common.go 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package shell
  2. import (
  3. "errors"
  4. "sync"
  5. )
  // DefaultMaxParallelization is the default maximum parallelization
// (concurrency) used by commands that support running work in parallel.
var DefaultMaxParallelization = 10
  // ErrorWaitGroup is a goroutine wait group that bounds how many tasks run at
// the same time and aggregates the errors those tasks return, if any.
// Construct it with NewErrorWaitGroup, which initializes the WaitGroup and the
// semaphore channel.
type ErrorWaitGroup struct {
	// maxConcurrency is the upper bound on simultaneously executing tasks.
	maxConcurrency int
	// wg counts the goroutines started by Add that have not yet finished.
	wg *sync.WaitGroup
	// wgSem is a counting semaphore whose buffer size is maxConcurrency.
	wgSem chan bool
	// errors holds the value returned by every completed task, nil included.
	errors []error
	// errorsMu serializes appends to errors from concurrent tasks.
	errorsMu sync.Mutex
}
  // ErrorWaitGroupTask is a unit of work queued with ErrorWaitGroup.Add; the
// error it returns is collected by the group.
type ErrorWaitGroupTask func() error
  19. func NewErrorWaitGroup(maxConcurrency int) *ErrorWaitGroup {
  20. if maxConcurrency <= 0 {
  21. // no concurrency = one task at the time
  22. maxConcurrency = 1
  23. }
  24. return &ErrorWaitGroup{
  25. maxConcurrency: maxConcurrency,
  26. wg: &sync.WaitGroup{},
  27. wgSem: make(chan bool, maxConcurrency),
  28. }
  29. }
  30. // Reset restarts an ErrorWaitGroup, keeping original settings. Errors and pending goroutines, if any, are flushed.
  31. func (ewg *ErrorWaitGroup) Reset() {
  32. close(ewg.wgSem)
  33. ewg.wg = &sync.WaitGroup{}
  34. ewg.wgSem = make(chan bool, ewg.maxConcurrency)
  35. ewg.errors = nil
  36. }
  37. // Add queues an ErrorWaitGroupTask to be executed as a goroutine.
  38. func (ewg *ErrorWaitGroup) Add(f ErrorWaitGroupTask) {
  39. if ewg.maxConcurrency <= 1 {
  40. // keep run order deterministic when parallelization is off
  41. ewg.errors = append(ewg.errors, f())
  42. return
  43. }
  44. ewg.wg.Add(1)
  45. go func() {
  46. ewg.wgSem <- true
  47. err := f()
  48. ewg.errorsMu.Lock()
  49. ewg.errors = append(ewg.errors, err)
  50. ewg.errorsMu.Unlock()
  51. <-ewg.wgSem
  52. ewg.wg.Done()
  53. }()
  54. }
  55. // Wait sleeps until all ErrorWaitGroupTasks are completed, then returns errors for them.
  56. func (ewg *ErrorWaitGroup) Wait() error {
  57. ewg.wg.Wait()
  58. return errors.Join(ewg.errors...)
  59. }