Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions utils/tracker/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package tracker_test

import (
"context"
"sync/atomic"
"time"
)

type barrier struct {
checkCh chan struct{}
doneCh chan struct{}
}

func newBarrier(size int) barrier {
checkCh := make(chan struct{}, size)
doneCh := make(chan struct{})

go func() {
for range size {
<-checkCh
}
close(doneCh)
}()

return barrier{
checkCh: checkCh,
doneCh: doneCh,
}
}

func (b barrier) check() {
b.checkCh <- struct{}{}
}

func (b barrier) done() {
<-b.doneCh
}

type service struct {
tracker waitGroup
isShutdown *atomic.Bool
passedCtxCheck barrier
finishedNonDelayed barrier
}

func NewService(tracker waitGroup, isShutdown *atomic.Bool, nonDelayed, delayed int) service {
service := service{
tracker: tracker,
isShutdown: isShutdown,
passedCtxCheck: newBarrier(nonDelayed + delayed),
finishedNonDelayed: newBarrier(nonDelayed),
}
return service
}

func (s *service) handle(ctx context.Context, isDelayed bool) bool {
select {
case <-ctx.Done():
return false
default:
}

// Communicate with the test that ctx.Done() check has passed
s.passedCtxCheck.check()

// We track the completion of the non delayed requests, then wait until they are all completed
// so that the Wait() at call site completes. Then we can check if any acquisition is done
// after Wait() completes.
if isDelayed {
s.finishedNonDelayed.done()
} else {
defer s.finishedNonDelayed.check()
}

isShutdown := s.isShutdown.Load()

if !s.tracker.Add(1) {
return false
}
defer s.tracker.Done()

// Because isShutdown is loaded before, if isShutdown is true while acquiring succeeds,
// this implies that the acquisition is done after ctx.Done() is closed
if isShutdown {
panic("acquired after shutdown")
}

// Simulate a small delay to create async boundary to increase the chances of race condition
time.Sleep(asyncBoundaryDelay)
return true
}

func (s *service) run(ctx context.Context) {
<-ctx.Done()
s.tracker.Wait()
}
37 changes: 37 additions & 0 deletions utils/tracker/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package tracker

import (
"sync"
)

// Inspired by net/http
// See https://cs.opensource.google/go/go/+/refs/tags/go1.25.1:src/net/http/server.go;l=3604
type Tracker struct {
mu sync.Mutex
wg sync.WaitGroup
done bool
}

func (g *Tracker) Add(delta int) bool {
g.mu.Lock()
defer g.mu.Unlock()

if g.done {
return false
}

g.wg.Add(delta)
return true
}

func (g *Tracker) Done() {
g.wg.Done()
}

func (g *Tracker) Wait() {
g.mu.Lock()
g.done = true
g.mu.Unlock()

g.wg.Wait()
}
104 changes: 104 additions & 0 deletions utils/tracker/tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package tracker_test

import (
"context"
"sync/atomic"
"testing"
"time"

"github.com/NethermindEth/juno/utils/tracker"
"github.com/sourcegraph/conc"
"github.com/stretchr/testify/require"
)

const (
nonDelayed = 10
delayed = 100
asyncBoundaryDelay = 1 * time.Millisecond
)

type waitGroup interface {
Add(delta int) bool
Done()
Wait()
}

func TestTracker(t *testing.T) {
runTests(t, func(context.Context) waitGroup {
return &tracker.Tracker{}
})
}

func TestSemaphoreWrapper(t *testing.T) {
runTests(t, func(ctx context.Context) waitGroup {
return NewSemaphoreWrapper(ctx)
})
}

func TestWaitGroupFailure(t *testing.T) {
t.Skip("This test is skipped because it's supposed to fail to demonstrate the issue")
runTests(t, func(context.Context) waitGroup {
return &waitGroupWrapper{}
})
}

func runTests(t *testing.T, tracker func(context.Context) waitGroup) {
t.Helper()

t.Run("normal case, all requests complete", func(t *testing.T) {
runTest(t, false, nonDelayed, 0, nonDelayed, tracker)
})

t.Run("early cancel, half of non-delayed requests complete", func(t *testing.T) {
runTest(t, true, nonDelayed, delayed, nonDelayed/2, tracker)
})
}

func runTest(
t *testing.T,
earlyCancel bool,
nonDelayed int,
delayed int,
expected int,
tracker func(context.Context) waitGroup,
) {
t.Helper()

ctx, cancel := context.WithCancel(t.Context())
isShutdown := atomic.Bool{}
success := atomic.Uint32{}
service := NewService(tracker(ctx), &isShutdown, nonDelayed, delayed)

clientWg := conc.NewWaitGroup()
serverWg := conc.NewWaitGroup()

serverWg.Go(func() {
service.run(ctx)
// Set isShutdown to true after finishing wait, so we can panic any subsequent acquisitions
isShutdown.Store(true)
})

if earlyCancel {
serverWg.Go(func() {
// Wait until all ctx.Done() checks have passed before canceling the context
service.passedCtxCheck.done()
cancel()
})
}

for isDelayed, count := range map[bool]int{false: nonDelayed, true: delayed} {
for range count {
clientWg.Go(func() {
if result := service.handle(ctx, isDelayed); result {
success.Add(1)
}
})
}
}

clientWg.Wait()
cancel()
serverWg.Wait()

require.GreaterOrEqual(t, int(success.Load()), expected)
}
43 changes: 43 additions & 0 deletions utils/tracker/wrapper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package tracker_test

import (
"context"
"sync"

"golang.org/x/sync/semaphore"
)

const highThreshold = 1000000

type waitGroupWrapper struct {
sync.WaitGroup
}

func (w *waitGroupWrapper) Add(delta int) bool {
w.WaitGroup.Add(delta)
return true
}

type semaphoreWrapper struct {
*semaphore.Weighted
ctx context.Context
}

func NewSemaphoreWrapper(ctx context.Context) *semaphoreWrapper {
return &semaphoreWrapper{
semaphore.NewWeighted(highThreshold),
ctx,
}
}

func (s *semaphoreWrapper) Add(delta int) bool {
return s.Weighted.Acquire(s.ctx, int64(delta)) == nil
}

func (s *semaphoreWrapper) Done() {
s.Weighted.Release(1)
}

func (s *semaphoreWrapper) Wait() {
_ = s.Weighted.Acquire(s.ctx, highThreshold)
}
Loading