From f8a33d37f3f4e016f3f34d128b3e073431160315 Mon Sep 17 00:00:00 2001 From: infrmtcs Date: Fri, 3 Oct 2025 12:08:24 +0700 Subject: [PATCH] feat: Implement tracker to synchronize WaitGroup Add() before Wait() --- utils/tracker/service_test.go | 96 +++++++++++++++++++++++++++++++ utils/tracker/tracker.go | 37 ++++++++++++ utils/tracker/tracker_test.go | 104 ++++++++++++++++++++++++++++++++++ utils/tracker/wrapper_test.go | 43 ++++++++++++++ 4 files changed, 280 insertions(+) create mode 100644 utils/tracker/service_test.go create mode 100644 utils/tracker/tracker.go create mode 100644 utils/tracker/tracker_test.go create mode 100644 utils/tracker/wrapper_test.go diff --git a/utils/tracker/service_test.go b/utils/tracker/service_test.go new file mode 100644 index 0000000000..55a7e4ddd3 --- /dev/null +++ b/utils/tracker/service_test.go @@ -0,0 +1,96 @@ +package tracker_test + +import ( + "context" + "sync/atomic" + "time" +) + +type barrier struct { + checkCh chan struct{} + doneCh chan struct{} +} + +func newBarrier(size int) barrier { + checkCh := make(chan struct{}, size) + doneCh := make(chan struct{}) + + go func() { + for range size { + <-checkCh + } + close(doneCh) + }() + + return barrier{ + checkCh: checkCh, + doneCh: doneCh, + } +} + +func (b barrier) check() { + b.checkCh <- struct{}{} +} + +func (b barrier) done() { + <-b.doneCh +} + +type service struct { + tracker waitGroup + isShutdown *atomic.Bool + passedCtxCheck barrier + finishedNonDelayed barrier +} + +func NewService(tracker waitGroup, isShutdown *atomic.Bool, nonDelayed, delayed int) service { + service := service{ + tracker: tracker, + isShutdown: isShutdown, + passedCtxCheck: newBarrier(nonDelayed + delayed), + finishedNonDelayed: newBarrier(nonDelayed), + } + return service +} + +func (s *service) handle(ctx context.Context, isDelayed bool) bool { + select { + case <-ctx.Done(): + return false + default: + } + + // Communicate with the test that ctx.Done() check has passed + s.passedCtxCheck.check() + + // We track the completion of the non delayed requests, then wait until they are all completed + // so that the Wait() at call site completes. Then we can check if any acquisition is done + // after Wait() completes. + if isDelayed { + s.finishedNonDelayed.done() + } else { + defer s.finishedNonDelayed.check() + } + + isShutdown := s.isShutdown.Load() + + if !s.tracker.Add(1) { + return false + } + defer s.tracker.Done() + + // Because isShutdown is loaded before, if isShutdown is true while acquiring succeeds, + // this implies that the acquisition is done after ctx.Done() is closed + if isShutdown { + panic("acquired after shutdown") + } + + // Simulate a small delay to create async boundary to increase the chances of race condition + time.Sleep(asyncBoundaryDelay) + return true +} + +func (s *service) run(ctx context.Context) { + <-ctx.Done() + s.tracker.Wait() +} diff --git a/utils/tracker/tracker.go b/utils/tracker/tracker.go new file mode 100644 index 0000000000..2f64e47290 --- /dev/null +++ b/utils/tracker/tracker.go @@ -0,0 +1,37 @@ +package tracker + +import ( + "sync" +) + +// Inspired by net/http +// See https://cs.opensource.google/go/go/+/refs/tags/go1.25.1:src/net/http/server.go;l=3604 +type Tracker struct { + mu sync.Mutex + wg sync.WaitGroup + done bool +} + +func (g *Tracker) Add(delta int) bool { + g.mu.Lock() + defer g.mu.Unlock() + + if g.done { + return false + } + + g.wg.Add(delta) + return true +} + +func (g *Tracker) Done() { + g.wg.Done() +} + +func (g *Tracker) Wait() { + g.mu.Lock() + g.done = true + g.mu.Unlock() + + g.wg.Wait() +} diff --git a/utils/tracker/tracker_test.go b/utils/tracker/tracker_test.go new file mode 100644 index 0000000000..d9a87036cb --- /dev/null +++ b/utils/tracker/tracker_test.go @@ -0,0 +1,104 @@ +package tracker_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/NethermindEth/juno/utils/tracker" + "github.com/sourcegraph/conc" + "github.com/stretchr/testify/require" +) + +const ( + nonDelayed = 10 + delayed = 100 + asyncBoundaryDelay = 1 * time.Millisecond +) + +type waitGroup interface { + Add(delta int) bool + Done() + Wait() +} + +func TestTracker(t *testing.T) { + runTests(t, func(context.Context) waitGroup { + return &tracker.Tracker{} + }) +} + +func TestSemaphoreWrapper(t *testing.T) { + runTests(t, func(ctx context.Context) waitGroup { + return NewSemaphoreWrapper(ctx) + }) +} + +func TestWaitGroupFailure(t *testing.T) { + t.Skip("This test is skipped because it's supposed to fail to demonstrate the issue") + runTests(t, func(context.Context) waitGroup { + return &waitGroupWrapper{} + }) +} + +func runTests(t *testing.T, tracker func(context.Context) waitGroup) { + t.Helper() + + t.Run("normal case, all requests complete", func(t *testing.T) { + runTest(t, false, nonDelayed, 0, nonDelayed, tracker) + }) + + t.Run("early cancel, half of non-delayed requests complete", func(t *testing.T) { + runTest(t, true, nonDelayed, delayed, nonDelayed/2, tracker) + }) +} + +func runTest( + t *testing.T, + earlyCancel bool, + nonDelayed int, + delayed int, + expected int, + tracker func(context.Context) waitGroup, +) { + t.Helper() + + ctx, cancel := context.WithCancel(t.Context()) + isShutdown := atomic.Bool{} + success := atomic.Uint32{} + service := NewService(tracker(ctx), &isShutdown, nonDelayed, delayed) + + clientWg := conc.NewWaitGroup() + serverWg := conc.NewWaitGroup() + + serverWg.Go(func() { + service.run(ctx) + // Set isShutdown to true after finishing wait, so we can panic any subsequent acquisitions + isShutdown.Store(true) + }) + + if earlyCancel { + serverWg.Go(func() { + // Wait until all ctx.Done() checks have passed before canceling the context + service.passedCtxCheck.done() + cancel() + }) + } + + for isDelayed, count := range map[bool]int{false: nonDelayed, true: delayed} { + for range count { + clientWg.Go(func() { + if result := service.handle(ctx, isDelayed); result { + success.Add(1) + } + }) + } + } + + clientWg.Wait() + cancel() + serverWg.Wait() + + require.GreaterOrEqual(t, int(success.Load()), expected) +} diff --git a/utils/tracker/wrapper_test.go b/utils/tracker/wrapper_test.go new file mode 100644 index 0000000000..888c1d7f7e --- /dev/null +++ b/utils/tracker/wrapper_test.go @@ -0,0 +1,43 @@ +package tracker_test + +import ( + "context" + "sync" + + "golang.org/x/sync/semaphore" +) + +const highThreshold = 1000000 + +type waitGroupWrapper struct { + sync.WaitGroup +} + +func (w *waitGroupWrapper) Add(delta int) bool { + w.WaitGroup.Add(delta) + return true +} + +type semaphoreWrapper struct { + *semaphore.Weighted + ctx context.Context +} + +func NewSemaphoreWrapper(ctx context.Context) *semaphoreWrapper { + return &semaphoreWrapper{ + semaphore.NewWeighted(highThreshold), + ctx, + } +} + +func (s *semaphoreWrapper) Add(delta int) bool { + return s.Weighted.Acquire(s.ctx, int64(delta)) == nil +} + +func (s *semaphoreWrapper) Done() { + s.Weighted.Release(1) +} + +func (s *semaphoreWrapper) Wait() { + _ = s.Weighted.Acquire(s.ctx, highThreshold) +}