diff --git a/p2p/server_test.go b/p2p/server_test.go
index 1e896b2e..db2ab575 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -17,6 +17,9 @@ func TestExchangeServer_handleRequestTimeout(t *testing.T) {
 	peer := createMocknet(t, 1)
 	s, err := store.NewStore[*headertest.DummyHeader](datastore.NewMapDatastore())
 	require.NoError(t, err)
+	head := headertest.RandDummyHeader(t)
+	head.HeightI %= 1000 // make it a bit lower
+	s.Init(context.Background(), head)
 	server, err := NewExchangeServer[*headertest.DummyHeader](
 		peer[0],
 		s,
diff --git a/store/batch.go b/store/batch.go
index 7785953f..4328da1c 100644
--- a/store/batch.go
+++ b/store/batch.go
@@ -1,22 +1,129 @@
 package store
 
 import (
+	"slices"
 	"sync"
 
	"github.com/celestiaorg/go-header"
 )
 
-// batch keeps an adjacent range of headers and loosely mimics the Store
-// interface. NOTE: Can fully implement Store for a use case.
+type batches[H header.Header[H]] struct {
+	batchesMu sync.RWMutex
+	batches   []*batch[H]
+
+	batchLenLimit int
+}
+
+func newEmptyBatches[H header.Header[H]]() *batches[H] {
+	return &batches[H]{batches: make([]*batch[H], 0, 8)}
+}
+
+// Append must be given an adjacent range of headers.
+// Once one of the internal batches reaches the length limit,
+// Append returns that batch together with true.
+func (bs *batches[H]) Append(headers ...H) (*batch[H], bool) {
+	// TODO: Check if headers are adjacent?
+	if len(headers) == 0 {
+		return nil, false
+	}
+	bs.batchesMu.Lock()
+	defer bs.batchesMu.Unlock()
+
+	// 1. Add headers as a new batch
+	newBatch := newBatch[H](len(headers))
+	newBatch.Append(headers...)
+	bs.batches = append(bs.batches, newBatch)
+
+	// 2. Ensure all the batches are sorted in descending order
+	slices.SortFunc(bs.batches, func(a, b *batch[H]) int {
+		return int(b.Head() - a.Head())
+	})
+
+	// 3. Merge adjacent and overlapping batches
+	mergeIdx := 0
+	for idx := 1; idx < len(bs.batches); idx++ {
+		curr := bs.batches[mergeIdx]
+		next := bs.batches[idx]
+
+		if !next.IsReadOnly() && curr.Tail()-1 <= next.Head() {
+			curr.Append(next.GetAll()...)
+		} else {
+			mergeIdx++
+			bs.batches[mergeIdx] = next
+		}
+	}
+	clear(bs.batches[mergeIdx+1:])
+	bs.batches = bs.batches[:mergeIdx+1]
+
+	// 4. Mark filled batches as read-only and return the first match, if any.
+	for i := len(bs.batches) - 1; i >= 0; i-- {
+		// Why in reverse? Several batches may be filled at once, but only
+		// one is processed at a time, so lower heights get priority.
+		b := bs.batches[i]
+		if b.Len() >= bs.batchLenLimit {
+			b.MarkReadOnly()
+			return b, true
+		}
+	}
+
+	return nil, false
+}
+
+func (bs *batches[H]) GetByHeight(height uint64) (H, error) {
+	bs.batchesMu.RLock()
+	defer bs.batchesMu.RUnlock()
+
+	for _, b := range bs.batches {
+		if height >= b.Tail() && height <= b.Head() {
+			return b.GetByHeight(height)
+		}
+	}
+
+	var zero H
+	return zero, header.ErrNotFound
+}
+
+func (bs *batches[H]) Get(hash header.Hash) (H, error) {
+	bs.batchesMu.RLock()
+	defer bs.batchesMu.RUnlock()
+
+	for _, b := range bs.batches {
+		h, err := b.Get(hash)
+		if err == nil {
+			return h, nil
+		}
+	}
+
+	var zero H
+	return zero, header.ErrNotFound
+}
+
+func (bs *batches[H]) Has(hash header.Hash) bool {
+	bs.batchesMu.RLock()
+	defer bs.batchesMu.RUnlock()
+
+	for _, b := range bs.batches {
+		if b.Has(hash) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// batch keeps a range of adjacent headers and loosely mimics the Store
+// interface.
 //
 // It keeps a mapping 'height -> header' and 'hash -> height'
 // unlike the Store which keeps 'hash -> header' and 'height -> hash'.
 // The approach simplifies implementation for the batch and
 // makes it better optimized for the GetByHeight case which is what we need.
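+//
+// An illustrative sketch (not part of the change; h8..h5 are hypothetical
+// headers at heights 8..5):
+//
+//	b := newBatch[*headertest.DummyHeader](4)
+//	b.Append(h8, h7, h6, h5)   // stored descending: Head()==8, Tail()==5
+//	h, err := b.GetByHeight(6) // resolves to b.headers[Head()-6], i.e. h6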
 type batch[H header.Header[H]] struct {
-	lk      sync.RWMutex
 	heights map[string]uint64
-	headers []H
+	headers []H // in descending order
+
+	readOnly bool
 }
 
 // newBatch creates the batch with the given pre-allocated size.
@@ -27,80 +134,83 @@ func newBatch[H header.Header[H]](size int) *batch[H] {
 	}
 }
 
+func (b *batch[H]) MarkReadOnly() {
+	b.readOnly = true
+}
+
+func (b *batch[H]) IsReadOnly() bool {
+	return b.readOnly
+}
+
+func (b *batch[H]) Head() uint64 {
+	if len(b.headers) == 0 {
+		return 0
+	}
+	return b.headers[0].Height()
+}
+
+func (b *batch[H]) Tail() uint64 {
+	if len(b.headers) == 0 {
+		return 0
+	}
+	return b.headers[len(b.headers)-1].Height()
+}
+
 // Len gives current length of the batch.
 func (b *batch[H]) Len() int {
-	b.lk.RLock()
-	defer b.lk.RUnlock()
 	return len(b.headers)
 }
 
 // GetAll returns a slice of all the headers in the batch.
 func (b *batch[H]) GetAll() []H {
-	b.lk.RLock()
-	defer b.lk.RUnlock()
 	return b.headers
 }
 
 // Get returns a header by its hash.
-func (b *batch[H]) Get(hash header.Hash) H {
-	b.lk.RLock()
-	defer b.lk.RUnlock()
+func (b *batch[H]) Get(hash header.Hash) (H, error) {
 	height, ok := b.heights[hash.String()]
 	if !ok {
 		var zero H
-		return zero
+		return zero, header.ErrNotFound
 	}
 
-	return b.getByHeight(height)
+	return b.GetByHeight(height)
 }
 
 // GetByHeight returns a header by its height.
-func (b *batch[H]) GetByHeight(height uint64) H {
-	b.lk.RLock()
-	defer b.lk.RUnlock()
-	return b.getByHeight(height)
-}
-
-func (b *batch[H]) getByHeight(height uint64) H {
-	var (
-		ln   = uint64(len(b.headers))
-		zero H
-	)
-	if ln == 0 {
-		return zero
-	}
-
-	head := b.headers[ln-1].Height()
-	base := head - ln
-	if height > head || height <= base {
-		return zero
+func (b *batch[H]) GetByHeight(height uint64) (H, error) {
+	var zero H
+	if len(b.headers) == 0 || height > b.Head() {
+		return zero, header.ErrNotFound
+	}
+	// headers are stored in descending order, so the height maps to the
+	// index Head()-height; bounds-check it before indexing.
+	idx := b.Head() - height
+	if idx >= uint64(len(b.headers)) {
+		return zero, header.ErrNotFound
+	}
+
+	h := b.headers[idx]
+	if h.Height() != height {
+		return zero, header.ErrNotFound
 	}
 
-	return b.headers[height-base-1]
+	return h, nil
 }
 
 // Append appends new headers to the batch.
 func (b *batch[H]) Append(headers ...H) {
-	b.lk.Lock()
-	defer b.lk.Unlock()
+	head, tail := b.Head(), b.Tail()
 	for _, h := range headers {
-		b.headers = append(b.headers, h)
-		b.heights[h.Hash().String()] = h.Height()
+		if len(b.headers) != 0 && h.Height() >= tail && h.Height() <= head {
+			// overwrite the existing header and drop its stale hash index
+			prev := b.headers[head-h.Height()]
+			delete(b.heights, prev.Hash().String())
+			b.headers[head-h.Height()] = h
+		} else {
+			// add new
+			b.headers = append(b.headers, h)
+		}
+		b.heights[h.Hash().String()] = h.Height()
 	}
 }
 
 // Has checks whether header by the hash is present in the batch.
 func (b *batch[H]) Has(hash header.Hash) bool {
-	b.lk.RLock()
-	defer b.lk.RUnlock()
 	_, ok := b.heights[hash.String()]
 	return ok
 }
 
 // Reset cleans references to batched headers.
 func (b *batch[H]) Reset() {
-	b.lk.Lock()
-	defer b.lk.Unlock()
 	b.headers = b.headers[:0]
 	for k := range b.heights {
 		delete(b.heights, k)
diff --git a/store/batch_test.go b/store/batch_test.go
new file mode 100644
index 00000000..6e6efe37
--- /dev/null
+++ b/store/batch_test.go
@@ -0,0 +1,165 @@
+package store
+
+import (
+	"slices"
+	"testing"
+
+	"github.com/celestiaorg/go-header/headertest"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestBatches_GetByHeight(t *testing.T) {
+	headers := headertest.NewTestSuite(t).GenDummyHeaders(8)
+	// reverse the order to be descending
+	slices.SortFunc(headers, func(a, b *headertest.DummyHeader) int {
+		return int(b.Height() - a.Height())
+	})
+
+	setup := [][]*headertest.DummyHeader{
+		headers[:2], // Batch 8-7
+		headers[4:], // Batch 4-1
+	}
+	expected := headers[5]
+
+	bs := newEmptyBatches[*headertest.DummyHeader]()
+	for _, headers := range setup {
+		b := newBatch[*headertest.DummyHeader](len(headers))
+		b.Append(headers...)
+		bs.batches = append(bs.batches, b)
+	}
+
+	actual, err := bs.GetByHeight(expected.Height())
+	assert.NoError(t, err)
+	assert.Equal(t, expected, actual)
+}
+
+func TestBatches_Append(t *testing.T) {
+	headers := headertest.NewTestSuite(t).GenDummyHeaders(8) // Pre-generate headers
+	// reverse the order to be descending
+	slices.SortFunc(headers, func(a, b *headertest.DummyHeader) int {
+		return int(b.Height() - a.Height())
+	})
+
+	tests := []struct {
+		name              string
+		setup             func() [][]*headertest.DummyHeader
+		appendAndExpected func() ([]*headertest.DummyHeader, [][]*headertest.DummyHeader)
+	}{
+		{
+			name: "Append fills gap between two batches",
+			setup: func() [][]*headertest.DummyHeader {
+				return [][]*headertest.DummyHeader{
+					headers[:2], // Batch 8-7
+					headers[4:], // Batch 4-1
+				}
+			},
+			appendAndExpected: func() ([]*headertest.DummyHeader, [][]*headertest.DummyHeader) {
+				toAppend := headers[2:4] // Headers 6,5
+				expected := [][]*headertest.DummyHeader{
+					headers, // Merged 8-1
+				}
+				return toAppend, expected
+			},
+		},
+		{
+			name: "Append adjacent to a batch and merges",
+			setup: func() [][]*headertest.DummyHeader {
+				return [][]*headertest.DummyHeader{
+					headers[:2], // Headers 8-7
+				}
+			},
+			appendAndExpected: func() ([]*headertest.DummyHeader, [][]*headertest.DummyHeader) {
+				toAppend := headers[2:4] // Headers 6,5
+				expected := [][]*headertest.DummyHeader{
+					headers[:4], // Merged 8-5
+				}
+				return toAppend, expected
+			},
+		},
+		{
+			name: "Append creates a new batch in between existing batches",
+			setup: func() [][]*headertest.DummyHeader {
+				return [][]*headertest.DummyHeader{
+					headers[:2], // Batch 8-7
+					headers[6:], // Batch 2-1
+				}
+			},
+			appendAndExpected: func() ([]*headertest.DummyHeader, [][]*headertest.DummyHeader) {
+				toAppend := headers[3:5] // Headers 5,4
+				expected := [][]*headertest.DummyHeader{
+					headers[:2],  // Batch 8-7
+					headers[3:5], // Batch 5-4
+					headers[6:],  // Batch 2-1
+				}
+
+				return toAppend, expected
+			},
+		},
+		{
+			name: "Append creates a new batch at the end",
+			setup: func() [][]*headertest.DummyHeader {
+				return [][]*headertest.DummyHeader{
+					headers[:2], // Batch 8-7
+				}
+			},
+			appendAndExpected: func() ([]*headertest.DummyHeader, [][]*headertest.DummyHeader) {
+				toAppend := headers[4:] // Headers 4-1
+				expected := [][]*headertest.DummyHeader{
+					headers[:2], // Batch 8-7
+					headers[4:], // Batch 4-1
+				}
+
+				return toAppend, expected
+			},
+		},
+		{
+			name: "Append overrides existing headers",
+			setup: func() [][]*headertest.DummyHeader {
+				return [][]*headertest.DummyHeader{
+					headers, // Entire batch 8-1
+				}
+			},
+			appendAndExpected: func() ([]*headertest.DummyHeader, [][]*headertest.DummyHeader) {
+				differentHeaders := headertest.NewTestSuite(t).GenDummyHeaders(8)
+				// reverse the order to be descending
+				slices.SortFunc(differentHeaders, func(a, b *headertest.DummyHeader) int {
+					return int(b.Height() - a.Height())
+				})
+
+				expectedHeaders := make([]*headertest.DummyHeader, len(headers))
+				copy(expectedHeaders, headers)
+				expectedHeaders[6] = differentHeaders[6]
+				expectedHeaders[7] = differentHeaders[7]
+
+				toAppend := expectedHeaders[6:] // Headers 2,1
+
+				return toAppend, [][]*headertest.DummyHeader{
+					expectedHeaders, // Batch 8-1 with heights 2,1 replaced
+				}
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			setup := test.setup()
+			toAppend, expected := test.appendAndExpected()
+
+			bs := newEmptyBatches[*headertest.DummyHeader]()
+			for _, headers := range setup {
+				b := newBatch[*headertest.DummyHeader](len(headers))
+				b.Append(slices.Clone(headers)...)
+				bs.batches = append(bs.batches, b)
+			}
+			bs.Append(slices.Clone(toAppend)...)
+
+			// Verify expected batch structure
+			var actualBatches [][]*headertest.DummyHeader
+			for _, b := range bs.batches {
+				actualBatches = append(actualBatches, b.GetAll())
+			}
+
+			assert.EqualValues(t, expected, actualBatches)
+		})
+	}
+}
diff --git a/store/heightsub.go b/store/heightsub.go
index 2335001d..1bed1d3f 100644
--- a/store/heightsub.go
+++ b/store/heightsub.go
@@ -5,126 +5,133 @@ import (
 	"errors"
 	"sync"
 	"sync/atomic"
-
-	"github.com/celestiaorg/go-header"
 )
 
 // errElapsedHeight is thrown when a requested height was already provided to heightSub.
 var errElapsedHeight = errors.New("elapsed height")
 
 // heightSub provides a minimalistic mechanism to wait till header for a height becomes available.
-type heightSub[H header.Header[H]] struct {
	// height refers to the latest locally available header height
 	// that has been fully verified and inserted into the subjective chain
 	height atomic.Uint64
 
-	heightReqsLk sync.Mutex
-	heightReqs   map[uint64]map[chan H]struct{}
+	heightSubsLk sync.Mutex
+	heightSubs   map[uint64]*sub
+}
+
+type sub struct {
+	signal chan struct{}
+	count  int
 }
 
 // newHeightSub instantiates new heightSub.
-func newHeightSub[H header.Header[H]]() *heightSub[H] {
-	return &heightSub[H]{
-		heightReqs: make(map[uint64]map[chan H]struct{}),
+func newHeightSub() *heightSub {
+	return &heightSub{
+		heightSubs: make(map[uint64]*sub),
+	}
+}
+
+// Init the heightSub with a given height.
+// Notifies all awaiting [Wait] calls lower than height.
+func (hs *heightSub) Init(height uint64) {
+	hs.height.Store(height)
+
+	hs.heightSubsLk.Lock()
+	defer hs.heightSubsLk.Unlock()
+
+	for h := range hs.heightSubs {
+		if h < height {
+			hs.notify(h, true)
+		}
 	}
 }
 
 // Height reports current height.
-func (hs *heightSub[H]) Height() uint64 {
+func (hs *heightSub) Height() uint64 {
 	return hs.height.Load()
 }
 
 // SetHeight sets the new head height for heightSub.
-func (hs *heightSub[H]) SetHeight(height uint64) {
-	hs.height.Store(height)
+// Notifies all awaiting [Wait] calls in range from [heightSub.Height] to height.
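+//
+// A minimal usage sketch (illustrative only; ctx is an assumed context.Context):
+//
+//	hs := newHeightSub()
+//	hs.Init(10)
+//	go func() { _ = hs.Wait(ctx, 12) }() // blocks: 12 is above the current height
+//	hs.SetHeight(12)                     // releases the waiter above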
+func (hs *heightSub) SetHeight(height uint64) {
+	for {
+		curr := hs.height.Load()
+		if curr >= height {
+			return
+		}
+		if !hs.height.CompareAndSwap(curr, height) {
+			continue
+		}
+
+		hs.heightSubsLk.Lock()
+		defer hs.heightSubsLk.Unlock() //nolint:gocritic // we have a return below
+
+		for ; curr <= height; curr++ {
+			hs.notify(curr, true)
+		}
+		return
+	}
+}
 
-// Sub subscribes for a header of a given height.
-// It can return errElapsedHeight, which means a requested header was already provided
+// Wait for a given height to be published.
+// It can return errElapsedHeight, which means a requested height was already seen
 // and caller should get it elsewhere.
-func (hs *heightSub[H]) Sub(ctx context.Context, height uint64) (H, error) {
-	var zero H
+func (hs *heightSub) Wait(ctx context.Context, height uint64) error {
 	if hs.Height() >= height {
-		return zero, errElapsedHeight
+		return errElapsedHeight
 	}
 
-	hs.heightReqsLk.Lock()
+	hs.heightSubsLk.Lock()
 	if hs.Height() >= height {
 		// This is a rare case we have to account for.
 		// The lock above can park a goroutine long enough for hs.height to change for a requested height,
 		// leaving the request never fulfilled and the goroutine deadlocked.
-		hs.heightReqsLk.Unlock()
-		return zero, errElapsedHeight
+		hs.heightSubsLk.Unlock()
+		return errElapsedHeight
 	}
-	resp := make(chan H, 1)
-	reqs, ok := hs.heightReqs[height]
+
+	sac, ok := hs.heightSubs[height]
 	if !ok {
-		reqs = make(map[chan H]struct{})
-		hs.heightReqs[height] = reqs
+		sac = &sub{
+			signal: make(chan struct{}, 1),
+		}
+		hs.heightSubs[height] = sac
 	}
-	reqs[resp] = struct{}{}
-	hs.heightReqsLk.Unlock()
+	sac.count++
+	hs.heightSubsLk.Unlock()
 
 	select {
-	case resp := <-resp:
-		return resp, nil
+	case <-sac.signal:
+		return nil
 	case <-ctx.Done():
 		// no need to keep the request, if the op has canceled
-		hs.heightReqsLk.Lock()
-		delete(reqs, resp)
-		if len(reqs) == 0 {
-			delete(hs.heightReqs, height)
-		}
-		hs.heightReqsLk.Unlock()
-		return zero, ctx.Err()
+		hs.heightSubsLk.Lock()
+		hs.notify(height, false)
+		hs.heightSubsLk.Unlock()
+		return ctx.Err()
 	}
 }
 
-// Pub processes all the outstanding subscriptions matching the given headers.
-// Pub is only safe when called from one goroutine.
-// For Pub to work correctly, heightSub has to be initialized with SetHeight
-// so that given headers are contiguous to the height on heightSub.
-func (hs *heightSub[H]) Pub(headers ...H) {
-	ln := len(headers)
-	if ln == 0 {
-		return
-	}
+// Notify and release the waiters in [Wait].
+// Note: it does not advance heightSub's height.
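+//
+// For example: with Height() at 4, Notify(5) releases a blocked
+// Wait(ctx, 5) while Height() stays 4; use SetHeight to advance
+// the height counter as well.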
+func (hs *heightSub) Notify(heights ...uint64) { + hs.heightSubsLk.Lock() + defer hs.heightSubsLk.Unlock() - height := hs.Height() - from, to := headers[0].Height(), headers[ln-1].Height() - if height+1 != from && height != 0 { // height != 0 is needed to enable init from any height and not only 1 - log.Fatalf("PLEASE FILE A BUG REPORT: headers given to the heightSub are in the wrong order: expected %d, got %d", height+1, from) - return + for _, h := range heights { + hs.notify(h, true) } - hs.SetHeight(to) - - hs.heightReqsLk.Lock() - defer hs.heightReqsLk.Unlock() - - // there is a common case where we Pub only header - // in this case, we shouldn't loop over each heightReqs - // and instead read from the map directly - if ln == 1 { - reqs, ok := hs.heightReqs[from] - if ok { - for req := range reqs { - req <- headers[0] // reqs must always be buffered, so this won't block - } - delete(hs.heightReqs, from) - } +} + +func (hs *heightSub) notify(height uint64, all bool) { + sac, ok := hs.heightSubs[height] + if !ok { return } - // instead of looping over each header in 'headers', we can loop over each request - // which will drastically decrease idle iterations, as there will be less requests than headers - for height, reqs := range hs.heightReqs { - // then we look if any of the requests match the given range of headers - if height >= from && height <= to { - // and if so, calculate its position and fulfill requests - h := headers[height-from] - for req := range reqs { - req <- h // reqs must always be buffered, so this won't block - } - delete(hs.heightReqs, height) - } + sac.count-- + if all || sac.count == 0 { + close(sac.signal) + delete(hs.heightSubs, height) } } diff --git a/store/heightsub_test.go b/store/heightsub_test.go index f5958422..6ef64a64 100644 --- a/store/heightsub_test.go +++ b/store/heightsub_test.go @@ -14,18 +14,14 @@ func TestHeightSub(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - hs := newHeightSub[*headertest.DummyHeader]() + hs := newHeightSub() // assert subscription returns nil for past heights { - h := headertest.RandDummyHeader(t) - h.HeightI = 100 - hs.SetHeight(99) - hs.Pub(h) + hs.Init(99) - h, err := hs.Sub(ctx, 10) + err := hs.Wait(ctx, 10) assert.ErrorIs(t, err, errElapsedHeight) - assert.Nil(t, h) } // assert actual subscription works @@ -34,16 +30,11 @@ func TestHeightSub(t *testing.T) { // fixes flakiness on CI time.Sleep(time.Millisecond) - h1 := headertest.RandDummyHeader(t) - h1.HeightI = 101 - h2 := headertest.RandDummyHeader(t) - h2.HeightI = 102 - hs.Pub(h1, h2) + hs.SetHeight(102) }() - h, err := hs.Sub(ctx, 101) + err := hs.Wait(ctx, 101) assert.NoError(t, err) - assert.NotNil(t, h) } // assert multiple subscriptions work @@ -51,16 +42,14 @@ func TestHeightSub(t *testing.T) { ch := make(chan error, 10) for range cap(ch) { go func() { - _, err := hs.Sub(ctx, 103) + err := hs.Wait(ctx, 103) ch <- err }() } time.Sleep(time.Millisecond * 10) - h3 := headertest.RandDummyHeader(t) - h3.HeightI = 103 - hs.Pub(h3) + hs.SetHeight(103) for range cap(ch) { assert.NoError(t, <-ch) @@ -68,18 +57,98 @@ func TestHeightSub(t *testing.T) { } } +func TestHeightSub_withWaitCancelled(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + hs := newHeightSub() + hs.Init(10) + + const waiters = 5 + + cancelChs := make([]chan error, waiters) + blockedChs := make([]chan error, waiters) + for i := range waiters { + cancelChs[i] = make(chan error, 1) + 
blockedChs[i] = make(chan error, 1) + + go func() { + ctx, cancel := context.WithTimeout(ctx, time.Duration(i+1)*time.Millisecond) + defer cancel() + + err := hs.Wait(ctx, 100) + cancelChs[i] <- err + }() + + go func() { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + err := hs.Wait(ctx, 100) + blockedChs[i] <- err + }() + } + + for i := range cancelChs { + err := <-cancelChs[i] + assert.ErrorIs(t, err, context.DeadlineExceeded) + } + + for i := range blockedChs { + select { + case <-blockedChs[i]: + t.Error("channel should be blocked") + default: + } + } +} + +// Test heightSub can accept non-adj headers without an error. +func TestHeightSubNonAdjacency(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + hs := newHeightSub() + hs.Init(99) + + go func() { + // fixes flakiness on CI + time.Sleep(time.Millisecond) + + hs.SetHeight(300) + }() + + err := hs.Wait(ctx, 200) + assert.NoError(t, err) +} + +// Test heightSub's height cannot go down but only up. +func TestHeightSub_monotonicHeight(t *testing.T) { + hs := newHeightSub() + + hs.Init(99) + assert.Equal(t, int64(hs.height.Load()), int64(99)) + + hs.SetHeight(300) + assert.Equal(t, int64(hs.height.Load()), int64(300)) + + hs.SetHeight(120) + assert.Equal(t, int64(hs.height.Load()), int64(300)) +} + func TestHeightSubCancellation(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() h := headertest.RandDummyHeader(t) - hs := newHeightSub[*headertest.DummyHeader]() + h.HeightI %= 1000 // make it a bit lower + hs := newHeightSub() - sub := make(chan *headertest.DummyHeader) + sub := make(chan struct{}) go func() { // subscribe first time - h, _ := hs.Sub(ctx, h.HeightI) - sub <- h + hs.Wait(ctx, h.Height()) + sub <- struct{}{} }() // give a bit time for subscription to settle @@ -88,19 +157,18 @@ func TestHeightSubCancellation(t *testing.T) { // subscribe again but with failed canceled context canceledCtx, cancel := context.WithCancel(ctx) cancel() - _, err := hs.Sub(canceledCtx, h.HeightI) - assert.Error(t, err) + err := hs.Wait(canceledCtx, h.Height()) + assert.ErrorIs(t, err, context.Canceled) - // publish header - hs.Pub(h) + // update height + hs.SetHeight(h.Height()) // ensure we still get our header select { - case subH := <-sub: - assert.Equal(t, h.HeightI, subH.HeightI) + case <-sub: case <-ctx.Done(): t.Error(ctx.Err()) } // ensure we don't have any active subscriptions - assert.Len(t, hs.heightReqs, 0) + assert.Len(t, hs.heightSubs, 0) } diff --git a/store/store.go b/store/store.go index 83303fd1..d2c64b48 100644 --- a/store/store.go +++ b/store/store.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "sync/atomic" "time" @@ -41,18 +42,20 @@ type Store[H header.Header[H]] struct { heightIndex *heightIndexer[H] // manages current store read head height (1) and // allows callers to wait until header for a height is stored (2) - heightSub *heightSub[H] + heightSub *heightSub // writing to datastore // - // queue of headers to be written - writes chan []H + writesMu sync.Mutex + // writesPending keeps headers pending to be written in one batch + writesPending *batch[H] + // queue of batches to be written + writesCh chan *batch[H] // signals when writes are finished writesDn chan struct{} - // writeHead maintains the current write head - writeHead atomic.Pointer[H] - // pending keeps headers pending to be written in one batch - pending *batch[H] + + // contiguousHead is the highest contiguous 
header observed
+	contiguousHead atomic.Pointer[H]
 
 	Params Parameters
 }
@@ -99,15 +102,15 @@ func newStore[H header.Header[H]](ds datastore.Batching, opts ...Option) (*Store
 	}
 
 	return &Store[H]{
-		ds:          wrappedStore,
-		cache:       cache,
-		metrics:     metrics,
-		heightIndex: index,
-		heightSub:   newHeightSub[H](),
-		writes:      make(chan []H, 16),
-		writesDn:    make(chan struct{}),
-		pending:     newBatch[H](params.WriteBatchSize),
-		Params:      params,
+		ds:            wrappedStore,
+		cache:         cache,
+		metrics:       metrics,
+		heightIndex:   index,
+		heightSub:     newHeightSub(),
+		writesCh:      make(chan *batch[H], 4),
+		writesDn:      make(chan struct{}),
+		writesPending: newBatch[H](params.WriteBatchSize),
+		Params:        params,
 	}, nil
 }
 
@@ -115,6 +118,11 @@ func (s *Store[H]) Init(ctx context.Context, initial H) error {
 	if s.heightSub.Height() != 0 {
 		return errors.New("store already initialized")
 	}
+
+	// initialize with the initial head before first flush.
+	s.contiguousHead.Store(&initial)
+	s.heightSub.Init(initial.Height())
+
 	// trust the given header as the initial head
 	err := s.flush(ctx, initial)
 	if err != nil {
@@ -122,35 +130,51 @@
 	}
 
 	log.Infow("initialized head", "height", initial.Height(), "hash", initial.Hash())
-	s.heightSub.Pub(initial)
 	return nil
 }
 
-func (s *Store[H]) Start(context.Context) error {
+// Start starts or restarts the Store.
+func (s *Store[H]) Start(ctx context.Context) error {
 	// closed s.writesDn means that store was stopped before, recreate chan.
 	select {
 	case <-s.writesDn:
+		s.writesCh = make(chan *batch[H], 4)
 		s.writesDn = make(chan struct{})
+		s.writesPending = newBatch[H](s.Params.WriteBatchSize)
 	default:
 	}
 
+	if err := s.loadContiguousHead(ctx); err != nil {
+		// we might start on an empty datastore; a missing head key is okay.
+		if !errors.Is(err, datastore.ErrNotFound) {
+			return fmt.Errorf("header/store: cannot load headKey: %w", err)
+		}
+	}
+
 	go s.flushLoop()
 	return nil
}
 
+// Stop stops the store and cleans up resources.
+// Canceling the context while stopping may leave the store in an inconsistent state.
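+//
+// A typical shutdown sketch (hypothetical caller code):
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+//	defer cancel()
+//	if err := store.Stop(ctx); err != nil {
+//		log.Errorw("stopping store", "err", err)
+//	}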
func (s *Store[H]) Stop(ctx context.Context) error { + s.writesMu.Lock() + defer s.writesMu.Unlock() + // check if store was already stopped select { case <-s.writesDn: return errStoppedStore default: } - // signal to prevent further writes to Store + // write the pending leftover select { - case s.writes <- nil: + case s.writesCh <- s.writesPending: + // signal closing to flushLoop + close(s.writesCh) case <-ctx.Done(): return ctx.Err() } - // wait till it is done writing + // wait till flushLoop is done writing select { case <-s.writesDn: case <-ctx.Done(): @@ -167,24 +191,13 @@ func (s *Store[H]) Height() uint64 { return s.heightSub.Height() } -func (s *Store[H]) Head(ctx context.Context, _ ...header.HeadOption[H]) (H, error) { - head, err := s.GetByHeight(ctx, s.heightSub.Height()) - if err == nil { - return head, nil +func (s *Store[H]) Head(_ context.Context, _ ...header.HeadOption[H]) (H, error) { + if head := s.contiguousHead.Load(); head != nil { + return *head, nil } var zero H - head, err = s.readHead(ctx) - switch { - default: - return zero, err - case errors.Is(err, datastore.ErrNotFound), errors.Is(err, header.ErrNotFound): - return zero, header.ErrNoHead - case err == nil: - s.heightSub.SetHeight(head.Height()) - log.Infow("loaded head", "height", head.Height(), "hash", head.Hash()) - return head, nil - } + return zero, header.ErrNoHead } func (s *Store[H]) Get(ctx context.Context, hash header.Hash) (H, error) { @@ -193,7 +206,7 @@ func (s *Store[H]) Get(ctx context.Context, hash header.Hash) (H, error) { return v, nil } // check if the requested header is not yet written on disk - if h := s.pending.Get(hash); !h.IsZero() { + if h, _ := s.writesPending.Get(hash); !h.IsZero() { return h, nil } @@ -217,22 +230,34 @@ func (s *Store[H]) GetByHeight(ctx context.Context, height uint64) (H, error) { if height == 0 { return zero, errors.New("header/store: height must be bigger than zero") } + + if h, err := s.getByHeight(ctx, height); err == nil { + return h, nil + } + // if the requested 'height' was not yet published // we subscribe to it - h, err := s.heightSub.Sub(ctx, height) - if !errors.Is(err, errElapsedHeight) { - return h, err + err := s.heightSub.Wait(ctx, height) + if err != nil && !errors.Is(err, errElapsedHeight) { + return zero, err } // otherwise, the errElapsedHeight is thrown, // which means the requested 'height' should be present // // check if the requested header is not yet written on disk - if h := s.pending.GetByHeight(height); !h.IsZero() { + + return s.getByHeight(ctx, height) +} + +func (s *Store[H]) getByHeight(ctx context.Context, height uint64) (H, error) { + // TODO: Synchronize with prepareWrite? 
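+	// Lookup order: pending in-memory writes first, then the height
+	// index and header data on disk.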
+ if h, _ := s.writesPending.GetByHeight(height); !h.IsZero() { return h, nil } hash, err := s.heightIndex.HashByHeight(ctx, height) if err != nil { + var zero H if errors.Is(err, datastore.ErrNotFound) { return zero, header.ErrNotFound } @@ -287,7 +312,7 @@ func (s *Store[H]) Has(ctx context.Context, hash header.Hash) (bool, error) { return ok, nil } // check if the requested header is not yet written on disk - if ok := s.pending.Has(hash); ok { + if ok := s.writesPending.Has(hash); ok { return ok, nil } @@ -304,23 +329,15 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { return nil } - var err error - // take current write head to verify headers against - var head H - headPtr := s.writeHead.Load() - if headPtr == nil { - head, err = s.Head(ctx) - if err != nil { - return err - } - } else { - head = *headPtr + // take current contiguous head to verify headers against + head, err := s.Head(ctx) + if err != nil { + return err } // collect valid headers verified := make([]H, 0, lh) for i, h := range headers { - err = head.Verify(h) if err != nil { var verErr *header.VerifyError @@ -344,27 +361,27 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { head = h } - onWrite := func() { - newHead := verified[len(verified)-1] - s.writeHead.Store(&newHead) - log.Infow("new head", "height", newHead.Height(), "hash", newHead.Hash()) - s.metrics.newHead(newHead.Height()) + // prepare headers to be written + toWrite, err := s.prepareWrite(ctx, verified) + switch { + case err != nil: + return err + case toWrite == nil: + return nil } // queue headers to be written on disk select { - case s.writes <- verified: + case s.writesCh <- toWrite: // we return an error here after writing, // as there might be an invalid header in between of a given range - onWrite() return err default: s.metrics.writesQueueBlocked(ctx) } - // if the writes queue is full, we block until it is not + // if the writesCh queue is full - we block anyway select { - case s.writes <- verified: - onWrite() + case s.writesCh <- toWrite: return err case <-s.writesDn: return errStoppedStore @@ -373,6 +390,32 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { } } +func (s *Store[H]) prepareWrite(ctx context.Context, headers []H) (*batch[H], error) { + s.writesMu.Lock() + defer s.writesMu.Unlock() + // check if store was stopped + select { + case <-s.writesDn: + return nil, errStoppedStore + default: + } + + // keep verified headers as pending writes and ensure they are accessible for reads + s.writesPending.Append(headers...) + // notify heightSub about new headers + // Notify after updating pending so unblocked heightSub waiters can get it + s.heightSub.Notify(getHeights(headers...)...) + // advance contiguousHead if we don't have gaps. 
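+	// e.g. with the head at 10 and heights 12-14 appended, the head stays
+	// at 10 until 11 arrives to close the gap.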
+ s.advanceContiguousHead(ctx, s.heightSub.Height()) // TODO: Ensure never IO blocking + + // don't flush and continue if pending write batch is not grown enough, + if s.writesPending.Len() < s.Params.WriteBatchSize { + return nil, nil + } + + return s.writesPending, nil +} + // flushLoop performs writing task to the underlying datastore in a separate routine // This way writes are controlled and manageable from one place allowing // (1) Appends not to be blocked on long disk IO writes and underlying DB compactions @@ -380,21 +423,10 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { func (s *Store[H]) flushLoop() { defer close(s.writesDn) ctx := context.Background() - for headers := range s.writes { - // add headers to the pending and ensure they are accessible - s.pending.Append(headers...) - // and notify waiters if any + increase current read head height - // it is important to do Pub after updating pending - // so pending is consistent with atomic Height counter on the heightSub - s.heightSub.Pub(headers...) - // don't flush and continue if pending batch is not grown enough, - // and Store is not stopping(headers == nil) - if s.pending.Len() < s.Params.WriteBatchSize && headers != nil { - continue - } + for headers := range s.writesCh { startTime := time.Now() - toFlush := s.pending.GetAll() + toFlush := headers.GetAll() for i := 0; ; i++ { err := s.flush(ctx, toFlush...) @@ -404,21 +436,15 @@ func (s *Store[H]) flushLoop() { from, to := toFlush[0].Height(), toFlush[len(toFlush)-1].Height() log.Errorw("writing header batch", "try", i+1, "from", from, "to", to, "err", err) - s.metrics.flush(ctx, time.Since(startTime), s.pending.Len(), true) + s.metrics.flush(ctx, time.Since(startTime), headers.Len(), true) const maxRetrySleep = time.Second sleep := min(10*time.Duration(i+1)*time.Millisecond, maxRetrySleep) time.Sleep(sleep) } - s.metrics.flush(ctx, time.Since(startTime), s.pending.Len(), false) - // reset pending - s.pending.Reset() - - if headers == nil { - // a signal to stop - return - } + s.metrics.flush(ctx, time.Since(startTime), headers.Len(), false) + headers.Reset() } } @@ -448,7 +474,8 @@ func (s *Store[H]) flush(ctx context.Context, headers ...H) error { } // marshal and add to batch reference to the new head - b, err := headers[ln-1].Hash().MarshalJSON() + head := *s.contiguousHead.Load() + b, err := head.Hash().MarshalJSON() if err != nil { return err } @@ -467,6 +494,18 @@ func (s *Store[H]) flush(ctx context.Context, headers ...H) error { return batch.Commit(ctx) } +// loadContiguousHead from the disk and sets contiguousHead and heightSub. +func (s *Store[H]) loadContiguousHead(ctx context.Context) error { + h, err := s.readHead(ctx) + if err != nil { + return err + } + + s.contiguousHead.Store(&h) + s.heightSub.SetHeight(h.Height()) + return nil +} + // readHead loads the head from the datastore. func (s *Store[H]) readHead(ctx context.Context) (H, error) { var zero H @@ -499,6 +538,35 @@ func (s *Store[H]) get(ctx context.Context, hash header.Hash) ([]byte, error) { return data, nil } +// advanceContiguousHead updates contiguousHead and heightSub if a higher +// contiguous header exists on a disk. 
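+// For example: with the contiguous head at 10 and headers 11-13 readable,
+// the head advances to 13; if 12 were missing, it would stop at 11.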
+func (s *Store[H]) advanceContiguousHead(ctx context.Context, height uint64) {
+	newHead := s.nextContiguousHead(ctx, height)
+	if newHead.IsZero() || newHead.Height() <= height {
+		return
+	}
+
+	s.contiguousHead.Store(&newHead)
+	s.heightSub.SetHeight(newHead.Height())
+	log.Infow("new head", "height", newHead.Height(), "hash", newHead.Hash())
+	s.metrics.newHead(newHead.Height())
+}
+
+// nextContiguousHead iterates upwards header by header until it finds a gap.
+// If the header at height+1 is not found, it returns a zero-value header.
+func (s *Store[H]) nextContiguousHead(ctx context.Context, height uint64) H {
+	var newHead H
+	for {
+		height++
+		h, err := s.getByHeight(ctx, height)
+		if err != nil {
+			break
+		}
+		newHead = h
+	}
+	return newHead
+}
+
 // indexTo saves mapping between header Height and Hash to the given batch.
 func indexTo[H header.Header[H]](ctx context.Context, batch datastore.Batch, headers ...H) error {
 	for _, h := range headers {
@@ -509,3 +577,11 @@ func indexTo[H header.Header[H]](ctx context.Context, batch datastore.Batch, hea
 	}
 	return nil
 }
+
+func getHeights[H header.Header[H]](headers ...H) []uint64 {
+	heights := make([]uint64, len(headers))
+	for i := range headers {
+		heights[i] = headers[i].Height()
+	}
+	return heights
+}
diff --git a/store/store_test.go b/store/store_test.go
index 96f5ff25..a152a42b 100644
--- a/store/store_test.go
+++ b/store/store_test.go
@@ -1,7 +1,10 @@
 package store
 
 import (
+	"bytes"
 	"context"
+	"math/rand"
+	stdsync "sync"
 	"testing"
 	"time"
 
@@ -10,6 +13,7 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
+	"github.com/celestiaorg/go-header"
 	"github.com/celestiaorg/go-header/headertest"
 )
 
@@ -145,6 +149,260 @@ func TestStore_Append_BadHeader(t *testing.T) {
 	require.Error(t, err)
 }
 
+func TestStore_Append(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	t.Cleanup(cancel)
+
+	suite := headertest.NewTestSuite(t)
+
+	ds := sync.MutexWrap(datastore.NewMapDatastore())
+	store := NewTestStore(t, ctx, ds, suite.Head(), WithWriteBatchSize(4))
+
+	head, err := store.Head(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, head.Hash(), suite.Head().Hash())
+
+	const workers = 10
+	const chunk = 5
+	headers := suite.GenDummyHeaders(workers * chunk)
+
+	errCh := make(chan error, workers)
+	var wg stdsync.WaitGroup
+	wg.Add(workers)
+
+	for i := range workers {
+		go func() {
+			defer wg.Done()
+			// make the appends happen in random order.
+			time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
+
+			err := store.Append(ctx, headers[i*chunk:(i+1)*chunk]...)
+			errCh <- err
+		}()
+	}
+
+	wg.Wait()
+	close(errCh)
+	for err := range errCh {
+		assert.NoError(t, err)
+	}
+
+	// wait for batch to be written.
+	time.Sleep(100 * time.Millisecond)
+
+	assert.Eventually(t, func() bool {
+		head, err = store.Head(ctx)
+		// plain comparisons only: assertions inside an Eventually
+		// condition would fail the test on the first poll instead of retrying.
+		switch {
+		case err != nil:
+			return false
+		case int(head.Height()) != int(headers[len(headers)-1].Height()):
+			return false
+		case !bytes.Equal(head.Hash(), headers[len(headers)-1].Hash()):
+			return false
+		default:
+			return true
+		}
+	}, time.Second, time.Millisecond)
+}
+
+func TestStore_Append_stableHeadWhenGaps(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	t.Cleanup(cancel)
+
+	suite := headertest.NewTestSuite(t)
+
+	ds := sync.MutexWrap(datastore.NewMapDatastore())
+	store := NewTestStore(t, ctx, ds, suite.Head(), WithWriteBatchSize(4))
+
+	head, err := store.Head(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, head.Hash(), suite.Head().Hash())
+
+	firstChunk := suite.GenDummyHeaders(5)
+	missedChunk := suite.GenDummyHeaders(5)
+	lastChunk := suite.GenDummyHeaders(5)
+
+	wantHead := firstChunk[len(firstChunk)-1]
+	latestHead := lastChunk[len(lastChunk)-1]
+
+	{
+		err := store.Append(ctx, firstChunk...)
+		require.NoError(t, err)
+		// wait for batch to be written.
+		time.Sleep(100 * time.Millisecond)
+
+		// head is advanced to the last known header.
+		head, err := store.Head(ctx)
+		require.NoError(t, err)
+		assert.Equal(t, head.Height(), wantHead.Height())
+		assert.Equal(t, head.Hash(), wantHead.Hash())
+
+		// check that store height is aligned with the head.
+		height := store.Height()
+		assert.Equal(t, height, head.Height())
+	}
+	{
+		err := store.Append(ctx, lastChunk...)
+		require.NoError(t, err)
+		// wait for batch to be written.
+		time.Sleep(100 * time.Millisecond)
+
+		// head is not advanced due to a gap.
+		head, err := store.Head(ctx)
+		require.NoError(t, err)
+		assert.Equal(t, head.Height(), wantHead.Height())
+		assert.Equal(t, head.Hash(), wantHead.Hash())
+
+		// check that store height is aligned with the head.
+		height := store.Height()
+		assert.Equal(t, height, head.Height())
+	}
+	{
+		err := store.Append(ctx, missedChunk...)
+		require.NoError(t, err)
+		// wait for batch to be written.
+		time.Sleep(time.Second)
+
+		// after appending missing headers we're on the latest header.
+		head, err := store.Head(ctx)
+		require.NoError(t, err)
+		assert.Equal(t, head.Height(), latestHead.Height())
+		assert.Equal(t, head.Hash(), latestHead.Hash())
+
+		// check that store height is aligned with the head.
+		height := store.Height()
+		assert.Equal(t, height, head.Height())
+	}
+}
+
+func TestStoreGetByHeight_whenGaps(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	t.Cleanup(cancel)
+
+	suite := headertest.NewTestSuite(t)
+
+	ds := sync.MutexWrap(datastore.NewMapDatastore())
+	store := NewTestStore(t, ctx, ds, suite.Head(), WithWriteBatchSize(10))
+
+	head, err := store.Head(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, head.Hash(), suite.Head().Hash())
+
+	{
+		firstChunk := suite.GenDummyHeaders(5)
+		latestHead := firstChunk[len(firstChunk)-1]
+
+		err := store.Append(ctx, firstChunk...)
+		require.NoError(t, err)
+		// wait for batch to be written.
+		time.Sleep(100 * time.Millisecond)
+
+		head, err := store.Head(ctx)
+		require.NoError(t, err)
+		assert.Equal(t, head.Height(), latestHead.Height())
+		assert.Equal(t, head.Hash(), latestHead.Hash())
+	}
+
+	missedChunk := suite.GenDummyHeaders(5)
+	wantMissHead := missedChunk[len(missedChunk)-2]
+
+	errChMiss := make(chan error, 1)
+	go func() {
+		shortCtx, shortCancel := context.WithTimeout(ctx, 3*time.Second)
+		defer shortCancel()
+
+		_, err := store.GetByHeight(shortCtx, wantMissHead.Height())
+		errChMiss <- err
+	}()
+
+	lastChunk := suite.GenDummyHeaders(5)
+	wantLastHead := lastChunk[len(lastChunk)-1]
+
+	errChLast := make(chan error, 1)
+	go func() {
+		shortCtx, shortCancel := context.WithTimeout(ctx, 3*time.Second)
+		defer shortCancel()
+
+		_, err := store.GetByHeight(shortCtx, wantLastHead.Height())
+		errChLast <- err
+	}()
+
+	// wait for the goroutines to start
+	time.Sleep(100 * time.Millisecond)
+
+	select {
+	case err := <-errChMiss:
+		t.Fatalf("store.GetByHeight on prelast height MUST be blocked, have error: %v", err)
+	case err := <-errChLast:
+		t.Fatalf("store.GetByHeight on last height MUST be blocked, have error: %v", err)
+	default:
+		// ok
+	}
+
+	{
+		err := store.Append(ctx, lastChunk...)
+		require.NoError(t, err)
+		// wait for batch to be written.
+		time.Sleep(100 * time.Millisecond)
+
+		select {
+		case err := <-errChMiss:
+			t.Fatalf("store.GetByHeight on prelast height MUST be blocked, have error: %v", err)
+		case err := <-errChLast:
+			require.NoError(t, err)
+		default:
+			t.Fatal("store.GetByHeight on last height MUST NOT be blocked")
+		}
+	}
+
+	{
+		err := store.Append(ctx, missedChunk...)
+		require.NoError(t, err)
+		// wait for batch to be written.
+		time.Sleep(100 * time.Millisecond)
+
+		select {
+		case err := <-errChMiss:
+			require.NoError(t, err)
+
+			head, err := store.GetByHeight(ctx, wantLastHead.Height())
+			require.NoError(t, err)
+			require.Equal(t, wantLastHead, head)
+		default:
+			t.Fatal("store.GetByHeight on prelast height MUST NOT be blocked")
+		}
+	}
+}
+
+func TestStoreGetByHeight_earlyAvailable(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	t.Cleanup(cancel)
+
+	suite := headertest.NewTestSuite(t)
+
+	ds := sync.MutexWrap(datastore.NewMapDatastore())
+	store := NewTestStore(t, ctx, ds, suite.Head(), WithWriteBatchSize(10))
+
+	const skippedHeaders = 15
+	suite.GenDummyHeaders(skippedHeaders)
+	lastChunk := suite.GenDummyHeaders(1)
+
+	{
+		err := store.Append(ctx, lastChunk...)
+		require.NoError(t, err)
+
+		// wait for batch to be written.
+ time.Sleep(100 * time.Millisecond) + } + + { + h, err := store.GetByHeight(ctx, lastChunk[0].Height()) + require.NoError(t, err) + require.Equal(t, h, lastChunk[0]) + } +} + // TestStore_GetRange tests possible combinations of requests and ensures that // the store can handle them adequately (even malformed requests) func TestStore_GetRange(t *testing.T) { @@ -253,6 +511,7 @@ func TestBatch_GetByHeightBeforeInit(t *testing.T) { t.Cleanup(cancel) suite := headertest.NewTestSuite(t) + suite.Head().HeightI = 1_000_000 ds := sync.MutexWrap(datastore.NewMapDatastore()) store, err := NewStore[*headertest.DummyHeader](ds) @@ -265,9 +524,8 @@ func TestBatch_GetByHeightBeforeInit(t *testing.T) { _ = store.Init(ctx, suite.Head()) }() - h, err := store.GetByHeight(ctx, 1) - require.NoError(t, err) - require.NotNil(t, h) + _, err = store.GetByHeight(ctx, 1) + require.ErrorIs(t, err, header.ErrNotFound) } func TestStoreInit(t *testing.T) { diff --git a/sync/sync_test.go b/sync/sync_test.go index b9acb2d3..b8d98b7d 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -47,19 +47,37 @@ func TestSyncSimpleRequestingHead(t *testing.T) { err = syncer.SyncWait(ctx) require.NoError(t, err) - exp, err := remoteStore.Head(ctx) - require.NoError(t, err) - - have, err := localStore.Head(ctx) - require.NoError(t, err) - assert.Equal(t, exp.Height(), have.Height()) - assert.Empty(t, syncer.pending.Head()) - - state := syncer.State() - assert.Equal(t, uint64(exp.Height()), state.Height) - assert.Equal(t, uint64(2), state.FromHeight) - assert.Equal(t, uint64(exp.Height()), state.ToHeight) - assert.True(t, state.Finished(), state) + // force sync to update underlying stores. + syncer.wantSync() + + // we need to wait for a flush + assert.Eventually(t, func() bool { + exp, err := remoteStore.Head(ctx) + require.NoError(t, err) + + have, err := localStore.Head(ctx) + require.NoError(t, err) + + state := syncer.State() + switch { + case exp.Height() != have.Height(): + return false + case syncer.pending.Head() != nil: + return false + + case uint64(exp.Height()) != state.Height: + return false + case uint64(2) != state.FromHeight: + return false + + case uint64(exp.Height()) != state.ToHeight: + return false + case !state.Finished(): + return false + default: + return true + } + }, 2*time.Second, 100*time.Millisecond) } func TestDoSyncFullRangeFromExternalPeer(t *testing.T) { @@ -206,11 +224,20 @@ func TestSyncPendingRangesWithMisses(t *testing.T) { exp, err := remoteStore.Head(ctx) require.NoError(t, err) - have, err := localStore.Head(ctx) - require.NoError(t, err) - - assert.Equal(t, exp.Height(), have.Height()) - assert.Empty(t, syncer.pending.Head()) // assert all cache from pending is used + // we need to wait for a flush + assert.Eventually(t, func() bool { + have, err := localStore.Head(ctx) + require.NoError(t, err) + + switch { + case exp.Height() != have.Height(): + return false + case !syncer.pending.Head().IsZero(): + return false + default: + return true + } + }, 2*time.Second, 100*time.Millisecond) } // TestSyncer_FindHeadersReturnsCorrectRange ensures that `findHeaders` returns