Skip to content

Commit bcce5d7

Browse files
author
Francesco Cosentino
committed
thread safety
1 parent 3b083a5 commit bcce5d7

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

examples/manual/main.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,12 @@ func main() {
3636
} else {
3737
fmt.Println(res)
3838
}
39+
40+
tm.RegisterTask(task)
41+
res, err = tm.ExecuteTask(task.ID)
42+
if err != nil {
43+
fmt.Println(err)
44+
} else {
45+
fmt.Println(res)
46+
}
3947
}

task.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package worker
33
import (
44
"context"
55
"errors"
6-
"fmt"
76
"sync/atomic"
87
"time"
98

@@ -53,8 +52,7 @@ type Task struct {
5352
}
5453

5554
// IsValid returns an error if the task is invalid
56-
func (t Task) IsValid() (err error) {
57-
fmt.Println("validating task")
55+
func (t *Task) IsValid() (err error) {
5856
if t.ID == uuid.Nil {
5957
err = ErrInvalidTaskID
6058
t.Error.Store(err.Error())

worker.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ type TaskManager struct {
1616
Registry sync.Map // Registry is a map of registered tasks
1717
Results chan interface{} // Results is the channel of results
1818
taskHeap taskHeap // heap of tasks
19-
wg sync.WaitGroup // wg is a wait group that waits for all tasks to finish
2019
limiter *rate.Limiter // limiter is a rate limiter that limits the number of tasks that can be executed at once
20+
wg sync.WaitGroup // wg is a wait group that waits for all tasks to finish
21+
mutex sync.RWMutex // mutex protects the task handling
2122
}
2223

2324
// NewTaskManager creates a new task manager
@@ -47,6 +48,8 @@ func NewTaskManager(maxTasks int, tasksPerSecond float64) Service {
4748
// RegisterTask registers a new task to the task manager
4849
func (tm *TaskManager) RegisterTask(tasks ...Task) {
4950
for _, task := range tasks {
51+
tm.mutex.RLock()
52+
defer tm.mutex.RUnlock()
5053
if task.IsValid() != nil {
5154
tm.Results <- task
5255
continue
@@ -101,6 +104,8 @@ func (tm *TaskManager) GetResults() <-chan interface{} {
101104

102105
// GetTask gets a task by its ID
103106
func (tm *TaskManager) GetTask(id uuid.UUID) (task Task, ok bool) {
107+
tm.mutex.RLock()
108+
defer tm.mutex.RUnlock()
104109
t, ok := tm.Registry.Load(id)
105110
if !ok {
106111
return
@@ -155,7 +160,8 @@ func (tm *TaskManager) ExecuteTask(id uuid.UUID) (interface{}, error) {
155160
// task not found
156161
return nil, ErrTaskNotFound
157162
}
158-
163+
tm.mutex.RLock()
164+
defer tm.mutex.RUnlock()
159165
// execute the task
160166
tm.executeTask(&task)
161167
// drain the results channel and return the result
@@ -192,6 +198,8 @@ func (tm *TaskManager) worker(workerID int) {
192198

193199
// executeTask executes a task
194200
func (tm *TaskManager) executeTask(task *Task) {
201+
tm.mutex.RLock()
202+
defer tm.mutex.RUnlock()
195203
defer tm.wg.Done()
196204

197205
// reserve a token from the limiter
@@ -232,6 +240,8 @@ func (tm *TaskManager) cancelTask(task *Task, reason CancelReason, notifyWG bool
232240
if notifyWG {
233241
defer tm.wg.Done()
234242
}
243+
tm.mutex.RLock()
244+
defer tm.mutex.RUnlock()
235245
task.Cancel()
236246
// set the cancelled time
237247
task.setCancelled()

0 commit comments

Comments
 (0)