Skip to content

Commit

Permalink
fix!: UnmarshalJSON is limited to the default Tree bug (#277)
Browse files Browse the repository at this point in the history
<!--
Please read and fill out this form before submitting your PR.

Please make sure you have reviewed our contributors guide before
submitting your
first PR.
-->

## Overview

Closed: #275 

<!-- 
Please provide an explanation of the PR, including the appropriate
context,
background, goal, and rationale. If there is an issue with this
information,
please provide a tl;dr and link the issue. 
-->

## Checklist

<!-- 
Please complete the checklist to ensure that the PR is ready to be
reviewed.

IMPORTANT:
PRs should be left in Draft until the below checklist is completed.
-->

- [x] New and updated code has appropriate documentation
- [x] New and updated code has new and/or updated testing
- [x] Required CI checks are passing
- [ ] Visual proof for any user facing features like CLI or
documentation updates
- [x] Linked issues closed with keywords

---------

Co-authored-by: Rootul P <[email protected]>
Co-authored-by: Sanaz Taheri <[email protected]>
  • Loading branch information
3 people authored Jan 30, 2024
1 parent 7fcdfd1 commit bb5e119
Show file tree
Hide file tree
Showing 6 changed files with 410 additions and 30 deletions.
48 changes: 48 additions & 0 deletions default_tree.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package rsmt2d

import (
"crypto/sha256"
"fmt"

"github.com/celestiaorg/merkletree"
)

var DefaultTreeName = "default-tree"

func init() {
err := RegisterTree(DefaultTreeName, NewDefaultTree)
if err != nil {
panic(fmt.Sprintf("%s already registered", DefaultTreeName))
}
}

var _ Tree = &DefaultTree{}

type DefaultTree struct {
*merkletree.Tree
leaves [][]byte
root []byte
}

func NewDefaultTree(_ Axis, _ uint) Tree {
return &DefaultTree{
Tree: merkletree.New(sha256.New()),
leaves: make([][]byte, 0, 128),
}
}

func (d *DefaultTree) Push(data []byte) error {
// ignore the idx, as this implementation doesn't need that info
d.leaves = append(d.leaves, data)
return nil
}

func (d *DefaultTree) Root() ([]byte, error) {
if d.root == nil {
for _, l := range d.leaves {
d.Tree.Push(l)
}
d.root = d.Tree.Root()
}
return d.root, nil
}
13 changes: 10 additions & 3 deletions extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,14 @@ func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) {

codec := NewLeoRSCodec()

edsWidth := 4 // number of shares per row/column in the extended data square
odsWidth := edsWidth / 2 // number of shares per row/column in the original data square
err := RegisterTree("testing-tree", newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize)))
assert.NoError(t, err)

// create a DA header
eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4)

assert.NotNil(t, eds)
dAHeaderRoots, err := eds.getRowRoots()
assert.NoError(t, err)
Expand Down Expand Up @@ -436,10 +442,11 @@ func createTestEdsWithNMT(t *testing.T, codec Codec, shareSize, namespaceSize in
for i, shareValue := range sharesValue {
shares[i] = bytes.Repeat([]byte{byte(shareValue)}, shareSize)
}
edsWidth := 4 // number of shares per row/column in the extended data square
odsWidth := edsWidth / 2 // number of shares per row/column in the original data square

eds, err := ComputeExtendedDataSquare(shares, codec, newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize)))
treeConstructorFn, err := TreeFn("testing-tree")
require.NoError(t, err)

eds, err := ComputeExtendedDataSquare(shares, codec, treeConstructorFn)
require.NoError(t, err)

return eds
Expand Down
36 changes: 32 additions & 4 deletions extendeddatasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,45 @@ import (
type ExtendedDataSquare struct {
*dataSquare
codec Codec
treeName string
originalDataWidth uint
}

func (eds *ExtendedDataSquare) MarshalJSON() ([]byte, error) {
return json.Marshal(&struct {
DataSquare [][]byte `json:"data_square"`
Codec string `json:"codec"`
Tree string `json:"tree"`
}{
DataSquare: eds.dataSquare.Flattened(),
Codec: eds.codec.Name(),
Tree: eds.treeName,
})
}

func (eds *ExtendedDataSquare) UnmarshalJSON(b []byte) error {
var aux struct {
DataSquare [][]byte `json:"data_square"`
Codec string `json:"codec"`
Tree string `json:"tree"`
}

if err := json.Unmarshal(b, &aux); err != nil {
err := json.Unmarshal(b, &aux)
if err != nil {
return err
}

var treeConstructor TreeConstructorFn
if aux.Tree == "" {
aux.Tree = DefaultTreeName
}

treeConstructor, err = TreeFn(aux.Tree)
if err != nil {
return err
}
importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], NewDefaultTree)

importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], treeConstructor)
if err != nil {
return err
}
Expand All @@ -61,12 +77,18 @@ func ComputeExtendedDataSquare(
if err != nil {
return nil, err
}

ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize))
if err != nil {
return nil, err
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec}
treeName := getTreeNameFromConstructorFn(treeCreatorFn)
if treeName == "" {
return nil, errors.New("tree name not found")
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName}
err = eds.erasureExtendSquare(codec)
if err != nil {
return nil, err
Expand All @@ -90,12 +112,18 @@ func ImportExtendedDataSquare(
if err != nil {
return nil, err
}

ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize))
if err != nil {
return nil, err
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec}
treeName := getTreeNameFromConstructorFn(treeCreatorFn)
if treeName == "" {
return nil, errors.New("tree name not found")
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName}
err = validateEdsWidth(eds.width)
if err != nil {
return nil, err
Expand Down
63 changes: 63 additions & 0 deletions extendeddatasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,69 @@ func TestMarshalJSON(t *testing.T) {
}
}

