@@ -44,6 +44,7 @@ type TaskManager struct {
44
44
MaxRetries int // MaxRetries is the maximum number of retries
45
45
limiter * rate.Limiter // limiter is a rate limiter that limits the number of tasks that can be executed at once
46
46
wg sync.WaitGroup // wg is a wait group that waits for all tasks to finish
47
+ running sync.WaitGroup // running is a wait group that waits for all running tasks to finish
47
48
mutex sync.RWMutex // mutex protects the task handling
48
49
quit chan struct {} // quit is a channel to signal all goroutines to stop
49
50
ctx context.Context // ctx is the context for the task manager
@@ -102,6 +103,7 @@ func NewTaskManager(maxWorkers int, maxTasks int, tasksPerSecond float64, timeou
102
103
MaxRetries : maxRetries ,
103
104
limiter : rate .NewLimiter (rate .Limit (tasksPerSecond ), maxTasks ),
104
105
wg : sync.WaitGroup {},
106
+ running : sync.WaitGroup {},
105
107
mutex : sync.RWMutex {},
106
108
quit : make (chan struct {}),
107
109
ctx : ctx ,
@@ -282,37 +284,23 @@ func (tm *TaskManager) RegisterTasks(ctx context.Context, tasks ...Task) {
282
284
283
285
// Wait waits for all tasks to complete or for the timeout to elapse
284
286
func (tm * TaskManager ) Wait (timeout time.Duration ) {
285
- timer := time .NewTimer (timeout )
286
- defer timer .Stop ()
287
-
288
- // flag to indicate if any tasks have been executed
289
- executed := false
287
+ done := make (chan struct {})
288
+ go func () {
289
+ tm .wg .Wait () // Wait for all tasks to be started
290
+ tm .running .Wait () // Wait for all running tasks to finish
291
+ close (done )
292
+ }()
290
293
291
- for {
292
- select {
293
- case <- tm .quit :
294
- // task manager has been closed, cancel all tasks
295
- tm .CancelAll ()
296
- close (tm .Results )
297
- close (tm .ctx .Value (ctxKeyCancelled {}).(chan Task ))
298
-
299
- case <- timer .C :
300
- // timeout reached, cancel all tasks
301
- tm .CancelAll ()
302
- default :
303
- // check if any tasks have been executed
304
- if ! executed {
305
- // no tasks have been executed, return immediately
306
- return
307
- }
308
- // wait for all tasks to finish
309
- // tm.scheduler.Wait()
310
- tm .wg .Wait ()
311
- // close the results and cancelled channels
312
- close (tm .Results )
313
- close (tm .ctx .Value (ctxKeyCancelled {}).(chan Task ))
314
- }
294
+ select {
295
+ case <- done :
296
+ // All tasks have finished
297
+ case <- time .After (timeout ):
298
+ // Timeout reached before all tasks finished
315
299
}
300
+
301
+ // close the results and cancelled channels
302
+ close (tm .Results )
303
+ close (tm .ctx .Value (ctxKeyCancelled {}).(chan Task ))
316
304
}
317
305
318
306
// Close stops the task manager and waits for all tasks to finish
@@ -378,11 +366,35 @@ func (tm *TaskManager) GetActiveTasks() int {
378
366
return int (tm .limiter .Limit ()) - tm .limiter .Burst ()
379
367
}
380
368
381
- // GetResults gets the results channel
382
- func (tm * TaskManager ) GetResults () <- chan Result {
369
+ // StreamResults streams the results channel
370
+ func (tm * TaskManager ) StreamResults () <- chan Result {
383
371
return tm .Results
384
372
}
385
373
374
+ // GetResults gets the results channel
375
+ func (tm * TaskManager ) GetResults () []Result {
376
+ results := make ([]Result , 0 )
377
+
378
+ // Create a done channel to signal when all tasks have finished
379
+ done := make (chan struct {})
380
+
381
+ // Start a goroutine to read from the Results channel
382
+ go func () {
383
+ for result := range tm .Results {
384
+ results = append (results , result )
385
+ }
386
+ close (done )
387
+ }()
388
+
389
+ // Wait for all tasks to finish
390
+ tm .Wait (tm .Timeout )
391
+
392
+ // Wait for the results goroutine to finish
393
+ <- done
394
+
395
+ return results
396
+ }
397
+
386
398
// GetCancelled gets the cancelled tasks channel
387
399
func (tm * TaskManager ) GetCancelled () <- chan Task {
388
400
return tm .ctx .Value (ctxKeyCancelled {}).(chan Task )
@@ -532,6 +544,9 @@ func (tm *TaskManager) ExecuteTask(id uuid.UUID, timeout time.Duration) (interfa
532
544
return nil , ErrTaskCancelled
533
545
default :
534
546
task .setStarted ()
547
+ // Increment the running wait group when the task starts
548
+ tm .running .Add (1 )
549
+ defer tm .running .Done ()
535
550
536
551
// execute the task
537
552
result , err := task .Fn ()
0 commit comments