Skip to content

Commit

Permalink
naming convention for tree name
Browse files Browse the repository at this point in the history
  • Loading branch information
sontrinh16 committed Dec 13, 2023
1 parent c79bb47 commit 857a175
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 38 deletions.
10 changes: 9 additions & 1 deletion default_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@ package rsmt2d

import (
"crypto/sha256"
"fmt"

"github.com/celestiaorg/merkletree"
)

var (
DefaultTreeName = "default-tree"
)

func init() {
registerTree(Default, NewDefaultTree)
err := RegisterTree(DefaultTreeName, NewDefaultTree)
if err != nil {
panic(fmt.Sprintf("%s already registered", DefaultTreeName))
}
}

var _ Tree = &DefaultTree{}
Expand Down
13 changes: 7 additions & 6 deletions extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestRepairExtendedDataSquare(t *testing.T) {
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, Default)
eds, err := ImportExtendedDataSquare(flattened, codec, DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand All @@ -67,7 +67,7 @@ func TestRepairExtendedDataSquare(t *testing.T) {
flattened[12], flattened[13], flattened[14] = nil, nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, Default)
eds, err := ImportExtendedDataSquare(flattened, codec, DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func BenchmarkRepair(b *testing.B) {

// Generate a new range original data square then extend it
square := genRandDS(originalDataWidth, shareSize)
eds, err := ComputeExtendedDataSquare(square, codec, Default)
eds, err := ComputeExtendedDataSquare(square, codec, DefaultTreeName)
if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -275,7 +275,7 @@ func BenchmarkRepair(b *testing.B) {
}

// Re-import the data square.
eds, _ = ImportExtendedDataSquare(flattened, codec, Default)
eds, _ = ImportExtendedDataSquare(flattened, codec, DefaultTreeName)

b.StartTimer()

Expand All @@ -301,7 +301,7 @@ func createTestEds(codec Codec, shareSize int) *ExtendedDataSquare {
eds, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, codec, Default)
}, codec, DefaultTreeName)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -392,7 +392,8 @@ func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) {

edsWidth := 4 // number of shares per row/column in the extended data square
odsWidth := edsWidth / 2 // number of shares per row/column in the original data square
registerTree("Testing", newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize)))
err := RegisterTree("testing-tree", newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize)))
assert.NoError(t, err)

// create a DA header
eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4)
Expand Down
2 changes: 1 addition & 1 deletion extendeddatasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func ComputeExtendedDataSquare(
return nil, err
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: Default}
eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: DefaultTreeName}
err = eds.erasureExtendSquare(codec)
if err != nil {
return nil, err
Expand Down
34 changes: 17 additions & 17 deletions extendeddatasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,29 +59,29 @@ func TestComputeExtendedDataSquare(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := ComputeExtendedDataSquare(tc.data, codec, Default)
result, err := ComputeExtendedDataSquare(tc.data, codec, DefaultTreeName)
assert.NoError(t, err)
assert.Equal(t, tc.want, result.squareRow)
})
}

t.Run("returns an error if chunkSize is not a multiple of 64", func(t *testing.T) {
chunk := bytes.Repeat([]byte{1}, 65)
_, err := ComputeExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), Default)
_, err := ComputeExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), DefaultTreeName)
assert.Error(t, err)
})
}

func TestImportExtendedDataSquare(t *testing.T) {
t.Run("is able to import an EDS", func(t *testing.T) {
eds := createExampleEds(t, shareSize)
got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), Default)
got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), DefaultTreeName)
assert.NoError(t, err)
assert.Equal(t, eds.Flattened(), got.Flattened())
})
t.Run("returns an error if chunkSize is not a multiple of 64", func(t *testing.T) {
chunk := bytes.Repeat([]byte{1}, 65)
_, err := ImportExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), Default)
_, err := ImportExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), DefaultTreeName)
assert.Error(t, err)
})
}
Expand All @@ -91,7 +91,7 @@ func TestMarshalJSON(t *testing.T) {
result, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, codec, Default)
}, codec, DefaultTreeName)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -162,7 +162,7 @@ func TestImmutableRoots(t *testing.T) {
result, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, codec, Default)
}, codec, DefaultTreeName)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func TestEDSRowColImmutable(t *testing.T) {
result, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, codec, Default)
}, codec, DefaultTreeName)
if err != nil {
panic(err)
}
Expand All @@ -220,7 +220,7 @@ func TestRowRoots(t *testing.T) {
eds, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, NewLeoRSCodec(), Default)
}, NewLeoRSCodec(), DefaultTreeName)
require.NoError(t, err)

rowRoots, err := eds.RowRoots()
Expand All @@ -232,7 +232,7 @@ func TestRowRoots(t *testing.T) {
eds, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, NewLeoRSCodec(), Default)
}, NewLeoRSCodec(), DefaultTreeName)
require.NoError(t, err)

// set a cell to nil to make the EDS incomplete
Expand All @@ -248,7 +248,7 @@ func TestColRoots(t *testing.T) {
eds, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, NewLeoRSCodec(), Default)
}, NewLeoRSCodec(), DefaultTreeName)
require.NoError(t, err)

