From d0237d21a4ac4276509616d9bd5bfd6424ce93fc Mon Sep 17 00:00:00 2001 From: Sammy Liu Date: Thu, 2 Nov 2023 14:22:51 -0400 Subject: [PATCH 1/3] feat(evm): route contract call via the nexus module --- x/evm/keeper/chainKeeper.go | 36 +++++++--- x/evm/keeper/vote_handler.go | 65 ++++++++++++++--- x/evm/types/expected_keepers.go | 2 + x/evm/types/mock/expected_keepers.go | 100 +++++++++++++++++++++++++++ 4 files changed, 185 insertions(+), 18 deletions(-) diff --git a/x/evm/keeper/chainKeeper.go b/x/evm/keeper/chainKeeper.go index f38380fc5..2e86fa393 100644 --- a/x/evm/keeper/chainKeeper.go +++ b/x/evm/keeper/chainKeeper.go @@ -857,7 +857,6 @@ func (k chainKeeper) GetEvent(ctx sdk.Context, eventID types.EventID) (event typ return event, event.Status != types.EventNonExistent } -// SetConfirmedEvent sets the event as confirmed func (k chainKeeper) SetConfirmedEvent(ctx sdk.Context, event types.Event) error { eventID := event.GetID() if _, ok := k.GetEvent(ctx, eventID); ok { @@ -865,14 +864,7 @@ func (k chainKeeper) SetConfirmedEvent(ctx sdk.Context, event types.Event) error } event.Status = types.EventConfirmed - - switch event.GetEvent().(type) { - case *types.Event_ContractCall, *types.Event_ContractCallWithToken, *types.Event_TokenSent, - *types.Event_Transfer, *types.Event_TokenDeployed, *types.Event_MultisigOperatorshipTransferred: - k.GetConfirmedEventQueue(ctx).Enqueue(getEventKey(eventID), &event) - default: - return fmt.Errorf("unsupported event type %T", event) - } + k.setEvent(ctx, event) events.Emit(ctx, &types.EVMEventConfirmed{ Chain: event.Chain, @@ -883,6 +875,32 @@ func (k chainKeeper) SetConfirmedEvent(ctx sdk.Context, event types.Event) error return nil } +// EnqueueConfirmedEvent enqueues the confirmed event +func (k chainKeeper) EnqueueConfirmedEvent(ctx sdk.Context, id types.EventID) error { + event, ok := k.GetEvent(ctx, id) + if !ok { + return fmt.Errorf("event %s does not exist", id) + } + if event.Status != types.EventConfirmed { + return fmt.Errorf("event %s is not confirmed", id) + } + + switch event.GetEvent().(type) { + // we no longer allow Event_ContractCall to be enqueued in the EVM module, but + // to be enqueued in the nexus module as a general message instead + case *types.Event_ContractCallWithToken, + *types.Event_TokenSent, + *types.Event_Transfer, + *types.Event_TokenDeployed, + *types.Event_MultisigOperatorshipTransferred: + k.GetConfirmedEventQueue(ctx).Enqueue(getEventKey(id), &event) + default: + return fmt.Errorf("unsupported event type %T", event) + } + + return nil +} + // SetEventCompleted sets the event as completed func (k chainKeeper) SetEventCompleted(ctx sdk.Context, eventID types.EventID) error { event, ok := k.GetEvent(ctx, eventID) diff --git a/x/evm/keeper/vote_handler.go b/x/evm/keeper/vote_handler.go index e9c612bbe..cc01d9867 100644 --- a/x/evm/keeper/vote_handler.go +++ b/x/evm/keeper/vote_handler.go @@ -3,12 +3,15 @@ package keeper import ( "fmt" + "github.com/CosmWasm/wasmd/x/wasm" "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/axelarnetwork/axelar-core/utils" "github.com/axelarnetwork/axelar-core/utils/events" "github.com/axelarnetwork/axelar-core/x/evm/types" nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" + tss "github.com/axelarnetwork/axelar-core/x/tss/exported" vote "github.com/axelarnetwork/axelar-core/x/vote/exported" "github.com/axelarnetwork/utils/funcs" ) @@ -177,7 +180,7 @@ func (v voteHandler) HandleResult(ctx sdk.Context, result codec.ProtoMarshaler) } for _, event := range voteEvents.Events { - if err := handleEvent(ctx, ck, event, chain); err != nil { + if err := v.handleEvent(ctx, ck, event, chain); err != nil { return err } } @@ -185,16 +188,21 @@ func (v voteHandler) HandleResult(ctx sdk.Context, result codec.ProtoMarshaler) return nil } -func handleEvent(ctx sdk.Context, ck types.ChainKeeper, event types.Event, chain nexus.Chain) error { - // check if event confirmed before - eventID := event.GetID() - if _, ok := ck.GetEvent(ctx, eventID); ok { - return fmt.Errorf("event %s is already confirmed", eventID) - } +func (v voteHandler) handleEvent(ctx sdk.Context, ck types.ChainKeeper, event types.Event, chain nexus.Chain) error { if err := ck.SetConfirmedEvent(ctx, event); err != nil { - panic(err) + return err } - ck.Logger(ctx).Info(fmt.Sprintf("confirmed %s event %s in transaction %s", chain.Name, eventID, event.TxID.Hex())) + + switch event.GetEvent().(type) { + case *types.Event_ContractCall: + if err := v.handleContractCall(ctx, event); err != nil { + return err + } + default: + funcs.MustNoErr(ck.EnqueueConfirmedEvent(ctx, event.GetID())) + } + + ck.Logger(ctx).Info(fmt.Sprintf("confirmed %s event %s in transaction %s", chain.Name, event.GetID(), event.TxID.Hex())) // Deprecated ctx.EventManager().EmitEvent( @@ -210,6 +218,45 @@ func handleEvent(ctx sdk.Context, ck types.ChainKeeper, event types.Event, chain return nil } +func (v voteHandler) handleContractCall(ctx sdk.Context, event types.Event) error { + msg := mustToGeneralMessage(ctx, v.nexus, event) + + if err := v.nexus.SetNewMessage(ctx, msg); err != nil { + return err + } + + if !msg.Recipient.Chain.IsFrom(types.ModuleName) { + return nil + } + + // if the message is sent to an EVM chain, try setting the message processing + // so that the end blocker can pick it up + _ = utils.RunCached(ctx, v.keeper, func(ctx sdk.Context) (bool, error) { + err := v.nexus.SetMessageProcessing(ctx, msg.ID) + + return err == nil, err + }) + + return nil +} + +func mustToGeneralMessage(ctx sdk.Context, n types.Nexus, event types.Event) nexus.GeneralMessage { + id := string(event.GetID()) + contractCall := event.GetEvent().(*types.Event_ContractCall).ContractCall + + sourceChain := funcs.MustOk(n.GetChain(ctx, event.Chain)) + sender := nexus.CrossChainAddress{Chain: sourceChain, Address: contractCall.Sender.Hex()} + + destinationChain, ok := n.GetChain(ctx, contractCall.DestinationChain) + if !ok { + // try forwarding it to wasm router if destination chain is not registered + destinationChain = nexus.Chain{Name: contractCall.DestinationChain, SupportsForeignAssets: false, KeyType: tss.None, Module: wasm.ModuleName} + } + recipient := nexus.CrossChainAddress{Chain: destinationChain, Address: contractCall.ContractAddress} + + return nexus.NewGeneralMessage(id, sender, recipient, contractCall.PayloadHash.Bytes(), event.TxID.Bytes(), event.Index, nil) +} + func mustGetMetadata(poll vote.Poll) types.PollMetadata { md := funcs.MustOk(poll.GetMetaData()) metadata, ok := md.(*types.PollMetadata) diff --git a/x/evm/types/expected_keepers.go b/x/evm/types/expected_keepers.go index c089fdd14..4d9230a91 100644 --- a/x/evm/types/expected_keepers.go +++ b/x/evm/types/expected_keepers.go @@ -76,6 +76,7 @@ type ChainKeeper interface { GetConfirmedEventQueue(ctx sdk.Context) utils.KVQueue GetEvent(ctx sdk.Context, eventID EventID) (Event, bool) SetConfirmedEvent(ctx sdk.Context, event Event) error + EnqueueConfirmedEvent(ctx sdk.Context, eventID EventID) error SetEventCompleted(ctx sdk.Context, eventID EventID) error SetEventFailed(ctx sdk.Context, eventID EventID) error @@ -116,6 +117,7 @@ type Nexus interface { RateLimitTransfer(ctx sdk.Context, chain nexus.ChainName, asset sdk.Coin, direction nexus.TransferDirection) error SetNewMessage(ctx sdk.Context, m nexus.GeneralMessage) error GetProcessingMessages(ctx sdk.Context, chain nexus.ChainName, limit int64) []nexus.GeneralMessage + SetMessageProcessing(ctx sdk.Context, id string) error SetMessageFailed(ctx sdk.Context, id string) error SetMessageExecuted(ctx sdk.Context, id string) error } diff --git a/x/evm/types/mock/expected_keepers.go b/x/evm/types/mock/expected_keepers.go index 03c7ace4f..897e27d9c 100644 --- a/x/evm/types/mock/expected_keepers.go +++ b/x/evm/types/mock/expected_keepers.go @@ -167,6 +167,9 @@ var _ types.Nexus = &NexusMock{} // SetMessageFailedFunc: func(ctx github_com_cosmos_cosmos_sdk_types.Context, id string) error { // panic("mock out the SetMessageFailed method") // }, +// SetMessageProcessingFunc: func(ctx github_com_cosmos_cosmos_sdk_types.Context, id string) error { +// panic("mock out the SetMessageProcessing method") +// }, // SetNewMessageFunc: func(ctx github_com_cosmos_cosmos_sdk_types.Context, m github_com_axelarnetwork_axelar_core_x_nexus_exported.GeneralMessage) error { // panic("mock out the SetNewMessage method") // }, @@ -243,6 +246,9 @@ type NexusMock struct { // SetMessageFailedFunc mocks the SetMessageFailed method. SetMessageFailedFunc func(ctx github_com_cosmos_cosmos_sdk_types.Context, id string) error + // SetMessageProcessingFunc mocks the SetMessageProcessing method. + SetMessageProcessingFunc func(ctx github_com_cosmos_cosmos_sdk_types.Context, id string) error + // SetNewMessageFunc mocks the SetNewMessage method. SetNewMessageFunc func(ctx github_com_cosmos_cosmos_sdk_types.Context, m github_com_axelarnetwork_axelar_core_x_nexus_exported.GeneralMessage) error @@ -432,6 +438,13 @@ type NexusMock struct { // ID is the id argument value. ID string } + // SetMessageProcessing holds details about calls to the SetMessageProcessing method. + SetMessageProcessing []struct { + // Ctx is the ctx argument value. + Ctx github_com_cosmos_cosmos_sdk_types.Context + // ID is the id argument value. + ID string + } // SetNewMessage holds details about calls to the SetNewMessage method. SetNewMessage []struct { // Ctx is the ctx argument value. @@ -462,6 +475,7 @@ type NexusMock struct { lockSetChainMaintainerState sync.RWMutex lockSetMessageExecuted sync.RWMutex lockSetMessageFailed sync.RWMutex + lockSetMessageProcessing sync.RWMutex lockSetNewMessage sync.RWMutex } @@ -1317,6 +1331,42 @@ func (mock *NexusMock) SetMessageFailedCalls() []struct { return calls } +// SetMessageProcessing calls SetMessageProcessingFunc. +func (mock *NexusMock) SetMessageProcessing(ctx github_com_cosmos_cosmos_sdk_types.Context, id string) error { + if mock.SetMessageProcessingFunc == nil { + panic("NexusMock.SetMessageProcessingFunc: method is nil but Nexus.SetMessageProcessing was just called") + } + callInfo := struct { + Ctx github_com_cosmos_cosmos_sdk_types.Context + ID string + }{ + Ctx: ctx, + ID: id, + } + mock.lockSetMessageProcessing.Lock() + mock.calls.SetMessageProcessing = append(mock.calls.SetMessageProcessing, callInfo) + mock.lockSetMessageProcessing.Unlock() + return mock.SetMessageProcessingFunc(ctx, id) +} + +// SetMessageProcessingCalls gets all the calls that were made to SetMessageProcessing. +// Check the length with: +// +// len(mockedNexus.SetMessageProcessingCalls()) +func (mock *NexusMock) SetMessageProcessingCalls() []struct { + Ctx github_com_cosmos_cosmos_sdk_types.Context + ID string +} { + var calls []struct { + Ctx github_com_cosmos_cosmos_sdk_types.Context + ID string + } + mock.lockSetMessageProcessing.RLock() + calls = mock.calls.SetMessageProcessing + mock.lockSetMessageProcessing.RUnlock() + return calls +} + // SetNewMessage calls SetNewMessageFunc. func (mock *NexusMock) SetNewMessage(ctx github_com_cosmos_cosmos_sdk_types.Context, m github_com_axelarnetwork_axelar_core_x_nexus_exported.GeneralMessage) error { if mock.SetNewMessageFunc == nil { @@ -1684,6 +1734,9 @@ var _ types.ChainKeeper = &ChainKeeperMock{} // EnqueueCommandFunc: func(ctx github_com_cosmos_cosmos_sdk_types.Context, cmd types.Command) error { // panic("mock out the EnqueueCommand method") // }, +// EnqueueConfirmedEventFunc: func(ctx github_com_cosmos_cosmos_sdk_types.Context, eventID types.EventID) error { +// panic("mock out the EnqueueConfirmedEvent method") +// }, // GenerateSaltFunc: func(ctx github_com_cosmos_cosmos_sdk_types.Context, recipient string) types.Hash { // panic("mock out the GenerateSalt method") // }, @@ -1820,6 +1873,9 @@ type ChainKeeperMock struct { // EnqueueCommandFunc mocks the EnqueueCommand method. EnqueueCommandFunc func(ctx github_com_cosmos_cosmos_sdk_types.Context, cmd types.Command) error + // EnqueueConfirmedEventFunc mocks the EnqueueConfirmedEvent method. + EnqueueConfirmedEventFunc func(ctx github_com_cosmos_cosmos_sdk_types.Context, eventID types.EventID) error + // GenerateSaltFunc mocks the GenerateSalt method. GenerateSaltFunc func(ctx github_com_cosmos_cosmos_sdk_types.Context, recipient string) types.Hash @@ -1971,6 +2027,13 @@ type ChainKeeperMock struct { // Cmd is the cmd argument value. Cmd types.Command } + // EnqueueConfirmedEvent holds details about calls to the EnqueueConfirmedEvent method. + EnqueueConfirmedEvent []struct { + // Ctx is the ctx argument value. + Ctx github_com_cosmos_cosmos_sdk_types.Context + // EventID is the eventID argument value. + EventID types.EventID + } // GenerateSalt holds details about calls to the GenerateSalt method. GenerateSalt []struct { // Ctx is the ctx argument value. @@ -2221,6 +2284,7 @@ type ChainKeeperMock struct { lockDeleteDeposit sync.RWMutex lockDeleteUnsignedCommandBatchID sync.RWMutex lockEnqueueCommand sync.RWMutex + lockEnqueueConfirmedEvent sync.RWMutex lockGenerateSalt sync.RWMutex lockGetBatchByID sync.RWMutex lockGetBurnerAddress sync.RWMutex @@ -2441,6 +2505,42 @@ func (mock *ChainKeeperMock) EnqueueCommandCalls() []struct { return calls } +// EnqueueConfirmedEvent calls EnqueueConfirmedEventFunc. +func (mock *ChainKeeperMock) EnqueueConfirmedEvent(ctx github_com_cosmos_cosmos_sdk_types.Context, eventID types.EventID) error { + if mock.EnqueueConfirmedEventFunc == nil { + panic("ChainKeeperMock.EnqueueConfirmedEventFunc: method is nil but ChainKeeper.EnqueueConfirmedEvent was just called") + } + callInfo := struct { + Ctx github_com_cosmos_cosmos_sdk_types.Context + EventID types.EventID + }{ + Ctx: ctx, + EventID: eventID, + } + mock.lockEnqueueConfirmedEvent.Lock() + mock.calls.EnqueueConfirmedEvent = append(mock.calls.EnqueueConfirmedEvent, callInfo) + mock.lockEnqueueConfirmedEvent.Unlock() + return mock.EnqueueConfirmedEventFunc(ctx, eventID) +} + +// EnqueueConfirmedEventCalls gets all the calls that were made to EnqueueConfirmedEvent. +// Check the length with: +// +// len(mockedChainKeeper.EnqueueConfirmedEventCalls()) +func (mock *ChainKeeperMock) EnqueueConfirmedEventCalls() []struct { + Ctx github_com_cosmos_cosmos_sdk_types.Context + EventID types.EventID +} { + var calls []struct { + Ctx github_com_cosmos_cosmos_sdk_types.Context + EventID types.EventID + } + mock.lockEnqueueConfirmedEvent.RLock() + calls = mock.calls.EnqueueConfirmedEvent + mock.lockEnqueueConfirmedEvent.RUnlock() + return calls +} + // GenerateSalt calls GenerateSaltFunc. func (mock *ChainKeeperMock) GenerateSalt(ctx github_com_cosmos_cosmos_sdk_types.Context, recipient string) types.Hash { if mock.GenerateSaltFunc == nil { From 05be83af5f1067ef163c26a252d2f2517119c0e3 Mon Sep 17 00:00:00 2001 From: Sammy Liu Date: Thu, 2 Nov 2023 16:58:29 -0400 Subject: [PATCH 2/3] add tests --- x/evm/keeper/vote_handler_test.go | 355 +++++++++++++++++------------- 1 file changed, 196 insertions(+), 159 deletions(-) diff --git a/x/evm/keeper/vote_handler_test.go b/x/evm/keeper/vote_handler_test.go index 4691da252..29f6a4fea 100644 --- a/x/evm/keeper/vote_handler_test.go +++ b/x/evm/keeper/vote_handler_test.go @@ -1,11 +1,13 @@ package keeper_test import ( - "errors" + "fmt" mathRand "math/rand" "testing" + "github.com/CosmWasm/wasmd/x/wasm" "github.com/cosmos/cosmos-sdk/codec" + sdkstore "github.com/cosmos/cosmos-sdk/store/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" @@ -13,20 +15,19 @@ import ( tmproto "github.com/tendermint/tendermint/proto/tendermint/types" "github.com/axelarnetwork/axelar-core/app/params" - "github.com/axelarnetwork/axelar-core/testutils" - "github.com/axelarnetwork/axelar-core/testutils/fake" - fakeMock "github.com/axelarnetwork/axelar-core/testutils/fake/interfaces/mock" + fakemock "github.com/axelarnetwork/axelar-core/testutils/fake/interfaces/mock" "github.com/axelarnetwork/axelar-core/testutils/rand" + axelarnet "github.com/axelarnetwork/axelar-core/x/axelarnet/exported" "github.com/axelarnetwork/axelar-core/x/evm/exported" "github.com/axelarnetwork/axelar-core/x/evm/keeper" "github.com/axelarnetwork/axelar-core/x/evm/types" "github.com/axelarnetwork/axelar-core/x/evm/types/mock" nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" - mock4 "github.com/axelarnetwork/axelar-core/x/nexus/exported/mock" + nexusmock "github.com/axelarnetwork/axelar-core/x/nexus/exported/mock" reward "github.com/axelarnetwork/axelar-core/x/reward/exported" - mock3 "github.com/axelarnetwork/axelar-core/x/reward/exported/mock" + rewardmock "github.com/axelarnetwork/axelar-core/x/reward/exported/mock" vote "github.com/axelarnetwork/axelar-core/x/vote/exported" - mock2 "github.com/axelarnetwork/axelar-core/x/vote/exported/mock" + votemock "github.com/axelarnetwork/axelar-core/x/vote/exported/mock" "github.com/axelarnetwork/utils/slices" . "github.com/axelarnetwork/utils/test" ) @@ -36,14 +37,14 @@ func TestHandleExpiredPoll(t *testing.T) { var ( ctx sdk.Context n *mock.NexusMock - rewardPool *mock3.RewardPoolMock - poll *mock2.PollMock - maintainerState *mock4.MaintainerStateMock + rewardPool *rewardmock.RewardPoolMock + poll *votemock.PollMock + maintainerState *nexusmock.MaintainerStateMock handler vote.VoteHandler ) givenVoteHandler := Given("the vote handler", func() { - ctx = sdk.NewContext(&fakeMock.MultiStoreMock{}, tmproto.Header{}, false, log.TestingLogger()) + ctx = sdk.NewContext(&fakemock.MultiStoreMock{}, tmproto.Header{}, false, log.TestingLogger()) encCfg := params.MakeEncodingConfig() k := &mock.BaseKeeperMock{ @@ -58,7 +59,7 @@ func TestHandleExpiredPoll(t *testing.T) { }, true }, } - rewardPool = &mock3.RewardPoolMock{} + rewardPool = &rewardmock.RewardPoolMock{} r := &mock.RewarderMock{ GetPoolFunc: func(sdk.Context, string) reward.RewardPool { return rewardPool }, } @@ -67,7 +68,7 @@ func TestHandleExpiredPoll(t *testing.T) { givenVoteHandler. When("some voter failed to vote for poll", func() { - poll = &mock2.PollMock{ + poll = &votemock.PollMock{ GetIDFunc: func() vote.PollID { return vote.PollID(rand.I64Between(10, 100)) }, GetRewardPoolNameFunc: func() (string, bool) { return rand.NormalizedStr(3), true }, GetMetaDataFunc: func() (codec.ProtoMarshaler, bool) { return &types.PollMetadata{Chain: exported.Ethereum.Name}, true }, @@ -78,7 +79,7 @@ func TestHandleExpiredPoll(t *testing.T) { } }). When("maintainer state can be found", func() { - maintainerState = &mock4.MaintainerStateMock{} + maintainerState = &nexusmock.MaintainerStateMock{} n.GetChainMaintainerStateFunc = func(sdk.Context, nexus.Chain, sdk.ValAddress) (nexus.MaintainerState, bool) { return maintainerState, true } @@ -103,7 +104,7 @@ func TestHandleExpiredPoll(t *testing.T) { givenVoteHandler. When("some voter failed to vote for poll", func() { - poll = &mock2.PollMock{ + poll = &votemock.PollMock{ GetIDFunc: func() vote.PollID { return vote.PollID(rand.I64Between(10, 100)) }, GetRewardPoolNameFunc: func() (string, bool) { return rand.NormalizedStr(3), true }, GetMetaDataFunc: func() (codec.ProtoMarshaler, bool) { return &types.PollMetadata{Chain: exported.Ethereum.Name}, true }, @@ -114,7 +115,7 @@ func TestHandleExpiredPoll(t *testing.T) { } }). When("maintainer state can not be found", func() { - maintainerState = &mock4.MaintainerStateMock{} + maintainerState = &nexusmock.MaintainerStateMock{} n.GetChainMaintainerStateFunc = func(sdk.Context, nexus.Chain, sdk.ValAddress) (nexus.MaintainerState, bool) { return nil, false } @@ -132,7 +133,7 @@ func TestHandleExpiredPoll(t *testing.T) { givenVoteHandler. When("no voter failed to vote for poll", func() { - poll = &mock2.PollMock{ + poll = &votemock.PollMock{ GetIDFunc: func() vote.PollID { return vote.PollID(rand.I64Between(10, 100)) }, GetRewardPoolNameFunc: func() (string, bool) { return rand.NormalizedStr(3), true }, GetMetaDataFunc: func() (codec.ProtoMarshaler, bool) { return &types.PollMetadata{Chain: exported.Ethereum.Name}, true }, @@ -143,7 +144,7 @@ func TestHandleExpiredPoll(t *testing.T) { } }). When("maintainer state can be found", func() { - maintainerState = &mock4.MaintainerStateMock{} + maintainerState = &nexusmock.MaintainerStateMock{} n.GetChainMaintainerStateFunc = func(sdk.Context, nexus.Chain, sdk.ValAddress) (nexus.MaintainerState, bool) { return maintainerState, true } @@ -169,181 +170,195 @@ func TestHandleResult(t *testing.T) { ctx sdk.Context basek *mock.BaseKeeperMock chaink *mock.ChainKeeperMock - n *mock.NexusMock - r *mock.RewarderMock + nexusK *mock.NexusMock result codec.ProtoMarshaler handler vote.VoteHandler ) setup := func() { - ctx = sdk.NewContext(&fakeMock.MultiStoreMock{}, tmproto.Header{}, false, log.TestingLogger()) + multiStore := fakemock.MultiStoreMock{} + multiStore.CacheMultiStoreFunc = func() sdkstore.CacheMultiStore { return &fakemock.CacheMultiStoreMock{} } + ctx = sdk.NewContext(&multiStore, tmproto.Header{}, false, log.TestingLogger()) basek = &mock.BaseKeeperMock{ - ForChainFunc: func(_ sdk.Context, chain nexus.ChainName) (types.ChainKeeper, error) { - if chain.Equals(evmChain) { - return chaink, nil - } - return nil, errors.New("unknown chain") - }, LoggerFunc: func(ctx sdk.Context) log.Logger { return log.TestingLogger() }, } chaink = &mock.ChainKeeperMock{ - GetEventFunc: func(sdk.Context, types.EventID) (types.Event, bool) { - return types.Event{}, false - }, - SetConfirmedEventFunc: func(sdk.Context, types.Event) error { - return nil - }, LoggerFunc: func(ctx sdk.Context) log.Logger { return log.TestingLogger() }, } + nexusK = &mock.NexusMock{} - chains := map[nexus.ChainName]nexus.Chain{ - exported.Ethereum.Name: exported.Ethereum, - } - n = &mock.NexusMock{ - IsChainActivatedFunc: func(ctx sdk.Context, chain nexus.Chain) bool { return true }, - GetChainFunc: func(ctx sdk.Context, chain nexus.ChainName) (nexus.Chain, bool) { - c, ok := chains[chain] - return c, ok - }, - } - r = &mock.RewarderMock{} - encCfg := params.MakeEncodingConfig() - handler = keeper.NewVoteHandler(encCfg.Codec, basek, n, r) + handler = keeper.NewVoteHandler(params.MakeEncodingConfig().Codec, basek, nexusK, &mock.RewarderMock{}) } - repeats := 20 - - t.Run("Given vote When events are not from the same source chain THEN return error", testutils.Func(func(t *testing.T) { - setup() - - result = &types.VoteEvents{ - Chain: nexus.ChainName(rand.Str(5)), - Events: randTransferEvents(int(rand.I64Between(5, 10))), - } - err := handler.HandleResult(ctx, result) - - assert.Error(t, err) - }).Repeat(repeats)) - - t.Run("Given vote When events empty THEN should nothing and return nil", testutils.Func(func(t *testing.T) { - setup() - - result = &types.VoteEvents{ - Chain: evmChain, - Events: []types.Event{}, - } - err := handler.HandleResult(ctx, result) - - assert.NoError(t, err) - }).Repeat(repeats)) - - t.Run("GIVEN vote WHEN chain is not registered THEN return error", testutils.Func(func(t *testing.T) { - setup() - n.GetChainFunc = func(ctx sdk.Context, chain nexus.ChainName) (nexus.Chain, bool) { - return nexus.Chain{}, false - } - result = &types.VoteEvents{ - Chain: evmChain, - Events: randTransferEvents(int(rand.I64Between(5, 10))), - } - err := handler.HandleResult(ctx, result) - - assert.Error(t, err) - }).Repeat(repeats)) - - t.Run("GIVEN vote WHEN chain is not activated THEN still confirm the event", testutils.Func(func(t *testing.T) { - setup() - - n.IsChainActivatedFunc = func(sdk.Context, nexus.Chain) bool { return false } - - result = &types.VoteEvents{ - Chain: evmChain, - Events: randTransferEvents(int(rand.I64Between(5, 10))), - } - err := handler.HandleResult(ctx, result) - - assert.NoError(t, err) - }).Repeat(repeats)) - - t.Run("GIVEN vote WHEN result is invalid THEN panic", testutils.Func(func(t *testing.T) { - setup() + givenHandler := Given("the vote handler", setup) - result = types.NewConfirmGatewayTxRequest(rand.AccAddr(), rand.Str(5), types.Hash(common.BytesToHash(rand.Bytes(common.HashLength)))) - assert.Panics(t, func() { - _ = handler.HandleResult(ctx, result) - }) - }).Repeat(repeats)) + givenHandler. + When("result is falsy", func() { + chain := nexus.ChainName(rand.Str(5)) + result = &types.VoteEvents{ + Chain: chain, + Events: nil, + } + }). + Then("should return nil and do nothing", func(t *testing.T) { + assert.NoError(t, handler.HandleResult(ctx, result)) + }). + Run(t) - t.Run("GIVEN already confirmed event WHEN handle deposit THEN return error", testutils.Func(func(t *testing.T) { - setup() + givenHandler. + When("source chain is not registered", func() { + chain := nexus.ChainName(rand.Str(5)) + result = &types.VoteEvents{ + Chain: chain, + Events: randTransferEvents(chain, rand.I64Between(5, 10)), + } - chaink.GetEventFunc = func(sdk.Context, types.EventID) (types.Event, bool) { - return types.Event{}, true - } + nexusK.GetChainFunc = func(_ sdk.Context, _ nexus.ChainName) (nexus.Chain, bool) { return nexus.Chain{}, false } + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, handler.HandleResult(ctx, result), "is not a registered chain") + }). + Run(t) - result = &types.VoteEvents{ - Chain: evmChain, - Events: randTransferEvents(int(rand.I64Between(5, 10))), - } - err := handler.HandleResult(ctx, result) + givenHandler. + When("source chain is not an evm chain", func() { + chain := nexus.ChainName(rand.Str(5)) + result = &types.VoteEvents{ + Chain: chain, + Events: randTransferEvents(chain, rand.I64Between(5, 10)), + } - assert.Error(t, err) - }).Repeat(repeats)) + nexusK.GetChainFunc = func(_ sdk.Context, _ nexus.ChainName) (nexus.Chain, bool) { return nexus.Chain{}, true } + basek.ForChainFunc = func(_ sdk.Context, _ nexus.ChainName) (types.ChainKeeper, error) { + return nil, fmt.Errorf("not an evm chain") + } + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, handler.HandleResult(ctx, result), "is not an evm chain") + }). + Run(t) - var ( - poll *mock2.PollMock - nexusMock *mock.NexusMock - ) + givenHandler. + When("source chain is an evm chain", func() { + result = &types.VoteEvents{ + Chain: exported.Ethereum.Name, + } - Given("a vote handler", func() { - encCfg := params.MakeEncodingConfig() - nexusMock = &mock.NexusMock{ - GetChainFunc: func(_ sdk.Context, chain nexus.ChainName) (nexus.Chain, bool) { - return nexus.Chain{ - Name: chain, - SupportsForeignAssets: true, - Module: types.ModuleName, - }, true - }, - } - rewarder := &mock.RewarderMock{ - GetPoolFunc: func(sdk.Context, string) reward.RewardPool { return &mock3.RewardPoolMock{} }, - } - handler = keeper.NewVoteHandler(encCfg.Codec, &mock.BaseKeeperMock{}, nexusMock, rewarder) - }). - Given("a completed poll", func() { - poll = &mock2.PollMock{ - GetStateFunc: func() vote.PollState { return vote.Completed }, - GetResultFunc: func() codec.ProtoMarshaler { return &types.VoteEvents{Chain: "ethereum", Events: nil} }, - GetRewardPoolNameFunc: func() (string, bool) { return "rewards", true }, - GetIDFunc: func() vote.PollID { return vote.PollID(rand.PosI64()) }, - GetMetaDataFunc: func() (codec.ProtoMarshaler, bool) { return &types.PollMetadata{Chain: "ethereum"}, true }, - GetVotersFunc: func() []sdk.ValAddress { return slices.Expand(func(int) sdk.ValAddress { return rand.ValAddr() }, 10) }, + nexusK.GetChainFunc = func(_ sdk.Context, chainName nexus.ChainName) (nexus.Chain, bool) { + switch chainName { + case exported.Ethereum.Name: + return exported.Ethereum, true + case axelarnet.Axelarnet.Name: + return axelarnet.Axelarnet, true + default: + return nexus.Chain{}, false + } } - }). - When("a voter is not a chain maintainer", func() { - nexusMock.GetChainMaintainerStateFunc = func(sdk.Context, nexus.Chain, sdk.ValAddress) (nexus.MaintainerState, bool) { - return nil, false + basek.ForChainFunc = func(_ sdk.Context, _ nexus.ChainName) (types.ChainKeeper, error) { + return chaink, nil } }). - Then("ignore that voter", func(t *testing.T) { - ctx := sdk.NewContext(fake.NewMultiStore(), tmproto.Header{}, false, log.TestingLogger()) - assert.NoError(t, handler.HandleCompletedPoll(ctx, poll)) - assert.Len(t, nexusMock.SetChainMaintainerStateCalls(), 0) - }).Run(t) + Branch( + When("failed to set the confirmed event", func() { + result.(*types.VoteEvents).Events = randTransferEvents(exported.Ethereum.Name, 1) + + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return fmt.Errorf("failed to set confirmed event") } + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, handler.HandleResult(ctx, result), "failed to set confirmed event") + }), + + When("event is not contract call", func() { + result.(*types.VoteEvents).Events = randTransferEvents(exported.Ethereum.Name, 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + Then("should enqueue the confirmed event", func(t *testing.T) { + chaink.EnqueueConfirmedEventFunc = func(_ sdk.Context, _ types.EventID) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, chaink.EnqueueConfirmedEventCalls(), 5) + }), + + When("event is contract call and is sent to an evm chain", func() { + result.(*types.VoteEvents).Events = randContractCallEvents(exported.Ethereum.Name, exported.Ethereum.Name, 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + When("succeeded to set as processing general messages", func() { + nexusK.SetMessageProcessingFunc = func(_ sdk.Context, _ string) error { return nil } + }). + Then("should set as processing general messages", func(t *testing.T) { + nexusK.SetNewMessageFunc = func(_ sdk.Context, _ nexus.GeneralMessage) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, nexusK.SetNewMessageCalls(), 5) + assert.Len(t, nexusK.SetMessageProcessingCalls(), 5) + }), + + When("event is contract call and is sent to an evm chain", func() { + result.(*types.VoteEvents).Events = randContractCallEvents(exported.Ethereum.Name, exported.Ethereum.Name, 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + When("failed to set as processing general messages", func() { + nexusK.SetMessageProcessingFunc = func(_ sdk.Context, _ string) error { return fmt.Errorf("failed") } + }). + Then("should set as approved general messages", func(t *testing.T) { + nexusK.SetNewMessageFunc = func(_ sdk.Context, _ nexus.GeneralMessage) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, nexusK.SetNewMessageCalls(), 5) + assert.Len(t, nexusK.SetMessageProcessingCalls(), 5) + }), + + When("event is contract call and is sent to an non-evm chain", func() { + result.(*types.VoteEvents).Events = randContractCallEvents(exported.Ethereum.Name, axelarnet.Axelarnet.Name, 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + Then("should set as approved general messages", func(t *testing.T) { + nexusK.SetNewMessageFunc = func(_ sdk.Context, _ nexus.GeneralMessage) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, nexusK.SetNewMessageCalls(), 5) + }), + + When("event is contract call and is sent to an unknown chain", func() { + result.(*types.VoteEvents).Events = randContractCallEvents(exported.Ethereum.Name, nexus.ChainName(rand.Str(5)), 5) + }). + When("succeeded to set the confirmed event", func() { + chaink.SetConfirmedEventFunc = func(_ sdk.Context, _ types.Event) error { return nil } + }). + Then("should set as approved general messages", func(t *testing.T) { + nexusK.SetNewMessageFunc = func(_ sdk.Context, _ nexus.GeneralMessage) error { return nil } + + assert.NoError(t, handler.HandleResult(ctx, result)) + assert.Len(t, nexusK.SetNewMessageCalls(), 5) + + for _, call := range nexusK.SetNewMessageCalls() { + assert.Equal(t, wasm.ModuleName, call.M.Recipient.Chain.Module) + } + }), + ). + Run(t) } -func randTransferEvents(n int) []types.Event { +func randTransferEvents(chain nexus.ChainName, n int64) []types.Event { events := make([]types.Event, n) burnerAddress := types.Address(common.BytesToAddress(rand.Bytes(common.AddressLength))) - for i := 0; i < n; i++ { + for i := int64(0); i < n; i++ { transfer := types.EventTransfer{ To: burnerAddress, Amount: sdk.NewUint(mathRand.Uint64()), } events[i] = types.Event{ - Chain: evmChain, + Chain: chain, TxID: types.Hash(common.BytesToHash(rand.Bytes(common.HashLength))), Index: uint64(rand.I64Between(1, 50)), Event: &types.Event_Transfer{ @@ -354,3 +369,25 @@ func randTransferEvents(n int) []types.Event { return events } + +func randContractCallEvents(chain nexus.ChainName, destinationChain nexus.ChainName, n int64) []types.Event { + events := make([]types.Event, n) + for i := int64(0); i < n; i++ { + contractCall := types.EventContractCall{ + Sender: types.Address(common.BytesToAddress(rand.Bytes(common.AddressLength))), + DestinationChain: destinationChain, + ContractAddress: common.BytesToAddress(rand.Bytes(common.AddressLength)).Hex(), + PayloadHash: types.Hash(common.BytesToHash(rand.Bytes(common.HashLength))), + } + events[i] = types.Event{ + Chain: chain, + TxID: types.Hash(common.BytesToHash(rand.Bytes(common.HashLength))), + Index: uint64(rand.I64Between(1, 50)), + Event: &types.Event_ContractCall{ + ContractCall: &contractCall, + }, + } + } + + return events +} From f77cd828ca3a873faa64ffc01d427739047f033d Mon Sep 17 00:00:00 2001 From: Sammy Liu Date: Thu, 2 Nov 2023 17:23:20 -0400 Subject: [PATCH 3/3] improve comments --- x/evm/keeper/chainKeeper.go | 4 ++-- x/evm/keeper/vote_handler.go | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/x/evm/keeper/chainKeeper.go b/x/evm/keeper/chainKeeper.go index 2e86fa393..b99d1023a 100644 --- a/x/evm/keeper/chainKeeper.go +++ b/x/evm/keeper/chainKeeper.go @@ -886,8 +886,8 @@ func (k chainKeeper) EnqueueConfirmedEvent(ctx sdk.Context, id types.EventID) er } switch event.GetEvent().(type) { - // we no longer allow Event_ContractCall to be enqueued in the EVM module, but - // to be enqueued in the nexus module as a general message instead + // the missing Event_ContractCall is no longer allowed to be enqueued in the + // EVM module, it must be routed through the nexus module instead case *types.Event_ContractCallWithToken, *types.Event_TokenSent, *types.Event_Transfer, diff --git a/x/evm/keeper/vote_handler.go b/x/evm/keeper/vote_handler.go index cc01d9867..cc4c11c79 100644 --- a/x/evm/keeper/vote_handler.go +++ b/x/evm/keeper/vote_handler.go @@ -193,6 +193,8 @@ func (v voteHandler) handleEvent(ctx sdk.Context, ck types.ChainKeeper, event ty return err } + // Event_ContractCall is no longer directly handled by the EVM module, + // which bypassed nexus routing switch event.GetEvent().(type) { case *types.Event_ContractCall: if err := v.handleContractCall(ctx, event); err != nil {