Skip to content

Commit

Permalink
feat: add migration checker cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Feb 12, 2025
1 parent b56dd9f commit 310a5f8
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 0 deletions.
236 changes: 236 additions & 0 deletions cmd/migration-checker/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
package main

import (
"bytes"
"encoding/hex"
"flag"
"fmt"
"os"
"runtime"
"sync"
"sync/atomic"

"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/core/types"
"github.com/scroll-tech/go-ethereum/crypto"
"github.com/scroll-tech/go-ethereum/ethdb/leveldb"
"github.com/scroll-tech/go-ethereum/rlp"
"github.com/scroll-tech/go-ethereum/trie"
)

var accountsDone atomic.Uint64
var trieCheckers = make(chan struct{}, runtime.GOMAXPROCS(0)*4)

type dbs struct {
zkDb *leveldb.Database
mptDb *leveldb.Database
}

func main() {
var (
mptDbPath = flag.String("mpt-db", "", "path to the MPT node DB")
zkDbPath = flag.String("zk-db", "", "path to the ZK node DB")
mptRoot = flag.String("mpt-root", "", "root hash of the MPT node")
zkRoot = flag.String("zk-root", "", "root hash of the ZK node")
)
flag.Parse()

zkDb, err := leveldb.New(*zkDbPath, 1024, 128, "", true)
panicOnError(err, "", "failed to open zk db")
mptDb, err := leveldb.New(*mptDbPath, 1024, 128, "", true)
panicOnError(err, "", "failed to open mpt db")

zkRootHash := common.HexToHash(*zkRoot)
mptRootHash := common.HexToHash(*mptRoot)

for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
trieCheckers <- struct{}{}
}

checkTrieEquality(&dbs{
zkDb: zkDb,
mptDb: mptDb,
}, zkRootHash, mptRootHash, "", checkAccountEquality, true)
}

func panicOnError(err error, label, msg string) {
if err != nil {
panic(fmt.Sprint(label, " error: ", msg, " ", err))
}
}

func dup(s []byte) []byte {
return append([]byte{}, s...)
}
func checkTrieEquality(dbs *dbs, zkRoot, mptRoot common.Hash, label string, leafChecker func(string, *dbs, []byte, []byte), top bool) {
zkTrie, err := trie.NewZkTrie(zkRoot, trie.NewZktrieDatabaseFromTriedb(trie.NewDatabaseWithConfig(dbs.zkDb, &trie.Config{Preimages: true})))
panicOnError(err, label, "failed to create zk trie")
mptTrie, err := trie.NewSecureNoTracer(mptRoot, trie.NewDatabaseWithConfig(dbs.mptDb, &trie.Config{Preimages: true}))
panicOnError(err, label, "failed to create mpt trie")

mptLeafCh := loadMPT(mptTrie, top)
zkLeafCh := loadZkTrie(zkTrie, top)

mptLeafMap := <-mptLeafCh
zkLeafMap := <-zkLeafCh

if len(mptLeafMap) != len(zkLeafMap) {
panic(fmt.Sprintf("%s MPT and ZK trie leaf count mismatch: MPT: %d, ZK: %d", label, len(mptLeafMap), len(zkLeafMap)))
}

for preimageKey, zkValue := range zkLeafMap {
if top {
// ZkTrie pads preimages with 0s to make them 32 bytes.
// So we might need to clear those zeroes here since we need 20 byte addresses at top level (ie state trie)
if len(preimageKey) > 20 {
for b := range preimageKey[20:] {
if b != 0 {
panic(fmt.Sprintf("%s padded byte is not 0 (preimage %s)", label, hex.EncodeToString([]byte(preimageKey))))
}
}
preimageKey = preimageKey[:20]
}
}

mptKey := crypto.Keccak256([]byte(preimageKey))
mptVal, ok := mptLeafMap[string(mptKey)]
if !ok {
panic(fmt.Sprintf("%s key %s (preimage %s) not found in mpt", label, hex.EncodeToString([]byte(mptKey)), hex.EncodeToString([]byte(preimageKey))))
}

leafChecker(fmt.Sprintf("%s key: %s", label, hex.EncodeToString([]byte(preimageKey))), dbs, zkValue, mptVal)
}
}

