Skip to content

Commit e85e942

Browse files
authored
Merge pull request #1824 from lightninglabs/wip/fix-flake-queueSeedlingsInBatch
tapgarden: harden `MockKeyRing` and fix flaky `queueSeedlingsInBatch` test
2 parents 6d104fd + 7139de5 commit e85e942

File tree

2 files changed

+71
-32
lines changed

2 files changed

+71
-32
lines changed

tapgarden/mock.go

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,10 @@ type MockKeyRing struct {
857857
KeyIndex uint32
858858

859859
Keys map[keychain.KeyLocator]*btcec.PrivateKey
860+
861+
// deriveNextKeyCallCount is used to track the number of calls to
862+
// DeriveNextKey.
863+
deriveNextKeyCallCount atomic.Uint64
860864
}
861865

862866
var _ KeyRing = (*MockKeyRing)(nil)
@@ -869,8 +873,11 @@ func NewMockKeyRing() *MockKeyRing {
869873
keyRing.On(
870874
"DeriveNextKey", mock.Anything,
871875
keychain.KeyFamily(asset.TaprootAssetsKeyFamily),
872-
).Return(nil)
873-
keyRing.On("DeriveNextTaprootAssetKey", mock.Anything).Return(nil)
876+
).Return(keychain.KeyDescriptor{}, nil)
877+
878+
keyRing.On(
879+
"DeriveNextTaprootAssetKey", mock.Anything,
880+
).Return(keychain.KeyDescriptor{}, nil)
874881

