Skip to content

Commit

Permalink
refactor(ipld): use Set/GetCell API from rsmt2d
Browse files Browse the repository at this point in the history
The latest rsmt2d version allows us to set cells/shares directly into the imported square.
Hence, we can avoid using flattened slices and the complexity they introduce when calculating the position or index at which to set a share.
Also, we now avoid reimporting the square on each repair retry, optimizing the whole process overall.
  • Loading branch information
Wondertan committed Sep 30, 2022
1 parent 0a32167 commit 7051382
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 95 deletions.
6 changes: 3 additions & 3 deletions ipld/get_namespaced_shares.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ type fetchedBounds struct {
highest int64
}

// update checks if the passed index is outside the current bounds,
// update checks if the passed pos is outside the current bounds,
// and updates the bounds atomically if it extends them.
func (b *fetchedBounds) update(index int64) {
lowest := atomic.LoadInt64(&b.lowest)
// try to write index to the lower bound if appropriate, and retry until the atomic op is successful
// try to write pos to the lower bound if appropriate, and retry until the atomic op is successful
// CAS ensures that we don't overwrite if the bound has been updated in another goroutine after the comparison here
for index < lowest && !atomic.CompareAndSwapInt64(&b.lowest, lowest, index) {
lowest = atomic.LoadInt64(&b.lowest)
Expand Down Expand Up @@ -193,7 +193,7 @@ func getLeavesByNamespace(
for i, lnk := range links {
newJob := &job{
id: lnk.Cid,
// position represents the index in a flattened binary tree,
// position represents the pos in a flattened binary tree,
// so we can return a slice of leaves in order
pos: j.pos*2 + i,
// we pass the context to job so that spans are tracked in a tree
Expand Down
101 changes: 43 additions & 58 deletions ipld/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,18 @@ func (r *Retriever) Retrieve(ctx context.Context, dah *da.DataAvailabilityHeader
// quadrant request retries. Also, provides an API
// to reconstruct the block once enough shares are fetched.
type retrievalSession struct {
dah *da.DataAvailabilityHeader
bget blockservice.BlockGetter
adder *NmtNodeAdder

treeFn rsmt2d.TreeConstructorFn
codec rsmt2d.Codec

dah *da.DataAvailabilityHeader
squareImported *rsmt2d.ExtendedDataSquare

quadrants []*quadrant
sharesLks []sync.Mutex
sharesCount uint32

squareLk sync.RWMutex
square [][]byte
squareSig chan struct{}
squareDn chan struct{}
// TODO(@Wondertan): Extract into a separate data structure
squareQuadrants []*quadrant
squareCellsLks [][]sync.Mutex
squareCellsCount uint32
squareSig chan struct{}
squareDn chan struct{}
squareLk sync.RWMutex
square *rsmt2d.ExtendedDataSquare

span trace.Span
}
Expand All @@ -131,29 +126,28 @@ func (r *Retriever) newSession(ctx context.Context, dah *da.DataAvailabilityHead
r.bServ,
format.MaxSizeBatchOption(batchSize(size)),
)
ses := &retrievalSession{
bget: blockservice.NewSession(ctx, r.bServ),
adder: adder,
treeFn: func() rsmt2d.Tree {
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, nmt.NodeVisitor(adder.Visit))
return &tree
},
codec: DefaultRSMT2DCodec(),
dah: dah,
quadrants: newQuadrants(dah),
sharesLks: make([]sync.Mutex, size*size),
square: make([][]byte, size*size),
squareSig: make(chan struct{}, 1),
squareDn: make(chan struct{}),
span: trace.SpanFromContext(ctx),
}

square, err := rsmt2d.ImportExtendedDataSquare(ses.square, ses.codec, ses.treeFn)
square, err := rsmt2d.ImportExtendedDataSquare(make([][]byte, size*size), DefaultRSMT2DCodec(), func() rsmt2d.Tree {
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, nmt.NodeVisitor(adder.Visit))
return &tree
})
if err != nil {
return nil, err
}

ses.squareImported = square
ses := &retrievalSession{
dah: dah,
bget: blockservice.NewSession(ctx, r.bServ),
adder: adder,
squareQuadrants: newQuadrants(dah),
squareCellsLks: make([][]sync.Mutex, size),
squareSig: make(chan struct{}, 1),
squareDn: make(chan struct{}),
square: square,
span: trace.SpanFromContext(ctx),
}
for i := range ses.squareCellsLks {
ses.squareCellsLks[i] = make([]sync.Mutex, size)
}
go ses.request(ctx)
return ses, nil
}
Expand All @@ -168,36 +162,24 @@ func (rs *retrievalSession) Done() <-chan struct{} {
// Reconstruct tries to reconstruct the data square and returns it on success.
func (rs *retrievalSession) Reconstruct(ctx context.Context) (*rsmt2d.ExtendedDataSquare, error) {
if rs.isReconstructed() {
return rs.squareImported, nil
return rs.square, nil
}
// prevent further writes to the square
rs.squareLk.Lock()
defer rs.squareLk.Unlock()

// TODO(@Wondertan): This is bad!
// * We should not reimport the square multiple times
// * We should set shares into imported square via SetShare(https://github.com/celestiaorg/rsmt2d/issues/83)
// to accomplish the above point.
{
squareImported, err := rsmt2d.ImportExtendedDataSquare(rs.square, rs.codec, rs.treeFn)
if err != nil {
return nil, err
}
rs.squareImported = squareImported
}

_, span := tracer.Start(ctx, "reconstruct-square")
defer span.End()

// and try to repair with what we have
err := rs.squareImported.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots)
err := rs.square.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots)
if err != nil {
span.RecordError(err)
return nil, err
}
log.Infow("data square reconstructed", "data_hash", hex.EncodeToString(rs.dah.Hash()), "size", len(rs.dah.RowsRoots))
close(rs.squareDn)
return rs.squareImported, nil
return rs.square, nil
}

// isReconstructed reports whether the square attached to the session
Expand Down Expand Up @@ -230,16 +212,16 @@ func (rs *retrievalSession) Close() error {
func (rs *retrievalSession) request(ctx context.Context) {
t := time.NewTicker(RetrieveQuadrantTimeout)
defer t.Stop()
for retry := 0; retry < len(rs.quadrants); retry++ {
q := rs.quadrants[retry]
for retry := 0; retry < len(rs.squareQuadrants); retry++ {
q := rs.squareQuadrants[retry]
log.Debugw("requesting quadrant",
"axis", q.source,
"x", q.x,
"y", q.y,
"size", len(q.roots),
)
rs.span.AddEvent("requesting quadrant", trace.WithAttributes(
attribute.Int("axis", q.source),
attribute.Int("axis", int(q.source)),
attribute.Int("x", q.x),
attribute.Int("y", q.y),
attribute.Int("size", len(q.roots)),
Expand All @@ -258,7 +240,7 @@ func (rs *retrievalSession) request(ctx context.Context) {
"size", len(q.roots),
)
rs.span.AddEvent("quadrant request timeout", trace.WithAttributes(
attribute.Int("axis", q.source),
attribute.Int("axis", int(q.source)),
attribute.Int("x", q.x),
attribute.Int("y", q.y),
attribute.Int("size", len(q.roots)),
Expand All @@ -283,17 +265,17 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) {
}
// and go get shares of left or the right side of the whole col/row axis
// the left or the right side of the tree represent some portion of the quadrant
// which we put into the rs.square share-by-share by calculating shares' indexes using q.index
// which we put into the rs.square share-by-share by calculating shares' indexes using q.pos
GetShares(ctx, rs.bget, nd.Links()[q.x].Cid, size, func(j int, share Share) {
// NOTE: Each share can appear twice here, for a Row and Col, respectively.
// These shares are always equal, and we allow only the first one to be written
// in the square.
// NOTE-2: We never actually fetch shares from the network *twice*.
// Once a share is downloaded from the network it is cached on the IPLD(blockservice) level.
// calc index of the share
idx := q.index(i, j)
// calc position of the share
x, y := q.pos(i, j)
// try to lock the share
ok := rs.sharesLks[idx].TryLock()
ok := rs.squareCellsLks[x][y].TryLock()
if !ok {
// if already locked and written - do nothing
return
Expand All @@ -310,14 +292,17 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) {
if rs.isReconstructed() {
return
}
rs.square[idx] = share
if rs.square.GetCell(uint(x), uint(y)) != nil {
return
}
rs.square.SetCell(uint(x), uint(y), share)
// if we have >= 1/4 of the square we can start trying to Reconstruct
// TODO(@Wondertan): This is not an ideal way to know when to start
// reconstruction and can cause idle reconstruction tries in some cases,
// but it is totally fine for the happy case and for now.
// The earlier we correctly know that we have the full square - the earlier
// we cancel ongoing requests - the less data is being wastedly transferred.
if atomic.AddUint32(&rs.sharesCount, 1) >= uint32(size*size) {
if atomic.AddUint32(&rs.squareCellsCount, 1) >= uint32(size*size) {
select {
case rs.squareSig <- struct{}{}:
default:
Expand Down
51 changes: 17 additions & 34 deletions ipld/retriever_quadrant.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package ipld

import (
"math"
"math/rand"
"time"

"github.com/ipfs/go-cid"
"github.com/tendermint/tendermint/pkg/da"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/ipld/plugin"
)

Expand Down Expand Up @@ -44,7 +45,7 @@ type quadrant struct {
// source defines the axis for quadrant
// it can be either 1 or 0 similar to x and y
// where 0 is Row source and 1 is Col respectively
source int
source rsmt2d.Axis
}

// newQuadrants constructs a slice of quadrants from DAHeader.
Expand All @@ -69,17 +70,13 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant {
}

for i := range quadrants {
// convert quadrant index into coordinates
// convert quadrant 1D index into 2D coordinates
x, y := i%2, i/2
if source == 1 { // swap coordinates for column
x, y = y, x
}

quadrants[i] = &quadrant{
roots: roots[qsize*y : qsize*(y+1)],
x: x,
y: y,
source: source,
source: rsmt2d.Axis(source),
}
}
}
Expand All @@ -92,30 +89,16 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant {
return quadrants
}

// index returns the offset of a share inside a data square that is stored as
// a single slice flattened by rows.
//
// rootIdx selects a root within q.roots and cellIdx selects a share under
// that root; the quadrant's x/y coordinates and its source axis determine how
// both are scaled and shifted.
//
// NOTE: The formula is deliberately written this way in order to:
//   - avoid copying shares
//   - stay generic for both rows and cols
//   - cope with the data square being flattened by rows only
//
// TODO(@Wondertan): This can be simplified by making rsmt2d working over 3D byte slice(not flattened)
func (q *quadrant) index(rootIdx, cellIdx int) int {
	qsize := len(q.roots)
	width := qsize * 2

	// One of the strides is 1 (width^0) and the other is the full width
	// (width^1); q.source decides which is which.
	colStride := pow(width, q.source)
	rowStride := pow(width, q.source^1)

	// Shift contributed by the quadrant's own coordinates, e.g. a share from
	// Q3 or Q4 has to skip over the quadrants laid out before it.
	quadrantShift := q.x*colStride*qsize + q.y*rowStride*qsize

	return rootIdx*rowStride + cellIdx*colStride + quadrantShift
}

func pow(x, y int) int {
return int(math.Pow(float64(x), float64(y)))
// pos translates a quadrant-relative share address into a pair of
// coordinates in the whole data square.
//
// rootIdx is shifted by len(q.roots)*q.y and cellIdx by len(q.roots)*q.x.
// For a Row-sourced quadrant the shifted rootIdx is returned first and the
// shifted cellIdx second; for a Col-sourced quadrant the order is swapped.
// It panics if q.source is neither rsmt2d.Row nor rsmt2d.Col.
func (q *quadrant) pos(rootIdx, cellIdx int) (int, int) {
	qsize := len(q.roots)
	alongAxis := cellIdx + qsize*q.x
	acrossAxis := rootIdx + qsize*q.y

	if q.source == rsmt2d.Row {
		return acrossAxis, alongAxis
	}
	if q.source == rsmt2d.Col {
		return alongAxis, acrossAxis
	}
	panic("unknown axis")
}

0 comments on commit 7051382

Please sign in to comment.