diff --git a/x/evm/keeper/vote_handler_test.go b/x/evm/keeper/vote_handler_test.go index 4691da2524..29f6a4fea5 100644 --- a/x/evm/keeper/vote_handler_test.go +++ b/x/evm/keeper/vote_handler_test.go @@ -1,11 +1,13 @@ package keeper_test import ( - "errors" + "fmt" mathRand "math/rand" "testing" + "github.com/CosmWasm/wasmd/x/wasm" "github.com/cosmos/cosmos-sdk/codec" + sdkstore "github.com/cosmos/cosmos-sdk/store/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" @@ -13,20 +15,19 @@ import ( tmproto "github.com/tendermint/tendermint/proto/tendermint/types" "github.com/axelarnetwork/axelar-core/app/params" - "github.com/axelarnetwork/axelar-core/testutils" - "github.com/axelarnetwork/axelar-core/testutils/fake" - fakeMock "github.com/axelarnetwork/axelar-core/testutils/fake/interfaces/mock" + fakemock "github.com/axelarnetwork/axelar-core/testutils/fake/interfaces/mock" "github.com/axelarnetwork/axelar-core/testutils/rand" + axelarnet "github.com/axelarnetwork/axelar-core/x/axelarnet/exported" "github.com/axelarnetwork/axelar-core/x/evm/exported" "github.com/axelarnetwork/axelar-core/x/evm/keeper" "github.com/axelarnetwork/axelar-core/x/evm/types" "github.com/axelarnetwork/axelar-core/x/evm/types/mock" nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" - mock4 "github.com/axelarnetwork/axelar-core/x/nexus/exported/mock" + nexusmock "github.com/axelarnetwork/axelar-core/x/nexus/exported/mock" reward "github.com/axelarnetwork/axelar-core/x/reward/exported" - mock3 "github.com/axelarnetwork/axelar-core/x/reward/exported/mock" + rewardmock "github.com/axelarnetwork/axelar-core/x/reward/exported/mock" vote "github.com/axelarnetwork/axelar-core/x/vote/exported" - mock2 "github.com/axelarnetwork/axelar-core/x/vote/exported/mock" + votemock "github.com/axelarnetwork/axelar-core/x/vote/exported/mock" "github.com/axelarnetwork/utils/slices" . "github.com/axelarnetwork/utils/test" ) @@ -36,14 +37,14 @@ func TestHandleExpiredPoll(t *testing.T) { var ( ctx sdk.Context n *mock.NexusMock - rewardPool *mock3.RewardPoolMock - poll *mock2.PollMock - maintainerState *mock4.MaintainerStateMock + rewardPool *rewardmock.RewardPoolMock + poll *votemock.PollMock + maintainerState *nexusmock.MaintainerStateMock handler vote.VoteHandler ) givenVoteHandler := Given("the vote handler", func() { - ctx = sdk.NewContext(&fakeMock.MultiStoreMock{}, tmproto.Header{}, false, log.TestingLogger()) + ctx = sdk.NewContext(&fakemock.MultiStoreMock{}, tmproto.Header{}, false, log.TestingLogger()) encCfg := params.MakeEncodingConfig() k := &mock.BaseKeeperMock{ @@ -58,7 +59,7 @@ func TestHandleExpiredPoll(t *testing.T) { }, true }, } - rewardPool = &mock3.RewardPoolMock{} + rewardPool = &rewardmock.RewardPoolMock{} r := &mock.RewarderMock{ GetPoolFunc: func(sdk.Context, string) reward.RewardPool { return rewardPool }, } @@ -67,7 +68,7 @@ func TestHandleExpiredPoll(t *testing.T) { givenVoteHandler. When("some voter failed to vote for poll", func() { - poll = &mock2.PollMock{ + poll = &votemock.PollMock{ GetIDFunc: func() vote.PollID { return vote.PollID(rand.I64Between(10, 100)) }, GetRewardPoolNameFunc: func() (string, bool) { return rand.NormalizedStr(3), true }, GetMetaDataFunc: func() (codec.ProtoMarshaler, bool) { return &types.PollMetadata{Chain: exported.Ethereum.Name}, true }, @@ -78,7 +79,7 @@ func TestHandleExpiredPoll(t *testing.T) { } }). When("maintainer state can be found", func() { - maintainerState = &mock4.MaintainerStateMock{} + maintainerState = &nexusmock.MaintainerStateMock{} n.GetChainMaintainerStateFunc = func(sdk.Context, nexus.Chain, sdk.ValAddress) (nexus.MaintainerState, bool) { return maintainerState, true } @@ -103,7 +104,7 @@ func TestHandleExpiredPoll(t *testing.T) { givenVoteHandler. When("some voter failed to vote for poll", func() { - poll = &mock2.PollMock{ + poll = &votemock.PollMock{ GetIDFunc: func() vote.PollID { return vote.PollID(rand.I64Between(10, 100)) }, GetRewardPoolNameFunc: func() (string, bool) { return rand.NormalizedStr(3), true }, GetMetaDataFunc: func() (codec.ProtoMarshaler, bool) { return &types.PollMetadata{Chain: exported.Ethereum.Name}, true }, @@ -114,7 +115,7 @@ func TestHandleExpiredPoll(t *testing.T) { } }). When("maintainer state can not be found", func() { - maintainerState = &mock4.MaintainerStateMock{} + maintainerState = &nexusmock.MaintainerStateMock{} n.GetChainMaintainerStateFunc = func(sdk.Context, nexus.Chain, sdk.ValAddress) (nexus.MaintainerState, bool) { return nil, false } @@ -132,7 +133,7 @@ func TestHandleExpiredPoll(t *testing.T) { givenVoteHandler. When("no voter failed to vote for poll", func() { - poll = &mock2.PollMock{ + poll = &votemock.PollMock{ GetIDFunc: func() vote.PollID { return vote.PollID(rand.I64Between(10, 100)) }, GetRewardPoolNameFunc: func() (string, bool) { return rand.NormalizedStr(3), true }, GetMetaDataFunc: func() (codec.ProtoMarshaler, bool) { return &types.PollMetadata{Chain: exported.Ethereum.Name}, true }, @@ -143,7 +144,7 @@ func TestHandleExpiredPoll(t *testing.T) { } }). When("maintainer state can be found", func() { - maintainerState = &mock4.MaintainerStateMock{} + maintainerState = &nexusmock.MaintainerStateMock{} n.GetChainMaintainerStateFunc = func(sdk.Context, nexus.Chain, sdk.ValAddress) (nexus.MaintainerState, bool) { return maintainerState, true } @@ -169,181 +170,195 @@ func TestHandleResult(t *testing.T) { ctx sdk.Context basek *mock.BaseKeeperMock chaink *mock.ChainKeeperMock - n *mock.NexusMock - r *mock.RewarderMock + nexusK *mock.NexusMock result codec.ProtoMarshaler handler vote.VoteHandler ) setup := func() { - ctx = sdk.NewContext(&fakeMock.MultiStoreMock{}, tmproto.Header{}, false, log.TestingLogger()) + multiStore := fakemock.MultiStoreMock{} + multiStore.CacheMultiStoreFunc = func() sdkstore.CacheMultiStore { return &fakemock.CacheMultiStoreMock{} } + ctx = sdk.NewContext(&multiStore, tmproto.Header{}, false, log.TestingLogger()) basek = &mock.BaseKeeperMock{ - ForChainFunc: func(_ sdk.Context, chain nexus.ChainName) (types.ChainKeeper, error) { - if chain.Equals(evmChain) { - return chaink, nil - } - return nil, errors.New("unknown chain") - }, LoggerFunc: func(ctx sdk.Context) log.Logger { return log.TestingLogger() }, } chaink = &mock.ChainKeeperMock{ - GetEventFunc: func(sdk.Context, types.EventID) (types.Event, bool) { - return types.Event{}, false - }, - SetConfirmedEventFunc: func(sdk.Context, types.Event) error { - return nil - }, LoggerFunc: func(ctx sdk.Context) log.Logger { return log.TestingLogger() }, } + nexusK = &mock.NexusMock{} - chains := map[nexus.ChainName]nexus.Chain{ - exported.Ethereum.Name: exported.Ethereum, - } - n = &mock.NexusMock{ - IsChainActivatedFunc: func(ctx sdk.Context, chain nexus.Chain) bool { return true }, - GetChainFunc: func(ctx sdk.Context, chain nexus.ChainName) (nexus.Chain, bool) { - c, ok := chains[chain] - return c, ok - }, - } - r = &mock.RewarderMock{} - encCfg := params.MakeEncodingConfig() - handler = keeper.NewVoteHandler(encCfg.Codec, basek, n, r) + handler = keeper.NewVoteHandler(params.MakeEncodingConfig().Codec, basek, nexusK, &mock.RewarderMock{}) } - repeats := 20 - - t.Run("Given vote When events are not from the same source chain THEN return error", testutils.Func(func(t *testing.T) { - setup() - - result = &types.VoteEvents{ - Chain: nexus.ChainName(rand.Str(5)), - Events: randTransferEvents(int(rand.I64Between(5, 10))), - } - err := handler.HandleResult(ctx, result) - - assert.Error(t, err) - }).Repeat(repeats)) - - t.Run("Given vote When events empty THEN should nothing and return nil", testutils.Func(func(t *testing.T) { - setup() - - result = &types.VoteEvents{ - Chain: evmChain, - Events: []types.Event{}, - } - err := handler.HandleResult(ctx, result) - - assert.NoError(t, err) - }).Repeat(repeats)) - - t.Run("GIVEN vote WHEN chain is not registered THEN return error", testutils.Func(func(t *testing.T) { - setup() - n.GetChainFunc = func(ctx sdk.Context, chain nexus.ChainName) (nexus.Chain, bool) { - return nexus.Chain{}, false - } - result = &types.VoteEvents{ - Chain: evmChain, - Events: randTransferEvents(int(rand.I64Between(5, 10))), - } - err := handler.HandleResult(ctx, result) - - assert.Error(t, err) - }).Repeat(repeats)) - - t.Run("GIVEN vote WHEN chain is not activated THEN still confirm the event", testutils.Func(func(t *testing.T) { - setup() - - n.IsChainActivatedFunc = func(sdk.Context, nexus.Chain) bool { return false } - - result = &types.VoteEvents{ - Chain: evmChain, - Events: randTransferEvents(int(rand.I64Between(5, 10))), - } - err := handler.HandleResult(ctx, result) - - assert.NoError(t, err) - }).Repeat(repeats)) - - t.Run("GIVEN vote WHEN result is invalid THEN panic", testutils.Func(func(t *testing.T) { - setup() + givenHandler := Given("the vote handler", setup) - result = types.NewConfirmGatewayTxRequest(rand.AccAddr(), rand.Str(5), types.Hash(common.BytesToHash(rand.Bytes(common.HashLength)))) - assert.Panics(t, func() { - _ = handler.HandleResult(ctx, result) - }) - }).Repeat(repeats)) + givenHandler. + When("result is falsy", func() { + chain := nexus.ChainName(rand.Str(5)) + result = &types.VoteEvents{ + Chain: chain, + Events: nil, + } + }). + Then("should return nil and do nothing", func(t *testing.T) { + assert.NoError(t, handler.HandleResult(ctx, result)) + }). + Run(t) - t.Run("GIVEN already confirmed event WHEN handle deposit THEN return error", testutils.Func(func(t *testing.T) { - setup() + givenHandler. + When("source chain is not registered", func() { + chain := nexus.ChainName(rand.Str(5)) + result = &types.VoteEvents{ + Chain: chain, + Events: randTransferEvents(chain, rand.I64Between(5, 10)), + } - chaink.GetEventFunc = func(sdk.Context, types.EventID) (types.Event, bool) { - return types.Event{}, true - } + nexusK.GetChainFunc = func(_ sdk.Context, _ nexus.ChainName) (nexus.Chain, bool) { return nexus.Chain{}, false } + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, handler.HandleResult(ctx, result), "is not a registered chain") + }). + Run(t) - result = &types.VoteEvents{ - Chain: evmChain, - Events: randTransferEvents(int(rand.I64Between(5, 10))), - } - err := handler.HandleResult(ctx, result) + givenHandler. + When("source chain is not an evm chain", func() { + chain := nexus.ChainName(rand.Str(5)) + result = &types.VoteEvents{ + Chain: chain, + Events: randTransferEvents(chain, rand.I64Between(5, 10)), + } - assert.Error(t, err) - }).Repeat(repeats)) + nexusK.GetChainFunc = func(_ sdk.Context, _ nexus.ChainName) (nexus.Chain, bool) { return nexus.Chain{}, true } + basek.ForChainFunc = func(_ sdk.Context, _ nexus.ChainName) (types.ChainKeeper, error) { + return nil, fmt.Errorf("not an evm chain") + } + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, handler.HandleResult(ctx, result), "is not an evm chain") + }). + Run(t) - var ( - poll *mock2.PollMock - nexusMock *mock.NexusMock - ) + givenHandler. + When("source chain is an evm chain", func() { + result = &types.VoteEvents{ + Chain: exported.Ethereum.Name, + } - Given("a vote handler", func() { - encCfg := params.MakeEncodingConfig() - nexusMock = &mock.NexusMock{ - GetChainFunc: func(_ sdk.Context, chain nexus.ChainName) (nexus.Chain, bool) { - return nexus.Chain{ - Name: chain, - SupportsForeignAssets: true, - Module: types.ModuleName, - }, true - }, - } - rewarder := &mock.RewarderMock{ - GetPoolFunc: func(sdk.Context, string) reward.RewardPool { return &mock3.RewardPoolMock{} }, - } - handler = keeper.NewVoteHandler(encCfg.Codec, &mock.BaseKeeperMock{}, nexusMock, rewarder) - }). - Given("a completed poll", func() { - poll = &mock2.PollMock{ - GetStateFunc: func() vote.PollState { return vote.Completed }, - GetResultFunc: func() codec.ProtoMarshaler { return &types.VoteEvents{Chain: "ethereum", Events: nil} }, - GetRewardPoolNameFunc: func() (string, bool) { return "rewards", true }, - GetIDFunc: func() vote.PollID { return vote.PollID(rand.PosI64()) }, - GetMetaDataFunc: func() (codec.ProtoMarshaler, bool) { return &types.PollMetadata{Chain: "ethereum"}, true }, - GetVotersFunc: func() []sdk.ValAddress { return slices.Expand(func(int) sdk.ValAddress { return rand.ValAddr() }, 10) }, + nexusK.GetChainFunc = func(_ sdk.Context, chainName nexus.ChainName) (nexus.Chain, bool) { + switch chainName { + case exported.Ethereum.Name: + return exported.Ethereum, true + case axelarnet.Axelarnet.Name: + return axelarnet.Axelarnet, true + default: + return nexus.Chain{}, false + } } - }). - When("a voter is not a chain maintainer", func() { - nexusMock.GetChainMaintainerStateFunc = func(sdk.Context, nexus.Chain, sdk.ValAddress) (nexus.MaintainerState, bool) { - return nil, false + basek.ForChainFunc = func(_ sdk.Context, _ nexus.ChainName) (types.ChainKeeper, error) { + return chaink, nil } }). - Then("ignore that voter", func(t *testing.T) { - ctx := sdk.NewContext(fake.NewMultiStore(), tmproto.Header{}, false, log.TestingLogger()) - assert.NoError(t, handler.HandleCompletedPoll(ctx, poll)) - assert.Len(t, nexusMock.SetChainMaintainerStateCalls(), 0) - }).Run(t) + Branch( + When("failed to set the confirmed event", func() { + result.(*types.VoteEvents).Events = randTransferEvents(exported.Ethereum.Name, 1) + + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return fmt.Errorf("failed to set confirmed event") } + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, handler.HandleResult(ctx, result), "failed to set confirmed event") + }), + + When("event is not contract call", func() { + result.(*types.VoteEvents).Events = randTransferEvents(exported.Ethereum.Name, 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + Then("should enqueue the confirmed event", func(t *testing.T) { + chaink.EnqueueConfirmedEventFunc = func(_ sdk.Context, _ types.EventID) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, chaink.EnqueueConfirmedEventCalls(), 5) + }), + + When("event is contract call and is sent to an evm chain", func() { + result.(*types.VoteEvents).Events = randContractCallEvents(exported.Ethereum.Name, exported.Ethereum.Name, 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + When("succeeded to set as processing general messages", func() { + nexusK.SetMessageProcessingFunc = func(_ sdk.Context, _ string) error { return nil } + }). + Then("should set as processing general messages", func(t *testing.T) { + nexusK.SetNewMessageFunc = func(_ sdk.Context, _ nexus.GeneralMessage) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, nexusK.SetNewMessageCalls(), 5) + assert.Len(t, nexusK.SetMessageProcessingCalls(), 5) + }), + + When("event is contract call and is sent to an evm chain", func() { + result.(*types.VoteEvents).Events = randContractCallEvents(exported.Ethereum.Name, exported.Ethereum.Name, 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + When("failed to set as processing general messages", func() { + nexusK.SetMessageProcessingFunc = func(_ sdk.Context, _ string) error { return fmt.Errorf("failed") } + }). + Then("should set as approved general messages", func(t *testing.T) { + nexusK.SetNewMessageFunc = func(_ sdk.Context, _ nexus.GeneralMessage) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, nexusK.SetNewMessageCalls(), 5) + assert.Len(t, nexusK.SetMessageProcessingCalls(), 5) + }), + + When("event is contract call and is sent to an non-evm chain", func() { + result.(*types.VoteEvents).Events = randContractCallEvents(exported.Ethereum.Name, axelarnet.Axelarnet.Name, 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + Then("should set as approved general messages", func(t *testing.T) { + nexusK.SetNewMessageFunc = func(_ sdk.Context, _ nexus.GeneralMessage) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, nexusK.SetNewMessageCalls(), 5) + }), + + When("event is contract call and is sent to an unknown chain", func() { + result.(*types.VoteEvents).Events = randContractCallEvents(exported.Ethereum.Name, nexus.ChainName(rand.Str(5)), 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + Then("should set as approved general messages", func(t *testing.T) { + nexusK.SetNewMessageFunc = func(_ sdk.Context, _ nexus.GeneralMessage) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, nexusK.SetNewMessageCalls(), 5) + + for _, call := range nexusK.SetNewMessageCalls() { + assert.Equal(t, wasm.ModuleName, call.M.Recipient.Chain.Module) + } + }), + ). + Run(t) } -func randTransferEvents(n int) []types.Event { +func randTransferEvents(chain nexus.ChainName, n int64) []types.Event { events := make([]types.Event, n) burnerAddress := types.Address(common.BytesToAddress(rand.Bytes(common.AddressLength))) - for i := 0; i < n; i++ { + for i := int64(0); i < n; i++ { transfer := types.EventTransfer{ To: burnerAddress, Amount: sdk.NewUint(mathRand.Uint64()), } events[i] = types.Event{ - Chain: evmChain, + Chain: chain, TxID: types.Hash(common.BytesToHash(rand.Bytes(common.HashLength))), Index: uint64(rand.I64Between(1, 50)), Event: &types.Event_Transfer{ @@ -354,3 +369,25 @@ func randTransferEvents(n int) []types.Event { return events } + +func randContractCallEvents(chain nexus.ChainName, destinationChain nexus.ChainName, n int64) []types.Event { + events := make([]types.Event, n) + for i := int64(0); i < n; i++ { + contractCall := types.EventContractCall{ + Sender: types.Address(common.BytesToAddress(rand.Bytes(common.AddressLength))), + DestinationChain: destinationChain, + ContractAddress: common.BytesToAddress(rand.Bytes(common.AddressLength)).Hex(), + PayloadHash: types.Hash(common.BytesToHash(rand.Bytes(common.HashLength))), + } + events[i] = types.Event{ + Chain: chain, + TxID: types.Hash(common.BytesToHash(rand.Bytes(common.HashLength))), + Index: uint64(rand.I64Between(1, 50)), + Event: &types.Event_ContractCall{ + ContractCall: &contractCall, + }, + } + } + + return events +}