task.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. package tasks
  2. import (
  3. "context"
  4. "fmt"
  5. "sync"
  6. "time"
  7. "github.com/seaweedfs/seaweedfs/weed/glog"
  8. "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
  9. "github.com/seaweedfs/seaweedfs/weed/worker/types"
  10. )
  11. // BaseTask provides common functionality for all tasks
  12. type BaseTask struct {
  13. taskType types.TaskType
  14. taskID string
  15. progress float64
  16. cancelled bool
  17. mutex sync.RWMutex
  18. startTime time.Time
  19. estimatedDuration time.Duration
  20. logger TaskLogger
  21. loggerConfig TaskLoggerConfig
  22. progressCallback func(float64, string) // Callback function for progress updates
  23. currentStage string // Current stage description
  24. }
  25. // NewBaseTask creates a new base task
  26. func NewBaseTask(taskType types.TaskType) *BaseTask {
  27. return &BaseTask{
  28. taskType: taskType,
  29. progress: 0.0,
  30. cancelled: false,
  31. loggerConfig: DefaultTaskLoggerConfig(),
  32. }
  33. }
  34. // NewBaseTaskWithLogger creates a new base task with custom logger configuration
  35. func NewBaseTaskWithLogger(taskType types.TaskType, loggerConfig TaskLoggerConfig) *BaseTask {
  36. return &BaseTask{
  37. taskType: taskType,
  38. progress: 0.0,
  39. cancelled: false,
  40. loggerConfig: loggerConfig,
  41. }
  42. }
  43. // InitializeLogger initializes the task logger with task details
  44. func (t *BaseTask) InitializeLogger(taskID string, workerID string, params types.TaskParams) error {
  45. return t.InitializeTaskLogger(taskID, workerID, params)
  46. }
  47. // InitializeTaskLogger initializes the task logger with task details (LoggerProvider interface)
  48. func (t *BaseTask) InitializeTaskLogger(taskID string, workerID string, params types.TaskParams) error {
  49. t.mutex.Lock()
  50. defer t.mutex.Unlock()
  51. t.taskID = taskID
  52. logger, err := NewTaskLogger(taskID, t.taskType, workerID, params, t.loggerConfig)
  53. if err != nil {
  54. return fmt.Errorf("failed to initialize task logger: %w", err)
  55. }
  56. t.logger = logger
  57. t.logger.Info("BaseTask initialized for task %s (type: %s)", taskID, t.taskType)
  58. return nil
  59. }
  60. // Type returns the task type
  61. func (t *BaseTask) Type() types.TaskType {
  62. return t.taskType
  63. }
  64. // GetProgress returns the current progress (0.0 to 100.0)
  65. func (t *BaseTask) GetProgress() float64 {
  66. t.mutex.RLock()
  67. defer t.mutex.RUnlock()
  68. return t.progress
  69. }
  70. // SetProgress sets the current progress and logs it
  71. func (t *BaseTask) SetProgress(progress float64) {
  72. t.mutex.Lock()
  73. if progress < 0 {
  74. progress = 0
  75. }
  76. if progress > 100 {
  77. progress = 100
  78. }
  79. oldProgress := t.progress
  80. callback := t.progressCallback
  81. stage := t.currentStage
  82. t.progress = progress
  83. t.mutex.Unlock()
  84. // Log progress change
  85. if t.logger != nil && progress != oldProgress {
  86. message := stage
  87. if message == "" {
  88. message = fmt.Sprintf("Progress updated from %.1f%% to %.1f%%", oldProgress, progress)
  89. }
  90. t.logger.LogProgress(progress, message)
  91. }
  92. // Call progress callback if set
  93. if callback != nil && progress != oldProgress {
  94. callback(progress, stage)
  95. }
  96. }
  97. // SetProgressWithStage sets the current progress with a stage description
  98. func (t *BaseTask) SetProgressWithStage(progress float64, stage string) {
  99. t.mutex.Lock()
  100. if progress < 0 {
  101. progress = 0
  102. }
  103. if progress > 100 {
  104. progress = 100
  105. }
  106. callback := t.progressCallback
  107. t.progress = progress
  108. t.currentStage = stage
  109. t.mutex.Unlock()
  110. // Log progress change
  111. if t.logger != nil {
  112. t.logger.LogProgress(progress, stage)
  113. }
  114. // Call progress callback if set
  115. if callback != nil {
  116. callback(progress, stage)
  117. }
  118. }
  119. // SetCurrentStage sets the current stage description
  120. func (t *BaseTask) SetCurrentStage(stage string) {
  121. t.mutex.Lock()
  122. defer t.mutex.Unlock()
  123. t.currentStage = stage
  124. }
  125. // GetCurrentStage returns the current stage description
  126. func (t *BaseTask) GetCurrentStage() string {
  127. t.mutex.RLock()
  128. defer t.mutex.RUnlock()
  129. return t.currentStage
  130. }
  131. // Cancel cancels the task
  132. func (t *BaseTask) Cancel() error {
  133. t.mutex.Lock()
  134. defer t.mutex.Unlock()
  135. if t.cancelled {
  136. return nil
  137. }
  138. t.cancelled = true
  139. if t.logger != nil {
  140. t.logger.LogStatus("cancelled", "Task cancelled by request")
  141. t.logger.Warning("Task %s was cancelled", t.taskID)
  142. }
  143. return nil
  144. }
  145. // IsCancelled returns whether the task is cancelled
  146. func (t *BaseTask) IsCancelled() bool {
  147. t.mutex.RLock()
  148. defer t.mutex.RUnlock()
  149. return t.cancelled
  150. }
  151. // SetStartTime sets the task start time
  152. func (t *BaseTask) SetStartTime(startTime time.Time) {
  153. t.mutex.Lock()
  154. defer t.mutex.Unlock()
  155. t.startTime = startTime
  156. if t.logger != nil {
  157. t.logger.LogStatus("running", fmt.Sprintf("Task started at %s", startTime.Format(time.RFC3339)))
  158. }
  159. }
  160. // GetStartTime returns the task start time
  161. func (t *BaseTask) GetStartTime() time.Time {
  162. t.mutex.RLock()
  163. defer t.mutex.RUnlock()
  164. return t.startTime
  165. }
  166. // SetEstimatedDuration sets the estimated duration
  167. func (t *BaseTask) SetEstimatedDuration(duration time.Duration) {
  168. t.mutex.Lock()
  169. defer t.mutex.Unlock()
  170. t.estimatedDuration = duration
  171. if t.logger != nil {
  172. t.logger.LogWithFields("INFO", "Estimated duration set", map[string]interface{}{
  173. "estimated_duration": duration.String(),
  174. "estimated_seconds": duration.Seconds(),
  175. })
  176. }
  177. }
  178. // GetEstimatedDuration returns the estimated duration
  179. func (t *BaseTask) GetEstimatedDuration() time.Duration {
  180. t.mutex.RLock()
  181. defer t.mutex.RUnlock()
  182. return t.estimatedDuration
  183. }
  184. // SetProgressCallback sets the progress callback function
  185. func (t *BaseTask) SetProgressCallback(callback func(float64, string)) {
  186. t.mutex.Lock()
  187. defer t.mutex.Unlock()
  188. t.progressCallback = callback
  189. }
  190. // SetLoggerConfig sets the logger configuration for this task
  191. func (t *BaseTask) SetLoggerConfig(config TaskLoggerConfig) {
  192. t.mutex.Lock()
  193. defer t.mutex.Unlock()
  194. t.loggerConfig = config
  195. }
  196. // GetLogger returns the task logger
  197. func (t *BaseTask) GetLogger() TaskLogger {
  198. t.mutex.RLock()
  199. defer t.mutex.RUnlock()
  200. return t.logger
  201. }
  202. // GetTaskLogger returns the task logger (LoggerProvider interface)
  203. func (t *BaseTask) GetTaskLogger() TaskLogger {
  204. t.mutex.RLock()
  205. defer t.mutex.RUnlock()
  206. return t.logger
  207. }
  208. // LogInfo logs an info message
  209. func (t *BaseTask) LogInfo(message string, args ...interface{}) {
  210. if t.logger != nil {
  211. t.logger.Info(message, args...)
  212. }
  213. }
  214. // LogWarning logs a warning message
  215. func (t *BaseTask) LogWarning(message string, args ...interface{}) {
  216. if t.logger != nil {
  217. t.logger.Warning(message, args...)
  218. }
  219. }
  220. // LogError logs an error message
  221. func (t *BaseTask) LogError(message string, args ...interface{}) {
  222. if t.logger != nil {
  223. t.logger.Error(message, args...)
  224. }
  225. }
  226. // LogDebug logs a debug message
  227. func (t *BaseTask) LogDebug(message string, args ...interface{}) {
  228. if t.logger != nil {
  229. t.logger.Debug(message, args...)
  230. }
  231. }
  232. // LogWithFields logs a message with structured fields
  233. func (t *BaseTask) LogWithFields(level string, message string, fields map[string]interface{}) {
  234. if t.logger != nil {
  235. t.logger.LogWithFields(level, message, fields)
  236. }
  237. }
  238. // FinishTask finalizes the task and closes the logger
  239. func (t *BaseTask) FinishTask(success bool, errorMsg string) error {
  240. if t.logger != nil {
  241. if success {
  242. t.logger.LogStatus("completed", "Task completed successfully")
  243. t.logger.Info("Task %s finished successfully", t.taskID)
  244. } else {
  245. t.logger.LogStatus("failed", fmt.Sprintf("Task failed: %s", errorMsg))
  246. t.logger.Error("Task %s failed: %s", t.taskID, errorMsg)
  247. }
  248. // Close logger
  249. if err := t.logger.Close(); err != nil {
  250. glog.Errorf("Failed to close task logger: %v", err)
  251. }
  252. }
  253. return nil
  254. }
  255. // ExecuteTask is a wrapper that handles common task execution logic with logging
  256. func (t *BaseTask) ExecuteTask(ctx context.Context, params types.TaskParams, executor func(context.Context, types.TaskParams) error) error {
  257. // Initialize logger if not already done
  258. if t.logger == nil {
  259. // Generate a temporary task ID if none provided
  260. if t.taskID == "" {
  261. t.taskID = fmt.Sprintf("task_%d", time.Now().UnixNano())
  262. }
  263. workerID := "unknown"
  264. if err := t.InitializeLogger(t.taskID, workerID, params); err != nil {
  265. glog.Warningf("Failed to initialize task logger: %v", err)
  266. }
  267. }
  268. t.SetStartTime(time.Now())
  269. t.SetProgress(0)
  270. if t.logger != nil {
  271. t.logger.LogWithFields("INFO", "Task execution started", map[string]interface{}{
  272. "volume_id": params.VolumeID,
  273. "server": getServerFromSources(params.TypedParams.Sources),
  274. "collection": params.Collection,
  275. })
  276. }
  277. // Create a context that can be cancelled
  278. ctx, cancel := context.WithCancel(ctx)
  279. defer cancel()
  280. // Monitor for cancellation
  281. go func() {
  282. for !t.IsCancelled() {
  283. select {
  284. case <-ctx.Done():
  285. return
  286. case <-time.After(time.Second):
  287. // Check cancellation every second
  288. }
  289. }
  290. t.LogWarning("Task cancellation detected, cancelling context")
  291. cancel()
  292. }()
  293. // Execute the actual task
  294. t.LogInfo("Starting task executor")
  295. err := executor(ctx, params)
  296. if err != nil {
  297. t.LogError("Task executor failed: %v", err)
  298. t.FinishTask(false, err.Error())
  299. return err
  300. }
  301. if t.IsCancelled() {
  302. t.LogWarning("Task was cancelled during execution")
  303. t.FinishTask(false, "cancelled")
  304. return context.Canceled
  305. }
  306. t.SetProgress(100)
  307. t.LogInfo("Task executor completed successfully")
  308. t.FinishTask(true, "")
  309. return nil
  310. }
  311. // UnsupportedTaskTypeError represents an error for unsupported task types
  312. type UnsupportedTaskTypeError struct {
  313. TaskType types.TaskType
  314. }
  315. func (e *UnsupportedTaskTypeError) Error() string {
  316. return "unsupported task type: " + string(e.TaskType)
  317. }
  318. // BaseTaskFactory provides common functionality for task factories
  319. type BaseTaskFactory struct {
  320. taskType types.TaskType
  321. capabilities []string
  322. description string
  323. }
  324. // NewBaseTaskFactory creates a new base task factory
  325. func NewBaseTaskFactory(taskType types.TaskType, capabilities []string, description string) *BaseTaskFactory {
  326. return &BaseTaskFactory{
  327. taskType: taskType,
  328. capabilities: capabilities,
  329. description: description,
  330. }
  331. }
  332. // Capabilities returns the capabilities required for this task type
  333. func (f *BaseTaskFactory) Capabilities() []string {
  334. return f.capabilities
  335. }
  336. // Description returns the description of this task type
  337. func (f *BaseTaskFactory) Description() string {
  338. return f.description
  339. }
  340. // ValidateParams validates task parameters
  341. func ValidateParams(params types.TaskParams, requiredFields ...string) error {
  342. for _, field := range requiredFields {
  343. switch field {
  344. case "volume_id":
  345. if params.VolumeID == 0 {
  346. return &ValidationError{Field: field, Message: "volume_id is required"}
  347. }
  348. case "server":
  349. if len(params.TypedParams.Sources) == 0 {
  350. return &ValidationError{Field: field, Message: "server is required"}
  351. }
  352. case "collection":
  353. if params.Collection == "" {
  354. return &ValidationError{Field: field, Message: "collection is required"}
  355. }
  356. }
  357. }
  358. return nil
  359. }
  360. // ValidationError represents a parameter validation error
  361. type ValidationError struct {
  362. Field string
  363. Message string
  364. }
  365. func (e *ValidationError) Error() string {
  366. return e.Field + ": " + e.Message
  367. }
  368. // getServerFromSources extracts the server address from unified sources
  369. func getServerFromSources(sources []*worker_pb.TaskSource) string {
  370. if len(sources) > 0 {
  371. return sources[0].Node
  372. }
  373. return ""
  374. }