colRoots, err := eds.ColRoots()
Expand All @@ -260,7 +260,7 @@ func TestColRoots(t *testing.T) {
eds, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, NewLeoRSCodec(), Default)
}, NewLeoRSCodec(), DefaultTreeName)
require.NoError(t, err)

// set a cell to nil to make the EDS incomplete
Expand Down Expand Up @@ -290,7 +290,7 @@ func BenchmarkExtensionEncoding(b *testing.B) {
fmt.Sprintf("%s %dx%dx%d ODS", codecName, i, i, len(square[0])),
func(b *testing.B) {
for n := 0; n < b.N; n++ {
eds, err := ComputeExtendedDataSquare(square, codec, Default)
eds, err := ComputeExtendedDataSquare(square, codec, DefaultTreeName)
if err != nil {
b.Error(err)
}
Expand All @@ -317,7 +317,7 @@ func BenchmarkExtensionWithRoots(b *testing.B) {
fmt.Sprintf("%s %dx%dx%d ODS", codecName, i, i, len(square[0])),
func(b *testing.B) {
for n := 0; n < b.N; n++ {
eds, err := ComputeExtendedDataSquare(square, codec, Default)
eds, err := ComputeExtendedDataSquare(square, codec, DefaultTreeName)
if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -396,7 +396,7 @@ func TestEquals(t *testing.T) {

unequalChunkSize := createExampleEds(t, shareSize*2)

unequalEds, err := ComputeExtendedDataSquare([][]byte{ones}, NewLeoRSCodec(), Default)
unequalEds, err := ComputeExtendedDataSquare([][]byte{ones}, NewLeoRSCodec(), DefaultTreeName)
require.NoError(t, err)

testCases := []testCase{
Expand Down Expand Up @@ -431,7 +431,7 @@ func TestRoots(t *testing.T) {
eds, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, NewLeoRSCodec(), Default)
}, NewLeoRSCodec(), DefaultTreeName)
require.NoError(t, err)

roots, err := eds.Roots()
Expand All @@ -458,7 +458,7 @@ func TestRoots(t *testing.T) {
eds, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, NewLeoRSCodec(), Default)
}, NewLeoRSCodec(), DefaultTreeName)
require.NoError(t, err)

// set a cell to nil to make the EDS incomplete
Expand All @@ -479,7 +479,7 @@ func createExampleEds(t *testing.T, chunkSize int) (eds *ExtendedDataSquare) {
threes, fours,
}

eds, err := ComputeExtendedDataSquare(ods, NewLeoRSCodec(), Default)
eds, err := ComputeExtendedDataSquare(ods, NewLeoRSCodec(), DefaultTreeName)
require.NoError(t, err)
return eds
}
12 changes: 6 additions & 6 deletions rsmt2d_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestEdsRepairRoundtripSimple(t *testing.T) {
threes, fours,
},
tt.codec,
rsmt2d.Default,
rsmt2d.DefaultTreeName,
)
if err != nil {
t.Errorf("ComputeExtendedDataSquare failed: %v", err)
Expand All @@ -56,7 +56,7 @@ func TestEdsRepairRoundtripSimple(t *testing.T) {
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.Default)
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestEdsRepairTwice(t *testing.T) {
threes, fours,
},
tt.codec,
rsmt2d.Default,
rsmt2d.DefaultTreeName,
)
if err != nil {
t.Errorf("ComputeExtendedDataSquare failed: %v", err)
Expand All @@ -120,7 +120,7 @@ func TestEdsRepairTwice(t *testing.T) {
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.Default)
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand All @@ -139,7 +139,7 @@ func TestEdsRepairTwice(t *testing.T) {
copy(flattened[1], missing)

// Re-import the data square.
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.Default)
eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand Down Expand Up @@ -205,7 +205,7 @@ func createExampleEds(t *testing.T, chunkSize int) (eds *rsmt2d.ExtendedDataSqua
threes, fours,
}

eds, err := rsmt2d.ComputeExtendedDataSquare(ods, rsmt2d.NewLeoRSCodec(), rsmt2d.Default)
eds, err := rsmt2d.ComputeExtendedDataSquare(ods, rsmt2d.NewLeoRSCodec(), rsmt2d.DefaultTreeName)
require.NoError(t, err)
return eds
}
13 changes: 6 additions & 7 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ package rsmt2d

import "fmt"

const (
Default = "Default-tree"
)

// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle tree
// inside of rsmt2d.
type TreeConstructorFn = func(axis Axis, index uint) Tree
Expand All @@ -22,12 +18,15 @@ type Tree interface {
Root() ([]byte, error)
}

// trees is a global map used for keeping track of registered tree constructors for testing and JSON unmarshalling
// trees is a global map used for keeping track of registered tree constructors for JSON serialization
// The keys of this map should be kebab cased. E.g. "default-tree"
var trees = make(map[string]TreeConstructorFn)

func registerTree(treeName string, treeConstructor TreeConstructorFn) {
func RegisterTree(treeName string, treeConstructor TreeConstructorFn) error {
if trees[treeName] != nil {
panic(fmt.Sprintf("%s already registered", treeName))
return fmt.Errorf("%s already registered", treeName)
}
trees[treeName] = treeConstructor

return nil
}

0 comments on commit 857a175

Please sign in to comment.