Skip to content

Commit 8566fd8

Browse files
author
Francesco Cosentino
committed
Added GetResults and StreamResults
1 parent 3687f27 commit 8566fd8

File tree

5 files changed

+63
-39
lines changed

5 files changed

+63
-39
lines changed

examples/test/test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ func main() {
1616
// create a new task manager
1717
tm := worker.NewTaskManager(4, 10, 5, time.Second*30, time.Second*30, 3)
1818
// close the task manager
19+
// defer tm.Close()
1920

2021
// register and execute 10 tasks in a separate goroutine
2122
go func() {
@@ -34,7 +35,8 @@ func main() {
3435
log.Fatal(error)
3536
}
3637
emptyFile.Close()
37-
time.Sleep(time.Second)
38+
time.Sleep(time.Millisecond * 100)
39+
3840
return fmt.Sprintf("** task number %v with id %s executed", j, id), err
3941
},
4042
Retries: 10,
@@ -61,7 +63,7 @@ func main() {
6163
log.Fatal(error)
6264
}
6365
emptyFile.Close()
64-
time.Sleep(time.Second)
66+
time.Sleep(time.Millisecond * 100)
6567
return fmt.Sprintf("**** task number %v with id %s executed", j, id), err
6668
},
6769
}
@@ -76,8 +78,8 @@ func main() {
7678
// tm.Close()
7779

7880
// wait for the tasks to finish and print the results
79-
for result := range tm.GetResults() {
80-
fmt.Println(result)
81+
for id, result := range tm.GetResults() {
82+
fmt.Println(id, result)
8183
}
8284

8385
}

middleware/logger.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,13 @@ func (mw *loggerMiddleware) GetActiveTasks() int {
131131
return mw.next.GetActiveTasks()
132132
}
133133

134+
// StreamResults streams the results channel
135+
func (mw *loggerMiddleware) StreamResults() <-chan worker.Result {
136+
return mw.next.StreamResults()
137+
}
138+
134139
// GetResults returns the results channel
135-
func (mw *loggerMiddleware) GetResults() <-chan worker.Result {
140+
func (mw *loggerMiddleware) GetResults() []worker.Result {
136141
return mw.next.GetResults()
137142
}
138143

service.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ type Service interface {
2727
CancelTask(id uuid.UUID)
2828
// GetActiveTasks returns the number of active tasks
2929
GetActiveTasks() int
30+
// StreamResults streams the `Result` channel
31+
StreamResults() <-chan Result
3032
// GetResults retruns the `Result` channel
31-
GetResults() <-chan Result
33+
GetResults() []Result
3234
// GetCancelled gets the cancelled tasks channel
3335
GetCancelled() <-chan Task
3436
// GetTask gets a task by its ID

tests/worker_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestTaskManager_Start(t *testing.T) {
4747
}
4848
tm.RegisterTask(context.Background(), task)
4949

50-
res := <-tm.GetResults()
50+
res := <-tm.StreamResults()
5151
if res.Task == nil {
5252
t.Fatalf("Task result was not added to the results channel")
5353
}
@@ -62,7 +62,7 @@ func TestTaskManager_GetResults(t *testing.T) {
6262
}
6363
tm.RegisterTask(context.Background(), task)
6464

65-
results := <-tm.GetResults()
65+
results := <-tm.StreamResults()
6666
if results.Task == nil {
6767
t.Fatalf("results channel is nil")
6868
}

worker.go

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type TaskManager struct {
4444
MaxRetries int // MaxRetries is the maximum number of retries
4545
limiter *rate.Limiter // limiter is a rate limiter that limits the number of tasks that can be executed at once
4646
wg sync.WaitGroup // wg is a wait group that waits for all tasks to finish
47+
running sync.WaitGroup // running is a wait group that waits for all running tasks to finish
4748
mutex sync.RWMutex // mutex protects the task handling
4849
quit chan struct{} // quit is a channel to signal all goroutines to stop
4950
ctx context.Context // ctx is the context for the task manager
@@ -102,6 +103,7 @@ func NewTaskManager(maxWorkers int, maxTasks int, tasksPerSecond float64, timeou
102103
MaxRetries: maxRetries,
103104
limiter: rate.NewLimiter(rate.Limit(tasksPerSecond), maxTasks),
104105
wg: sync.WaitGroup{},
106+
running: sync.WaitGroup{},
105107
mutex: sync.RWMutex{},
106108
quit: make(chan struct{}),
107109
ctx: ctx,
@@ -282,37 +284,23 @@ func (tm *TaskManager) RegisterTasks(ctx context.Context, tasks ...Task) {
282284

283285
// Wait waits for all tasks to complete or for the timeout to elapse
284286
func (tm *TaskManager) Wait(timeout time.Duration) {
285-
timer := time.NewTimer(timeout)
286-
defer timer.Stop()
287-
288-
// flag to indicate if any tasks have been executed
289-
executed := false
287+
done := make(chan struct{})
288+
go func() {
289+
tm.wg.Wait() // Wait for all tasks to be started
290+
tm.running.Wait() // Wait for all running tasks to finish
291+
close(done)
292+
}()
290293

291-
for {
292-
select {
293-
case <-tm.quit:
294-
// task manager has been closed, cancel all tasks
295-
tm.CancelAll()
296-
close(tm.Results)
297-
close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task))
298-
299-
case <-timer.C:
300-
// timeout reached, cancel all tasks
301-
tm.CancelAll()
302-
default:
303-
// check if any tasks have been executed
304-
if !executed {
305-
// no tasks have been executed, return immediately
306-
return
307-
}
308-
// wait for all tasks to finish
309-
// tm.scheduler.Wait()
310-
tm.wg.Wait()
311-
// close the results and cancelled channels
312-
close(tm.Results)
313-
close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task))
314-
}
294+
select {
295+
case <-done:
296+
// All tasks have finished
297+
case <-time.After(timeout):
298+
// Timeout reached before all tasks finished
315299
}
300+
301+
// close the results and cancelled channels
302+
close(tm.Results)
303+
close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task))
316304
}
317305

318306
// Close stops the task manager and waits for all tasks to finish
@@ -378,11 +366,35 @@ func (tm *TaskManager) GetActiveTasks() int {
378366
return int(tm.limiter.Limit()) - tm.limiter.Burst()
379367
}
380368

381-
// GetResults gets the results channel
382-
func (tm *TaskManager) GetResults() <-chan Result {
369+
// StreamResults streams the results channel
370+
func (tm *TaskManager) StreamResults() <-chan Result {
383371
return tm.Results
384372
}
385373

374+
// GetResults gets the results channel
375+
func (tm *TaskManager) GetResults() []Result {
376+
results := make([]Result, 0)
377+
378+
// Create a done channel to signal when all tasks have finished
379+
done := make(chan struct{})
380+
381+
// Start a goroutine to read from the Results channel
382+
go func() {
383+
for result := range tm.Results {
384+
results = append(results, result)
385+
}
386+
close(done)
387+
}()
388+
389+
// Wait for all tasks to finish
390+
tm.Wait(tm.Timeout)
391+
392+
// Wait for the results goroutine to finish
393+
<-done
394+
395+
return results
396+
}
397+
386398
// GetCancelled gets the cancelled tasks channel
387399
func (tm *TaskManager) GetCancelled() <-chan Task {
388400
return tm.ctx.Value(ctxKeyCancelled{}).(chan Task)
@@ -532,6 +544,9 @@ func (tm *TaskManager) ExecuteTask(id uuid.UUID, timeout time.Duration) (interfa
532544
return nil, ErrTaskCancelled
533545
default:
534546
task.setStarted()
547+
// Increment the running wait group when the task starts
548+
tm.running.Add(1)
549+
defer tm.running.Done()
535550

536551
// execute the task
537552
result, err := task.Fn()

0 commit comments

Comments
 (0)