// TestUnmarshalJSON test the UnmarshalJSON function.
func TestUnmarshalJSON(t *testing.T) {
treeName := "testing_unmarshalJSON_tree"
treeConstructorFn := sudoConstructorFn
err := RegisterTree(treeName, treeConstructorFn)
require.NoError(t, err)

codec := NewLeoRSCodec()
result, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, codec, treeConstructorFn)
if err != nil {
panic(err)
}

tests := []struct {
name string
malleate func()
expectedTreeName string
cleanUp func()
}{
{
"Tree field exists",
func() {},
treeName,
func() {
cleanUp(treeName)
},
},
{
"Tree field missing",
func() {
// clear the tree name value in the eds before marshal
result.treeName = ""
},
DefaultTreeName,
func() {},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.malleate()
edsBytes, err := json.Marshal(result)
if err != nil {
t.Errorf("failed to marshal EDS: %v", err)
}

var eds ExtendedDataSquare
err = json.Unmarshal(edsBytes, &eds)
if err != nil {
t.Errorf("failed to unmarshal EDS: %v", err)
}
if !reflect.DeepEqual(result.squareRow, eds.squareRow) {
t.Errorf("eds not equal after json marshal/unmarshal")
}
require.Equal(t, test.expectedTreeName, eds.treeName)

test.cleanUp()
})
}
}

func TestNewExtendedDataSquare(t *testing.T) {
t.Run("returns an error if edsWidth is not even", func(t *testing.T) {
edsWidth := uint(1)
Expand Down
78 changes: 55 additions & 23 deletions tree.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package rsmt2d

import (
"crypto/sha256"

"github.com/celestiaorg/merkletree"
"fmt"
"reflect"
"sync"
)

// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle tree
Expand All @@ -22,33 +22,65 @@ type Tree interface {
Root() ([]byte, error)
}

var _ Tree = &DefaultTree{}
// treeFns is a global map used for keeping track of registered tree constructors for JSON serialization
// The keys of this map should be kebab cased. E.g. "default-tree"
var treeFns = sync.Map{}

// RegisterTree must be called in the init function
func RegisterTree(treeName string, treeConstructor TreeConstructorFn) error {
if _, ok := treeFns.Load(treeName); ok {
return fmt.Errorf("%s already registered", treeName)
}

treeFns.Store(treeName, treeConstructor)

type DefaultTree struct {
*merkletree.Tree
leaves [][]byte
root []byte
return nil
}

func NewDefaultTree(_ Axis, _ uint) Tree {
return &DefaultTree{
Tree: merkletree.New(sha256.New()),
leaves: make([][]byte, 0, 128),
// TreeFn get tree constructor function by tree name from the global map registry
func TreeFn(treeName string) (TreeConstructorFn, error) {
var treeFn TreeConstructorFn
v, ok := treeFns.Load(treeName)
if !ok {
return nil, fmt.Errorf("%s not registered yet", treeName)
}
treeFn, ok = v.(TreeConstructorFn)
if !ok {
return nil, fmt.Errorf("key %s has invalid interface", treeName)
}

return treeFn, nil
}

func (d *DefaultTree) Push(data []byte) error {
// ignore the idx, as this implementation doesn't need that info
d.leaves = append(d.leaves, data)
return nil
// removeTreeFn removes a treeConstructorFn by treeName.
// Only use for test cleanup. Proceed with caution.
func removeTreeFn(treeName string) {
treeFns.Delete(treeName)
}

func (d *DefaultTree) Root() ([]byte, error) {
if d.root == nil {
for _, l := range d.leaves {
d.Tree.Push(l)
// Get the tree name by the tree constructor function from the global map registry
// TODO: this code is temporary until all breaking changes is handle here: https://github.com/celestiaorg/rsmt2d/pull/278
func getTreeNameFromConstructorFn(treeConstructor TreeConstructorFn) string {
key := ""
treeFns.Range(func(k, v interface{}) bool {
keyString, ok := k.(string)
if !ok {
// continue checking other key, value
return true
}
d.root = d.Tree.Root()
}
return d.root, nil
treeFn, ok := v.(TreeConstructorFn)
if !ok {
// continue checking other key, value
return true
}

if reflect.DeepEqual(reflect.ValueOf(treeFn), reflect.ValueOf(treeConstructor)) {
key = keyString
return false
}

return true
})

return key
}
Loading

0 comments on commit bb5e119

Please sign in to comment.