@@ -857,6 +857,10 @@ type MockKeyRing struct {
857
857
KeyIndex uint32
858
858
859
859
Keys map [keychain.KeyLocator ]* btcec.PrivateKey
860
+
861
+ // deriveNextKeyCallCount is used to track the number of calls to
862
+ // DeriveNextKey.
863
+ deriveNextKeyCallCount atomic.Uint64
860
864
}
861
865
862
866
var _ KeyRing = (* MockKeyRing )(nil )
@@ -869,8 +873,11 @@ func NewMockKeyRing() *MockKeyRing {
869
873
keyRing .On (
870
874
"DeriveNextKey" , mock .Anything ,
871
875
keychain .KeyFamily (asset .TaprootAssetsKeyFamily ),
872
- ).Return (nil )
873
- keyRing .On ("DeriveNextTaprootAssetKey" , mock .Anything ).Return (nil )
876
+ ).Return (keychain.KeyDescriptor {}, nil )
877
+
878
+ keyRing .On (
879
+ "DeriveNextTaprootAssetKey" , mock .Anything ,
880
+ ).Return (keychain.KeyDescriptor {}, nil )
874
881
875
882
return keyRing
876
883
}
@@ -880,6 +887,7 @@ func NewMockKeyRing() *MockKeyRing {
880
887
func (m * MockKeyRing ) DeriveNextTaprootAssetKey (
881
888
ctx context.Context ) (keychain.KeyDescriptor , error ) {
882
889
890
+ // No need to lock mutex here, DeriveNextKey does that for us.
883
891
m .Called (ctx )
884
892
885
893
return m .DeriveNextKey (ctx , asset .TaprootAssetsKeyFamily )
@@ -888,20 +896,21 @@ func (m *MockKeyRing) DeriveNextTaprootAssetKey(
888
896
func (m * MockKeyRing ) DeriveNextKey (ctx context.Context ,
889
897
keyFam keychain.KeyFamily ) (keychain.KeyDescriptor , error ) {
890
898
899
+ m .Lock ()
900
+ defer func () {
901
+ m .KeyIndex ++
902
+ m .Unlock ()
903
+ }()
904
+
891
905
m .Called (ctx , keyFam )
906
+ m .deriveNextKeyCallCount .Add (1 )
892
907
893
908
select {
894
909
case <- ctx .Done ():
895
910
return keychain.KeyDescriptor {}, fmt .Errorf ("shutting down" )
896
911
default :
897
912
}
898
913
899
- m .Lock ()
900
- defer func () {
901
- m .KeyIndex ++
902
- m .Unlock ()
903
- }()
904
-
905
914
priv , err := btcec .NewPrivateKey ()
906
915
if err != nil {
907
916
return keychain.KeyDescriptor {}, err
@@ -925,10 +934,10 @@ func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
925
934
func (m * MockKeyRing ) IsLocalKey (ctx context.Context ,
926
935
d keychain.KeyDescriptor ) bool {
927
936
928
- m .Called (ctx , d )
937
+ m .Lock ()
938
+ defer m .Unlock ()
929
939
930
- m .RLock ()
931
- defer m .RUnlock ()
940
+ m .Called (ctx , d )
932
941
933
942
priv , ok := m .Keys [d .KeyLocator ]
934
943
if ok && priv .PubKey ().IsEqual (d .PubKey ) {
@@ -945,8 +954,8 @@ func (m *MockKeyRing) IsLocalKey(ctx context.Context,
945
954
}
946
955
947
956
func (m * MockKeyRing ) PubKeyAt (t * testing.T , idx uint32 ) * btcec.PublicKey {
948
- m .RLock ()
949
- defer m .RUnlock ()
957
+ m .Lock ()
958
+ defer m .Unlock ()
950
959
951
960
loc := keychain.KeyLocator {
952
961
Index : idx ,
@@ -962,8 +971,8 @@ func (m *MockKeyRing) PubKeyAt(t *testing.T, idx uint32) *btcec.PublicKey {
962
971
}
963
972
964
973
func (m * MockKeyRing ) ScriptKeyAt (t * testing.T , idx uint32 ) asset.ScriptKey {
965
- m .RLock ()
966
- defer m .RUnlock ()
974
+ m .Lock ()
975
+ defer m .Unlock ()
967
976
968
977
loc := keychain.KeyLocator {
969
978
Index : idx ,
@@ -984,13 +993,13 @@ func (m *MockKeyRing) ScriptKeyAt(t *testing.T, idx uint32) asset.ScriptKey {
984
993
func (m * MockKeyRing ) DeriveSharedKey (_ context.Context , key * btcec.PublicKey ,
985
994
locator * keychain.KeyLocator ) ([sha256 .Size ]byte , error ) {
986
995
996
+ m .Lock ()
997
+ defer m .Unlock ()
998
+
987
999
if locator == nil {
988
1000
return [32 ]byte {}, fmt .Errorf ("locator is nil" )
989
1001
}
990
1002
991
- m .RLock ()
992
- defer m .RUnlock ()
993
-
994
1003
priv , ok := m .Keys [* locator ]
995
1004
if ! ok {
996
1005
return [32 ]byte {}, fmt .Errorf ("script key not found at index " +
@@ -1003,6 +1012,19 @@ func (m *MockKeyRing) DeriveSharedKey(_ context.Context, key *btcec.PublicKey,
1003
1012
return ecdh .ECDH (key )
1004
1013
}
1005
1014
1015
+ // DeriveNextKeyCallCount returns the number of calls to DeriveNextKey. This is
1016
+ // useful in tests to assert that the key ring was used as expected in
1017
+ // concurrent scenarios.
1018
+ func (m * MockKeyRing ) DeriveNextKeyCallCount () int {
1019
+ return int (m .deriveNextKeyCallCount .Load ())
1020
+ }
1021
+
1022
+ // ResetDeriveNextKeyCallCount resets the call counter for DeriveNextKey to
1023
+ // zero. This is useful in tests to ensure a clean state for assertions.
1024
+ func (m * MockKeyRing ) ResetDeriveNextKeyCallCount () {
1025
+ m .deriveNextKeyCallCount .Store (0 )
1026
+ }
1027
+
1006
1028
type MockGenSigner struct {
1007
1029
KeyRing * MockKeyRing
1008
1030
failSigning atomic.Bool
0 commit comments