Commit fcb3261

call capture handlers in goroutines and introduce capture timeout
alert timeout fixes a race condition
1 parent 92f65be commit fcb3261

File tree: 2 files changed (+99, -8 lines)

capture.go

Lines changed: 51 additions & 8 deletions
@@ -5,11 +5,15 @@ import (
     "log"
     "runtime"
     "strings"
+    "sync"
     "time"
 
     pkgerrors "github.com/pkg/errors"
 )
 
+// CaptureTimeout limits how long to wait for a capture ID to be returned from a capture handler.
+var CaptureTimeout = 500 * time.Millisecond
+
 type CaptureProvider string // i.e. "sentry"
 
 type CaptureID string // may be a URL or any string that allows a captured error to be looked up
@@ -117,11 +121,12 @@ func alert(exception error) error {
     // infinite recursion. Here, we try to prevent that. This is relatively expensive, but we're alerting, which
     // shouldn't happen often.
     pc := make([]uintptr, 42)
-    runtime.Callers(1, pc) // skip 1 (runtime.Callers)
+    runtime.Callers(1, pc) // skip 1 (the one skipped is runtime.Callers)
     cf := runtime.CallersFrames(pc)
     us, _ := cf.Next()
     for them, ok := cf.Next(); ok; them, ok = cf.Next() {
-        if us.Func.Name() == them.Func.Name() {
+        // use HasPrefix here, not simple equality, because handlers are called from goroutine (below)
+        if strings.HasPrefix(them.Func.Name(), us.Func.Name()) {
             log.Printf("cannot alert, recursion detected (%s): %+v", us.Func.Name(), exception)
             return exception // don't recurse again
         }
@@ -149,18 +154,56 @@ func alert(exception error) error {
         return true
     })
 
+    // Run handlers in goroutines, so that if one handler is deadlocked
+    // it does not prevent others from running, or us from returning.
+
+    timer := time.NewTimer(CaptureTimeout)
+    defer timer.Stop()
+
+    done := make(chan struct{})
+    finish := sync.OnceFunc(func() { close(done) })
+    var mu sync.Mutex
+
+    // start a goroutine for each handler
     for provider, handler := range capture {
-        defer func() {
-            if r := recover(); r != nil {
-                log.Printf("failed to capture exception (%q): %+v", provider, r)
+        provider := provider
+        handler := handler
+        go func() {
+            defer func() {
+                if r := recover(); r != nil {
+                    log.Printf("failed to capture exception (%q): %+v", provider, r)
+                }
+            }()
+
+            id := handler(exception, arg...)
+
+            mu.Lock()
+            defer mu.Unlock()
+            select {
+            case <-done:
+                // we are too late
+            default:
+                e.id[provider] = id
+                if len(e.id) == len(capture) {
+                    finish()
+                }
             }
         }()
+    }
 
-        id := handler(exception, arg...)
-        if id != "" {
-            e.id[provider] = id
+    // wait until done or timed out
+waitLoop:
+    for {
+        select {
+        case <-timer.C:
+            mu.Lock()
+            defer mu.Unlock()
+            finish()
+        case <-done:
+            break waitLoop
         }
     }
+
     return e
 }
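Background on the HasPrefix change in the second hunk (this sketch is not part of the commit): a goroutine's stack begins at the function passed to go, so when a handler recursively alerts, the recursion check no longer sees another frame named exactly like alert; it sees the enclosing anonymous function, whose runtime name is the caller's name plus a ".funcN" suffix. The standalone program below, using hypothetical names (main.parent), illustrates that naming.

package main

import (
    "fmt"
    "runtime"
)

func parent() {
    name := make(chan string)
    go func() {
        // Walk this goroutine's own stack; skip 1 frame (runtime.Callers itself).
        pc := make([]uintptr, 8)
        n := runtime.Callers(1, pc)
        frames := runtime.CallersFrames(pc[:n])
        f, _ := frames.Next()
        name <- f.Function // innermost frame of this goroutine
    }()
    // Prints something like "main.parent.func1": it has "main.parent" as a
    // prefix but is not equal to it, and "main.parent" itself never appears
    // in the goroutine's stack.
    fmt.Println(<-name)
}

func main() { parent() }

With the old equality check, a recursive alert reached from inside such a goroutine would go undetected; matching on the prefix still catches it.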

capture_test.go

Lines changed: 48 additions & 0 deletions
@@ -3,6 +3,7 @@ package errors_test
 import (
     "fmt"
     "strings"
+    "sync/atomic"
     "testing"
     "time"
 
@@ -102,3 +103,50 @@ func TestCaptureRecurse(t *testing.T) {
         t.Errorf("alert did not capture")
     }
 }
+
+func TestCaptureTimeout(t *testing.T) {
+    var called atomic.Uint64   // how many handlers have been called
+    var returned atomic.Uint64 // how many returned
+    n := 5                                           // how many slow handlers we will register
+    slow := errors.CaptureTimeout / time.Duration(n) // fastest duration of a slow handler
+
+    slowHandler := func(ex error, arg ...any) errors.CaptureID {
+        c := called.Add(1)
+        defer returned.Add(1)
+
+        // slow so that if multiple handlers are registered, capture will timeout
+        time.Sleep(time.Duration(c+1) * slow) // use count to make each handler slower than the one before
+        return errors.CaptureID(fmt.Sprintf("slowHandler %d", c))
+    }
+
+    for i := 0; i < n; i++ {
+        name := errors.CaptureProvider(fmt.Sprintf("slowHandler %d", i+1))
+        errors.RegisterCapture(name, slowHandler)
+        defer errors.UnregisterCapture(name)
+    }
+
+    beforeAlert := time.Now()
+    err := errors.Alertf(t.Name())
+    howLong := time.Since(beforeAlert)
+
+    // make sure we didn't wait much longer than CaptureTimeout
+    if howLong > errors.CaptureTimeout+(10*time.Millisecond) {
+        t.Errorf("alert to %d handlers took longer than timeout by %s", n, howLong-errors.CaptureTimeout)
+    }
+
+    if int(called.Load()) != n {
+        t.Errorf("expected to call %d handlers, called %d", n, called.Load())
+    }
+
+    // we don't expect the alert to wait for all handlers
+    if returned.Load() >= called.Load() {
+        t.Error("alert waited for all slow handlers to return")
+    }
+
+    // some handlers should be fast enough that alert waits for them
+    if returned.Load() == 0 {
+        t.Errorf("alert did not wait for any handlers")
+    }
+
+    t.Log(err) // should show capture IDs returned from faster handlers, but not slower handlers
+}
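For illustration only (not part of this commit), a further sketch of the caller-facing behaviour, written in the style of the test above and assuming the same errors_test package and its existing imports; the test name, provider names, capture IDs, and the 10x sleep multiplier are arbitrary. It shows that a handler outliving CaptureTimeout no longer blocks Alertf, while a fast handler still completes within the wait.

func TestCaptureFastAndStuck(t *testing.T) {
    // One handler returns immediately; the other sleeps far past the timeout.
    fast := func(ex error, arg ...any) errors.CaptureID { return "fast-id" }
    stuck := func(ex error, arg ...any) errors.CaptureID {
        time.Sleep(10 * errors.CaptureTimeout)
        return "stuck-id"
    }

    errors.RegisterCapture("fast", fast)
    defer errors.UnregisterCapture("fast")
    errors.RegisterCapture("stuck", stuck)
    defer errors.UnregisterCapture("stuck")

    start := time.Now()
    err := errors.Alertf(t.Name())
    elapsed := time.Since(start)

    // Alertf should return once CaptureTimeout fires rather than waiting
    // for the stuck handler to return.
    if elapsed > errors.CaptureTimeout+(10*time.Millisecond) {
        t.Errorf("Alertf blocked for %s, expected roughly %s", elapsed, errors.CaptureTimeout)
    }
    t.Log(err) // expected to mention the fast handler's capture ID, but not the stuck one's
}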