func checkAccountEquality(label string, dbs *dbs, zkAccountBytes, mptAccountBytes []byte) {
mptAccount := &types.StateAccount{}
panicOnError(rlp.DecodeBytes(mptAccountBytes, mptAccount), label, "failed to decode mpt account")
zkAccount, err := types.UnmarshalStateAccount(zkAccountBytes)
panicOnError(err, label, "failed to decode zk account")

if mptAccount.Nonce != zkAccount.Nonce {
panic(fmt.Sprintf("%s nonce mismatch: zk: %d, mpt: %d", label, zkAccount.Nonce, mptAccount.Nonce))
}

if mptAccount.Balance.Cmp(zkAccount.Balance) != 0 {
panic(fmt.Sprintf("%s balance mismatch: zk: %s, mpt: %s", label, zkAccount.Balance.String(), mptAccount.Balance.String()))
}

if !bytes.Equal(mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash) {
panic(fmt.Sprintf("%s code hash mismatch: zk: %s, mpt: %s", label, hex.EncodeToString(zkAccount.KeccakCodeHash), hex.EncodeToString(mptAccount.KeccakCodeHash)))
}

if (zkAccount.Root == common.Hash{}) != (mptAccount.Root == types.EmptyRootHash) {
panic(fmt.Sprintf("%s empty account root mismatch", label))
} else if zkAccount.Root != (common.Hash{}) {
zkRoot := common.BytesToHash(zkAccount.Root[:])
mptRoot := common.BytesToHash(mptAccount.Root[:])
<-trieCheckers
go func() {
defer func() {
if p := recover(); p != nil {
fmt.Println(p)
os.Exit(1)
}
}()

checkTrieEquality(dbs, zkRoot, mptRoot, label, checkStorageEquality, false)
accountsDone.Add(1)
fmt.Println("Accounts done:", accountsDone.Load())
trieCheckers <- struct{}{}
}()
} else {
accountsDone.Add(1)
fmt.Println("Accounts done:", accountsDone.Load())
}
}

func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes []byte) {
zkValue := common.BytesToHash(zkStorageBytes)
_, content, _, err := rlp.Split(mptStorageBytes)
panicOnError(err, label, "failed to decode mpt storage")
mptValue := common.BytesToHash(content)
if !bytes.Equal(zkValue[:], mptValue[:]) {
panic(fmt.Sprintf("%s storage mismatch: zk: %s, mpt: %s", label, zkValue.Hex(), mptValue.Hex()))
}
}

func loadMPT(mptTrie *trie.SecureTrie, parallel bool) chan map[string][]byte {
startKey := make([]byte, 32)
workers := 1 << 5
if !parallel {
workers = 1
}
step := byte(0xFF) / byte(workers)

mptLeafMap := make(map[string][]byte, 1000)
var mptLeafMutex sync.Mutex

var mptWg sync.WaitGroup
for i := 0; i < workers; i++ {
startKey[0] = byte(i) * step
trieIt := trie.NewIterator(mptTrie.NodeIterator(startKey))

mptWg.Add(1)
go func() {
defer mptWg.Done()
for trieIt.Next() {
if parallel {
mptLeafMutex.Lock()
}

if _, ok := mptLeafMap[string(trieIt.Key)]; ok {
mptLeafMutex.Unlock()
break
}

mptLeafMap[string(dup(trieIt.Key))] = dup(trieIt.Value)

if parallel {
mptLeafMutex.Unlock()
}

if parallel && len(mptLeafMap)%10000 == 0 {
fmt.Println("MPT Accounts Loaded:", len(mptLeafMap))
}
}
}()
}

respChan := make(chan map[string][]byte)
go func() {
mptWg.Wait()
respChan <- mptLeafMap
}()
return respChan
}

func loadZkTrie(zkTrie *trie.ZkTrie, parallel bool) chan map[string][]byte {
zkLeafMap := make(map[string][]byte, 1000)
var zkLeafMutex sync.Mutex
zkDone := make(chan map[string][]byte)
go func() {
zkTrie.CountLeaves(func(key, value []byte) {
preimageKey := zkTrie.GetKey(key)
if len(preimageKey) == 0 {
panic(fmt.Sprintf("preimage not found zk trie %s", hex.EncodeToString(key)))
}

if parallel {
zkLeafMutex.Lock()
}

zkLeafMap[string(dup(preimageKey))] = value

if parallel {
zkLeafMutex.Unlock()
}

if parallel && len(zkLeafMap)%10000 == 0 {
fmt.Println("ZK Accounts Loaded:", len(zkLeafMap))
}
}, parallel)
zkDone <- zkLeafMap
}()
return zkDone
}
10 changes: 10 additions & 0 deletions trie/secure_trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) {
return &SecureTrie{trie: *trie, preimages: db.preimages}, nil
}

