diff --git a/plugin/evm/atomic/sync/atomic_sync_extender.go b/plugin/evm/atomic/sync/atomic_sync_extender.go index 56735eb509..8174548da4 100644 --- a/plugin/evm/atomic/sync/atomic_sync_extender.go +++ b/plugin/evm/atomic/sync/atomic_sync_extender.go @@ -30,7 +30,7 @@ func NewAtomicSyncExtender(backend interfaces.AtomicBackend, stateSyncRequestSiz } func (a *AtomicSyncExtender) Sync(ctx context.Context, client syncclient.LeafClient, verDB *versiondb.Database, syncSummary message.Syncable) error { - atomicSyncSummary, ok := syncSummary.(*AtomicBlockSyncSummary) + atomicSyncSummary, ok := syncSummary.(*AtomicSyncSummary) if !ok { return fmt.Errorf("expected *AtomicBlockSyncSummary, got %T", syncSummary) } diff --git a/plugin/evm/atomic/sync/syncable.go b/plugin/evm/atomic/sync/atomic_sync_summary.go similarity index 63% rename from plugin/evm/atomic/sync/syncable.go rename to plugin/evm/atomic/sync/atomic_sync_summary.go index 8a6e574983..fe5fb72454 100644 --- a/plugin/evm/atomic/sync/syncable.go +++ b/plugin/evm/atomic/sync/atomic_sync_summary.go @@ -7,8 +7,8 @@ import ( "context" "fmt" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/coreth/plugin/evm/atomic" "github.com/ava-labs/coreth/plugin/evm/message" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" @@ -17,13 +17,26 @@ import ( ) var ( - _ message.Syncable = (*AtomicBlockSyncSummary)(nil) + _ message.Syncable = (*AtomicSyncSummary)(nil) _ message.SyncableParser = (*AtomicSyncSummaryParser)(nil) ) -// AtomicBlockSyncSummary provides the information necessary to sync a node starting +// codecWithAtomicSync is the codec manager that contains the codec for AtomicBlockSyncSummary and +// other message types that are used in the network protocol. This is to ensure that the codec +// version is consistent across all message types and includes the codec for AtomicBlockSyncSummary. +var codecWithAtomicSync codec.Manager + +func init() { + var err error + codecWithAtomicSync, err = message.NewCodec(AtomicSyncSummary{}) + if err != nil { + panic(fmt.Errorf("failed to create codec manager: %w", err)) + } +} + +// AtomicSyncSummary provides the information necessary to sync a node starting // at the given block. -type AtomicBlockSyncSummary struct { +type AtomicSyncSummary struct { BlockNumber uint64 `serialize:"true"` BlockHash common.Hash `serialize:"true"` BlockRoot common.Hash `serialize:"true"` @@ -34,10 +47,6 @@ type AtomicBlockSyncSummary struct { acceptImpl message.AcceptImplFn } -func init() { - message.SyncSummaryType = &AtomicBlockSyncSummary{} -} - type AtomicSyncSummaryParser struct{} func NewAtomicSyncSummaryParser() *AtomicSyncSummaryParser { @@ -45,8 +54,8 @@ func NewAtomicSyncSummaryParser() *AtomicSyncSummaryParser { } func (a *AtomicSyncSummaryParser) ParseFromBytes(summaryBytes []byte, acceptImpl message.AcceptImplFn) (message.Syncable, error) { - summary := AtomicBlockSyncSummary{} - if codecVersion, err := atomic.Codec.Unmarshal(summaryBytes, &summary); err != nil { + summary := AtomicSyncSummary{} + if codecVersion, err := codecWithAtomicSync.Unmarshal(summaryBytes, &summary); err != nil { return nil, fmt.Errorf("failed to parse syncable summary: %w", err) } else if codecVersion != message.Version { return nil, fmt.Errorf("failed to parse syncable summary due to unexpected codec version (got %d, expected %d)", codecVersion, message.Version) @@ -62,14 +71,14 @@ func (a *AtomicSyncSummaryParser) ParseFromBytes(summaryBytes []byte, acceptImpl return &summary, nil } -func NewAtomicSyncSummary(blockHash common.Hash, blockNumber uint64, blockRoot common.Hash, atomicRoot common.Hash) (*AtomicBlockSyncSummary, error) { - summary := AtomicBlockSyncSummary{ +func NewAtomicSyncSummary(blockHash common.Hash, blockNumber uint64, blockRoot common.Hash, atomicRoot common.Hash) (*AtomicSyncSummary, error) { + summary := AtomicSyncSummary{ BlockNumber: blockNumber, BlockHash: blockHash, BlockRoot: blockRoot, AtomicRoot: atomicRoot, } - bytes, err := atomic.Codec.Marshal(message.Version, &summary) + bytes, err := codecWithAtomicSync.Marshal(message.Version, &summary) if err != nil { return nil, fmt.Errorf("failed to marshal syncable summary: %w", err) } @@ -84,35 +93,31 @@ func NewAtomicSyncSummary(blockHash common.Hash, blockNumber uint64, blockRoot c return &summary, nil } -func (a *AtomicBlockSyncSummary) GetBlockNumber() uint64 { - return a.BlockNumber -} - -func (a *AtomicBlockSyncSummary) GetBlockHash() common.Hash { +func (a *AtomicSyncSummary) GetBlockHash() common.Hash { return a.BlockHash } -func (a *AtomicBlockSyncSummary) GetBlockRoot() common.Hash { +func (a *AtomicSyncSummary) GetBlockRoot() common.Hash { return a.BlockRoot } -func (a *AtomicBlockSyncSummary) Bytes() []byte { +func (a *AtomicSyncSummary) Bytes() []byte { return a.bytes } -func (a *AtomicBlockSyncSummary) Height() uint64 { +func (a *AtomicSyncSummary) Height() uint64 { return a.BlockNumber } -func (a *AtomicBlockSyncSummary) ID() ids.ID { +func (a *AtomicSyncSummary) ID() ids.ID { return a.summaryID } -func (a *AtomicBlockSyncSummary) String() string { +func (a *AtomicSyncSummary) String() string { return fmt.Sprintf("AtomicBlockSyncSummary(BlockHash=%s, BlockNumber=%d, BlockRoot=%s, AtomicRoot=%s)", a.BlockHash, a.BlockNumber, a.BlockRoot, a.AtomicRoot) } -func (a *AtomicBlockSyncSummary) Accept(context.Context) (block.StateSyncMode, error) { +func (a *AtomicSyncSummary) Accept(context.Context) (block.StateSyncMode, error) { if a.acceptImpl == nil { return block.StateSyncSkipped, fmt.Errorf("accept implementation not specified for summary: %s", a) } diff --git a/plugin/evm/atomic/sync/atomic_sync_summary_test.go b/plugin/evm/atomic/sync/atomic_sync_summary_test.go new file mode 100644 index 0000000000..9b534c75ed --- /dev/null +++ b/plugin/evm/atomic/sync/atomic_sync_summary_test.go @@ -0,0 +1,46 @@ +// (c) 2021-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package sync + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/ava-labs/avalanchego/snow/engine/snowman/block" + "github.com/ava-labs/coreth/plugin/evm/message" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" +) + +func TestMarshalAtomicSyncSummary(t *testing.T) { + atomicSyncSummary, err := NewAtomicSyncSummary(common.Hash{1}, 2, common.Hash{3}, common.Hash{4}) + require.NoError(t, err) + + require.Equal(t, common.Hash{1}, atomicSyncSummary.GetBlockHash()) + require.Equal(t, uint64(2), atomicSyncSummary.Height()) + require.Equal(t, common.Hash{3}, atomicSyncSummary.GetBlockRoot()) + + expectedBase64Bytes := "AAAAAAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==" + require.Equal(t, expectedBase64Bytes, base64.StdEncoding.EncodeToString(atomicSyncSummary.Bytes())) + + parser := NewAtomicSyncSummaryParser() + called := false + acceptImplTest := func(message.Syncable) (block.StateSyncMode, error) { + called = true + return block.StateSyncSkipped, nil + } + s, err := parser.ParseFromBytes(atomicSyncSummary.Bytes(), acceptImplTest) + require.NoError(t, err) + require.Equal(t, atomicSyncSummary.GetBlockHash(), s.GetBlockHash()) + require.Equal(t, atomicSyncSummary.Height(), s.Height()) + require.Equal(t, atomicSyncSummary.GetBlockRoot(), s.GetBlockRoot()) + require.Equal(t, atomicSyncSummary.AtomicRoot, s.(*AtomicSyncSummary).AtomicRoot) + require.Equal(t, atomicSyncSummary.Bytes(), s.Bytes()) + + mode, err := s.Accept(context.TODO()) + require.NoError(t, err) + require.Equal(t, block.StateSyncSkipped, mode) + require.True(t, called) +} diff --git a/plugin/evm/atomic/sync/atomic_syncer_test.go b/plugin/evm/atomic/sync/atomic_syncer_test.go index fa47e52279..41984a7c2f 100644 --- a/plugin/evm/atomic/sync/atomic_syncer_test.go +++ b/plugin/evm/atomic/sync/atomic_syncer_test.go @@ -47,14 +47,14 @@ func testAtomicSyncer(t *testing.T, serverTrieDB *triedb.Database, targetHeight numLeaves := 0 mockClient := syncclient.NewMockClient( - message.Codec, - handlers.NewLeafsRequestHandler(serverTrieDB, state.AtomicTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats()), + codecWithAtomicSync, + handlers.NewLeafsRequestHandler(serverTrieDB, state.AtomicTrieKeyLength, nil, codecWithAtomicSync, handlerstats.NewNoopHandlerStats()), nil, nil, ) clientDB := versiondb.New(memdb.New()) - repo, err := state.NewAtomicTxRepository(clientDB, message.Codec, 0) + repo, err := state.NewAtomicTxRepository(clientDB, codecWithAtomicSync, 0) if err != nil { t.Fatal("could not initialize atomix tx repository", err) } diff --git a/plugin/evm/message/block_request.go b/plugin/evm/message/block_request.go index f1f353f2f7..9bb053bf07 100644 --- a/plugin/evm/message/block_request.go +++ b/plugin/evm/message/block_request.go @@ -12,9 +12,7 @@ import ( "github.com/ethereum/go-ethereum/common" ) -var ( - _ Request = BlockRequest{} -) +var _ Request = BlockRequest{} // BlockRequest is a request to retrieve Parents number of blocks starting from Hash from newest-oldest manner type BlockRequest struct { diff --git a/plugin/evm/message/block_request_test.go b/plugin/evm/message/block_request_test.go index cd9070117d..9e519676d9 100644 --- a/plugin/evm/message/block_request_test.go +++ b/plugin/evm/message/block_request_test.go @@ -23,12 +23,12 @@ func TestMarshalBlockRequest(t *testing.T) { base64BlockRequest := "AAAAAAAAAAAAAAAAAABzb21lIGhhc2ggaXMgaGVyZSB5bwAAAAAAAAU5AEA=" - blockRequestBytes, err := Codec.Marshal(Version, blockRequest) + blockRequestBytes, err := codecWithBlockSync.Marshal(Version, blockRequest) assert.NoError(t, err) assert.Equal(t, base64BlockRequest, base64.StdEncoding.EncodeToString(blockRequestBytes)) var b BlockRequest - _, err = Codec.Unmarshal(blockRequestBytes, &b) + _, err = codecWithBlockSync.Unmarshal(blockRequestBytes, &b) assert.NoError(t, err) assert.Equal(t, blockRequest.Hash, b.Hash) assert.Equal(t, blockRequest.Height, b.Height) @@ -54,12 +54,12 @@ func TestMarshalBlockResponse(t *testing.T) { base64BlockResponse := "AAAAAAAgAAAAIU8WP18PmmIdcpVmx00QA3xNe7sEB9HixkmBhVrYaB0NhgAAADnR6ZTSxCKs0gigByk5SH9pmeudGKRHhARdh/PGfPInRumVr1olNnlRuqL/bNRxxIPxX7kLrbN8WCEAAAA6tmgLTnyLdjobHUnUlVyEhiFjJSU/7HON16nii/khEZwWDwcCRIYVu9oIMT9qjrZo0gv1BZh1kh5migAAACtb3yx/xIRo0tbFL1BU4tCDa/hMcXTLdHY2TMPb2Wiw9xcu2FeUuzWLDDtSAAAAO12heG+f69ehnQ97usvgJVqlt9RL7ED4TIkrm//UNimwIjvupfT3Q5H0RdFa/UKUBAN09pJLmMv4cT+NAAAAMpYtJOLK/Mrjph+1hrFDI6a8j5598dkpMz/5k5M76m9bOvbeA3Q2bEcZ5DobBn2JvH8BAAAAOfHxekxyFaO1OeseWEnGB327VyL1cXoomiZvl2R5gZmOvqicC0s3OXARXoLtb0ElyPpzEeTX3vqSLQAAACc2zU8kq/ffhmuqVgODZ61hRd4e6PSosJk+vfiIOgrYvpw5eLBIg+UAAAAkahVqnexqQOmh0AfwM8KCMGG90Oqln45NpkMBBSINCyloi3NLAAAAKI6gENd8luqAp6Zl9gb2pjt/Pf0lZ8GJeeTWDyZobZvy+ybJAf81TN4AAAA8FgfuKbpk+Eq0PKDG5rkcH9O+iZBDQXnTr0SRo2kBLbktGE/DnRc0/1cWQolTu2hl/PkrDDoXyQKL6ZFOAAAAMwl50YMDVvKlTD3qsqS0R11jr76PtWmHx39YGFJvGBS+gjNQ6rE5NfMdhEhFF+kkrveK4QAAADhRwAdVkgww7CmjcDk0v1CijaECl13tp351hXnqPf5BNqv3UrO4Jx0D6USzyds2a3UEX479adIq5QAAADpBGUfLVbzqQGsy1hCL1oWE9X43yqxuM/6qMmOjmUNwJLqcmxRniidPAakQrilfbvv+X1q/RMzeJjtWAAAAKAZjPn05Bp8BojnENlhUw69/a0HWMfkrmo0S9BJXMl//My91drBiBVYAAAAqMEo+Pq6QGlJyDahcoeSzjq8/RMbG74Ni8vVPwA4J1vwlZAhUwV38rKqKAAAAOyzszlo6lLTTOKUUPmNAjYcksM8/rhej95vhBy+2PDXWBCxBYPOO6eKp8/tP+wAZtFTVIrX/oXYEGT+4AAAAMpZnz1PD9SDIibeb9QTPtXx2ASMtWJuszqnW4mPiXCd0HT9sYsu7FdmvvL9/faQasECOAAAALzk4vxd0rOdwmk8JHpqD/erg7FXrIzqbU5TLPHhWtUbTE8ijtMHA4FRH9Lo3DrNtAAAAPLz97PUi4qbx7Qr+wfjiD6q+32sWLnF9OnSKWGd6DFY0j4khomaxHQ8zTGL+UrpTrxl3nLKUi2Vw/6C3cwAAADqWPBMK15dRJSEPDvHDFAkPB8eab1ccJG8+msC3QT7xEL1YsAznO/9wb3/0tvRAkKMnEfMgjk5LictRAAAAJ2XOZAA98kaJKNWiO5ynQPgMk4LZxgNK0pYMeWUD4c4iFyX1DK8fvwAAADtcR6U9v459yvyeE4ZHpLRO1LzpZO1H90qllEaM7TI8t28NP6xHbJ+wP8kij7roj9WAZjoEVLaDEiB/CgAAADc7WExi1QJ84VpPClglDY+1Dnfyv08BUuXUlDWAf51Ll75vt3lwRmpWJv4zQIz56I4seXQIoy0pAAAAKkFrryBqmDIJgsharXA4SFnAWksTodWy9b/vWm7ZLaSCyqlWjltv6dip3QAAAC7Z6wkne1AJRMvoAKCxUn6mRymoYdL2SXoyNcN/QZJ3nsHZazscVCT84LcnsDByAAAAI+ZAq8lEj93rIZHZRcBHZ6+Eev0O212IV7eZrLGOSv+r4wN/AAAAL/7MQW5zTTc8Xr68nNzFlbzOPHvT2N+T+rfhJd3rr+ZaMb1dQeLSzpwrF4kvD+oZAAAAMTGikNy/poQG6HcHP/CINOGXpANKpIr6P4W4picIyuu6yIC1uJuT2lOBAWRAIQTmSLYAAAA1ImobDzE6id38RUxfj3KsibOLGfU3hMGem+rAPIdaJ9sCneN643pCMYgTSHaFkpNZyoxeuU4AAAA9FS3Br0LquOKSXG2u5N5e+fnc8I38vQK4CAk5hYWSig995QvhptwdV2joU3mI/dzlYum5SMkYu6PpM+XEAAAAAC3Nrne6HSWbGIpLIchvvCPXKLRTR+raZQryTFbQgAqGkTMgiKgFvVXERuJesHU=" - blockResponseBytes, err := Codec.Marshal(Version, blockResponse) + blockResponseBytes, err := codecWithBlockSync.Marshal(Version, blockResponse) assert.NoError(t, err) assert.Equal(t, base64BlockResponse, base64.StdEncoding.EncodeToString(blockResponseBytes)) var b BlockResponse - _, err = Codec.Unmarshal(blockResponseBytes, &b) + _, err = codecWithBlockSync.Unmarshal(blockResponseBytes, &b) assert.NoError(t, err) assert.Equal(t, blockResponse.Blocks, b.Blocks) } diff --git a/plugin/evm/message/block_sync_summary.go b/plugin/evm/message/block_sync_summary.go new file mode 100644 index 0000000000..04c74bb682 --- /dev/null +++ b/plugin/evm/message/block_sync_summary.go @@ -0,0 +1,119 @@ +// (c) 2021-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package message + +import ( + "context" + "fmt" + + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/ids" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + + "github.com/ava-labs/avalanchego/snow/engine/snowman/block" +) + +var _ Syncable = (*BlockSyncSummary)(nil) + +// codecWithBlockSync is the codec manager that contains the codec for BlockSyncSummary and +// other message types that are used in the network protocol. This is to ensure that the codec +// version is consistent across all message types and includes the codec for BlockSyncSummary. +var codecWithBlockSync codec.Manager + +func init() { + var err error + codecWithBlockSync, err = NewCodec(BlockSyncSummary{}) + if err != nil { + panic(fmt.Errorf("failed to create codec manager: %w", err)) + } +} + +// BlockSyncSummary provides the information necessary to sync a node starting +// at the given block. +type BlockSyncSummary struct { + BlockNumber uint64 `serialize:"true"` + BlockHash common.Hash `serialize:"true"` + BlockRoot common.Hash `serialize:"true"` + + summaryID ids.ID + bytes []byte + acceptImpl AcceptImplFn +} + +type BlockSyncSummaryParser struct{} + +func NewBlockSyncSummaryParser() *BlockSyncSummaryParser { + return &BlockSyncSummaryParser{} +} + +func (b *BlockSyncSummaryParser) ParseFromBytes(summaryBytes []byte, acceptImpl AcceptImplFn) (Syncable, error) { + summary := BlockSyncSummary{} + if codecVersion, err := codecWithBlockSync.Unmarshal(summaryBytes, &summary); err != nil { + return nil, fmt.Errorf("failed to parse syncable summary: %w", err) + } else if codecVersion != Version { + return nil, fmt.Errorf("failed to parse syncable summary due to unexpected codec version (%d != %d)", codecVersion, Version) + } + + summary.bytes = summaryBytes + summaryID, err := ids.ToID(crypto.Keccak256(summaryBytes)) + if err != nil { + return nil, fmt.Errorf("failed to compute summary ID: %w", err) + } + summary.summaryID = summaryID + summary.acceptImpl = acceptImpl + return &summary, nil +} + +func NewBlockSyncSummary(blockHash common.Hash, blockNumber uint64, blockRoot common.Hash) (*BlockSyncSummary, error) { + summary := BlockSyncSummary{ + BlockNumber: blockNumber, + BlockHash: blockHash, + BlockRoot: blockRoot, + } + bytes, err := codecWithBlockSync.Marshal(Version, &summary) + if err != nil { + return nil, fmt.Errorf("failed to marshal syncable summary: %w", err) + } + + summary.bytes = bytes + summaryID, err := ids.ToID(crypto.Keccak256(bytes)) + if err != nil { + return nil, fmt.Errorf("failed to compute summary ID: %w", err) + } + summary.summaryID = summaryID + + return &summary, nil +} + +func (s *BlockSyncSummary) GetBlockHash() common.Hash { + return s.BlockHash +} + +func (s *BlockSyncSummary) GetBlockRoot() common.Hash { + return s.BlockRoot +} + +func (s *BlockSyncSummary) Bytes() []byte { + return s.bytes +} + +func (s *BlockSyncSummary) Height() uint64 { + return s.BlockNumber +} + +func (s *BlockSyncSummary) ID() ids.ID { + return s.summaryID +} + +func (s *BlockSyncSummary) String() string { + return fmt.Sprintf("BlockSyncSummary(BlockHash=%s, BlockNumber=%d, BlockRoot=%s)", s.BlockHash, s.BlockNumber, s.BlockRoot) +} + +func (s *BlockSyncSummary) Accept(context.Context) (block.StateSyncMode, error) { + if s.acceptImpl == nil { + return block.StateSyncSkipped, fmt.Errorf("accept implementation not specified for summary: %s", s) + } + return s.acceptImpl(s) +} diff --git a/plugin/evm/message/block_sync_summary_test.go b/plugin/evm/message/block_sync_summary_test.go new file mode 100644 index 0000000000..f7a4c19975 --- /dev/null +++ b/plugin/evm/message/block_sync_summary_test.go @@ -0,0 +1,44 @@ +// (c) 2021-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package message + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/ava-labs/avalanchego/snow/engine/snowman/block" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" +) + +func TestMarshalBlockSyncSummary(t *testing.T) { + blockSyncSummary, err := NewBlockSyncSummary(common.Hash{1}, 2, common.Hash{3}) + require.NoError(t, err) + + require.Equal(t, common.Hash{1}, blockSyncSummary.GetBlockHash()) + require.Equal(t, uint64(2), blockSyncSummary.Height()) + require.Equal(t, common.Hash{3}, blockSyncSummary.GetBlockRoot()) + + expectedBase64Bytes := "AAAAAAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" + require.Equal(t, expectedBase64Bytes, base64.StdEncoding.EncodeToString(blockSyncSummary.Bytes())) + + parser := NewBlockSyncSummaryParser() + called := false + acceptImplTest := func(Syncable) (block.StateSyncMode, error) { + called = true + return block.StateSyncSkipped, nil + } + s, err := parser.ParseFromBytes(blockSyncSummary.Bytes(), acceptImplTest) + require.NoError(t, err) + require.Equal(t, blockSyncSummary.GetBlockHash(), s.GetBlockHash()) + require.Equal(t, blockSyncSummary.Height(), s.Height()) + require.Equal(t, blockSyncSummary.GetBlockRoot(), s.GetBlockRoot()) + require.Equal(t, blockSyncSummary.Bytes(), s.Bytes()) + + mode, err := s.Accept(context.TODO()) + require.NoError(t, err) + require.Equal(t, block.StateSyncSkipped, mode) + require.True(t, called) +} diff --git a/plugin/evm/message/code_request_test.go b/plugin/evm/message/code_request_test.go index 88cedb54d4..10321a389d 100644 --- a/plugin/evm/message/code_request_test.go +++ b/plugin/evm/message/code_request_test.go @@ -21,12 +21,12 @@ func TestMarshalCodeRequest(t *testing.T) { base64CodeRequest := "AAAAAAABAAAAAAAAAAAAAAAAAAAAAAAAAHNvbWUgY29kZSBwbHM=" - codeRequestBytes, err := Codec.Marshal(Version, codeRequest) + codeRequestBytes, err := codecWithBlockSync.Marshal(Version, codeRequest) assert.NoError(t, err) assert.Equal(t, base64CodeRequest, base64.StdEncoding.EncodeToString(codeRequestBytes)) var c CodeRequest - _, err = Codec.Unmarshal(codeRequestBytes, &c) + _, err = codecWithBlockSync.Unmarshal(codeRequestBytes, &c) assert.NoError(t, err) assert.Equal(t, codeRequest.Hashes, c.Hashes) } @@ -47,12 +47,12 @@ func TestMarshalCodeResponse(t *testing.T) { base64CodeResponse := "AAAAAAABAAAAMlL9/AchgmVPFj9fD5piHXKVZsdNEAN8TXu7BAfR4sZJgYVa2GgdDYbR6R4AFnk5y2aU" - codeResponseBytes, err := Codec.Marshal(Version, codeResponse) + codeResponseBytes, err := codecWithBlockSync.Marshal(Version, codeResponse) assert.NoError(t, err) assert.Equal(t, base64CodeResponse, base64.StdEncoding.EncodeToString(codeResponseBytes)) var c CodeResponse - _, err = Codec.Unmarshal(codeResponseBytes, &c) + _, err = codecWithBlockSync.Unmarshal(codeResponseBytes, &c) assert.NoError(t, err) assert.Equal(t, codeResponse.Data, c.Data) } diff --git a/plugin/evm/message/codec.go b/plugin/evm/message/codec.go index d7de1820c6..d2599647a3 100644 --- a/plugin/evm/message/codec.go +++ b/plugin/evm/message/codec.go @@ -15,14 +15,11 @@ const ( maxMessageSize = 2*units.MiB - 64*units.KiB // Subtract 64 KiB from p2p network cap to leave room for encoding overhead from AvalancheGo ) -var ( - Codec codec.Manager - // TODO: Remove this once we have a better way to register types (i.e use a different codec version or use build flags) - SyncSummaryType interface{} = BlockSyncSummary{} -) - -func init() { - Codec = codec.NewManager(maxMessageSize) +// NewCodec returns a codec manager that can be used to marshal and unmarshal +// messages, including the provided syncSummaryType. syncSummaryType can be used +// to register a type for sync summaries. +func NewCodec(syncSummaryType interface{}) (codec.Manager, error) { + codec := codec.NewManager(maxMessageSize) c := linearcodec.NewDefault() errs := wrappers.Errs{} @@ -30,7 +27,7 @@ func init() { c.SkipRegistrations(2) errs.Add( // Types for state sync frontier consensus - c.RegisterType(SyncSummaryType), + c.RegisterType(syncSummaryType), // state sync types c.RegisterType(BlockRequest{}), @@ -45,10 +42,12 @@ func init() { c.RegisterType(BlockSignatureRequest{}), c.RegisterType(SignatureResponse{}), - Codec.RegisterCodec(Version, c), + codec.RegisterCodec(Version, c), ) if errs.Errored() { - panic(errs.Err) + return nil, errs.Err } + + return codec, nil } diff --git a/plugin/evm/message/leafs_request_test.go b/plugin/evm/message/leafs_request_test.go index f70aad7bba..9b98c45aba 100644 --- a/plugin/evm/message/leafs_request_test.go +++ b/plugin/evm/message/leafs_request_test.go @@ -38,12 +38,12 @@ func TestMarshalLeafsRequest(t *testing.T) { base64LeafsRequest := "AAAAAAAAAAAAAAAAAAAAAABpbSBST09UaW5nIGZvciB5YQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIFL9/AchgmVPFj9fD5piHXKVZsdNEAN8TXu7BAfR4sZJAAAAIIGFWthoHQ2G0ekeABZ5OctmlNLEIqzSCKAHKTlIf2mZBAAB" - leafsRequestBytes, err := Codec.Marshal(Version, leafsRequest) + leafsRequestBytes, err := codecWithBlockSync.Marshal(Version, leafsRequest) assert.NoError(t, err) assert.Equal(t, base64LeafsRequest, base64.StdEncoding.EncodeToString(leafsRequestBytes)) var l LeafsRequest - _, err = Codec.Unmarshal(leafsRequestBytes, &l) + _, err = codecWithBlockSync.Unmarshal(leafsRequestBytes, &l) assert.NoError(t, err) assert.Equal(t, leafsRequest.Root, l.Root) assert.Equal(t, leafsRequest.Start, l.Start) @@ -92,12 +92,12 @@ func TestMarshalLeafsResponse(t *testing.T) { base64LeafsResponse := "AAAAAAAQAAAAIE8WP18PmmIdcpVmx00QA3xNe7sEB9HixkmBhVrYaB0NAAAAIGagByk5SH9pmeudGKRHhARdh/PGfPInRumVr1olNnlRAAAAIK2zfFghtmgLTnyLdjobHUnUlVyEhiFjJSU/7HON16niAAAAIIYVu9oIMfUFmHWSHmaKW98sf8SERZLSVyvNBmjS1sUvAAAAIHHb2Wiw9xcu2FeUuzWLDDtSXaF4b5//CUJ52xlE69ehAAAAIPhMiSs77qX090OR9EXRWv1ClAQDdPaSS5jL+HE/jZYtAAAAIMr8yuOmvI+effHZKTM/+ZOTO+pvWzr23gN0NmxHGeQ6AAAAIBZZpE856x5YScYHfbtXIvVxeiiaJm+XZHmBmY6+qJwLAAAAIHOq53hmZ/fpNs1PJKv334ZrqlYDg2etYUXeHuj0qLCZAAAAIHiN5WOvpGfUnexqQOmh0AfwM8KCMGG90Oqln45NpkMBAAAAIKAQ13yW6oCnpmX2BvamO389/SVnwYl55NYPJmhtm/L7AAAAIAfuKbpk+Eq0PKDG5rkcH9O+iZBDQXnTr0SRo2kBLbktAAAAILsXyQKL6ZFOt2ScbJNHgAl50YMDVvKlTD3qsqS0R11jAAAAIOqxOTXzHYRIRRfpJK73iuFRwAdVklg2twdYhWUMMOwpAAAAIHnqPf5BNqv3UrO4Jx0D6USzyds2a3UEX479adIq5UEZAAAAIDLWEMqsbjP+qjJjo5lDcCS6nJsUZ4onTwGpEK4pX277AAAAEAAAAAmG0ekeABZ5OcsAAAAMuqL/bNRxxIPxX7kLAAAACov5IRGcFg8HAkQAAAAIUFTi0INr+EwAAAAOnQ97usvgJVqlt9RL7EAAAAAJfI0BkZLCQiTiAAAACxsGfYm8fwHx9XOYAAAADUs3OXARXoLtb0ElyPoAAAAKPr34iDoK2L6cOQAAAAoFIg0LKWiLc0uOAAAACCbJAf81TN4WAAAADBhPw50XNP9XFkKJUwAAAAuvvo+1aYfHf1gYUgAAAAqjcDk0v1CijaECAAAADkfLVT12lCZ670686kBrAAAADf5fWr9EzN4mO1YGYz4AAAAEAAAADlcyXwVWMEo+Pq4Uwo0MAAAADeo50qHks46vP0TGxu8AAAAOg2Ly9WQIVMFd/KyqiiwAAAAL7M5aOpS00zilFD4=" - leafsResponseBytes, err := Codec.Marshal(Version, leafsResponse) + leafsResponseBytes, err := codecWithBlockSync.Marshal(Version, leafsResponse) assert.NoError(t, err) assert.Equal(t, base64LeafsResponse, base64.StdEncoding.EncodeToString(leafsResponseBytes)) var l LeafsResponse - _, err = Codec.Unmarshal(leafsResponseBytes, &l) + _, err = codecWithBlockSync.Unmarshal(leafsResponseBytes, &l) assert.NoError(t, err) assert.Equal(t, leafsResponse.Keys, l.Keys) assert.Equal(t, leafsResponse.Vals, l.Vals) diff --git a/plugin/evm/message/signature_request_test.go b/plugin/evm/message/signature_request_test.go index 59614fbb2e..a9fb5470c0 100644 --- a/plugin/evm/message/signature_request_test.go +++ b/plugin/evm/message/signature_request_test.go @@ -21,12 +21,12 @@ func TestMarshalMessageSignatureRequest(t *testing.T) { } base64MessageSignatureRequest := "AABET0ZBSElAawAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==" - signatureRequestBytes, err := Codec.Marshal(Version, signatureRequest) + signatureRequestBytes, err := codecWithBlockSync.Marshal(Version, signatureRequest) require.NoError(t, err) require.Equal(t, base64MessageSignatureRequest, base64.StdEncoding.EncodeToString(signatureRequestBytes)) var s MessageSignatureRequest - _, err = Codec.Unmarshal(signatureRequestBytes, &s) + _, err = codecWithBlockSync.Unmarshal(signatureRequestBytes, &s) require.NoError(t, err) require.Equal(t, signatureRequest.MessageID, s.MessageID) } @@ -39,12 +39,12 @@ func TestMarshalBlockSignatureRequest(t *testing.T) { } base64BlockSignatureRequest := "AABET0ZBSElAawAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==" - signatureRequestBytes, err := Codec.Marshal(Version, signatureRequest) + signatureRequestBytes, err := codecWithBlockSync.Marshal(Version, signatureRequest) require.NoError(t, err) require.Equal(t, base64BlockSignatureRequest, base64.StdEncoding.EncodeToString(signatureRequestBytes)) var s BlockSignatureRequest - _, err = Codec.Unmarshal(signatureRequestBytes, &s) + _, err = codecWithBlockSync.Unmarshal(signatureRequestBytes, &s) require.NoError(t, err) require.Equal(t, signatureRequest.BlockID, s.BlockID) } @@ -62,12 +62,12 @@ func TestMarshalSignatureResponse(t *testing.T) { } base64SignatureResponse := "AAABI0VniavN7wEjRWeJq83vASNFZ4mrze8BI0VniavN7wEjRWeJq83vASNFZ4mrze8BI0VniavN7wEjRWeJq83vASNFZ4mrze8BI0VniavN7wEjRWeJq83vASNFZ4mrze8=" - signatureResponseBytes, err := Codec.Marshal(Version, signatureResponse) + signatureResponseBytes, err := codecWithBlockSync.Marshal(Version, signatureResponse) require.NoError(t, err) require.Equal(t, base64SignatureResponse, base64.StdEncoding.EncodeToString(signatureResponseBytes)) var s SignatureResponse - _, err = Codec.Unmarshal(signatureResponseBytes, &s) + _, err = codecWithBlockSync.Unmarshal(signatureResponseBytes, &s) require.NoError(t, err) require.Equal(t, signatureResponse.Signature, s.Signature) } diff --git a/plugin/evm/message/syncable.go b/plugin/evm/message/syncable.go index 8a8918b891..c83ae2d8cb 100644 --- a/plugin/evm/message/syncable.go +++ b/plugin/evm/message/syncable.go @@ -4,21 +4,13 @@ package message import ( - "context" - "fmt" - - "github.com/ava-labs/avalanchego/ids" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" "github.com/ava-labs/avalanchego/snow/engine/snowman/block" ) -var _ Syncable = (*BlockSyncSummary)(nil) - type Syncable interface { block.StateSummary - GetBlockNumber() uint64 GetBlockHash() common.Hash GetBlockRoot() common.Hash } @@ -28,95 +20,3 @@ type SyncableParser interface { } type AcceptImplFn func(Syncable) (block.StateSyncMode, error) - -// BlockSyncSummary provides the information necessary to sync a node starting -// at the given block. -type BlockSyncSummary struct { - BlockNumber uint64 `serialize:"true"` - BlockHash common.Hash `serialize:"true"` - BlockRoot common.Hash `serialize:"true"` - - summaryID ids.ID - bytes []byte - acceptImpl AcceptImplFn -} - -type BlockSyncSummaryParser struct{} - -func NewBlockSyncSummaryParser() *BlockSyncSummaryParser { - return &BlockSyncSummaryParser{} -} - -func (b *BlockSyncSummaryParser) ParseFromBytes(summaryBytes []byte, acceptImpl AcceptImplFn) (Syncable, error) { - summary := BlockSyncSummary{} - if codecVersion, err := Codec.Unmarshal(summaryBytes, &summary); err != nil { - return nil, fmt.Errorf("failed to parse syncable summary: %w", err) - } else if codecVersion != Version { - return nil, fmt.Errorf("failed to parse syncable summary due to unexpected codec version (%d != %d)", codecVersion, Version) - } - - summary.bytes = summaryBytes - summaryID, err := ids.ToID(crypto.Keccak256(summaryBytes)) - if err != nil { - return nil, fmt.Errorf("failed to compute summary ID: %w", err) - } - summary.summaryID = summaryID - summary.acceptImpl = acceptImpl - return &summary, nil -} - -func NewBlockSyncSummary(blockHash common.Hash, blockNumber uint64, blockRoot common.Hash) (*BlockSyncSummary, error) { - summary := BlockSyncSummary{ - BlockNumber: blockNumber, - BlockHash: blockHash, - BlockRoot: blockRoot, - } - bytes, err := Codec.Marshal(Version, &summary) - if err != nil { - return nil, fmt.Errorf("failed to marshal syncable summary: %w", err) - } - - summary.bytes = bytes - summaryID, err := ids.ToID(crypto.Keccak256(bytes)) - if err != nil { - return nil, fmt.Errorf("failed to compute summary ID: %w", err) - } - summary.summaryID = summaryID - - return &summary, nil -} - -func (s *BlockSyncSummary) GetBlockNumber() uint64 { - return s.BlockNumber -} - -func (s *BlockSyncSummary) GetBlockHash() common.Hash { - return s.BlockHash -} - -func (s *BlockSyncSummary) GetBlockRoot() common.Hash { - return s.BlockRoot -} - -func (s *BlockSyncSummary) Bytes() []byte { - return s.bytes -} - -func (s *BlockSyncSummary) Height() uint64 { - return s.BlockNumber -} - -func (s *BlockSyncSummary) ID() ids.ID { - return s.summaryID -} - -func (s *BlockSyncSummary) String() string { - return fmt.Sprintf("BlockSyncSummary(BlockHash=%s, BlockNumber=%d, BlockRoot=%s)", s.BlockHash, s.BlockNumber, s.BlockRoot) -} - -func (s *BlockSyncSummary) Accept(context.Context) (block.StateSyncMode, error) { - if s.acceptImpl == nil { - return block.StateSyncSkipped, fmt.Errorf("accept implementation not specified for summary: %s", s) - } - return s.acceptImpl(s) -} diff --git a/plugin/evm/sync/syncervm_client.go b/plugin/evm/sync/syncervm_client.go index f2e8dac7d9..b315895127 100644 --- a/plugin/evm/sync/syncervm_client.go +++ b/plugin/evm/sync/syncervm_client.go @@ -161,7 +161,7 @@ func (client *stateSyncerClient) ParseStateSummary(_ context.Context, summaryByt // stateSync blockingly performs the state sync for the EVM state and the atomic state // to [client.syncSummary]. returns an error if one occurred. func (client *stateSyncerClient) stateSync(ctx context.Context) error { - if err := client.syncBlocks(ctx, client.syncSummary.GetBlockHash(), client.syncSummary.GetBlockNumber(), ParentsToFetch); err != nil { + if err := client.syncBlocks(ctx, client.syncSummary.GetBlockHash(), client.syncSummary.Height(), ParentsToFetch); err != nil { return err } @@ -345,8 +345,8 @@ func (client *stateSyncerClient) finishSync() error { if block.Hash() != client.syncSummary.GetBlockHash() { return fmt.Errorf("attempted to set last summary block to unexpected block hash: (%s != %s)", block.Hash(), client.syncSummary.GetBlockHash()) } - if block.NumberU64() != client.syncSummary.GetBlockNumber() { - return fmt.Errorf("attempted to set last summary block to unexpected block number: (%d != %d)", block.NumberU64(), client.syncSummary.GetBlockNumber()) + if block.NumberU64() != client.syncSummary.Height() { + return fmt.Errorf("attempted to set last summary block to unexpected block number: (%d != %d)", block.NumberU64(), client.syncSummary.Height()) } // BloomIndexer needs to know that some parts of the chain are not available diff --git a/plugin/evm/vm.go b/plugin/evm/vm.go index 85c1530d00..ef6c535d40 100644 --- a/plugin/evm/vm.go +++ b/plugin/evm/vm.go @@ -159,6 +159,8 @@ var ( metadataPrefix = []byte("metadata") warpPrefix = []byte("warp") ethDBPrefix = []byte("ethdb") + + networkCodec codec.Manager ) var ( @@ -200,6 +202,13 @@ func init() { // Preserving the log level allows us to update the root handler while writing to the original // [os.Stderr] that is being piped through to the logger via the rpcchainvm. originalStderr = os.Stderr + + // Register the codec for the atomic block sync summary + var err error + networkCodec, err = message.NewCodec(atomicsync.AtomicSyncSummary{}) + if err != nil { + panic(fmt.Errorf("failed to create codec manager: %w", err)) + } } // VM implements the snowman.ChainVM interface @@ -271,8 +280,7 @@ type VM struct { profiler profiler.ContinuousProfiler peer.Network - client peer.NetworkClient - networkCodec codec.Manager + client peer.NetworkClient p2pValidators *p2p.Validators @@ -541,8 +549,7 @@ func (vm *VM) Initialize( return fmt.Errorf("failed to initialize p2p network: %w", err) } vm.p2pValidators = p2p.NewValidators(p2pNetwork.Peers, vm.ctx.Log, vm.ctx.SubnetID, vm.ctx.ValidatorState, maxValidatorSetStaleness) - vm.networkCodec = message.Codec - vm.Network = peer.NewNetwork(p2pNetwork, appSender, vm.networkCodec, chainCtx.NodeID, vm.config.MaxOutboundActiveRequests) + vm.Network = peer.NewNetwork(p2pNetwork, appSender, networkCodec, chainCtx.NodeID, vm.config.MaxOutboundActiveRequests) vm.client = peer.NewNetworkClient(vm.Network) // Initialize warp backend @@ -710,7 +717,7 @@ func (vm *VM) initializeStateSyncClient(lastAcceptedHeight uint64) error { Client: statesyncclient.NewClient( &statesyncclient.ClientConfig{ NetworkClient: vm.client, - Codec: vm.networkCodec, + Codec: networkCodec, Stats: stats.NewClientSyncerStats(leafMetricsNames), StateSyncNodeIDs: stateSyncIDs, BlockParser: vm, @@ -1239,7 +1246,7 @@ func (vm *VM) setAppRequestHandlers() error { vm.blockChain, vm.chaindb, vm.warpBackend, - vm.networkCodec, + networkCodec, vm.leafRequestTypeConfigs, ) vm.Network.SetRequestHandler(networkHandler) @@ -1493,7 +1500,8 @@ func (vm *VM) CreateHandlers(context.Context) (map[string]http.Handler, error) { } if vm.config.WarpAPIEnabled { - if err := handler.RegisterName("warp", warp.NewAPI(vm.ctx.NetworkID, vm.ctx.SubnetID, vm.ctx.ChainID, vm.ctx.ValidatorState, vm.warpBackend, vm.client, vm.requirePrimaryNetworkSigners)); err != nil { + warpAPI := warp.NewAPI(vm.ctx, networkCodec, vm.warpBackend, vm.client, vm.requirePrimaryNetworkSigners) + if err := handler.RegisterName("warp", warpAPI); err != nil { return nil, err } enabledAPIs = append(enabledAPIs, "warp") diff --git a/plugin/evm/vm_warp_test.go b/plugin/evm/vm_warp_test.go index e44f5ba606..b2d7dd46d8 100644 --- a/plugin/evm/vm_warp_test.go +++ b/plugin/evm/vm_warp_test.go @@ -753,7 +753,7 @@ func TestMessageSignatureRequestsToVM(t *testing.T) { appSender.SendAppResponseF = func(ctx context.Context, nodeID ids.NodeID, requestID uint32, responseBytes []byte) error { calledSendAppResponseFn = true var response message.SignatureResponse - _, err := message.Codec.Unmarshal(responseBytes, &response) + _, err := networkCodec.Unmarshal(responseBytes, &response) require.NoError(t, err) require.Equal(t, test.expectedResponse, response.Signature) @@ -764,7 +764,7 @@ func TestMessageSignatureRequestsToVM(t *testing.T) { MessageID: test.messageID, } - requestBytes, err := message.Codec.Marshal(message.Version, &signatureRequest) + requestBytes, err := networkCodec.Marshal(message.Version, &signatureRequest) require.NoError(t, err) // Send the app request and make sure we called SendAppResponseFn @@ -811,7 +811,7 @@ func TestBlockSignatureRequestsToVM(t *testing.T) { appSender.SendAppResponseF = func(ctx context.Context, nodeID ids.NodeID, requestID uint32, responseBytes []byte) error { calledSendAppResponseFn = true var response message.SignatureResponse - _, err := message.Codec.Unmarshal(responseBytes, &response) + _, err := networkCodec.Unmarshal(responseBytes, &response) require.NoError(t, err) require.Equal(t, test.expectedResponse, response.Signature) @@ -822,7 +822,7 @@ func TestBlockSignatureRequestsToVM(t *testing.T) { BlockID: test.blockID, } - requestBytes, err := message.Codec.Marshal(message.Version, &signatureRequest) + requestBytes, err := networkCodec.Marshal(message.Version, &signatureRequest) require.NoError(t, err) // Send the app request and make sure we called SendAppResponseFn diff --git a/sync/client/client_test.go b/sync/client/client_test.go index 167b5ce120..31658dafb4 100644 --- a/sync/client/client_test.go +++ b/sync/client/client_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/coreth/consensus/dummy" @@ -30,6 +31,16 @@ import ( "github.com/ethereum/go-ethereum/crypto" ) +var networkCodec codec.Manager + +func init() { + var err error + networkCodec, err = message.NewCodec(message.BlockSyncSummary{}) + if err != nil { + panic(err) + } +} + func TestGetCode(t *testing.T) { mockNetClient := &mockNetwork{} @@ -86,7 +97,7 @@ func TestGetCode(t *testing.T) { stateSyncClient := NewClient(&ClientConfig{ NetworkClient: mockNetClient, - Codec: message.Codec, + Codec: networkCodec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, BlockParser: mockBlockParser, @@ -98,7 +109,7 @@ func TestGetCode(t *testing.T) { defer cancel() codeHashes, res, expectedCode := test.setupRequest() - responseBytes, err := message.Codec.Marshal(message.Version, res) + responseBytes, err := networkCodec.Marshal(message.Version, res) if err != nil { t.Fatal(err) } @@ -157,13 +168,13 @@ func TestGetBlocks(t *testing.T) { mockNetClient := &mockNetwork{} stateSyncClient := NewClient(&ClientConfig{ NetworkClient: mockNetClient, - Codec: message.Codec, + Codec: networkCodec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, BlockParser: mockBlockParser, }) - blocksRequestHandler := handlers.NewBlockRequestHandler(buildGetter(blocks), message.Codec, handlerstats.NewNoopHandlerStats()) + blocksRequestHandler := handlers.NewBlockRequestHandler(buildGetter(blocks), networkCodec, handlerstats.NewNoopHandlerStats()) // encodeBlockSlice takes a slice of blocks that are ordered in increasing height order // and returns a slice of byte slices with those blocks encoded in reverse order @@ -254,12 +265,12 @@ func TestGetBlocks(t *testing.T) { t.Fatalf("failed to get block response: %s", err) } var blockResponse message.BlockResponse - if _, err = message.Codec.Unmarshal(response, &blockResponse); err != nil { + if _, err = networkCodec.Unmarshal(response, &blockResponse); err != nil { t.Fatalf("failed to marshal block response: %s", err) } // Replace middle value with garbage data blockResponse.Blocks[10] = []byte("invalid value replacing block bytes") - responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) + responseBytes, err := networkCodec.Marshal(message.Version, blockResponse) if err != nil { t.Fatalf("failed to marshal block response: %s", err) } @@ -308,7 +319,7 @@ func TestGetBlocks(t *testing.T) { blockResponse := message.BlockResponse{ Blocks: blockBytes, } - responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) + responseBytes, err := networkCodec.Marshal(message.Version, blockResponse) if err != nil { t.Fatalf("failed to marshal block response: %s", err) } @@ -327,7 +338,7 @@ func TestGetBlocks(t *testing.T) { blockResponse := message.BlockResponse{ Blocks: nil, } - responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) + responseBytes, err := networkCodec.Marshal(message.Version, blockResponse) if err != nil { t.Fatalf("failed to marshal block response: %s", err) } @@ -348,7 +359,7 @@ func TestGetBlocks(t *testing.T) { blockResponse := message.BlockResponse{ Blocks: blockBytes, } - responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) + responseBytes, err := networkCodec.Marshal(message.Version, blockResponse) if err != nil { t.Fatalf("failed to marshal block response: %s", err) } @@ -415,10 +426,10 @@ func TestGetLeafs(t *testing.T) { largeTrieRoot, largeTrieKeys, _ := syncutils.GenerateTrie(t, trieDB, 100_000, common.HashLength) smallTrieRoot, _, _ := syncutils.GenerateTrie(t, trieDB, leafsLimit, common.HashLength) - handler := handlers.NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats()) + handler := handlers.NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, nil, networkCodec, handlerstats.NewNoopHandlerStats()) client := NewClient(&ClientConfig{ NetworkClient: &mockNetwork{}, - Codec: message.Codec, + Codec: networkCodec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, BlockParser: mockBlockParser, @@ -594,13 +605,13 @@ func TestGetLeafs(t *testing.T) { t.Fatal("Failed to create valid response") } var leafResponse message.LeafsResponse - if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { + if _, err := networkCodec.Unmarshal(response, &leafResponse); err != nil { t.Fatal(err) } leafResponse.Keys = leafResponse.Keys[1:] leafResponse.Vals = leafResponse.Vals[1:] - modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) + modifiedResponse, err := networkCodec.Marshal(message.Version, leafResponse) if err != nil { t.Fatal(err) } @@ -625,7 +636,7 @@ func TestGetLeafs(t *testing.T) { t.Fatal("Failed to create valid response") } var leafResponse message.LeafsResponse - if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { + if _, err := networkCodec.Unmarshal(response, &leafResponse); err != nil { t.Fatal(err) } modifiedRequest := request @@ -655,13 +666,13 @@ func TestGetLeafs(t *testing.T) { t.Fatal("Failed to create valid response") } var leafResponse message.LeafsResponse - if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { + if _, err := networkCodec.Unmarshal(response, &leafResponse); err != nil { t.Fatal(err) } leafResponse.Keys = leafResponse.Keys[:len(leafResponse.Keys)-2] leafResponse.Vals = leafResponse.Vals[:len(leafResponse.Vals)-2] - modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) + modifiedResponse, err := networkCodec.Marshal(message.Version, leafResponse) if err != nil { t.Fatal(err) } @@ -686,14 +697,14 @@ func TestGetLeafs(t *testing.T) { t.Fatal("Failed to create valid response") } var leafResponse message.LeafsResponse - if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { + if _, err := networkCodec.Unmarshal(response, &leafResponse); err != nil { t.Fatal(err) } // Remove middle key-value pair response leafResponse.Keys = append(leafResponse.Keys[:100], leafResponse.Keys[101:]...) leafResponse.Vals = append(leafResponse.Vals[:100], leafResponse.Vals[101:]...) - modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) + modifiedResponse, err := networkCodec.Marshal(message.Version, leafResponse) if err != nil { t.Fatal(err) } @@ -718,13 +729,13 @@ func TestGetLeafs(t *testing.T) { t.Fatal("Failed to create valid response") } var leafResponse message.LeafsResponse - if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { + if _, err := networkCodec.Unmarshal(response, &leafResponse); err != nil { t.Fatal(err) } // Remove middle key-value pair response leafResponse.Vals[100] = []byte("garbage value data") - modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) + modifiedResponse, err := networkCodec.Marshal(message.Version, leafResponse) if err != nil { t.Fatal(err) } @@ -750,13 +761,13 @@ func TestGetLeafs(t *testing.T) { } var leafResponse message.LeafsResponse - if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { + if _, err := networkCodec.Unmarshal(response, &leafResponse); err != nil { t.Fatal(err) } // Remove the proof leafResponse.ProofVals = nil - modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) + modifiedResponse, err := networkCodec.Marshal(message.Version, leafResponse) if err != nil { t.Fatal(err) } @@ -797,13 +808,13 @@ func TestGetLeafsRetries(t *testing.T) { trieDB := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil) root, _, _ := syncutils.GenerateTrie(t, trieDB, 100_000, common.HashLength) - handler := handlers.NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats()) + handler := handlers.NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, nil, networkCodec, handlerstats.NewNoopHandlerStats()) mockNetClient := &mockNetwork{} const maxAttempts = 8 client := NewClient(&ClientConfig{ NetworkClient: mockNetClient, - Codec: message.Codec, + Codec: networkCodec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, BlockParser: mockBlockParser, @@ -865,7 +876,7 @@ func TestStateSyncNodes(t *testing.T) { } client := NewClient(&ClientConfig{ NetworkClient: mockNetClient, - Codec: message.Codec, + Codec: networkCodec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: stateSyncNodes, BlockParser: mockBlockParser, diff --git a/sync/handlers/block_request_test.go b/sync/handlers/block_request_test.go index 7b0124d8f9..aed723a4a7 100644 --- a/sync/handlers/block_request_test.go +++ b/sync/handlers/block_request_test.go @@ -55,7 +55,7 @@ func executeBlockRequestTest(t testing.TB, test blockRequestTest, blocks []*type return blk }, } - blockRequestHandler := NewBlockRequestHandler(blockProvider, message.Codec, mockHandlerStats) + blockRequestHandler := NewBlockRequestHandler(blockProvider, networkCodec, mockHandlerStats) var blockRequest message.BlockRequest if test.startBlockHash != (common.Hash{}) { @@ -84,7 +84,7 @@ func executeBlockRequestTest(t testing.TB, test blockRequestTest, blocks []*type assert.NotEmpty(t, responseBytes) var response message.BlockResponse - if _, err = message.Codec.Unmarshal(responseBytes, &response); err != nil { + if _, err = networkCodec.Unmarshal(responseBytes, &response); err != nil { t.Fatal("error unmarshalling", err) } assert.Len(t, response.Blocks, test.expectedBlocks) @@ -102,7 +102,7 @@ func executeBlockRequestTest(t testing.TB, test blockRequestTest, blocks []*type } func TestBlockRequestHandler(t *testing.T) { - var gspec = &core.Genesis{ + gspec := &core.Genesis{ Config: params.TestChainConfig, } memdb := rawdb.NewMemoryDatabase() @@ -214,7 +214,7 @@ func TestBlockRequestHandlerLargeBlocks(t *testing.T) { } func TestBlockRequestHandlerCtxExpires(t *testing.T) { - var gspec = &core.Genesis{ + gspec := &core.Genesis{ Config: params.TestChainConfig, } memdb := rawdb.NewMemoryDatabase() @@ -252,7 +252,7 @@ func TestBlockRequestHandlerCtxExpires(t *testing.T) { return blk }, } - blockRequestHandler := NewBlockRequestHandler(blockProvider, message.Codec, stats.NewNoopHandlerStats()) + blockRequestHandler := NewBlockRequestHandler(blockProvider, networkCodec, stats.NewNoopHandlerStats()) responseBytes, err := blockRequestHandler.OnBlockRequest(ctx, ids.GenerateTestNodeID(), 1, message.BlockRequest{ Hash: blocks[10].Hash(), @@ -265,7 +265,7 @@ func TestBlockRequestHandlerCtxExpires(t *testing.T) { assert.NotEmpty(t, responseBytes) var response message.BlockResponse - if _, err = message.Codec.Unmarshal(responseBytes, &response); err != nil { + if _, err = networkCodec.Unmarshal(responseBytes, &response); err != nil { t.Fatal("error unmarshalling", err) } // requested 8 blocks, received cancelAfterNumRequests because of timeout diff --git a/sync/handlers/code_request_test.go b/sync/handlers/code_request_test.go index 1bf5bd5223..797c191311 100644 --- a/sync/handlers/code_request_test.go +++ b/sync/handlers/code_request_test.go @@ -35,7 +35,7 @@ func TestCodeRequestHandler(t *testing.T) { rawdb.WriteCode(database, maxSizeCodeHash, maxSizeCodeBytes) mockHandlerStats := &stats.MockHandlerStats{} - codeRequestHandler := NewCodeRequestHandler(database, message.Codec, mockHandlerStats) + codeRequestHandler := NewCodeRequestHandler(database, networkCodec, mockHandlerStats) tests := map[string]struct { setup func() (request message.CodeRequest, expectedCodeResponse [][]byte) @@ -100,7 +100,7 @@ func TestCodeRequestHandler(t *testing.T) { return } var response message.CodeResponse - if _, err = message.Codec.Unmarshal(responseBytes, &response); err != nil { + if _, err = networkCodec.Unmarshal(responseBytes, &response); err != nil { t.Fatal("error unmarshalling CodeResponse", err) } if len(expectedResponse) != len(response.Data) { diff --git a/sync/handlers/leafs_request_test.go b/sync/handlers/leafs_request_test.go index 296e87371f..eb57362e01 100644 --- a/sync/handlers/leafs_request_test.go +++ b/sync/handlers/leafs_request_test.go @@ -9,6 +9,7 @@ import ( "math/rand" "testing" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/coreth/core/rawdb" "github.com/ava-labs/coreth/core/state/snapshot" @@ -24,6 +25,16 @@ import ( "github.com/stretchr/testify/assert" ) +var networkCodec codec.Manager + +func init() { + var err error + networkCodec, err = message.NewCodec(message.BlockSyncSummary{}) + if err != nil { + panic(err) + } +} + func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { rand.Seed(1) mockHandlerStats := &stats.MockHandlerStats{} @@ -74,7 +85,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { } } snapshotProvider := &TestSnapshotProvider{} - leafsHandler := NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, snapshotProvider, message.Codec, mockHandlerStats) + leafsHandler := NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, snapshotProvider, networkCodec, mockHandlerStats) snapConfig := snapshot.Config{ CacheSize: 64, AsyncBuild: false, @@ -228,7 +239,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.Len(t, leafsResponse.Keys, 500) assert.Len(t, leafsResponse.Vals, 500) @@ -248,7 +259,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.Len(t, leafsResponse.Keys, 500) assert.Len(t, leafsResponse.Vals, 500) @@ -302,7 +313,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit) assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit) @@ -323,7 +334,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit) assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit) @@ -345,7 +356,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit) assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit) @@ -370,7 +381,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, 40, len(leafsResponse.Keys)) assert.EqualValues(t, 40, len(leafsResponse.Vals)) @@ -392,7 +403,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, 600, len(leafsResponse.Keys)) assert.EqualValues(t, 600, len(leafsResponse.Vals)) @@ -414,7 +425,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, len(leafsResponse.Keys), 0) assert.EqualValues(t, len(leafsResponse.Vals), 0) @@ -437,7 +448,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assert.NotEmpty(t, response) var leafsResponse message.LeafsResponse - if _, err = message.Codec.Unmarshal(response, &leafsResponse); err != nil { + if _, err = networkCodec.Unmarshal(response, &leafsResponse); err != nil { t.Fatalf("unexpected error when unmarshalling LeafsResponse: %v", err) } @@ -465,7 +476,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys)) assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals)) @@ -513,7 +524,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys)) assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals)) @@ -546,7 +557,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys)) assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals)) @@ -592,7 +603,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys)) assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals)) @@ -633,7 +644,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, 500, len(leafsResponse.Keys)) assert.EqualValues(t, 500, len(leafsResponse.Vals)) @@ -670,7 +681,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { assert.NoError(t, err) var leafsResponse message.LeafsResponse - _, err = message.Codec.Unmarshal(response, &leafsResponse) + _, err = networkCodec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) assert.EqualValues(t, 1, len(leafsResponse.Keys)) assert.EqualValues(t, 1, len(leafsResponse.Vals)) diff --git a/sync/statesync/code_syncer_test.go b/sync/statesync/code_syncer_test.go index 574290e286..1fd20ba643 100644 --- a/sync/statesync/code_syncer_test.go +++ b/sync/statesync/code_syncer_test.go @@ -10,7 +10,6 @@ import ( "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/coreth/core/rawdb" - "github.com/ava-labs/coreth/plugin/evm/message" statesyncclient "github.com/ava-labs/coreth/sync/client" "github.com/ava-labs/coreth/sync/handlers" handlerstats "github.com/ava-labs/coreth/sync/handlers/stats" @@ -40,8 +39,8 @@ func testCodeSyncer(t *testing.T, test codeSyncerTest) { } // Set up mockClient - codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, message.Codec, handlerstats.NewNoopHandlerStats()) - mockClient := statesyncclient.NewMockClient(message.Codec, nil, codeRequestHandler, nil) + codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, networkCodec, handlerstats.NewNoopHandlerStats()) + mockClient := statesyncclient.NewMockClient(networkCodec, nil, codeRequestHandler, nil) mockClient.GetCodeIntercept = test.getCodeIntercept clientDB := rawdb.NewMemoryDatabase() diff --git a/sync/statesync/sync_test.go b/sync/statesync/sync_test.go index c64d36faa5..99c5d17b0e 100644 --- a/sync/statesync/sync_test.go +++ b/sync/statesync/sync_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/coreth/core/rawdb" "github.com/ava-labs/coreth/core/state/snapshot" "github.com/ava-labs/coreth/core/types" @@ -34,6 +35,16 @@ const testSyncTimeout = 30 * time.Second var errInterrupted = errors.New("interrupted sync") +var networkCodec codec.Manager + +func init() { + var err error + networkCodec, err = message.NewCodec(message.BlockSyncSummary{}) + if err != nil { + panic(err) + } +} + type syncTest struct { ctx context.Context prepareForTest func(t *testing.T) (clientDB ethdb.Database, serverDB ethdb.Database, serverTrieDB *triedb.Database, syncRoot common.Hash) @@ -49,9 +60,9 @@ func testSync(t *testing.T, test syncTest) { ctx = test.ctx } clientDB, serverDB, serverTrieDB, root := test.prepareForTest(t) - leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats()) - codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, message.Codec, handlerstats.NewNoopHandlerStats()) - mockClient := statesyncclient.NewMockClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil) + leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, message.StateTrieKeyLength, nil, networkCodec, handlerstats.NewNoopHandlerStats()) + codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, networkCodec, handlerstats.NewNoopHandlerStats()) + mockClient := statesyncclient.NewMockClient(networkCodec, leafsRequestHandler, codeRequestHandler, nil) // Set intercept functions for the mock client mockClient.GetLeafsIntercept = test.GetLeafsIntercept mockClient.GetCodeIntercept = test.GetCodeIntercept diff --git a/warp/aggregator/signature_getter.go b/warp/aggregator/signature_getter.go index 8bdb60fea1..d5e36ac972 100644 --- a/warp/aggregator/signature_getter.go +++ b/warp/aggregator/signature_getter.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto/bls" avalancheWarp "github.com/ava-labs/avalanchego/vms/platformvm/warp" @@ -36,12 +37,14 @@ type NetworkClient interface { // NetworkSignatureGetter fetches warp signatures on behalf of the // aggregator using VM App-Specific Messaging type NetworkSignatureGetter struct { - Client NetworkClient + Client NetworkClient + networkCodec codec.Manager } -func NewSignatureGetter(client NetworkClient) *NetworkSignatureGetter { +func NewSignatureGetter(client NetworkClient, networkCodec codec.Manager) *NetworkSignatureGetter { return &NetworkSignatureGetter{ - Client: client, + Client: client, + networkCodec: networkCodec, } } @@ -60,7 +63,7 @@ func (s *NetworkSignatureGetter) GetSignature(ctx context.Context, nodeID ids.No signatureReq := message.MessageSignatureRequest{ MessageID: unsignedWarpMessage.ID(), } - signatureReqBytes, err = message.RequestToBytes(message.Codec, signatureReq) + signatureReqBytes, err = message.RequestToBytes(s.networkCodec, signatureReq) if err != nil { return nil, fmt.Errorf("failed to marshal signature request: %w", err) } @@ -68,7 +71,7 @@ func (s *NetworkSignatureGetter) GetSignature(ctx context.Context, nodeID ids.No signatureReq := message.BlockSignatureRequest{ BlockID: p.Hash, } - signatureReqBytes, err = message.RequestToBytes(message.Codec, signatureReq) + signatureReqBytes, err = message.RequestToBytes(s.networkCodec, signatureReq) if err != nil { return nil, fmt.Errorf("failed to marshal signature request: %w", err) } @@ -102,7 +105,7 @@ func (s *NetworkSignatureGetter) GetSignature(ctx context.Context, nodeID ids.No continue } var response message.SignatureResponse - if _, err := message.Codec.Unmarshal(signatureRes, &response); err != nil { + if _, err := s.networkCodec.Unmarshal(signatureRes, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal signature res: %w", err) } if response.Signature == [bls.SignatureLen]byte{} { diff --git a/warp/handlers/signature_request_test.go b/warp/handlers/signature_request_test.go index 77a4af087e..528bf87bfd 100644 --- a/warp/handlers/signature_request_test.go +++ b/warp/handlers/signature_request_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/ava-labs/avalanchego/cache" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto/bls" @@ -20,6 +21,16 @@ import ( "github.com/stretchr/testify/require" ) +var networkCodec codec.Manager + +func init() { + var err error + networkCodec, err = message.NewCodec(message.BlockSyncSummary{}) + if err != nil { + panic(err) + } +} + func TestMessageSignatureHandler(t *testing.T) { database := memdb.New() snowCtx := utils.TestSnowContext() @@ -102,7 +113,7 @@ func TestMessageSignatureHandler(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - handler := NewSignatureRequestHandler(backend, message.Codec) + handler := NewSignatureRequestHandler(backend, networkCodec) request, expectedResponse := test.setup() responseBytes, err := handler.OnMessageSignatureRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) @@ -116,7 +127,7 @@ func TestMessageSignatureHandler(t *testing.T) { return } var response message.SignatureResponse - _, err = message.Codec.Unmarshal(responseBytes, &response) + _, err = networkCodec.Unmarshal(responseBytes, &response) require.NoError(t, err, "error unmarshalling SignatureResponse") require.Equal(t, expectedResponse, response.Signature[:]) @@ -189,7 +200,7 @@ func TestBlockSignatureHandler(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - handler := NewSignatureRequestHandler(backend, message.Codec) + handler := NewSignatureRequestHandler(backend, networkCodec) request, expectedResponse := test.setup() responseBytes, err := handler.OnBlockSignatureRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) @@ -203,7 +214,7 @@ func TestBlockSignatureHandler(t *testing.T) { return } var response message.SignatureResponse - _, err = message.Codec.Unmarshal(responseBytes, &response) + _, err = networkCodec.Unmarshal(responseBytes, &response) require.NoError(t, err, "error unmarshalling SignatureResponse") require.Equal(t, expectedResponse, response.Signature[:]) diff --git a/warp/service.go b/warp/service.go index 610fc85a91..d160fe6812 100644 --- a/warp/service.go +++ b/warp/service.go @@ -8,8 +8,9 @@ import ( "errors" "fmt" + "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/vms/platformvm/warp" "github.com/ava-labs/avalanchego/vms/platformvm/warp/payload" "github.com/ava-labs/coreth/peer" @@ -23,22 +24,18 @@ var errNoValidators = errors.New("cannot aggregate signatures from subnet with n // API introduces snowman specific functionality to the evm type API struct { - networkID uint32 - sourceSubnetID, sourceChainID ids.ID - backend Backend - state validators.State - client peer.NetworkClient - requirePrimaryNetworkSigners func() bool + chainContext *snow.Context + backend Backend + signatureGetter aggregator.SignatureGetter + requirePrimaryNetworkSigners func() bool } -func NewAPI(networkID uint32, sourceSubnetID ids.ID, sourceChainID ids.ID, state validators.State, backend Backend, client peer.NetworkClient, requirePrimaryNetworkSigners func() bool) *API { +func NewAPI(chainCtx *snow.Context, networkCodec codec.Manager, backend Backend, client peer.NetworkClient, requirePrimaryNetworkSigners func() bool) *API { + signatureGetter := aggregator.NewSignatureGetter(client, networkCodec) return &API{ - networkID: networkID, - sourceSubnetID: sourceSubnetID, - sourceChainID: sourceChainID, backend: backend, - state: state, - client: client, + chainContext: chainCtx, + signatureGetter: signatureGetter, requirePrimaryNetworkSigners: requirePrimaryNetworkSigners, } } @@ -89,7 +86,7 @@ func (a *API) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, qu if err != nil { return nil, err } - unsignedMessage, err := warp.NewUnsignedMessage(a.networkID, a.sourceChainID, blockHashPayload.Bytes()) + unsignedMessage, err := warp.NewUnsignedMessage(a.chainContext.NetworkID, a.chainContext.ChainID, blockHashPayload.Bytes()) if err != nil { return nil, err } @@ -98,7 +95,7 @@ func (a *API) GetBlockAggregateSignature(ctx context.Context, blockID ids.ID, qu } func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warp.UnsignedMessage, quorumNum uint64, subnetIDStr string) (hexutil.Bytes, error) { - subnetID := a.sourceSubnetID + subnetID := a.chainContext.SubnetID if len(subnetIDStr) > 0 { sid, err := ids.FromString(subnetIDStr) if err != nil { @@ -106,12 +103,13 @@ func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warp.Uns } subnetID = sid } - pChainHeight, err := a.state.GetCurrentHeight(ctx) + validatorState := a.chainContext.ValidatorState + pChainHeight, err := validatorState.GetCurrentHeight(ctx) if err != nil { return nil, err } - state := warpValidators.NewState(a.state, a.sourceSubnetID, a.sourceChainID, a.requirePrimaryNetworkSigners()) + state := warpValidators.NewState(validatorState, subnetID, a.chainContext.ChainID, a.requirePrimaryNetworkSigners()) validators, totalWeight, err := warp.GetCanonicalValidatorSet(ctx, state, pChainHeight, subnetID) if err != nil { return nil, fmt.Errorf("failed to get validator set: %w", err) @@ -127,7 +125,7 @@ func (a *API) aggregateSignatures(ctx context.Context, unsignedMessage *warp.Uns "totalWeight", totalWeight, ) - agg := aggregator.New(aggregator.NewSignatureGetter(a.client), validators, totalWeight) + agg := aggregator.New(a.signatureGetter, validators, totalWeight) signatureResult, err := agg.AggregateSignatures(ctx, unsignedMessage, quorumNum) if err != nil { return nil, err