diff --git a/plugin/evm/vm_warp_test.go b/plugin/evm/vm_warp_test.go index 6fc43dab5d..c4aa7fdc39 100644 --- a/plugin/evm/vm_warp_test.go +++ b/plugin/evm/vm_warp_test.go @@ -343,22 +343,29 @@ func testWarpVMTransaction(t *testing.T, scheme string, unsignedMessage *avalanc GetSubnetIDF: func(context.Context, ids.ID) (ids.ID, error) { return ids.Empty, nil }, - GetValidatorSetF: func(_ context.Context, height uint64, _ ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + GetWarpValidatorSetF: func(_ context.Context, height uint64, _ ids.ID) (validators.WarpSet, error) { if height < minimumValidPChainHeight { - return nil, getValidatorSetTestErr + return validators.WarpSet{}, getValidatorSetTestErr } - return map[ids.NodeID]*validators.GetValidatorOutput{ - nodeID1: { - NodeID: nodeID1, - PublicKey: blsPublicKey1, - Weight: 50, + vdrs := validators.WarpSet{ + Validators: []*validators.Warp{ + { + PublicKey: blsPublicKey1, + PublicKeyBytes: bls.PublicKeyToUncompressedBytes(blsPublicKey1), + Weight: 50, + NodeIDs: []ids.NodeID{nodeID1}, + }, + { + PublicKey: blsPublicKey2, + PublicKeyBytes: bls.PublicKeyToUncompressedBytes(blsPublicKey2), + Weight: 50, + NodeIDs: []ids.NodeID{nodeID2}, + }, }, - nodeID2: { - NodeID: nodeID2, - PublicKey: blsPublicKey2, - Weight: 50, - }, - }, nil + TotalWeight: 100, + } + avagoUtils.Sort(vdrs.Validators) + return vdrs, nil }, } @@ -645,24 +652,28 @@ func testReceiveWarpMessage( } return vm.ctx.SubnetID, nil }, - GetValidatorSetF: func(_ context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + GetWarpValidatorSetF: func(_ context.Context, height uint64, subnetID ids.ID) (validators.WarpSet, error) { if height < minimumValidPChainHeight { - return nil, getValidatorSetTestErr + return validators.WarpSet{}, getValidatorSetTestErr } signers := subnetSigners if subnetID == constants.PrimaryNetworkID { signers = primarySigners } - vdrOutput := make(map[ids.NodeID]*validators.GetValidatorOutput) + vdrs := validators.WarpSet{} for _, s := range signers { - vdrOutput[s.nodeID] = &validators.GetValidatorOutput{ - NodeID: s.nodeID, - PublicKey: s.secret.PublicKey(), - Weight: s.weight, - } + pk := s.secret.PublicKey() + vdrs.Validators = append(vdrs.Validators, &validators.Warp{ + PublicKey: pk, + PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk), + Weight: s.weight, + NodeIDs: []ids.NodeID{s.nodeID}, + }) + vdrs.TotalWeight += s.weight } - return vdrOutput, nil + avagoUtils.Sort(vdrs.Validators) + return vdrs, nil }, } diff --git a/precompile/contracts/warp/config.go b/precompile/contracts/warp/config.go index 843a3b7dcc..66d25a8a94 100644 --- a/precompile/contracts/warp/config.go +++ b/precompile/contracts/warp/config.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/vms/evm/predicate" "github.com/ava-labs/avalanchego/vms/platformvm/warp" "github.com/ava-labs/avalanchego/vms/platformvm/warp/payload" @@ -16,8 +17,6 @@ import ( "github.com/ava-labs/libevm/log" "github.com/ava-labs/coreth/precompile/precompileconfig" - - warpValidators "github.com/ava-labs/coreth/warp/validators" ) const ( @@ -202,24 +201,48 @@ func (c *Config) VerifyPredicate(predicateContext *precompileconfig.PredicateCon quorumNumerator = c.QuorumNumerator } - log.Debug("verifying warp message", "warpMsg", warpMsg, "quorumNum", quorumNumerator, "quorumDenom", WarpQuorumDenominator) + log.Debug("verifying warp message", + "warpMsg", warpMsg, + "quorumNum", quorumNumerator, + "quorumDenom", WarpQuorumDenominator, + ) - // Wrap validators.State on the chain snow context to special case the Primary Network - state := warpValidators.NewState( - predicateContext.SnowCtx.ValidatorState, - predicateContext.SnowCtx.SubnetID, + sourceSubnetID, err := predicateContext.SnowCtx.ValidatorState.GetSubnetID( + context.TODO(), warpMsg.SourceChainID, - c.RequirePrimaryNetworkSigners, ) + if err != nil { + log.Debug("failed to retrieve subnetID for chain", + "msgID", warpMsg.ID(), + "chainID", warpMsg.SourceChainID, + "err", err, + ) + return fmt.Errorf("%w: %w", errCannotRetrieveValidatorSet, err) + } + + if sourceSubnetID == constants.PrimaryNetworkID { + // For the X-chain and the C-chain, chains can be configured not to + // require the primary network validators to have signed the warp + // message and to use the, likely smaller, local subnet's validator set. + // + // The primary network validator set is never required when verifying + // messages from the P-chain because the P-chain is always synced. + if !c.RequirePrimaryNetworkSigners || warpMsg.SourceChainID == constants.PlatformChainID { + sourceSubnetID = predicateContext.SnowCtx.SubnetID + } + } - validatorSet, err := warp.GetCanonicalValidatorSetFromChainID( - context.Background(), - state, + validatorSet, err := predicateContext.SnowCtx.ValidatorState.GetWarpValidatorSet( + context.TODO(), predicateContext.ProposerVMBlockCtx.PChainHeight, - warpMsg.UnsignedMessage.SourceChainID, + sourceSubnetID, ) if err != nil { - log.Debug("failed to retrieve canonical validator set", "msgID", warpMsg.ID(), "err", err) + log.Debug("failed to retrieve canonical validator set", + "msgID", warpMsg.ID(), + "subnetID", sourceSubnetID, + "err", err, + ) return fmt.Errorf("%w: %w", errCannotRetrieveValidatorSet, err) } @@ -231,7 +254,10 @@ func (c *Config) VerifyPredicate(predicateContext *precompileconfig.PredicateCon WarpQuorumDenominator, ) if err != nil { - log.Debug("failed to verify warp signature", "msgID", warpMsg.ID(), "err", err) + log.Debug("failed to verify warp signature", + "msgID", warpMsg.ID(), + "err", err, + ) return fmt.Errorf("%w: %w", errFailedVerification, err) } diff --git a/precompile/contracts/warp/predicate_test.go b/precompile/contracts/warp/predicate_test.go index 9620ebb059..6bb0ccc3f6 100644 --- a/precompile/contracts/warp/predicate_test.go +++ b/precompile/contracts/warp/predicate_test.go @@ -188,8 +188,7 @@ type validatorRange struct { // createSnowCtx creates a snow.Context instance with a validator state specified by the given validatorRanges func createSnowCtx(tb testing.TB, validatorRanges []validatorRange) *snow.Context { - getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput) - + validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput) for _, validatorRange := range validatorRanges { for i := validatorRange.start; i < validatorRange.end; i++ { validatorOutput := &validators.GetValidatorOutput{ @@ -199,20 +198,19 @@ func createSnowCtx(tb testing.TB, validatorRanges []validatorRange) *snow.Contex if validatorRange.publicKey { validatorOutput.PublicKey = testVdrs[i].vdr.PublicKey } - getValidatorsOutput[testVdrs[i].nodeID] = validatorOutput + validatorSet[testVdrs[i].nodeID] = validatorOutput } } snowCtx := snowtest.Context(tb, snowtest.CChainID) - state := &validatorstest.State{ + snowCtx.ValidatorState = &validatorstest.State{ GetSubnetIDF: func(context.Context, ids.ID) (ids.ID, error) { return sourceSubnetID, nil }, - GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - return getValidatorsOutput, nil + GetWarpValidatorSetF: func(context.Context, uint64, ids.ID) (validators.WarpSet, error) { + return validators.FlattenValidatorSet(validatorSet) }, } - snowCtx.ValidatorState = state return snowCtx } @@ -247,20 +245,29 @@ func testWarpMessageFromPrimaryNetwork(t *testing.T, requirePrimaryNetworkSigner unsignedMsg, err := avalancheWarp.NewUnsignedMessage(constants.UnitTestID, cChainID, addressedCall.Bytes()) require.NoError(err) - getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput) - blsSignatures := make([]*bls.Signature, 0, numKeys) + var ( + warpValidators = validators.WarpSet{ + Validators: make([]*validators.Warp, 0, numKeys), + TotalWeight: 20 * uint64(numKeys), + } + blsSignatures = make([]*bls.Signature, 0, numKeys) + ) for i := 0; i < numKeys; i++ { - sig, err := testVdrs[i].sk.Sign(unsignedMsg.Bytes()) + vdr := testVdrs[i] + sig, err := vdr.sk.Sign(unsignedMsg.Bytes()) require.NoError(err) - - validatorOutput := &validators.GetValidatorOutput{ - NodeID: testVdrs[i].nodeID, - Weight: 20, - PublicKey: testVdrs[i].vdr.PublicKey, - } - getValidatorsOutput[testVdrs[i].nodeID] = validatorOutput blsSignatures = append(blsSignatures, sig) + + pk := vdr.sk.PublicKey() + warpValidators.Validators = append(warpValidators.Validators, &validators.Warp{ + PublicKey: pk, + PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk), + Weight: 20, + NodeIDs: []ids.NodeID{vdr.nodeID}, + }) } + agoUtils.Sort(warpValidators.Validators) + aggregateSignature, err := bls.AggregateSignatures(blsSignatures) require.NoError(err) bitSet := set.NewBits() @@ -284,13 +291,13 @@ func testWarpMessageFromPrimaryNetwork(t *testing.T, requirePrimaryNetworkSigner require.Equal(chainID, cChainID) return constants.PrimaryNetworkID, nil // Return Primary Network SubnetID }, - GetValidatorSetF: func(_ context.Context, _ uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + GetWarpValidatorSetF: func(_ context.Context, _ uint64, subnetID ids.ID) (validators.WarpSet, error) { expectedSubnetID := snowCtx.SubnetID if requirePrimaryNetworkSigners { expectedSubnetID = constants.PrimaryNetworkID } require.Equal(expectedSubnetID, subnetID) - return getValidatorsOutput, nil + return warpValidators, nil }, } @@ -717,25 +724,26 @@ func makeWarpPredicateTests(tb testing.TB) map[string]precompiletest.PredicateTe testName := fmt.Sprintf("%d validators w/ %d signers/repeated PublicKeys", totalNodes, numSigners) pred := createPredicate(numSigners) - getValidatorsOutput := make(map[ids.NodeID]*validators.GetValidatorOutput, totalNodes) + validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput, totalNodes) for i := 0; i < totalNodes; i++ { - getValidatorsOutput[testVdrs[i].nodeID] = &validators.GetValidatorOutput{ + validatorSet[testVdrs[i].nodeID] = &validators.GetValidatorOutput{ NodeID: testVdrs[i].nodeID, Weight: 20, PublicKey: testVdrs[i%numSigners].vdr.PublicKey, } } + warpValidators, err := validators.FlattenValidatorSet(validatorSet) + require.NoError(tb, err) snowCtx := snowtest.Context(tb, snowtest.CChainID) - state := &validatorstest.State{ + snowCtx.ValidatorState = &validatorstest.State{ GetSubnetIDF: func(context.Context, ids.ID) (ids.ID, error) { return sourceSubnetID, nil }, - GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - return getValidatorsOutput, nil + GetWarpValidatorSetF: func(context.Context, uint64, ids.ID) (validators.WarpSet, error) { + return warpValidators, nil }, } - snowCtx.ValidatorState = state predicateTests[testName] = createValidPredicateTest(snowCtx, uint64(numSigners), pred) } diff --git a/warp/service.go b/warp/service.go index ed2f438c94..a56f451c88 100644 --- a/warp/service.go +++ b/warp/service.go @@ -105,7 +105,7 @@ func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warp.Uns return nil, err } - validatorSet, err := warp.GetCanonicalValidatorSetFromSubnetID(ctx, validatorState, pChainHeight, subnetID) + validatorSet, err := validatorState.GetWarpValidatorSet(ctx, pChainHeight, subnetID) if err != nil { return nil, fmt.Errorf("failed to get validator set: %w", err) } diff --git a/warp/validators/state.go b/warp/validators/state.go deleted file mode 100644 index 5853075069..0000000000 --- a/warp/validators/state.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package validators - -import ( - "context" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/constants" -) - -var _ validators.State = (*State)(nil) - -// State provides a special case used to handle Avalanche Warp Message verification for messages sent -// from the Primary Network. Subnets have strictly fewer validators than the Primary Network, so we require -// signatures from a threshold of the RECEIVING subnet validator set rather than the full Primary Network -// since the receiving subnet already relies on a majority of its validators being correct. -type State struct { - validators.State - mySubnetID ids.ID - sourceChainID ids.ID - requirePrimaryNetworkSigners bool -} - -// NewState returns a wrapper of [validators.State] which special cases the handling of the Primary Network. -// -// The wrapped state will return the [mySubnetID's] validator set instead of the Primary Network when -// the Primary Network SubnetID is passed in. -func NewState(state validators.State, mySubnetID ids.ID, sourceChainID ids.ID, requirePrimaryNetworkSigners bool) *State { - return &State{ - State: state, - mySubnetID: mySubnetID, - sourceChainID: sourceChainID, - requirePrimaryNetworkSigners: requirePrimaryNetworkSigners, - } -} - -func (s *State) GetValidatorSet( - ctx context.Context, - height uint64, - subnetID ids.ID, -) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - // If the subnetID is anything other than the Primary Network, or Primary - // Network signers are required (except P-Chain), this is a direct passthrough. - usePrimary := s.requirePrimaryNetworkSigners && s.sourceChainID != constants.PlatformChainID - if usePrimary || subnetID != constants.PrimaryNetworkID { - return s.State.GetValidatorSet(ctx, height, subnetID) - } - - // If the requested subnet is the primary network, then we return the validator - // set for the Subnet that is receiving the message instead. - return s.State.GetValidatorSet(ctx, height, s.mySubnetID) -} diff --git a/warp/validators/state_test.go b/warp/validators/state_test.go deleted file mode 100644 index 4ed8fe9610..0000000000 --- a/warp/validators/state_test.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package validators - -import ( - "context" - "testing" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/snowtest" - "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/snow/validators/validatorsmock" - "github.com/ava-labs/avalanchego/utils/constants" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" -) - -func TestGetValidatorSetPrimaryNetwork(t *testing.T) { - require := require.New(t) - ctrl := gomock.NewController(t) - - mySubnetID := ids.GenerateTestID() - otherSubnetID := ids.GenerateTestID() - - mockState := validatorsmock.NewState(ctrl) - snowCtx := snowtest.Context(t, snowtest.CChainID) - snowCtx.SubnetID = mySubnetID - snowCtx.ValidatorState = mockState - state := NewState(snowCtx.ValidatorState, snowCtx.SubnetID, snowCtx.ChainID, false) - // Expect that requesting my validator set returns my validator set - mockState.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), mySubnetID).Return(make(map[ids.NodeID]*validators.GetValidatorOutput), nil) - output, err := state.GetValidatorSet(context.Background(), 10, mySubnetID) - require.NoError(err) - require.Empty(output) - - // Expect that requesting the Primary Network validator set overrides and returns my validator set - mockState.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), mySubnetID).Return(make(map[ids.NodeID]*validators.GetValidatorOutput), nil) - output, err = state.GetValidatorSet(context.Background(), 10, constants.PrimaryNetworkID) - require.NoError(err) - require.Empty(output) - - // Expect that requesting other validator set returns that validator set - mockState.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), otherSubnetID).Return(make(map[ids.NodeID]*validators.GetValidatorOutput), nil) - output, err = state.GetValidatorSet(context.Background(), 10, otherSubnetID) - require.NoError(err) - require.Empty(output) -}