diff --git a/default_tree.go b/default_tree.go index d07328e..c9bb040 100644 --- a/default_tree.go +++ b/default_tree.go @@ -2,12 +2,20 @@ package rsmt2d import ( "crypto/sha256" + "fmt" "github.com/celestiaorg/merkletree" ) +var ( + DefaultTreeName = "default-tree" +) + func init() { - registerTree(Default, NewDefaultTree) + err := RegisterTree(DefaultTreeName, NewDefaultTree) + if err != nil { + panic(fmt.Sprintf("%s already registered", DefaultTreeName)) + } } var _ Tree = &DefaultTree{} diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index 79c7877..6888012 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -42,7 +42,7 @@ func TestRepairExtendedDataSquare(t *testing.T) { flattened[12], flattened[13] = nil, nil // Re-import the data square. - eds, err := ImportExtendedDataSquare(flattened, codec, Default) + eds, err := ImportExtendedDataSquare(flattened, codec, DefaultTreeName) if err != nil { t.Errorf("ImportExtendedDataSquare failed: %v", err) } @@ -67,7 +67,7 @@ func TestRepairExtendedDataSquare(t *testing.T) { flattened[12], flattened[13], flattened[14] = nil, nil, nil // Re-import the data square. - eds, err := ImportExtendedDataSquare(flattened, codec, Default) + eds, err := ImportExtendedDataSquare(flattened, codec, DefaultTreeName) if err != nil { t.Errorf("ImportExtendedDataSquare failed: %v", err) } @@ -237,7 +237,7 @@ func BenchmarkRepair(b *testing.B) { // Generate a new range original data square then extend it square := genRandDS(originalDataWidth, shareSize) - eds, err := ComputeExtendedDataSquare(square, codec, Default) + eds, err := ComputeExtendedDataSquare(square, codec, DefaultTreeName) if err != nil { b.Error(err) } @@ -275,7 +275,7 @@ func BenchmarkRepair(b *testing.B) { } // Re-import the data square. - eds, _ = ImportExtendedDataSquare(flattened, codec, Default) + eds, _ = ImportExtendedDataSquare(flattened, codec, DefaultTreeName) b.StartTimer() @@ -301,7 +301,7 @@ func createTestEds(codec Codec, shareSize int) *ExtendedDataSquare { eds, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, codec, Default) + }, codec, DefaultTreeName) if err != nil { panic(err) } @@ -392,7 +392,8 @@ func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) { edsWidth := 4 // number of shares per row/column in the extended data square odsWidth := edsWidth / 2 // number of shares per row/column in the original data square - registerTree("Testing", newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize))) + err := RegisterTree("testing-tree", newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize))) + assert.NoError(t, err) // create a DA header eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4) diff --git a/extendeddatasquare.go b/extendeddatasquare.go index 3156af9..4080d6f 100644 --- a/extendeddatasquare.go +++ b/extendeddatasquare.go @@ -77,7 +77,7 @@ func ComputeExtendedDataSquare( return nil, err } - eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: Default} + eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: DefaultTreeName} err = eds.erasureExtendSquare(codec) if err != nil { return nil, err diff --git a/extendeddatasquare_test.go b/extendeddatasquare_test.go index b109726..b852e1c 100644 --- a/extendeddatasquare_test.go +++ b/extendeddatasquare_test.go @@ -59,7 +59,7 @@ func TestComputeExtendedDataSquare(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - result, err := ComputeExtendedDataSquare(tc.data, codec, Default) + result, err := ComputeExtendedDataSquare(tc.data, codec, DefaultTreeName) assert.NoError(t, err) assert.Equal(t, tc.want, result.squareRow) }) @@ -67,7 +67,7 @@ func TestComputeExtendedDataSquare(t *testing.T) { t.Run("returns an error if chunkSize is not a multiple of 64", func(t *testing.T) { chunk := bytes.Repeat([]byte{1}, 65) - _, err := ComputeExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), Default) + _, err := ComputeExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), DefaultTreeName) assert.Error(t, err) }) } @@ -75,13 +75,13 @@ func TestComputeExtendedDataSquare(t *testing.T) { func TestImportExtendedDataSquare(t *testing.T) { t.Run("is able to import an EDS", func(t *testing.T) { eds := createExampleEds(t, shareSize) - got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), Default) + got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), DefaultTreeName) assert.NoError(t, err) assert.Equal(t, eds.Flattened(), got.Flattened()) }) t.Run("returns an error if chunkSize is not a multiple of 64", func(t *testing.T) { chunk := bytes.Repeat([]byte{1}, 65) - _, err := ImportExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), Default) + _, err := ImportExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), DefaultTreeName) assert.Error(t, err) }) } @@ -91,7 +91,7 @@ func TestMarshalJSON(t *testing.T) { result, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, codec, Default) + }, codec, DefaultTreeName) if err != nil { panic(err) } @@ -162,7 +162,7 @@ func TestImmutableRoots(t *testing.T) { result, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, codec, Default) + }, codec, DefaultTreeName) if err != nil { panic(err) } @@ -197,7 +197,7 @@ func TestEDSRowColImmutable(t *testing.T) { result, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, codec, Default) + }, codec, DefaultTreeName) if err != nil { panic(err) } @@ -220,7 +220,7 @@ func TestRowRoots(t *testing.T) { eds, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, NewLeoRSCodec(), Default) + }, NewLeoRSCodec(), DefaultTreeName) require.NoError(t, err) rowRoots, err := eds.RowRoots() @@ -232,7 +232,7 @@ func TestRowRoots(t *testing.T) { eds, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, NewLeoRSCodec(), Default) + }, NewLeoRSCodec(), DefaultTreeName) require.NoError(t, err) // set a cell to nil to make the EDS incomplete @@ -248,7 +248,7 @@ func TestColRoots(t *testing.T) { eds, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, NewLeoRSCodec(), Default) + }, NewLeoRSCodec(), DefaultTreeName) require.NoError(t, err) colRoots, err := eds.ColRoots() @@ -260,7 +260,7 @@ func TestColRoots(t *testing.T) { eds, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, NewLeoRSCodec(), Default) + }, NewLeoRSCodec(), DefaultTreeName) require.NoError(t, err) // set a cell to nil to make the EDS incomplete @@ -290,7 +290,7 @@ func BenchmarkExtensionEncoding(b *testing.B) { fmt.Sprintf("%s %dx%dx%d ODS", codecName, i, i, len(square[0])), func(b *testing.B) { for n := 0; n < b.N; n++ { - eds, err := ComputeExtendedDataSquare(square, codec, Default) + eds, err := ComputeExtendedDataSquare(square, codec, DefaultTreeName) if err != nil { b.Error(err) } @@ -317,7 +317,7 @@ func BenchmarkExtensionWithRoots(b *testing.B) { fmt.Sprintf("%s %dx%dx%d ODS", codecName, i, i, len(square[0])), func(b *testing.B) { for n := 0; n < b.N; n++ { - eds, err := ComputeExtendedDataSquare(square, codec, Default) + eds, err := ComputeExtendedDataSquare(square, codec, DefaultTreeName) if err != nil { b.Error(err) } @@ -396,7 +396,7 @@ func TestEquals(t *testing.T) { unequalChunkSize := createExampleEds(t, shareSize*2) - unequalEds, err := ComputeExtendedDataSquare([][]byte{ones}, NewLeoRSCodec(), Default) + unequalEds, err := ComputeExtendedDataSquare([][]byte{ones}, NewLeoRSCodec(), DefaultTreeName) require.NoError(t, err) testCases := []testCase{ @@ -431,7 +431,7 @@ func TestRoots(t *testing.T) { eds, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, NewLeoRSCodec(), Default) + }, NewLeoRSCodec(), DefaultTreeName) require.NoError(t, err) roots, err := eds.Roots() @@ -458,7 +458,7 @@ func TestRoots(t *testing.T) { eds, err := ComputeExtendedDataSquare([][]byte{ ones, twos, threes, fours, - }, NewLeoRSCodec(), Default) + }, NewLeoRSCodec(), DefaultTreeName) require.NoError(t, err) // set a cell to nil to make the EDS incomplete @@ -479,7 +479,7 @@ func createExampleEds(t *testing.T, chunkSize int) (eds *ExtendedDataSquare) { threes, fours, } - eds, err := ComputeExtendedDataSquare(ods, NewLeoRSCodec(), Default) + eds, err := ComputeExtendedDataSquare(ods, NewLeoRSCodec(), DefaultTreeName) require.NoError(t, err) return eds } diff --git a/rsmt2d_test.go b/rsmt2d_test.go index 9bbdca0..2561c7e 100644 --- a/rsmt2d_test.go +++ b/rsmt2d_test.go @@ -35,7 +35,7 @@ func TestEdsRepairRoundtripSimple(t *testing.T) { threes, fours, }, tt.codec, - rsmt2d.Default, + rsmt2d.DefaultTreeName, ) if err != nil { t.Errorf("ComputeExtendedDataSquare failed: %v", err) @@ -56,7 +56,7 @@ func TestEdsRepairRoundtripSimple(t *testing.T) { flattened[12], flattened[13] = nil, nil // Re-import the data square. - eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.Default) + eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName) if err != nil { t.Errorf("ImportExtendedDataSquare failed: %v", err) } @@ -97,7 +97,7 @@ func TestEdsRepairTwice(t *testing.T) { threes, fours, }, tt.codec, - rsmt2d.Default, + rsmt2d.DefaultTreeName, ) if err != nil { t.Errorf("ComputeExtendedDataSquare failed: %v", err) @@ -120,7 +120,7 @@ func TestEdsRepairTwice(t *testing.T) { flattened[12], flattened[13] = nil, nil // Re-import the data square. - eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.Default) + eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName) if err != nil { t.Errorf("ImportExtendedDataSquare failed: %v", err) } @@ -139,7 +139,7 @@ func TestEdsRepairTwice(t *testing.T) { copy(flattened[1], missing) // Re-import the data square. - eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.Default) + eds, err = rsmt2d.ImportExtendedDataSquare(flattened, tt.codec, rsmt2d.DefaultTreeName) if err != nil { t.Errorf("ImportExtendedDataSquare failed: %v", err) } @@ -205,7 +205,7 @@ func createExampleEds(t *testing.T, chunkSize int) (eds *rsmt2d.ExtendedDataSqua threes, fours, } - eds, err := rsmt2d.ComputeExtendedDataSquare(ods, rsmt2d.NewLeoRSCodec(), rsmt2d.Default) + eds, err := rsmt2d.ComputeExtendedDataSquare(ods, rsmt2d.NewLeoRSCodec(), rsmt2d.DefaultTreeName) require.NoError(t, err) return eds } diff --git a/tree.go b/tree.go index dde7d70..0972f7c 100644 --- a/tree.go +++ b/tree.go @@ -2,10 +2,6 @@ package rsmt2d import "fmt" -const ( - Default = "Default-tree" -) - // TreeConstructorFn creates a fresh Tree instance to be used as the Merkle tree // inside of rsmt2d. type TreeConstructorFn = func(axis Axis, index uint) Tree @@ -22,12 +18,15 @@ type Tree interface { Root() ([]byte, error) } -// trees is a global map used for keeping track of registered tree constructors for testing and JSON unmarshalling +// trees is a global map used for keeping track of registered tree constructors for JSON serialization +// The keys of this map should be kebab cased. E.g. "default-tree" var trees = make(map[string]TreeConstructorFn) -func registerTree(treeName string, treeConstructor TreeConstructorFn) { +func RegisterTree(treeName string, treeConstructor TreeConstructorFn) error { if trees[treeName] != nil { - panic(fmt.Sprintf("%s already registered", treeName)) + return fmt.Errorf("%s already registered", treeName) } trees[treeName] = treeConstructor + + return nil }