Skip to content

Commit

Permalink
Node/EVM: Verify EVM chain ID
Browse files Browse the repository at this point in the history
  • Loading branch information
bruce-riley committed Jan 28, 2025
1 parent c888252 commit 4f0037f
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 2 deletions.
51 changes: 49 additions & 2 deletions node/pkg/watchers/evm/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"math"
"math/big"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand All @@ -28,6 +30,7 @@ import (
"github.com/certusone/wormhole/node/pkg/query"
"github.com/certusone/wormhole/node/pkg/readiness"
"github.com/certusone/wormhole/node/pkg/supervisor"
"github.com/wormhole-foundation/wormhole/sdk"
"github.com/wormhole-foundation/wormhole/sdk/vaa"
)

Expand Down Expand Up @@ -211,14 +214,18 @@ func (w *Watcher) Run(parentCtx context.Context) error {
ContractAddress: w.contract.Hex(),
})

timeout, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
if err := w.verifyEvmChainID(ctx, logger); err != nil {
return fmt.Errorf("failed to verify evm chain id: %w", err)
}

finalizedPollingSupported, safePollingSupported, err := w.getFinality(ctx)
if err != nil {
return fmt.Errorf("failed to determine finality: %w", err)
}

timeout, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()

if finalizedPollingSupported {
if safePollingSupported {
logger.Info("polling for finalized and safe blocks")
Expand Down Expand Up @@ -794,6 +801,46 @@ func (w *Watcher) getFinality(ctx context.Context) (bool, bool, error) {
return finalized, safe, nil
}

// verifyEvmChainID reads the EVM chain ID from the node and verifies that it matches the expected value (making sure we aren't connected to the wrong chain).
func (w *Watcher) verifyEvmChainID(ctx context.Context, logger *zap.Logger) error {
timeout, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()

c, err := rpc.DialContext(timeout, w.url)
if err != nil {
return fmt.Errorf("failed to connect to endpoint: %w", err)
}

var str string
err = c.CallContext(ctx, &str, "eth_chainId")
if err != nil {
return fmt.Errorf("failed to read evm chain id: %w", err)
}

evmChainID, err := strconv.ParseUint(strings.TrimPrefix(str, "0x"), 16, 64)
if err != nil {
return fmt.Errorf(`eth_chainId returned an invalid int: "%s"`, str)
}

logger.Info("queried evm chain id", zap.Uint64("evmChainID", evmChainID))

if w.unsafeDevMode {
// In devnet we log the result but don't enforce it.
return nil
}

expectedEvmChainID, err := sdk.GetEvmChainID(string(w.env), w.chainID)
if err != nil {
return fmt.Errorf("failed to look up evm chain id: %w", err)
}

if evmChainID != uint64(expectedEvmChainID) {
return fmt.Errorf("evm chain ID miss match, expected %d, received %d", expectedEvmChainID, evmChainID)
}

return nil
}

// SetL1Finalizer is used to set the layer one finalizer.
func (w *Watcher) SetL1Finalizer(l1Finalizer interfaces.L1Finalizer) {
w.l1Finalizer = l1Finalizer
Expand Down
45 changes: 45 additions & 0 deletions sdk/evm_chain_ids.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package sdk

import (
"errors"
"strings"

"github.com/wormhole-foundation/wormhole/sdk/vaa"
)

var ErrInvalidEnv = errors.New("invalid environment")
var ErrNotFound = errors.New("not found")

// IsEvmChainID if the specified chain is defined as an EVM chain ID in the specified environment.
func IsEvmChainID(env string, chainID vaa.ChainID) (bool, error) {
var m *map[vaa.ChainID]int
if env == "prod" || env == "mainnet" {
m = &MainnetEvmChainIDs
} else if env == "test" || env == "testnet" {
m = &TestnetEvmChainIDs
} else {
return false, ErrInvalidEnv
}
_, exists := (*m)[chainID]
return exists, nil
}

// GetEvmChainID returns the expected EVM chain ID associated with the given Wormhole chain ID and environment passed it.
func GetEvmChainID(env string, chainID vaa.ChainID) (int, error) {
env = strings.ToLower(env)
if env == "prod" || env == "mainnet" {
return getEvmChainID(MainnetEvmChainIDs, chainID)
}
if env == "test" || env == "testnet" {
return getEvmChainID(TestnetEvmChainIDs, chainID)
}
return 0, ErrInvalidEnv
}

func getEvmChainID(evmChains map[vaa.ChainID]int, chainID vaa.ChainID) (int, error) {
id, exists := evmChains[chainID]
if !exists {
return 0, ErrNotFound
}
return id, nil
}
74 changes: 74 additions & 0 deletions sdk/evm_chain_ids_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package sdk

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wormhole-foundation/wormhole/sdk/vaa"
)

func TestGetEvmChainID(t *testing.T) {
type test struct {
env string
input vaa.ChainID
output int
err error
}

// Note: Don't intend to list every chain here, just enough to verify `GetEvmChainID`.
tests := []test{
{env: "mainnet", input: vaa.ChainIDUnset, output: 0, err: ErrNotFound},
{env: "mainnet", input: vaa.ChainIDSepolia, output: 0, err: ErrNotFound},
{env: "mainnet", input: vaa.ChainIDEthereum, output: 1},
{env: "mainnet", input: vaa.ChainIDArbitrum, output: 42161},
{env: "testnet", input: vaa.ChainIDSepolia, output: 11155111},
{env: "testnet", input: vaa.ChainIDEthereum, output: 17000},
{env: "junk", input: vaa.ChainIDEthereum, output: 17000, err: ErrInvalidEnv},
}

for _, tc := range tests {
t.Run(tc.env+"-"+tc.input.String(), func(t *testing.T) {
evmChainID, err := GetEvmChainID(tc.env, tc.input)
if tc.err != nil {
assert.ErrorIs(t, tc.err, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.output, evmChainID)
}
})
}
}
func TestIsEvmChainID(t *testing.T) {
type test struct {
env string
input vaa.ChainID
output bool
err error
}

// Note: Don't intend to list every chain here, just enough to verify `GetEvmChainID`.
tests := []test{
{env: "mainnet", input: vaa.ChainIDUnset, output: false},
{env: "mainnet", input: vaa.ChainIDSepolia, output: false},
{env: "mainnet", input: vaa.ChainIDEthereum, output: true},
{env: "mainnet", input: vaa.ChainIDArbitrum, output: true},
{env: "mainnet", input: vaa.ChainIDSolana, output: false},
{env: "testnet", input: vaa.ChainIDSepolia, output: true},
{env: "testnet", input: vaa.ChainIDEthereum, output: true},
{env: "testnet", input: vaa.ChainIDTerra, output: false},
{env: "junk", input: vaa.ChainIDEthereum, output: true, err: ErrInvalidEnv},
}

for _, tc := range tests {
t.Run(tc.env+"-"+tc.input.String(), func(t *testing.T) {
result, err := IsEvmChainID(tc.env, tc.input)
if tc.err != nil {
assert.ErrorIs(t, tc.err, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.output, result)
}
})
}
}

0 comments on commit 4f0037f

Please sign in to comment.