||
- package ipc
import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/vmihailenco/msgpack/v5"
)
- // Client provides IPC communication with the Rust RDMA engine
- type Client struct {
- socketPath string
- conn net.Conn
- mu sync.RWMutex
- logger *logrus.Logger
- connected bool
- }
- // NewClient creates a new IPC client
- func NewClient(socketPath string, logger *logrus.Logger) *Client {
- if logger == nil {
- logger = logrus.New()
- logger.SetLevel(logrus.InfoLevel)
- }
- return &Client{
- socketPath: socketPath,
- logger: logger,
- }
- }
- // Connect establishes connection to the Rust RDMA engine
- func (c *Client) Connect(ctx context.Context) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.connected {
- return nil
- }
- c.logger.WithField("socket", c.socketPath).Info("🔗 Connecting to Rust RDMA engine")
- dialer := &net.Dialer{}
- conn, err := dialer.DialContext(ctx, "unix", c.socketPath)
- if err != nil {
- c.logger.WithError(err).Error("❌ Failed to connect to RDMA engine")
- return fmt.Errorf("failed to connect to RDMA engine at %s: %w", c.socketPath, err)
- }
- c.conn = conn
- c.connected = true
- c.logger.Info("✅ Connected to Rust RDMA engine")
- return nil
- }
- // Disconnect closes the connection
- func (c *Client) Disconnect() {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.conn != nil {
- c.conn.Close()
- c.conn = nil
- c.connected = false
- c.logger.Info("🔌 Disconnected from Rust RDMA engine")
- }
- }
- // IsConnected returns connection status
- func (c *Client) IsConnected() bool {
- c.mu.RLock()
- defer c.mu.RUnlock()
- return c.connected
- }
- // SendMessage sends an IPC message and waits for response
- func (c *Client) SendMessage(ctx context.Context, msg *IpcMessage) (*IpcMessage, error) {
- c.mu.RLock()
- conn := c.conn
- connected := c.connected
- c.mu.RUnlock()
- if !connected || conn == nil {
- return nil, fmt.Errorf("not connected to RDMA engine")
- }
- // Set write timeout
- if deadline, ok := ctx.Deadline(); ok {
- conn.SetWriteDeadline(deadline)
- } else {
- conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
- }
- c.logger.WithField("type", msg.Type).Debug("📤 Sending message to Rust engine")
- // Serialize message with MessagePack
- data, err := msgpack.Marshal(msg)
- if err != nil {
- c.logger.WithError(err).Error("❌ Failed to marshal message")
- return nil, fmt.Errorf("failed to marshal message: %w", err)
- }
- // Send message length (4 bytes) + message data
- lengthBytes := make([]byte, 4)
- binary.LittleEndian.PutUint32(lengthBytes, uint32(len(data)))
- if _, err := conn.Write(lengthBytes); err != nil {
- c.logger.WithError(err).Error("❌ Failed to send message length")
- return nil, fmt.Errorf("failed to send message length: %w", err)
- }
- if _, err := conn.Write(data); err != nil {
- c.logger.WithError(err).Error("❌ Failed to send message data")
- return nil, fmt.Errorf("failed to send message data: %w", err)
- }
- c.logger.WithFields(logrus.Fields{
- "type": msg.Type,
- "size": len(data),
- }).Debug("📤 Message sent successfully")
- // Read response
- return c.readResponse(ctx, conn)
- }
- // readResponse reads and deserializes the response message
- func (c *Client) readResponse(ctx context.Context, conn net.Conn) (*IpcMessage, error) {
- // Set read timeout
- if deadline, ok := ctx.Deadline(); ok {
- conn.SetReadDeadline(deadline)
- } else {
- conn.SetReadDeadline(time.Now().Add(30 * time.Second))
- }
- // Read message length (4 bytes)
- lengthBytes := make([]byte, 4)
- if _, err := conn.Read(lengthBytes); err != nil {
- c.logger.WithError(err).Error("❌ Failed to read response length")
- return nil, fmt.Errorf("failed to read response length: %w", err)
- }
- length := binary.LittleEndian.Uint32(lengthBytes)
- if length > 64*1024*1024 { // 64MB sanity check
- c.logger.WithField("length", length).Error("❌ Response message too large")
- return nil, fmt.Errorf("response message too large: %d bytes", length)
- }
- // Read message data
- data := make([]byte, length)
- if _, err := conn.Read(data); err != nil {
- c.logger.WithError(err).Error("❌ Failed to read response data")
- return nil, fmt.Errorf("failed to read response data: %w", err)
- }
- c.logger.WithField("size", length).Debug("📥 Response received")
- // Deserialize with MessagePack
- var response IpcMessage
- if err := msgpack.Unmarshal(data, &response); err != nil {
- c.logger.WithError(err).Error("❌ Failed to unmarshal response")
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
- c.logger.WithField("type", response.Type).Debug("📥 Response deserialized successfully")
- return &response, nil
- }
- // High-level convenience methods
- // Ping sends a ping message to test connectivity
- func (c *Client) Ping(ctx context.Context, clientID *string) (*PongResponse, error) {
- msg := NewPingMessage(clientID)
- response, err := c.SendMessage(ctx, msg)
- if err != nil {
- return nil, err
- }
- if response.Type == MsgError {
- errorData, err := msgpack.Marshal(response.Data)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
- }
- var errorResp ErrorResponse
- if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
- return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
- }
- return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
- }
- if response.Type != MsgPong {
- return nil, fmt.Errorf("unexpected response type: %s", response.Type)
- }
- // Convert response data to PongResponse
- pongData, err := msgpack.Marshal(response.Data)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal pong data: %w", err)
- }
- var pong PongResponse
- if err := msgpack.Unmarshal(pongData, &pong); err != nil {
- return nil, fmt.Errorf("failed to unmarshal pong response: %w", err)
- }
- return &pong, nil
- }
- // GetCapabilities requests engine capabilities
- func (c *Client) GetCapabilities(ctx context.Context, clientID *string) (*GetCapabilitiesResponse, error) {
- msg := NewGetCapabilitiesMessage(clientID)
- response, err := c.SendMessage(ctx, msg)
- if err != nil {
- return nil, err
- }
- if response.Type == MsgError {
- errorData, err := msgpack.Marshal(response.Data)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
- }
- var errorResp ErrorResponse
- if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
- return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
- }
- return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
- }
- if response.Type != MsgGetCapabilitiesResponse {
- return nil, fmt.Errorf("unexpected response type: %s", response.Type)
- }
- // Convert response data to GetCapabilitiesResponse
- capsData, err := msgpack.Marshal(response.Data)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal capabilities data: %w", err)
- }
- var caps GetCapabilitiesResponse
- if err := msgpack.Unmarshal(capsData, &caps); err != nil {
- return nil, fmt.Errorf("failed to unmarshal capabilities response: %w", err)
- }
- return &caps, nil
- }
- // StartRead initiates an RDMA read operation
- func (c *Client) StartRead(ctx context.Context, req *StartReadRequest) (*StartReadResponse, error) {
- msg := NewStartReadMessage(req)
- response, err := c.SendMessage(ctx, msg)
- if err != nil {
- return nil, err
- }
- if response.Type == MsgError {
- errorData, err := msgpack.Marshal(response.Data)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
- }
- var errorResp ErrorResponse
- if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
- return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
- }
- return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
- }
- if response.Type != MsgStartReadResponse {
- return nil, fmt.Errorf("unexpected response type: %s", response.Type)
- }
- // Convert response data to StartReadResponse
- startData, err := msgpack.Marshal(response.Data)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal start read data: %w", err)
- }
- var startResp StartReadResponse
- if err := msgpack.Unmarshal(startData, &startResp); err != nil {
- return nil, fmt.Errorf("failed to unmarshal start read response: %w", err)
- }
- return &startResp, nil
- }
- // CompleteRead completes an RDMA read operation
- func (c *Client) CompleteRead(ctx context.Context, sessionID string, success bool, bytesTransferred uint64, clientCrc *uint32) (*CompleteReadResponse, error) {
- msg := NewCompleteReadMessage(sessionID, success, bytesTransferred, clientCrc, nil)
- response, err := c.SendMessage(ctx, msg)
- if err != nil {
- return nil, err
- }
- if response.Type == MsgError {
- errorData, err := msgpack.Marshal(response.Data)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal engine error data: %w", err)
- }
- var errorResp ErrorResponse
- if err := msgpack.Unmarshal(errorData, &errorResp); err != nil {
- return nil, fmt.Errorf("failed to unmarshal engine error response: %w", err)
- }
- return nil, fmt.Errorf("engine error: %s - %s", errorResp.Code, errorResp.Message)
- }
- if response.Type != MsgCompleteReadResponse {
- return nil, fmt.Errorf("unexpected response type: %s", response.Type)
- }
- // Convert response data to CompleteReadResponse
- completeData, err := msgpack.Marshal(response.Data)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal complete read data: %w", err)
- }
- var completeResp CompleteReadResponse
- if err := msgpack.Unmarshal(completeData, &completeResp); err != nil {
- return nil, fmt.Errorf("failed to unmarshal complete read response: %w", err)
- }
- return &completeResp, nil
- }
|