diff --git a/x/ante/batch.go b/x/ante/batch.go index 28d47d45db..c3c350dfee 100644 --- a/x/ante/batch.go +++ b/x/ante/batch.go @@ -42,33 +42,42 @@ func NewBatchDecorator(cdc codec.Codec) BatchDecorator { // AnteHandle record qualified refund for the multiSig and vote transactions func (b BatchDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { - msgs := tx.GetMsgs() + unwrappedMsgs, err := unpackMsgs2(tx.GetMsgs()) + if err != nil { + return ctx, err + } + + feeTx, ok := tx.(sdk.FeeTx) + if !ok { + return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "tx must be a FeeTx") + } + + return next(ctx, txWithUnwrappedMsgs{feeTx, unwrappedMsgs}, simulate) +} + +func unpackMsgs(msgs []sdk.Msg) ([]sdk.Msg, error) { + var unpackedMsgs []sdk.Msg + idx := 0 - var unwrappedMsgs []sdk.Msg - start := 0 for i, msg := range msgs { if batchReq, ok := msg.(*batchtypes.BatchRequest); ok { // Bulk append messages, including the current batch request - unwrappedMsgs = append(unwrappedMsgs, msgs[start:i+1]...) + unpackedMsgs = append(unpackedMsgs, msgs[idx:i+1]...) innerMsgs := batchReq.UnwrapMessages() if batchtypes.AnyBatch(innerMsgs) { - return ctx, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "nested batch requests are not allowed") + return []sdk.Msg{}, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "nested batch requests are not allowed") } - unwrappedMsgs = append(unwrappedMsgs, innerMsgs...) + unpackedMsgs = append(unpackedMsgs, innerMsgs...) - start = i + 1 + idx = i + 1 } } - if len(unwrappedMsgs) == 0 { - return next(ctx, tx, simulate) - } - - feeTx, ok := tx.(sdk.FeeTx) - if !ok { - return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "tx must be a FeeTx") + // avoid copying the slice if there are no batch requests + if len(unpackedMsgs) == 0 { + return msgs, nil } - return next(ctx, txWithUnwrappedMsgs{feeTx, unwrappedMsgs}, simulate) + return append(unpackedMsgs, msgs[idx:]...), nil } diff --git a/x/ante/batch_test.go b/x/ante/batch_test.go index b1969a63f4..bdb0667a50 100644 --- a/x/ante/batch_test.go +++ b/x/ante/batch_test.go @@ -53,6 +53,30 @@ func TestBatch(t *testing.T) { }). Run(t) + givenBatchAnteHandler. + When("messages do not contain batch", func() { + tx = &mock.FeeTxMock{ + GetMsgsFunc: func() []sdk.Msg { + return []sdk.Msg{ + votetypes.NewVoteRequest(sender, vote.PollID(rand.PosI64()), evmTypes.NewVoteEvents(nexus.ChainName(rand.NormalizedStr(3)))), + votetypes.NewVoteRequest(sender, vote.PollID(rand.PosI64()), evmTypes.NewVoteEvents(nexus.ChainName(rand.NormalizedStr(3)))), + votetypes.NewVoteRequest(sender, vote.PollID(rand.PosI64()), evmTypes.NewVoteEvents(nexus.ChainName(rand.NormalizedStr(3)))), + } + }, + } + }). + Then("should pass messages as it", func(t *testing.T) { + _, err := handler.AnteHandle(sdk.Context{}, tx, false, + func(_ sdk.Context, tx sdk.Tx, _ bool) (sdk.Context, error) { + unwrappedMsgs = tx.GetMsgs() + return sdk.Context{}, nil + }) + + assert.NoError(t, err) + assert.Equal(t, 3, len(unwrappedMsgs)) + }). + Run(t) + givenBatchAnteHandler. When("a Batch Request is valid", func() { batchMsg = batchtypes.NewBatchRequest(sender, []sdk.Msg{ @@ -63,7 +87,13 @@ func TestBatch(t *testing.T) { Then("should unwrap inner message", func(t *testing.T) { tx = &mock.FeeTxMock{ GetMsgsFunc: func() []sdk.Msg { - return []sdk.Msg{batchMsg, batchMsg} + return []sdk.Msg{ + votetypes.NewVoteRequest(sender, vote.PollID(rand.PosI64()), evmTypes.NewVoteEvents(nexus.ChainName(rand.NormalizedStr(3)))), + batchMsg, + votetypes.NewVoteRequest(sender, vote.PollID(rand.PosI64()), evmTypes.NewVoteEvents(nexus.ChainName(rand.NormalizedStr(3)))), + batchMsg, + votetypes.NewVoteRequest(sender, vote.PollID(rand.PosI64()), evmTypes.NewVoteEvents(nexus.ChainName(rand.NormalizedStr(3)))), + } }, } @@ -74,7 +104,7 @@ func TestBatch(t *testing.T) { }) assert.NoError(t, err) - assert.Equal(t, 6, len(unwrappedMsgs)) + assert.Equal(t, 9, len(unwrappedMsgs)) }). Run(t) }