client.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. package ipc
  import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/vmihailenco/msgpack/v5"
)
  12. // Client provides IPC communication with the Rust RDMA engine
  13. type Client struct {
  14. socketPath string
  15. conn net.Conn
  16. mu sync.RWMutex
  17. logger *logrus.Logger
  18. connected bool
  19. }
  20. // NewClient creates a new IPC client
  21. func NewClient(socketPath string, logger *logrus.Logger) *Client {
  22. if logger == nil {
  23. logger = logrus.New()
  24. logger.SetLevel(logrus.InfoLevel)
  25. }
  26. return &Client{
  27. socketPath: socketPath,
  28. logger: logger,
  29. }
  30. }
  31. // Connect establishes connection to the Rust RDMA engine
  32. func (c *Client) Connect(ctx context.Context) error {
  33. c.mu.Lock()
  34. defer c.mu.Unlock()
  35. if c.connected {
  36. return nil
  37. }
  38. c.logger.WithField("socket", c.socketPath).Info("🔗 Connecting to Rust RDMA engine")
  39. dialer := &net.Dialer{}
  40. conn, err := dialer.DialContext(ctx, "unix", c.socketPath)
  41. if err != nil {
  42. c.logger.WithError(err).Error("❌ Failed to connect to RDMA engine")
  43. return fmt.Errorf("failed to connect to RDMA engine at %s: %w", c.socketPath, err)
  44. }
  45. c.conn = conn
  46. c.connected = true
  47. c.logger.Info("✅ Connected to Rust RDMA engine")
  48. return nil
  49. }
  50. // Disconnect closes the connection
  51. func (c *Client) Disconnect() {
  52. c.mu.Lock()
  53. defer c.mu.Unlock()
  54. if c.conn != nil {
  55. c.conn.Close()
  56. c.conn = nil
  57. c.connected = false
  58. c.logger.Info("🔌 Disconnected from Rust RDMA engine")
  59. }
  60. }
  61. // IsConnected returns connection status
  62. func (c *Client) IsConnected() bool {
  63. c.mu.RLock()
  64. defer c.mu.RUnlock()
  65. return c.connected
  66. }
  67. // SendMessage sends an IPC message and waits for response
  68. func (c *Client) SendMessage(ctx context.Context, msg *IpcMessage) (*IpcMessage, error) {
  69. c.mu.RLock()
  70. conn := c.conn
  71. connected := c.connected
  72. c.mu.RUnlock()
  73. if !connected || conn == nil {
  74. return nil, fmt.Errorf("not connected to RDMA engine")
  75. }
  76. // Set write timeout
  77. if deadline, ok := ctx.Deadline(); ok {
  78. conn.SetWriteDeadline(deadline)
  79. } else {
  80. conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
  81. }
  82. c.logger.WithField("type", msg.Type).Debug("📤 Sending message to Rust engine")
  83. // Serialize message with MessagePack
  84. data, err := msgpack.Marshal(msg)
  85. if err != nil {
  86. c.logger.WithError(err).Error("❌ Failed to marshal message")
  87. return nil, fmt.Errorf("failed to marshal message: %w", err)
  88. }
  89. // Send message length (4 bytes) + message data
  90. lengthBytes := make([]byte, 4)
  91. binary.LittleEndian.PutUint32(lengthBytes, uint32(len(data)))
  92. if _, err := conn.Write(lengthBytes); err != nil {
  93. c.logger.WithError(err).Error("❌ Failed to send message length")
  94. return nil, fmt.Errorf("failed to send message length: %w", err)
  95. }
  96. if _, err := conn.Write(data); err != nil {
  97. c.logger.WithError(err).Error("❌ Failed to send message data")
  98. return nil, fmt.Errorf("failed to send message data: %w", err)
  99. }
  100. c.logger.WithFields(logrus.Fields{
  101. "type": msg.Type,
  102. "size": len(data),
  103. }).Debug("📤 Message sent successfully")
  104. // Read response
  105. return c.readResponse(ctx, conn)
  106. }
  107. // readResponse reads and deserializes the response message
  108. func (c *Client) readResponse(ctx context.Context, conn net.Conn) (*IpcMessage, error) {
  109. // Set read timeout
  110. if deadline, ok := ctx.Deadline(); ok {
  111. conn.SetReadDeadline(deadline)
  112. } else {
  113. conn.SetReadDeadline(time.Now().Add(30 * time.Second))
  114. }
  115. // Read message length (4 bytes)
  116. lengthBytes := make([]byte, 4)
  117. if _, err := conn.Read(lengthBytes); err != nil {
  118. c.logger.WithError(err).Error("❌ Failed to read response length")
  119. return nil, fmt.Errorf("failed to read response length: %w", err)
  120. }
  121. length := binary.LittleEndian.Uint32(lengthBytes)
  122. if length > 64*1024*1024 { // 64MB sanity check
  123. c.logger.WithField("length", length).Error("❌ Response message too large")
  124. return nil, fmt.Errorf("response message too large: %d bytes", length)
  125. }
  126. // Read message data
  127. data := make([]byte, length)
  128. if _, err := conn.Read(data); err != nil {
  129. c.logger.WithError(err).Error("❌ Failed to read response data")
  130. return nil, fmt.Errorf("failed to read response data: %w", err)
  131. }
  132. c.logger.WithField("size", length).Debug("📥 Response received")
  133. // Deserialize with MessagePack
  134. var response IpcMessage
  135. if err := msgpack.Unmarshal(data, &response); err != nil {
  136. c.logger.WithError(err).Error("❌ Failed to unmarshal response")
  137. return nil, fmt.Errorf("failed to unmarshal response: %w", err)
  138. }
  139. c.logger.WithField("type", response.Type).Debug("📥 Response deserialized successfully")
  140. return &response, nil
  141. }
  142. // High-level convenience methods
  143. // Ping sends a ping message to test connectivity
  144. func (c *Client) Ping(ctx context.Context, clientID *string) (*PongResponse, error) {
  145. msg := NewPingMessage(clientID)
  146. response, err := c.SendMessage(ctx, msg)
  147. if err != nil {
  148. return nil, err
  149. }
  150. if response.Type == MsgError {
  151. errorData, err := msgpack.Marshal(response.Data)
  152. if err != nil {
  153. return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
  154. }
  155. var errorResp ErrorResponse
  156. if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
  157. return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
  158. }
  159. return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
  160. }
  161. if response.Type != MsgPong {
  162. return nil, fmt.Errorf("unexpected response type: %s", response.Type)
  163. }
  164. // Convert response data to PongResponse
  165. pongData, err := msgpack.Marshal(response.Data)
  166. if err != nil {
  167. return nil, fmt.Errorf("failed to marshal pong data: %w", err)
  168. }
  169. var pong PongResponse
  170. if err := msgpack.Unmarshal(pongData, &pong); err != nil {
  171. return nil, fmt.Errorf("failed to unmarshal pong response: %w", err)
  172. }
  173. return &pong, nil
  174. }
  175. // GetCapabilities requests engine capabilities
  176. func (c *Client) GetCapabilities(ctx context.Context, clientID *string) (*GetCapabilitiesResponse, error) {
  177. msg := NewGetCapabilitiesMessage(clientID)
  178. response, err := c.SendMessage(ctx, msg)
  179. if err != nil {
  180. return nil, err
  181. }
  182. if response.Type == MsgError {
  183. errorData, err := msgpack.Marshal(response.Data)
  184. if err != nil {
  185. return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
  186. }
  187. var errorResp ErrorResponse
  188. if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
  189. return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
  190. }
  191. return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
  192. }
  193. if response.Type != MsgGetCapabilitiesResponse {
  194. return nil, fmt.Errorf("unexpected response type: %s", response.Type)
  195. }
  196. // Convert response data to GetCapabilitiesResponse
  197. capsData, err := msgpack.Marshal(response.Data)
  198. if err != nil {
  199. return nil, fmt.Errorf("failed to marshal capabilities data: %w", err)
  200. }
  201. var caps GetCapabilitiesResponse
  202. if err := msgpack.Unmarshal(capsData, &caps); err != nil {
  203. return nil, fmt.Errorf("failed to unmarshal capabilities response: %w", err)
  204. }
  205. return &caps, nil
  206. }
  207. // StartRead initiates an RDMA read operation
  208. func (c *Client) StartRead(ctx context.Context, req *StartReadRequest) (*StartReadResponse, error) {
  209. msg := NewStartReadMessage(req)
  210. response, err := c.SendMessage(ctx, msg)
  211. if err != nil {
  212. return nil, err
  213. }
  214. if response.Type == MsgError {
  215. errorData, err := msgpack.Marshal(response.Data)
  216. if err != nil {
  217. return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
  218. }
  219. var errorResp ErrorResponse
  220. if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
  221. return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
  222. }
  223. return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
  224. }
  225. if response.Type != MsgStartReadResponse {
  226. return nil, fmt.Errorf("unexpected response type: %s", response.Type)
  227. }
  228. // Convert response data to StartReadResponse
  229. startData, err := msgpack.Marshal(response.Data)
  230. if err != nil {
  231. return nil, fmt.Errorf("failed to marshal start read data: %w", err)
  232. }
  233. var startResp StartReadResponse
  234. if err := msgpack.Unmarshal(startData, &startResp); err != nil {
  235. return nil, fmt.Errorf("failed to unmarshal start read response: %w", err)
  236. }
  237. return &startResp, nil
  238. }
  239. // CompleteRead completes an RDMA read operation
  240. func (c *Client) CompleteRead(ctx context.Context, sessionID string, success bool, bytesTransferred uint64, clientCrc *uint32) (*CompleteReadResponse, error) {
  241. msg := NewCompleteReadMessage(sessionID, success, bytesTransferred, clientCrc, nil)
  242. response, err := c.SendMessage(ctx, msg)
  243. if err != nil {
  244. return nil, err
  245. }
  246. if response.Type == MsgError {
  247. errorData, err := msgpack.Marshal(response.Data)
  248. if err != nil {
  249. return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
  250. }
  251. var errorResp ErrorResponse
  252. if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
  253. return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
  254. }
  255. return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
  256. }
  257. if response.Type != MsgCompleteReadResponse {
  258. return nil, fmt.Errorf("unexpected response type: %s", response.Type)
  259. }
  260. // Convert response data to CompleteReadResponse
  261. completeData, err := msgpack.Marshal(response.Data)
  262. if err != nil {
  263. return nil, fmt.Errorf("failed to marshal complete read data: %w", err)
  264. }
  265. var completeResp CompleteReadResponse
  266. if err := msgpack.Unmarshal(completeData, &completeResp); err != nil {
  267. return nil, fmt.Errorf("failed to unmarshal complete read response: %w", err)
  268. }
  269. return &completeResp, nil
  270. }