From c5a42a0a89bc41e1486406d1fa15ddeebd6d7b4a Mon Sep 17 00:00:00 2001 From: Sammy Liu Date: Tue, 7 Nov 2023 13:48:52 -0500 Subject: [PATCH] add tests --- x/axelarnet/keeper/message_route.go | 2 +- x/axelarnet/keeper/message_route_test.go | 212 ++++++++++++++++++++- x/axelarnet/keeper/msg_server.go | 2 +- x/axelarnet/keeper/msg_server_test.go | 8 +- x/axelarnet/types/expected_keepers.go | 2 +- x/axelarnet/types/mock/expected_keepers.go | 110 ++++++++++- x/evm/keeper/message_route_test.go | 7 +- x/nexus/keeper/general_message.go | 9 +- x/nexus/types/message_router_test.go | 151 +++++++++++++++ 9 files changed, 476 insertions(+), 27 deletions(-) create mode 100644 x/nexus/types/message_router_test.go diff --git a/x/axelarnet/keeper/message_route.go b/x/axelarnet/keeper/message_route.go index 1666d14cb..d9bfbae25 100644 --- a/x/axelarnet/keeper/message_route.go +++ b/x/axelarnet/keeper/message_route.go @@ -61,7 +61,7 @@ func escrowAssetToMessageSender( asset := sdk.NewCoin(exported.NativeAsset, sdk.OneInt()) sender := routingCtx.Sender - if routingCtx.FeeGranter != nil { + if !routingCtx.FeeGranter.Empty() { req := types.RouteMessageRequest{ Sender: routingCtx.Sender, ID: msg.ID, diff --git a/x/axelarnet/keeper/message_route_test.go b/x/axelarnet/keeper/message_route_test.go index 3a681fd1a..fc4f2486d 100644 --- a/x/axelarnet/keeper/message_route_test.go +++ b/x/axelarnet/keeper/message_route_test.go @@ -1,16 +1,218 @@ package keeper_test import ( + "context" "testing" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/stretchr/testify/assert" + + "github.com/axelarnetwork/axelar-core/testutils/rand" + "github.com/axelarnetwork/axelar-core/x/axelarnet/exported" + "github.com/axelarnetwork/axelar-core/x/axelarnet/keeper" + "github.com/axelarnetwork/axelar-core/x/axelarnet/types" + "github.com/axelarnetwork/axelar-core/x/axelarnet/types/mock" + evmtestutils "github.com/axelarnetwork/axelar-core/x/evm/types/testutils" + nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" + nexustestutils "github.com/axelarnetwork/axelar-core/x/nexus/exported/testutils" + "github.com/axelarnetwork/utils/funcs" + "github.com/axelarnetwork/utils/slices" . "github.com/axelarnetwork/utils/test" ) +func randPayload() []byte { + bytesType := funcs.Must(abi.NewType("bytes", "bytes", nil)) + stringType := funcs.Must(abi.NewType("string", "string", nil)) + stringArrayType := funcs.Must(abi.NewType("string[]", "string[]", nil)) + + argNum := int(rand.I64Between(1, 10)) + + var args abi.Arguments + for i := 0; i < argNum; i += 1 { + args = append(args, abi.Argument{Type: stringType}) + } + + schema := abi.Arguments{{Type: stringType}, {Type: stringArrayType}, {Type: stringArrayType}, {Type: bytesType}} + payload := funcs.Must( + schema.Pack( + rand.StrBetween(5, 10), + slices.Expand2(func() string { return rand.Str(5) }, argNum), + slices.Expand2(func() string { return "string" }, argNum), + funcs.Must(args.Pack(slices.Expand2(func() interface{} { return "string" }, argNum)...)), + ), + ) + + return append(funcs.Must(hexutil.Decode(types.CosmWasmV1)), payload...) +} + +func randMsg(status nexus.GeneralMessage_Status, payload []byte, token ...*sdk.Coin) nexus.GeneralMessage { + var asset *sdk.Coin + if len(token) > 0 { + asset = token[0] + } + + return nexus.GeneralMessage{ + ID: rand.NormalizedStr(10), + Sender: nexus.CrossChainAddress{ + Chain: nexustestutils.RandomChain(), + Address: rand.NormalizedStr(42), + }, + Recipient: nexus.CrossChainAddress{ + Chain: nexustestutils.RandomChain(), + Address: rand.NormalizedStr(42), + }, + PayloadHash: evmtestutils.RandomHash().Bytes(), + Status: status, + Asset: asset, + SourceTxID: evmtestutils.RandomHash().Bytes(), + SourceTxIndex: uint64(rand.I64Between(0, 100)), + } +} + func TestNewMessageRoute(t *testing.T) { - // var ( - // route nexus.MessageRoute - // ) + var ( + ctx sdk.Context + routingCtx nexus.RoutingContext + msg nexus.GeneralMessage + route nexus.MessageRoute + + k keeper.Keeper + feegrantK *mock.FeegrantKeeperMock + ibcK *mock.IBCKeeperMock + bankK *mock.BankKeeperMock + nexusK *mock.NexusMock + accountK *mock.AccountKeeperMock + ) + + givenMessageRoute := Given("the message route", func() { + ctx, k, _, feegrantK = setup() + + ibcK = &mock.IBCKeeperMock{} + bankK = &mock.BankKeeperMock{} + nexusK = &mock.NexusMock{} + accountK = &mock.AccountKeeperMock{} + + route = keeper.NewMessageRoute(k, ibcK, feegrantK, bankK, nexusK, accountK) + }) + + givenMessageRoute. + When("payload is nil", func() { + routingCtx = nexus.RoutingContext{Payload: nil} + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, route(ctx, routingCtx, msg), "payload is required") + }). + Run(t) + + givenMessageRoute. + When("the message cannot be translated", func() { + routingCtx = nexus.RoutingContext{ + Sender: rand.AccAddr(), + FeeGranter: nil, + Payload: rand.Bytes(100), + } + msg = randMsg(nexus.Processing, routingCtx.Payload) + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, route(ctx, routingCtx, msg), "invalid payload") + }). + Run(t) + + whenTheMessageCanBeTranslated := When("the message can be translated", func() { + routingCtx = nexus.RoutingContext{ + Sender: rand.AccAddr(), + Payload: randPayload(), + } + }) + + givenMessageRoute. + When2(whenTheMessageCanBeTranslated). + When("the message has no token transfer", func() { + msg = randMsg(nexus.Processing, routingCtx.Payload) + }). + Branch( + When("the fee granter is not set", func() { + routingCtx.FeeGranter = nil + }). + Then("should deduct the fee from the sender", func(t *testing.T) { + bankK.SendCoinsFunc = func(_ sdk.Context, _, _ sdk.AccAddress, _ sdk.Coins) error { return nil } + ibcK.SendMessageFunc = func(_ context.Context, _ nexus.CrossChainAddress, _ sdk.Coin, _, _ string) error { + return nil + } + + assert.NoError(t, route(ctx, routingCtx, msg)) + + assert.Len(t, bankK.SendCoinsCalls(), 1) + assert.Equal(t, routingCtx.Sender, bankK.SendCoinsCalls()[0].FromAddr) + assert.Equal(t, types.AxelarGMPAccount, bankK.SendCoinsCalls()[0].ToAddr) + assert.Equal(t, sdk.NewCoins(sdk.NewCoin(exported.NativeAsset, sdk.OneInt())), bankK.SendCoinsCalls()[0].Amt) + + assert.Len(t, ibcK.SendMessageCalls(), 1) + assert.Equal(t, msg.Recipient, ibcK.SendMessageCalls()[0].Recipient) + assert.Equal(t, sdk.NewCoin(exported.NativeAsset, sdk.OneInt()), ibcK.SendMessageCalls()[0].Asset) + assert.Equal(t, msg.ID, ibcK.SendMessageCalls()[0].ID) + }), + + When("the fee granter is set", func() { + routingCtx.FeeGranter = rand.AccAddr() + }). + Then("should deduct the fee from the fee granter", func(t *testing.T) { + feegrantK.UseGrantedFeesFunc = func(_ sdk.Context, granter, _ sdk.AccAddress, _ sdk.Coins, _ []sdk.Msg) error { + return nil + } + bankK.SendCoinsFunc = func(_ sdk.Context, _, _ sdk.AccAddress, _ sdk.Coins) error { return nil } + ibcK.SendMessageFunc = func(_ context.Context, _ nexus.CrossChainAddress, _ sdk.Coin, _, _ string) error { + return nil + } + + assert.NoError(t, route(ctx, routingCtx, msg)) + + assert.Len(t, feegrantK.UseGrantedFeesCalls(), 1) + assert.Equal(t, routingCtx.FeeGranter, feegrantK.UseGrantedFeesCalls()[0].Granter) + assert.Equal(t, routingCtx.Sender, feegrantK.UseGrantedFeesCalls()[0].Grantee) + assert.Equal(t, sdk.NewCoins(sdk.NewCoin(exported.NativeAsset, sdk.OneInt())), feegrantK.UseGrantedFeesCalls()[0].Fee) + + assert.Len(t, bankK.SendCoinsCalls(), 1) + assert.Equal(t, routingCtx.FeeGranter, bankK.SendCoinsCalls()[0].FromAddr) + assert.Equal(t, types.AxelarGMPAccount, bankK.SendCoinsCalls()[0].ToAddr) + assert.Equal(t, sdk.NewCoins(sdk.NewCoin(exported.NativeAsset, sdk.OneInt())), bankK.SendCoinsCalls()[0].Amt) + + assert.Len(t, ibcK.SendMessageCalls(), 1) + assert.Equal(t, msg.Recipient, ibcK.SendMessageCalls()[0].Recipient) + assert.Equal(t, sdk.NewCoin(exported.NativeAsset, sdk.OneInt()), ibcK.SendMessageCalls()[0].Asset) + assert.Equal(t, msg.ID, ibcK.SendMessageCalls()[0].ID) + }), + ). + Run(t) + + givenMessageRoute. + When2(whenTheMessageCanBeTranslated). + When("the message has token transfer", func() { + coin := rand.Coin() + msg = randMsg(nexus.Processing, routingCtx.Payload, &coin) + }). + Then("should deduct from the corresponding account", func(t *testing.T) { + nexusK.GetChainByNativeAssetFunc = func(_ sdk.Context, _ string) (nexus.Chain, bool) { + return exported.Axelarnet, true + } + bankK.SendCoinsFunc = func(_ sdk.Context, _, _ sdk.AccAddress, _ sdk.Coins) error { return nil } + ibcK.SendMessageFunc = func(_ context.Context, _ nexus.CrossChainAddress, _ sdk.Coin, _, _ string) error { + return nil + } + + assert.NoError(t, route(ctx, routingCtx, msg)) + + assert.Len(t, bankK.SendCoinsCalls(), 1) + assert.Equal(t, types.GetEscrowAddress(msg.Asset.Denom), bankK.SendCoinsCalls()[0].FromAddr) + assert.Equal(t, types.AxelarGMPAccount, bankK.SendCoinsCalls()[0].ToAddr) + assert.Equal(t, sdk.NewCoins(*msg.Asset), bankK.SendCoinsCalls()[0].Amt) - // givenMessageRoute := Given("the message route", func() { - // }) + assert.Len(t, ibcK.SendMessageCalls(), 1) + assert.Equal(t, msg.Recipient, ibcK.SendMessageCalls()[0].Recipient) + assert.Equal(t, *msg.Asset, ibcK.SendMessageCalls()[0].Asset) + assert.Equal(t, msg.ID, ibcK.SendMessageCalls()[0].ID) + }). + Run(t) } diff --git a/x/axelarnet/keeper/msg_server.go b/x/axelarnet/keeper/msg_server.go index 700c5389f..99bb342b0 100644 --- a/x/axelarnet/keeper/msg_server.go +++ b/x/axelarnet/keeper/msg_server.go @@ -489,7 +489,7 @@ func (s msgServer) RouteMessage(c context.Context, req *types.RouteMessageReques FeeGranter: req.Feegranter, Payload: req.Payload, } - if err := s.nexus.RouteMessage(ctx, routingCtx, req.ID); err != nil { + if err := s.nexus.RouteMessage(ctx, req.ID, routingCtx); err != nil { return nil, err } diff --git a/x/axelarnet/keeper/msg_server_test.go b/x/axelarnet/keeper/msg_server_test.go index cbf7803c0..fec023e0f 100644 --- a/x/axelarnet/keeper/msg_server_test.go +++ b/x/axelarnet/keeper/msg_server_test.go @@ -989,16 +989,16 @@ func TestRouteMessage(t *testing.T) { givenMsgServer. When("route message successfully", func() { - nexusK.RouteMessageFunc = func(_ sdk.Context, _ nexus.RoutingContext, _ string) error { return nil } + nexusK.RouteMessageFunc = func(_ sdk.Context, _ string, _ ...nexus.RoutingContext) error { return nil } }). Then("should route the correct message", func(t *testing.T) { _, err := server.RouteMessage(sdk.WrapSDKContext(ctx), &req) assert.NoError(t, err) assert.Len(t, nexusK.RouteMessageCalls(), 1) - assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx.Sender, req.Sender) - assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx.FeeGranter, req.Feegranter) - assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx.Payload, req.Payload) + assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx[0].Sender, req.Sender) + assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx[0].FeeGranter, req.Feegranter) + assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx[0].Payload, req.Payload) assert.Equal(t, nexusK.RouteMessageCalls()[0].ID, req.ID) }). Run(t) diff --git a/x/axelarnet/types/expected_keepers.go b/x/axelarnet/types/expected_keepers.go index b6183de0f..18867fd71 100644 --- a/x/axelarnet/types/expected_keepers.go +++ b/x/axelarnet/types/expected_keepers.go @@ -62,7 +62,7 @@ type Nexus interface { SetMessageFailed(ctx sdk.Context, id string) error GenerateMessageID(ctx sdk.Context) (string, []byte, uint64) ValidateAddress(ctx sdk.Context, address nexus.CrossChainAddress) error - RouteMessage(ctx sdk.Context, routingCtx nexus.RoutingContext, id string) error + RouteMessage(ctx sdk.Context, id string, routingCtx ...nexus.RoutingContext) error } // BankKeeper defines the expected interface contract the vesting module requires diff --git a/x/axelarnet/types/mock/expected_keepers.go b/x/axelarnet/types/mock/expected_keepers.go index 280c07542..fffe7519b 100644 --- a/x/axelarnet/types/mock/expected_keepers.go +++ b/x/axelarnet/types/mock/expected_keepers.go @@ -611,7 +611,7 @@ var _ axelarnettypes.Nexus = &NexusMock{} // RegisterAssetFunc: func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain, asset github_com_axelarnetwork_axelar_core_x_nexus_exported.Asset, limit cosmossdktypes.Uint, window time.Duration) error { // panic("mock out the RegisterAsset method") // }, -// RouteMessageFunc: func(ctx cosmossdktypes.Context, routingCtx github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext, id string) error { +// RouteMessageFunc: func(ctx cosmossdktypes.Context, id string, routingCtx ...github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext) error { // panic("mock out the RouteMessage method") // }, // SetChainFunc: func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain) { @@ -691,7 +691,7 @@ type NexusMock struct { RegisterAssetFunc func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain, asset github_com_axelarnetwork_axelar_core_x_nexus_exported.Asset, limit cosmossdktypes.Uint, window time.Duration) error // RouteMessageFunc mocks the RouteMessage method. - RouteMessageFunc func(ctx cosmossdktypes.Context, routingCtx github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext, id string) error + RouteMessageFunc func(ctx cosmossdktypes.Context, id string, routingCtx ...github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext) error // SetChainFunc mocks the SetChain method. SetChainFunc func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain) @@ -852,10 +852,10 @@ type NexusMock struct { RouteMessage []struct { // Ctx is the ctx argument value. Ctx cosmossdktypes.Context - // RoutingCtx is the routingCtx argument value. - RoutingCtx github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext // ID is the id argument value. ID string + // RoutingCtx is the routingCtx argument value. + RoutingCtx []github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext } // SetChain holds details about calls to the SetChain method. SetChain []struct { @@ -1550,23 +1550,23 @@ func (mock *NexusMock) RegisterAssetCalls() []struct { } // RouteMessage calls RouteMessageFunc. -func (mock *NexusMock) RouteMessage(ctx cosmossdktypes.Context, routingCtx github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext, id string) error { +func (mock *NexusMock) RouteMessage(ctx cosmossdktypes.Context, id string, routingCtx ...github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext) error { if mock.RouteMessageFunc == nil { panic("NexusMock.RouteMessageFunc: method is nil but Nexus.RouteMessage was just called") } callInfo := struct { Ctx cosmossdktypes.Context - RoutingCtx github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext ID string + RoutingCtx []github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext }{ Ctx: ctx, - RoutingCtx: routingCtx, ID: id, + RoutingCtx: routingCtx, } mock.lockRouteMessage.Lock() mock.calls.RouteMessage = append(mock.calls.RouteMessage, callInfo) mock.lockRouteMessage.Unlock() - return mock.RouteMessageFunc(ctx, routingCtx, id) + return mock.RouteMessageFunc(ctx, id, routingCtx...) } // RouteMessageCalls gets all the calls that were made to RouteMessage. @@ -1575,13 +1575,13 @@ func (mock *NexusMock) RouteMessage(ctx cosmossdktypes.Context, routingCtx githu // len(mockedNexus.RouteMessageCalls()) func (mock *NexusMock) RouteMessageCalls() []struct { Ctx cosmossdktypes.Context - RoutingCtx github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext ID string + RoutingCtx []github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext } { var calls []struct { Ctx cosmossdktypes.Context - RoutingCtx github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext ID string + RoutingCtx []github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext } mock.lockRouteMessage.RLock() calls = mock.calls.RouteMessage @@ -3282,3 +3282,93 @@ func (mock *FeegrantKeeperMock) UseGrantedFeesCalls() []struct { mock.lockUseGrantedFees.RUnlock() return calls } + +// Ensure, that IBCKeeperMock does implement axelarnettypes.IBCKeeper. +// If this is not the case, regenerate this file with moq. +var _ axelarnettypes.IBCKeeper = &IBCKeeperMock{} + +// IBCKeeperMock is a mock implementation of axelarnettypes.IBCKeeper. +// +// func TestSomethingThatUsesIBCKeeper(t *testing.T) { +// +// // make and configure a mocked axelarnettypes.IBCKeeper +// mockedIBCKeeper := &IBCKeeperMock{ +// SendMessageFunc: func(c context.Context, recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress, asset cosmossdktypes.Coin, payload string, id string) error { +// panic("mock out the SendMessage method") +// }, +// } +// +// // use mockedIBCKeeper in code that requires axelarnettypes.IBCKeeper +// // and then make assertions. +// +// } +type IBCKeeperMock struct { + // SendMessageFunc mocks the SendMessage method. + SendMessageFunc func(c context.Context, recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress, asset cosmossdktypes.Coin, payload string, id string) error + + // calls tracks calls to the methods. + calls struct { + // SendMessage holds details about calls to the SendMessage method. + SendMessage []struct { + // C is the c argument value. + C context.Context + // Recipient is the recipient argument value. + Recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress + // Asset is the asset argument value. + Asset cosmossdktypes.Coin + // Payload is the payload argument value. + Payload string + // ID is the id argument value. + ID string + } + } + lockSendMessage sync.RWMutex +} + +// SendMessage calls SendMessageFunc. +func (mock *IBCKeeperMock) SendMessage(c context.Context, recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress, asset cosmossdktypes.Coin, payload string, id string) error { + if mock.SendMessageFunc == nil { + panic("IBCKeeperMock.SendMessageFunc: method is nil but IBCKeeper.SendMessage was just called") + } + callInfo := struct { + C context.Context + Recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress + Asset cosmossdktypes.Coin + Payload string + ID string + }{ + C: c, + Recipient: recipient, + Asset: asset, + Payload: payload, + ID: id, + } + mock.lockSendMessage.Lock() + mock.calls.SendMessage = append(mock.calls.SendMessage, callInfo) + mock.lockSendMessage.Unlock() + return mock.SendMessageFunc(c, recipient, asset, payload, id) +} + +// SendMessageCalls gets all the calls that were made to SendMessage. +// Check the length with: +// +// len(mockedIBCKeeper.SendMessageCalls()) +func (mock *IBCKeeperMock) SendMessageCalls() []struct { + C context.Context + Recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress + Asset cosmossdktypes.Coin + Payload string + ID string +} { + var calls []struct { + C context.Context + Recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress + Asset cosmossdktypes.Coin + Payload string + ID string + } + mock.lockSendMessage.RLock() + calls = mock.calls.SendMessage + mock.lockSendMessage.RUnlock() + return calls +} diff --git a/x/evm/keeper/message_route_test.go b/x/evm/keeper/message_route_test.go index f326e007a..74446d81e 100644 --- a/x/evm/keeper/message_route_test.go +++ b/x/evm/keeper/message_route_test.go @@ -3,13 +3,14 @@ package keeper_test import ( "testing" - "github.com/axelarnetwork/axelar-core/testutils/fake" - "github.com/axelarnetwork/axelar-core/x/evm/keeper" - nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/assert" "github.com/tendermint/tendermint/libs/log" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + + "github.com/axelarnetwork/axelar-core/testutils/fake" + "github.com/axelarnetwork/axelar-core/x/evm/keeper" + nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" ) func TestNewMessageRoute(t *testing.T) { diff --git a/x/nexus/keeper/general_message.go b/x/nexus/keeper/general_message.go index a929f5be3..4fb486a83 100644 --- a/x/nexus/keeper/general_message.go +++ b/x/nexus/keeper/general_message.go @@ -231,7 +231,9 @@ func (k Keeper) validateAddressAndAsset(ctx sdk.Context, address exported.CrossC return k.validateAsset(ctx, address.Chain, asset.Denom) } -func (k Keeper) RouteMessage(ctx sdk.Context, routingCtx exported.RoutingContext, id string) error { +// RouteMessage routes the given general message to the corresponding module and +// set the message status to processing +func (k Keeper) RouteMessage(ctx sdk.Context, id string, routingCtx ...exported.RoutingContext) error { err := k.SetMessageProcessing(ctx, id) if err != nil { return err @@ -239,5 +241,8 @@ func (k Keeper) RouteMessage(ctx sdk.Context, routingCtx exported.RoutingContext k.Logger(ctx).Debug("set general message status to processing", "messageID", id) - return k.getMessageRouter().Route(ctx, routingCtx, funcs.MustOk(k.GetMessage(ctx, id))) + if len(routingCtx) == 0 { + routingCtx = []exported.RoutingContext{{}} + } + return k.getMessageRouter().Route(ctx, routingCtx[0], funcs.MustOk(k.GetMessage(ctx, id))) } diff --git a/x/nexus/types/message_router_test.go b/x/nexus/types/message_router_test.go new file mode 100644 index 000000000..fa9ef4664 --- /dev/null +++ b/x/nexus/types/message_router_test.go @@ -0,0 +1,151 @@ +package types_test + +import ( + "fmt" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/assert" + "github.com/tendermint/tendermint/libs/log" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + + "github.com/axelarnetwork/axelar-core/testutils/fake" + "github.com/axelarnetwork/axelar-core/testutils/rand" + "github.com/axelarnetwork/axelar-core/x/nexus/exported" + "github.com/axelarnetwork/axelar-core/x/nexus/types" + . "github.com/axelarnetwork/utils/test" +) + +func TestAddRoute(t *testing.T) { + var ( + router types.MessageRouter + module string + ) + + givenRouter := Given("a message router", func() { + router = types.NewMessageRouter() + }) + + givenRouter. + When("it is sealed", func() { + router.Seal() + module = "module" + }). + Then("it panics when adding a route", func(t *testing.T) { + assert.PanicsWithValue(t, "cannot add handler (router sealed)", func() { + router.AddRoute(module, nil) + }) + }). + Run(t) + + givenRouter. + When("module is empty", func() { + module = "" + }). + Then("it panics when adding a route", func(t *testing.T) { + assert.PanicsWithValue(t, "module name cannot be an empty string", func() { + router.AddRoute(module, nil) + }) + }). + Run(t) + + givenRouter. + When("module route is added already", func() { + module = "module" + router.AddRoute(module, func(_ sdk.Context, _ exported.RoutingContext, _ exported.GeneralMessage) error { + return nil + }) + }). + Then("it panics when adding a route again", func(t *testing.T) { + assert.PanicsWithValue(t, fmt.Sprintf("router for module %s has already been registered", module), func() { + router.AddRoute(module, func(_ sdk.Context, _ exported.RoutingContext, _ exported.GeneralMessage) error { + return nil + }) + }) + }). + Run(t) +} + +func TestRoute(t *testing.T) { + var ( + ctx sdk.Context + routingCtx exported.RoutingContext + msg exported.GeneralMessage + router types.MessageRouter + module string + routeCount uint + route exported.MessageRoute + ) + + givenRouter := Given("a message router", func() { + ctx = sdk.NewContext(fake.NewMultiStore(), tmproto.Header{}, false, log.TestingLogger()) + router = types.NewMessageRouter() + }) + + givenRouter. + When("it is not sealed", func() {}). + Then("it panics when routing a message", func(t *testing.T) { + assert.PanicsWithValue(t, "cannot route message (router not sealed)", func() { + router.Route(ctx, exported.RoutingContext{}, exported.GeneralMessage{}) + }) + }). + Run(t) + + whenIsSealed := When("it is sealed", func() { + router.Seal() + }) + + givenRouter. + When2(whenIsSealed). + When("module is not found", func() { + msg = exported.GeneralMessage{Recipient: exported.CrossChainAddress{Chain: exported.Chain{Module: "unknown"}}} + }). + Then("it should return error", func(t *testing.T) { + assert.ErrorContains(t, router.Route(ctx, routingCtx, msg), "no router found") + }). + Run(t) + + givenRouter. + When("route is added", func() { + module = "module" + routeCount = 0 + route = func(_ sdk.Context, _ exported.RoutingContext, msg exported.GeneralMessage) error { + routeCount++ + return nil + } + + router.AddRoute(module, route) + }). + When2(whenIsSealed). + Branch( + When("payload is provided but does not match the payload hash", func() { + routingCtx = exported.RoutingContext{Payload: []byte("payload")} + msg = exported.GeneralMessage{PayloadHash: rand.Bytes(common.HashLength), Recipient: exported.CrossChainAddress{Chain: exported.Chain{Module: module}}} + }). + Then("it should return error", func(t *testing.T) { + assert.ErrorContains(t, router.Route(ctx, routingCtx, msg), "payload hash does not match") + }), + + When("payload is provided and matches the payload hash", func() { + payload := rand.Bytes(100) + routingCtx = exported.RoutingContext{Payload: payload} + msg = exported.GeneralMessage{PayloadHash: crypto.Keccak256Hash(payload).Bytes(), Recipient: exported.CrossChainAddress{Chain: exported.Chain{Module: module}}} + }). + Then("it should succeed", func(t *testing.T) { + assert.NoError(t, router.Route(ctx, routingCtx, msg), "payload hash does not match") + assert.Equal(t, uint(1), routeCount) + }), + + When("payload is not provided", func() { + routingCtx = exported.RoutingContext{Payload: nil} + msg = exported.GeneralMessage{Recipient: exported.CrossChainAddress{Chain: exported.Chain{Module: module}}} + }). + Then("it should succeed", func(t *testing.T) { + assert.NoError(t, router.Route(ctx, routingCtx, msg), "payload hash does not match") + assert.Equal(t, uint(1), routeCount) + }), + ). + Run(t) +}