875882
return keyRing
876883
}
@@ -880,6 +887,7 @@ func NewMockKeyRing() *MockKeyRing {
880887
func (m *MockKeyRing) DeriveNextTaprootAssetKey(
881888
ctx context.Context) (keychain.KeyDescriptor, error) {
882889

890+
// No need to lock mutex here, DeriveNextKey does that for us.
883891
m.Called(ctx)
884892

885893
return m.DeriveNextKey(ctx, asset.TaprootAssetsKeyFamily)
@@ -888,20 +896,21 @@ func (m *MockKeyRing) DeriveNextTaprootAssetKey(
888896
func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
889897
keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
890898

899+
m.Lock()
900+
defer func() {
901+
m.KeyIndex++
902+
m.Unlock()
903+
}()
904+
891905
m.Called(ctx, keyFam)
906+
m.deriveNextKeyCallCount.Add(1)
892907

893908
select {
894909
case <-ctx.Done():
895910
return keychain.KeyDescriptor{}, fmt.Errorf("shutting down")
896911
default:
897912
}
898913

899-
m.Lock()
900-
defer func() {
901-
m.KeyIndex++
902-
m.Unlock()
903-
}()
904-
905914
priv, err := btcec.NewPrivateKey()
906915
if err != nil {
907916
return keychain.KeyDescriptor{}, err
@@ -925,10 +934,10 @@ func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
925934
func (m *MockKeyRing) IsLocalKey(ctx context.Context,
926935
d keychain.KeyDescriptor) bool {
927936

928-
m.Called(ctx, d)
937+
m.Lock()
938+
defer m.Unlock()
929939

930-
m.RLock()
931-
defer m.RUnlock()
940+
m.Called(ctx, d)
932941

933942
priv, ok := m.Keys[d.KeyLocator]
934943
if ok && priv.PubKey().IsEqual(d.PubKey) {
@@ -945,8 +954,8 @@ func (m *MockKeyRing) IsLocalKey(ctx context.Context,
945954
}
946955

947956
func (m *MockKeyRing) PubKeyAt(t *testing.T, idx uint32) *btcec.PublicKey {
948-
m.RLock()
949-
defer m.RUnlock()
957+
m.Lock()
958+
defer m.Unlock()
950959

951960
loc := keychain.KeyLocator{
952961
Index: idx,
@@ -962,8 +971,8 @@ func (m *MockKeyRing) PubKeyAt(t *testing.T, idx uint32) *btcec.PublicKey {
962971
}
963972

964973
func (m *MockKeyRing) ScriptKeyAt(t *testing.T, idx uint32) asset.ScriptKey {
965-
m.RLock()
966-
defer m.RUnlock()
974+
m.Lock()
975+
defer m.Unlock()
967976

968977
loc := keychain.KeyLocator{
969978
Index: idx,
@@ -984,13 +993,13 @@ func (m *MockKeyRing) ScriptKeyAt(t *testing.T, idx uint32) asset.ScriptKey {
984993
func (m *MockKeyRing) DeriveSharedKey(_ context.Context, key *btcec.PublicKey,
985994
locator *keychain.KeyLocator) ([sha256.Size]byte, error) {
986995

996+
m.Lock()
997+
defer m.Unlock()
998+
987999
if locator == nil {
9881000
return [32]byte{}, fmt.Errorf("locator is nil")
9891001
}
9901002

991-
m.RLock()
992-
defer m.RUnlock()
993-
9941003
priv, ok := m.Keys[*locator]
9951004
if !ok {
9961005
return [32]byte{}, fmt.Errorf("script key not found at index "+
@@ -1003,6 +1012,19 @@ func (m *MockKeyRing) DeriveSharedKey(_ context.Context, key *btcec.PublicKey,
10031012
return ecdh.ECDH(key)
10041013
}
10051014

1015+
// DeriveNextKeyCallCount returns the number of calls to DeriveNextKey. This is
1016+
// useful in tests to assert that the key ring was used as expected in
1017+
// concurrent scenarios.
1018+
func (m *MockKeyRing) DeriveNextKeyCallCount() int {
1019+
return int(m.deriveNextKeyCallCount.Load())
1020+
}
1021+
1022+
// ResetDeriveNextKeyCallCount resets the call counter for DeriveNextKey to
1023+
// zero. This is useful in tests to ensure a clean state for assertions.
1024+
func (m *MockKeyRing) ResetDeriveNextKeyCallCount() {
1025+
m.deriveNextKeyCallCount.Store(0)
1026+
}
1027+
10061028
type MockGenSigner struct {
10071029
KeyRing *MockKeyRing
10081030
failSigning atomic.Bool

tapgarden/planter_test.go

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,9 @@ func (t *mintingTestHarness) queueSeedlingsInBatch(isFunded bool,
272272

273273
for i, seedling := range seedlings {
274274
seedling := seedling
275+
276+
t.keyRing.ResetDeriveNextKeyCallCount()
275277
keyCount := 0
276-
t.keyRing.Calls = nil
277278

278279
// For the first seedling sent, we should get a new request,
279280
// representing the batch internal key.
@@ -310,18 +311,21 @@ func (t *mintingTestHarness) queueSeedlingsInBatch(isFunded bool,
310311
// The received update should be a state of MintingStateSeed.
311312
require.Equal(t, tapgarden.MintingStateSeed, update.NewState)
312313

313-
require.Eventually(t, func() bool {
314+
err = wait.NoError(func() error {
314315
// Assert that the key ring method DeriveNextKey was
315316
// called the expected number of times.
316-
count := 0
317-
for _, call := range t.keyRing.Calls {
318-
if call.Method == "DeriveNextKey" {
319-
count++
320-
}
317+
expectedCount := keyCount
318+
actualCount := t.keyRing.DeriveNextKeyCallCount()
319+
320+
if actualCount < expectedCount {
321+
return fmt.Errorf("expected %d calls to key "+
322+
"derivation, got %d", expectedCount,
323+
actualCount)
321324
}
322325

323-
return count == keyCount
324-
}, defaultTimeout, wait.PollInterval)
326+
return nil
327+
}, defaultTimeout)
328+
require.NoError(t, err)
325329
}
326330
}
327331

@@ -332,13 +336,26 @@ func (t *mintingTestHarness) assertPendingBatchExists(numSeedlings int) {
332336

333337
// The planter is a state machine, so we need to wait until it has
334338
// reached the expected state.
335-
require.Eventually(t, func() bool {
339+
err := wait.NoError(func() error {
336340
batch, err := t.planter.PendingBatch()
337-
require.NoError(t, err)
341+
if err != nil {
342+
return fmt.Errorf("unable to fetch pending batch: %w",
343+
err)
344+
}
345+
346+
if batch == nil {
347+
return fmt.Errorf("expected pending batch to be " +
348+
"non-nil")
349+
}
350+
351+
if len(batch.Seedlings) < numSeedlings {
352+
return fmt.Errorf("expected %d seedlings, got %d",
353+
numSeedlings, len(batch.Seedlings))
354+
}
338355

339-
require.NotNil(t, batch)
340-
return len(batch.Seedlings) == numSeedlings
341-
}, defaultTimeout, wait.PollInterval)
356+
return nil
357+
}, defaultTimeout)
358+
require.NoError(t, err)
342359
}
343360

344361
// assertNoActiveBatch asserts that no pending batch exists.

0 commit comments

Comments
 (0)