worker_grpc_server.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. package dash
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net"
  7. "sync"
  8. "time"
  9. "github.com/seaweedfs/seaweedfs/weed/glog"
  10. "github.com/seaweedfs/seaweedfs/weed/pb"
  11. "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
  12. "github.com/seaweedfs/seaweedfs/weed/security"
  13. "github.com/seaweedfs/seaweedfs/weed/util"
  14. "google.golang.org/grpc"
  15. "google.golang.org/grpc/peer"
  16. )
  17. // WorkerGrpcServer implements the WorkerService gRPC interface
  18. type WorkerGrpcServer struct {
  19. worker_pb.UnimplementedWorkerServiceServer
  20. adminServer *AdminServer
  21. // Worker connection management
  22. connections map[string]*WorkerConnection
  23. connMutex sync.RWMutex
  24. // Log request correlation
  25. pendingLogRequests map[string]*LogRequestContext
  26. logRequestsMutex sync.RWMutex
  27. // gRPC server
  28. grpcServer *grpc.Server
  29. listener net.Listener
  30. running bool
  31. stopChan chan struct{}
  32. }
  33. // LogRequestContext tracks pending log requests
  34. type LogRequestContext struct {
  35. TaskID string
  36. WorkerID string
  37. ResponseCh chan *worker_pb.TaskLogResponse
  38. Timeout time.Time
  39. }
  40. // WorkerConnection represents an active worker connection
  41. type WorkerConnection struct {
  42. workerID string
  43. stream worker_pb.WorkerService_WorkerStreamServer
  44. lastSeen time.Time
  45. capabilities []MaintenanceTaskType
  46. address string
  47. maxConcurrent int32
  48. outgoing chan *worker_pb.AdminMessage
  49. ctx context.Context
  50. cancel context.CancelFunc
  51. }
  52. // NewWorkerGrpcServer creates a new gRPC server for worker connections
  53. func NewWorkerGrpcServer(adminServer *AdminServer) *WorkerGrpcServer {
  54. return &WorkerGrpcServer{
  55. adminServer: adminServer,
  56. connections: make(map[string]*WorkerConnection),
  57. pendingLogRequests: make(map[string]*LogRequestContext),
  58. stopChan: make(chan struct{}),
  59. }
  60. }
  61. // StartWithTLS starts the gRPC server on the specified port with optional TLS
  62. func (s *WorkerGrpcServer) StartWithTLS(port int) error {
  63. if s.running {
  64. return fmt.Errorf("worker gRPC server is already running")
  65. }
  66. // Create listener
  67. listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
  68. if err != nil {
  69. return fmt.Errorf("failed to listen on port %d: %v", port, err)
  70. }
  71. // Create gRPC server with optional TLS
  72. grpcServer := pb.NewGrpcServer(security.LoadServerTLS(util.GetViper(), "grpc.admin"))
  73. worker_pb.RegisterWorkerServiceServer(grpcServer, s)
  74. s.grpcServer = grpcServer
  75. s.listener = listener
  76. s.running = true
  77. // Start cleanup routine
  78. go s.cleanupRoutine()
  79. // Start serving in a goroutine
  80. go func() {
  81. if err := s.grpcServer.Serve(listener); err != nil {
  82. if s.running {
  83. glog.Errorf("Worker gRPC server error: %v", err)
  84. }
  85. }
  86. }()
  87. return nil
  88. }
  89. // Stop stops the gRPC server
  90. func (s *WorkerGrpcServer) Stop() error {
  91. if !s.running {
  92. return nil
  93. }
  94. s.running = false
  95. close(s.stopChan)
  96. // Close all worker connections
  97. s.connMutex.Lock()
  98. for _, conn := range s.connections {
  99. conn.cancel()
  100. close(conn.outgoing)
  101. }
  102. s.connections = make(map[string]*WorkerConnection)
  103. s.connMutex.Unlock()
  104. // Stop gRPC server
  105. if s.grpcServer != nil {
  106. s.grpcServer.GracefulStop()
  107. }
  108. // Close listener
  109. if s.listener != nil {
  110. s.listener.Close()
  111. }
  112. glog.Infof("Worker gRPC server stopped")
  113. return nil
  114. }
  115. // WorkerStream handles bidirectional communication with workers
  116. func (s *WorkerGrpcServer) WorkerStream(stream worker_pb.WorkerService_WorkerStreamServer) error {
  117. ctx := stream.Context()
  118. // get client address
  119. address := findClientAddress(ctx)
  120. // Wait for initial registration message
  121. msg, err := stream.Recv()
  122. if err != nil {
  123. return fmt.Errorf("failed to receive registration message: %w", err)
  124. }
  125. registration := msg.GetRegistration()
  126. if registration == nil {
  127. return fmt.Errorf("first message must be registration")
  128. }
  129. registration.Address = address
  130. workerID := registration.WorkerId
  131. if workerID == "" {
  132. return fmt.Errorf("worker ID cannot be empty")
  133. }
  134. glog.Infof("Worker %s connecting from %s", workerID, registration.Address)
  135. // Create worker connection
  136. connCtx, connCancel := context.WithCancel(ctx)
  137. conn := &WorkerConnection{
  138. workerID: workerID,
  139. stream: stream,
  140. lastSeen: time.Now(),
  141. address: registration.Address,
  142. maxConcurrent: registration.MaxConcurrent,
  143. outgoing: make(chan *worker_pb.AdminMessage, 100),
  144. ctx: connCtx,
  145. cancel: connCancel,
  146. }
  147. // Convert capabilities
  148. capabilities := make([]MaintenanceTaskType, len(registration.Capabilities))
  149. for i, cap := range registration.Capabilities {
  150. capabilities[i] = MaintenanceTaskType(cap)
  151. }
  152. conn.capabilities = capabilities
  153. // Register connection
  154. s.connMutex.Lock()
  155. s.connections[workerID] = conn
  156. s.connMutex.Unlock()
  157. // Register worker with maintenance manager
  158. s.registerWorkerWithManager(conn)
  159. // Send registration response
  160. regResponse := &worker_pb.AdminMessage{
  161. Timestamp: time.Now().Unix(),
  162. Message: &worker_pb.AdminMessage_RegistrationResponse{
  163. RegistrationResponse: &worker_pb.RegistrationResponse{
  164. Success: true,
  165. Message: "Worker registered successfully",
  166. },
  167. },
  168. }
  169. select {
  170. case conn.outgoing <- regResponse:
  171. case <-time.After(5 * time.Second):
  172. glog.Errorf("Failed to send registration response to worker %s", workerID)
  173. }
  174. // Start outgoing message handler
  175. go s.handleOutgoingMessages(conn)
  176. // Handle incoming messages
  177. for {
  178. select {
  179. case <-ctx.Done():
  180. glog.Infof("Worker %s connection closed: %v", workerID, ctx.Err())
  181. s.unregisterWorker(workerID)
  182. return nil
  183. case <-connCtx.Done():
  184. glog.Infof("Worker %s connection cancelled", workerID)
  185. s.unregisterWorker(workerID)
  186. return nil
  187. default:
  188. }
  189. msg, err := stream.Recv()
  190. if err != nil {
  191. if err == io.EOF {
  192. glog.Infof("Worker %s disconnected", workerID)
  193. } else {
  194. glog.Errorf("Error receiving from worker %s: %v", workerID, err)
  195. }
  196. s.unregisterWorker(workerID)
  197. return err
  198. }
  199. conn.lastSeen = time.Now()
  200. s.handleWorkerMessage(conn, msg)
  201. }
  202. }
  203. // handleOutgoingMessages sends messages to worker
  204. func (s *WorkerGrpcServer) handleOutgoingMessages(conn *WorkerConnection) {
  205. for {
  206. select {
  207. case <-conn.ctx.Done():
  208. return
  209. case msg, ok := <-conn.outgoing:
  210. if !ok {
  211. return
  212. }
  213. if err := conn.stream.Send(msg); err != nil {
  214. glog.Errorf("Failed to send message to worker %s: %v", conn.workerID, err)
  215. conn.cancel()
  216. return
  217. }
  218. }
  219. }
  220. }
  221. // handleWorkerMessage processes incoming messages from workers
  222. func (s *WorkerGrpcServer) handleWorkerMessage(conn *WorkerConnection, msg *worker_pb.WorkerMessage) {
  223. workerID := conn.workerID
  224. switch m := msg.Message.(type) {
  225. case *worker_pb.WorkerMessage_Heartbeat:
  226. s.handleHeartbeat(conn, m.Heartbeat)
  227. case *worker_pb.WorkerMessage_TaskRequest:
  228. s.handleTaskRequest(conn, m.TaskRequest)
  229. case *worker_pb.WorkerMessage_TaskUpdate:
  230. s.handleTaskUpdate(conn, m.TaskUpdate)
  231. case *worker_pb.WorkerMessage_TaskComplete:
  232. s.handleTaskCompletion(conn, m.TaskComplete)
  233. case *worker_pb.WorkerMessage_TaskLogResponse:
  234. s.handleTaskLogResponse(conn, m.TaskLogResponse)
  235. case *worker_pb.WorkerMessage_Shutdown:
  236. glog.Infof("Worker %s shutting down: %s", workerID, m.Shutdown.Reason)
  237. s.unregisterWorker(workerID)
  238. default:
  239. glog.Warningf("Unknown message type from worker %s", workerID)
  240. }
  241. }
  242. // registerWorkerWithManager registers the worker with the maintenance manager
  243. func (s *WorkerGrpcServer) registerWorkerWithManager(conn *WorkerConnection) {
  244. if s.adminServer.maintenanceManager == nil {
  245. return
  246. }
  247. worker := &MaintenanceWorker{
  248. ID: conn.workerID,
  249. Address: conn.address,
  250. LastHeartbeat: time.Now(),
  251. Status: "active",
  252. Capabilities: conn.capabilities,
  253. MaxConcurrent: int(conn.maxConcurrent),
  254. CurrentLoad: 0,
  255. }
  256. s.adminServer.maintenanceManager.RegisterWorker(worker)
  257. glog.V(1).Infof("Registered worker %s with maintenance manager", conn.workerID)
  258. }
  259. // handleHeartbeat processes heartbeat messages
  260. func (s *WorkerGrpcServer) handleHeartbeat(conn *WorkerConnection, heartbeat *worker_pb.WorkerHeartbeat) {
  261. if s.adminServer.maintenanceManager != nil {
  262. s.adminServer.maintenanceManager.UpdateWorkerHeartbeat(conn.workerID)
  263. }
  264. // Send heartbeat response
  265. response := &worker_pb.AdminMessage{
  266. Timestamp: time.Now().Unix(),
  267. Message: &worker_pb.AdminMessage_HeartbeatResponse{
  268. HeartbeatResponse: &worker_pb.HeartbeatResponse{
  269. Success: true,
  270. Message: "Heartbeat acknowledged",
  271. },
  272. },
  273. }
  274. select {
  275. case conn.outgoing <- response:
  276. case <-time.After(time.Second):
  277. glog.Warningf("Failed to send heartbeat response to worker %s", conn.workerID)
  278. }
  279. }
  280. // handleTaskRequest processes task requests from workers
  281. func (s *WorkerGrpcServer) handleTaskRequest(conn *WorkerConnection, request *worker_pb.TaskRequest) {
  282. // glog.Infof("DEBUG handleTaskRequest: Worker %s requesting tasks with capabilities %v", conn.workerID, conn.capabilities)
  283. if s.adminServer.maintenanceManager == nil {
  284. glog.Infof("DEBUG handleTaskRequest: maintenance manager is nil")
  285. return
  286. }
  287. // Get next task from maintenance manager
  288. task := s.adminServer.maintenanceManager.GetNextTask(conn.workerID, conn.capabilities)
  289. // glog.Infof("DEBUG handleTaskRequest: GetNextTask returned task: %v", task != nil)
  290. if task != nil {
  291. glog.Infof("DEBUG handleTaskRequest: Assigning task %s (type: %s) to worker %s", task.ID, task.Type, conn.workerID)
  292. // Use typed params directly - master client should already be configured in the params
  293. var taskParams *worker_pb.TaskParams
  294. if task.TypedParams != nil {
  295. taskParams = task.TypedParams
  296. } else {
  297. // Create basic params if none exist
  298. taskParams = &worker_pb.TaskParams{
  299. VolumeId: task.VolumeID,
  300. Collection: task.Collection,
  301. Sources: []*worker_pb.TaskSource{
  302. {
  303. Node: task.Server,
  304. VolumeId: task.VolumeID,
  305. },
  306. },
  307. }
  308. }
  309. // Send task assignment
  310. assignment := &worker_pb.AdminMessage{
  311. Timestamp: time.Now().Unix(),
  312. Message: &worker_pb.AdminMessage_TaskAssignment{
  313. TaskAssignment: &worker_pb.TaskAssignment{
  314. TaskId: task.ID,
  315. TaskType: string(task.Type),
  316. Params: taskParams,
  317. Priority: int32(task.Priority),
  318. CreatedTime: time.Now().Unix(),
  319. },
  320. },
  321. }
  322. select {
  323. case conn.outgoing <- assignment:
  324. glog.Infof("DEBUG handleTaskRequest: Successfully assigned task %s to worker %s", task.ID, conn.workerID)
  325. case <-time.After(time.Second):
  326. glog.Warningf("Failed to send task assignment to worker %s", conn.workerID)
  327. }
  328. } else {
  329. // glog.Infof("DEBUG handleTaskRequest: No tasks available for worker %s", conn.workerID)
  330. }
  331. }
  332. // handleTaskUpdate processes task progress updates
  333. func (s *WorkerGrpcServer) handleTaskUpdate(conn *WorkerConnection, update *worker_pb.TaskUpdate) {
  334. if s.adminServer.maintenanceManager != nil {
  335. s.adminServer.maintenanceManager.UpdateTaskProgress(update.TaskId, float64(update.Progress))
  336. glog.V(3).Infof("Updated task %s progress: %.1f%%", update.TaskId, update.Progress)
  337. }
  338. }
  339. // handleTaskCompletion processes task completion notifications
  340. func (s *WorkerGrpcServer) handleTaskCompletion(conn *WorkerConnection, completion *worker_pb.TaskComplete) {
  341. if s.adminServer.maintenanceManager != nil {
  342. errorMsg := ""
  343. if !completion.Success {
  344. errorMsg = completion.ErrorMessage
  345. }
  346. s.adminServer.maintenanceManager.CompleteTask(completion.TaskId, errorMsg)
  347. if completion.Success {
  348. glog.V(1).Infof("Worker %s completed task %s successfully", conn.workerID, completion.TaskId)
  349. } else {
  350. glog.Errorf("Worker %s failed task %s: %s", conn.workerID, completion.TaskId, completion.ErrorMessage)
  351. }
  352. }
  353. }
  354. // handleTaskLogResponse processes task log responses from workers
  355. func (s *WorkerGrpcServer) handleTaskLogResponse(conn *WorkerConnection, response *worker_pb.TaskLogResponse) {
  356. requestKey := fmt.Sprintf("%s:%s", response.WorkerId, response.TaskId)
  357. s.logRequestsMutex.RLock()
  358. requestContext, exists := s.pendingLogRequests[requestKey]
  359. s.logRequestsMutex.RUnlock()
  360. if !exists {
  361. glog.Warningf("Received unexpected log response for task %s from worker %s", response.TaskId, response.WorkerId)
  362. return
  363. }
  364. glog.V(1).Infof("Received log response for task %s from worker %s: %d entries", response.TaskId, response.WorkerId, len(response.LogEntries))
  365. // Send response to waiting channel
  366. select {
  367. case requestContext.ResponseCh <- response:
  368. // Response delivered successfully
  369. case <-time.After(time.Second):
  370. glog.Warningf("Failed to deliver log response for task %s from worker %s: timeout", response.TaskId, response.WorkerId)
  371. }
  372. // Clean up the pending request
  373. s.logRequestsMutex.Lock()
  374. delete(s.pendingLogRequests, requestKey)
  375. s.logRequestsMutex.Unlock()
  376. }
  377. // unregisterWorker removes a worker connection
  378. func (s *WorkerGrpcServer) unregisterWorker(workerID string) {
  379. s.connMutex.Lock()
  380. if conn, exists := s.connections[workerID]; exists {
  381. conn.cancel()
  382. close(conn.outgoing)
  383. delete(s.connections, workerID)
  384. }
  385. s.connMutex.Unlock()
  386. glog.V(1).Infof("Unregistered worker %s", workerID)
  387. }
  388. // cleanupRoutine periodically cleans up stale connections
  389. func (s *WorkerGrpcServer) cleanupRoutine() {
  390. ticker := time.NewTicker(30 * time.Second)
  391. defer ticker.Stop()
  392. for {
  393. select {
  394. case <-s.stopChan:
  395. return
  396. case <-ticker.C:
  397. s.cleanupStaleConnections()
  398. }
  399. }
  400. }
  401. // cleanupStaleConnections removes connections that haven't been seen recently
  402. func (s *WorkerGrpcServer) cleanupStaleConnections() {
  403. cutoff := time.Now().Add(-2 * time.Minute)
  404. s.connMutex.Lock()
  405. defer s.connMutex.Unlock()
  406. for workerID, conn := range s.connections {
  407. if conn.lastSeen.Before(cutoff) {
  408. glog.Warningf("Cleaning up stale worker connection: %s", workerID)
  409. conn.cancel()
  410. close(conn.outgoing)
  411. delete(s.connections, workerID)
  412. }
  413. }
  414. }
  415. // GetConnectedWorkers returns a list of currently connected workers
  416. func (s *WorkerGrpcServer) GetConnectedWorkers() []string {
  417. s.connMutex.RLock()
  418. defer s.connMutex.RUnlock()
  419. workers := make([]string, 0, len(s.connections))
  420. for workerID := range s.connections {
  421. workers = append(workers, workerID)
  422. }
  423. return workers
  424. }
  425. // RequestTaskLogs requests execution logs from a worker for a specific task
  426. func (s *WorkerGrpcServer) RequestTaskLogs(workerID, taskID string, maxEntries int32, logLevel string) ([]*worker_pb.TaskLogEntry, error) {
  427. s.connMutex.RLock()
  428. conn, exists := s.connections[workerID]
  429. s.connMutex.RUnlock()
  430. if !exists {
  431. return nil, fmt.Errorf("worker %s is not connected", workerID)
  432. }
  433. // Create response channel for this request
  434. responseCh := make(chan *worker_pb.TaskLogResponse, 1)
  435. requestKey := fmt.Sprintf("%s:%s", workerID, taskID)
  436. // Register pending request
  437. requestContext := &LogRequestContext{
  438. TaskID: taskID,
  439. WorkerID: workerID,
  440. ResponseCh: responseCh,
  441. Timeout: time.Now().Add(10 * time.Second),
  442. }
  443. s.logRequestsMutex.Lock()
  444. s.pendingLogRequests[requestKey] = requestContext
  445. s.logRequestsMutex.Unlock()
  446. // Create log request message
  447. logRequest := &worker_pb.AdminMessage{
  448. AdminId: "admin-server",
  449. Timestamp: time.Now().Unix(),
  450. Message: &worker_pb.AdminMessage_TaskLogRequest{
  451. TaskLogRequest: &worker_pb.TaskLogRequest{
  452. TaskId: taskID,
  453. WorkerId: workerID,
  454. IncludeMetadata: true,
  455. MaxEntries: maxEntries,
  456. LogLevel: logLevel,
  457. },
  458. },
  459. }
  460. // Send the request through the worker's outgoing channel
  461. select {
  462. case conn.outgoing <- logRequest:
  463. glog.V(1).Infof("Log request sent to worker %s for task %s", workerID, taskID)
  464. case <-time.After(5 * time.Second):
  465. // Clean up pending request on timeout
  466. s.logRequestsMutex.Lock()
  467. delete(s.pendingLogRequests, requestKey)
  468. s.logRequestsMutex.Unlock()
  469. return nil, fmt.Errorf("timeout sending log request to worker %s", workerID)
  470. }
  471. // Wait for response
  472. select {
  473. case response := <-responseCh:
  474. if !response.Success {
  475. return nil, fmt.Errorf("worker log request failed: %s", response.ErrorMessage)
  476. }
  477. glog.V(1).Infof("Received %d log entries for task %s from worker %s", len(response.LogEntries), taskID, workerID)
  478. return response.LogEntries, nil
  479. case <-time.After(10 * time.Second):
  480. // Clean up pending request on timeout
  481. s.logRequestsMutex.Lock()
  482. delete(s.pendingLogRequests, requestKey)
  483. s.logRequestsMutex.Unlock()
  484. return nil, fmt.Errorf("timeout waiting for log response from worker %s", workerID)
  485. }
  486. }
  487. // RequestTaskLogsFromAllWorkers requests logs for a task from all connected workers
  488. func (s *WorkerGrpcServer) RequestTaskLogsFromAllWorkers(taskID string, maxEntries int32, logLevel string) (map[string][]*worker_pb.TaskLogEntry, error) {
  489. s.connMutex.RLock()
  490. workerIDs := make([]string, 0, len(s.connections))
  491. for workerID := range s.connections {
  492. workerIDs = append(workerIDs, workerID)
  493. }
  494. s.connMutex.RUnlock()
  495. results := make(map[string][]*worker_pb.TaskLogEntry)
  496. for _, workerID := range workerIDs {
  497. logs, err := s.RequestTaskLogs(workerID, taskID, maxEntries, logLevel)
  498. if err != nil {
  499. glog.V(1).Infof("Failed to get logs from worker %s for task %s: %v", workerID, taskID, err)
  500. // Store empty result with error information for debugging
  501. results[workerID+"_error"] = []*worker_pb.TaskLogEntry{
  502. {
  503. Timestamp: time.Now().Unix(),
  504. Level: "ERROR",
  505. Message: fmt.Sprintf("Failed to retrieve logs from worker %s: %v", workerID, err),
  506. Fields: map[string]string{"source": "admin"},
  507. },
  508. }
  509. continue
  510. }
  511. if len(logs) > 0 {
  512. results[workerID] = logs
  513. } else {
  514. glog.V(2).Infof("No logs found for task %s on worker %s", taskID, workerID)
  515. }
  516. }
  517. return results, nil
  518. }
  519. // convertTaskParameters converts task parameters to protobuf format
  520. func convertTaskParameters(params map[string]interface{}) map[string]string {
  521. result := make(map[string]string)
  522. for key, value := range params {
  523. result[key] = fmt.Sprintf("%v", value)
  524. }
  525. return result
  526. }
  527. func findClientAddress(ctx context.Context) string {
  528. // fmt.Printf("FromContext %+v\n", ctx)
  529. pr, ok := peer.FromContext(ctx)
  530. if !ok {
  531. glog.Error("failed to get peer from ctx")
  532. return ""
  533. }
  534. if pr.Addr == net.Addr(nil) {
  535. glog.Error("failed to get peer address")
  536. return ""
  537. }
  538. return pr.Addr.String()
  539. }