func NewSecureNoTracer(root common.Hash, db *Database) (*SecureTrie, error) {
t, err := NewSecure(root, db)
if err != nil {
return nil, err
}

t.trie.tracer = nil
return t, nil
}

// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *SecureTrie) Get(key []byte) []byte {
Expand Down
24 changes: 24 additions & 0 deletions trie/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,21 @@ func newTracer() *tracer {
// blob internally. Don't change the value outside of function since
// it's not deep-copied.
func (t *tracer) onRead(path []byte, val []byte) {
if t == nil {
return
}

t.accessList[string(path)] = val
}

// onInsert tracks the newly inserted trie node. If it's already
// in the deletion set (resurrected node), then just wipe it from
// the deletion set as it's "untouched".
func (t *tracer) onInsert(path []byte) {
if t == nil {
return
}

if _, present := t.deletes[string(path)]; present {
delete(t.deletes, string(path))
return
Expand All @@ -78,6 +86,10 @@ func (t *tracer) onInsert(path []byte) {
// in the addition set, then just wipe it from the addition set
// as it's untouched.
func (t *tracer) onDelete(path []byte) {
if t == nil {
return
}

if _, present := t.inserts[string(path)]; present {
delete(t.inserts, string(path))
return
Expand All @@ -87,13 +99,21 @@ func (t *tracer) onDelete(path []byte) {

// reset clears the content tracked by tracer.
func (t *tracer) reset() {
if t == nil {
return
}

t.inserts = make(map[string]struct{})
t.deletes = make(map[string]struct{})
t.accessList = make(map[string][]byte)
}

// copy returns a deep copied tracer instance.
func (t *tracer) copy() *tracer {
if t == nil {
return nil
}

accessList := make(map[string][]byte, len(t.accessList))
for path, blob := range t.accessList {
accessList[path] = common.CopyBytes(blob)
Expand All @@ -107,6 +127,10 @@ func (t *tracer) copy() *tracer {

// deletedNodes returns a list of node paths which are deleted from the trie.
func (t *tracer) deletedNodes() []string {
if t == nil {
return nil
}

var paths []string
for path := range t.deletes {
// It's possible a few deleted nodes were embedded
Expand Down
39 changes: 39 additions & 0 deletions trie/zk_trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,42 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead
func (t *ZkTrie) Witness() map[string]struct{} {
panic("not implemented")
}

func (t *ZkTrie) CountLeaves(cb func(key, value []byte), parallel bool) uint64 {
root, err := t.ZkTrie.Tree().Root()
if err != nil {
panic("CountLeaves cannot get root")
}
return t.countLeaves(root, cb, 0, parallel)
}

func (t *ZkTrie) countLeaves(root *zkt.Hash, cb func(key, value []byte), depth int, parallel bool) uint64 {
if root == nil {
return 0
}

rootNode, err := t.ZkTrie.Tree().GetNode(root)
if err != nil {
panic("countLeaves cannot get rootNode")
}

if rootNode.Type == zktrie.NodeTypeLeaf_New {
cb(append([]byte{}, rootNode.NodeKey.Bytes()...), append([]byte{}, rootNode.Data()...))
return 1
} else {
if parallel && depth < 5 {
count := make(chan uint64)
leftT := t.Copy()
rightT := t.Copy()
go func() {
count <- leftT.countLeaves(rootNode.ChildL, cb, depth+1, parallel)
}()
go func() {
count <- rightT.countLeaves(rootNode.ChildR, cb, depth+1, parallel)
}()
return <-count + <-count
} else {
return t.countLeaves(rootNode.ChildL, cb, depth+1, parallel) + t.countLeaves(rootNode.ChildR, cb, depth+1, parallel)
}
}
}

0 comments on commit 310a5f8

Please sign in to comment.