Skip to content

Commit

Permalink
Support multiple previous ATXs (#6024)
Browse files Browse the repository at this point in the history
## Motivation

Add support for multiple previous ATXs as a merged ATX might reference multiple ATXs.
  • Loading branch information
poszu committed Aug 9, 2024
1 parent fcce5d8 commit 16d8ac9
Show file tree
Hide file tree
Showing 25 changed files with 433 additions and 218 deletions.
1 change: 1 addition & 0 deletions activation/activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ func publishAtxV1(
return codec.Decode(got, &watx)
})
require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob()))
require.NoError(tb, atxs.SetPost(tab.db, watx.ID(), watx.PrevATXID, 0, watx.SmesherID, watx.NumUnits))
tab.atxsdata.AddFromAtx(toAtx(tb, &watx), false)
return &watx
}
Expand Down
5 changes: 4 additions & 1 deletion activation/e2e/atx_merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ func Test_MarryAndMerge(t *testing.T) {
require.Equal(t, units[i], atxFromDb.NumUnits)
require.Equal(t, signer.NodeID(), atxFromDb.SmesherID)
require.Equal(t, publish, atxFromDb.PublishEpoch)
require.Equal(t, mergedATX2.ID(), atxFromDb.PrevATXID)
prev, err := atxs.Previous(db, atxFromDb.ID())
require.NoError(t, err)
require.Len(t, prev, 1)
require.Equal(t, mergedATX2.ID(), prev[0])
}
}
2 changes: 1 addition & 1 deletion activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ func (h *HandlerV1) storeAtx(
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
err = atxs.SetUnits(tx, atx.ID(), atx.SmesherID, watx.NumUnits)
err = atxs.SetPost(tx, atx.ID(), watx.PrevATXID, 0, atx.SmesherID, watx.NumUnits)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("set atx units: %w", err)
}
Expand Down
107 changes: 49 additions & 58 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,25 @@ func (h *HandlerV2) processATX(
return fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err)
}

parts, err := h.syntacticallyValidateDeps(ctx, watx)
atxData, err := h.syntacticallyValidateDeps(ctx, watx)
if err != nil {
return fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err)
}
atxData.marriages = marrying

atx := &types.ActivationTx{
PublishEpoch: watx.PublishEpoch,
MarriageATX: watx.MarriageATX,
Coinbase: watx.Coinbase,
BaseTickHeight: baseTickHeight,
NumUnits: parts.effectiveUnits,
TickCount: parts.ticks,
Weight: parts.weight,
NumUnits: atxData.effectiveUnits,
TickCount: atxData.ticks,
Weight: atxData.weight,
VRFNonce: types.VRFPostIndex(watx.VRFNonce),
SmesherID: watx.SmesherID,
}

if watx.Initial == nil {
// FIXME: update to keep many previous ATXs to support merged ATXs
atx.PrevATXID = watx.PreviousATXs[0]
} else {
if watx.Initial != nil {
atx.CommitmentATX = &watx.Initial.CommitmentATX
}

Expand All @@ -141,12 +139,12 @@ func (h *HandlerV2) processATX(
atx.SetID(watx.ID())
atx.SetReceived(received)

if err := h.storeAtx(ctx, atx, watx, marrying, parts.units); err != nil {
if err := h.storeAtx(ctx, atx, atxData); err != nil {
return fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err)
}

events.ReportNewActivation(atx)
h.logger.Info("new atx", log.ZContext(ctx), zap.Inline(atx))
h.logger.Debug("new atx", log.ZContext(ctx), zap.Inline(atx))
return err
}

Expand Down Expand Up @@ -434,11 +432,19 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e
return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX)
}

type atxParts struct {
type idData struct {
previous types.ATXID
previousIndex int
units uint32
}

type activationTx struct {
*wire.ActivationTxV2
ticks uint64
weight uint64
effectiveUnits uint32
units map[types.NodeID]uint32
ids map[types.NodeID]idData
marriages []marriage
}

type nipostSize struct {
Expand Down Expand Up @@ -496,9 +502,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error
func (h *HandlerV2) syntacticallyValidateDeps(
ctx context.Context,
atx *wire.ActivationTxV2,
) (*atxParts, error) {
parts := atxParts{
units: make(map[types.NodeID]uint32),
) (*activationTx, error) {
result := activationTx{
ActivationTxV2: atx,
ids: make(map[types.NodeID]idData),
}
if atx.Initial != nil {
if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil {
Expand Down Expand Up @@ -586,7 +593,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
nipostSizes[i].ticks = leaves / h.tickSize
}

parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp()
result.effectiveUnits, result.weight, err = nipostSizes.sumUp()
if err != nil {
return nil, err
}
Expand All @@ -597,6 +604,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
for _, post := range niposts.Posts {
id := equivocationSet[post.MarriageIndex]
var commitment types.ATXID
var previous types.ATXID
if atx.Initial != nil {
commitment = atx.Initial.CommitmentATX
} else {
Expand All @@ -608,6 +616,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if id == atx.SmesherID {
smesherCommitment = &commitment
}
previous = previousAtxs[post.PrevATXIndex].ID()
}

err := h.nipostValidator.PostV2(
Expand Down Expand Up @@ -635,7 +644,11 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if err != nil {
return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err)
}
parts.units[id] = post.NumUnits
result.ids[id] = idData{
previous: previous,
previousIndex: int(post.PrevATXIndex),
units: post.NumUnits,
}
}
}

Expand All @@ -649,42 +662,36 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}
}

parts.ticks = nipostSizes.minTicks()
return &parts, nil
result.ticks = nipostSizes.minTicks()
return &result, nil
}

func (h *HandlerV2) checkMalicious(
ctx context.Context,
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []marriage,
ids []types.NodeID,
) error {
malicious, err := identities.IsMalicious(tx, watx.SmesherID)
func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activationTx) error {
malicious, err := identities.IsMalicious(tx, atx.SmesherID)
if err != nil {
return fmt.Errorf("checking if node is malicious: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoubleMarry(ctx, tx, watx, marrying)
malicious, err = h.checkDoubleMarry(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double marry: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoublePost(ctx, tx, watx, ids)
malicious, err = h.checkDoublePost(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double post: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoubleMerge(ctx, tx, watx)
malicious, err = h.checkDoubleMerge(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double merge: %w", err)
}
Expand All @@ -700,13 +707,8 @@ func (h *HandlerV2) checkMalicious(
return nil
}

func (h *HandlerV2) checkDoubleMarry(
ctx context.Context,
tx *sql.Tx,
atx *wire.ActivationTxV2,
marrying []marriage,
) (bool, error) {
for _, m := range marrying {
func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for _, m := range atx.marriages {
mATX, err := identities.MarriageATX(tx, m.id)
if err != nil {
return false, fmt.Errorf("checking if ID is married: %w", err)
Expand All @@ -725,7 +727,7 @@ func (h *HandlerV2) checkDoubleMarry(
var otherAtx wire.ActivationTxV2
codec.MustDecode(blob.Bytes, &otherAtx)

proof, err := wire.NewDoubleMarryProof(tx, atx, &otherAtx, m.id)
proof, err := wire.NewDoubleMarryProof(tx, atx.ActivationTxV2, &otherAtx, m.id)
if err != nil {
return true, fmt.Errorf("creating double marry proof: %w", err)
}
Expand All @@ -735,13 +737,8 @@ func (h *HandlerV2) checkDoubleMarry(
return false, nil
}

func (h *HandlerV2) checkDoublePost(
ctx context.Context,
tx *sql.Tx,
atx *wire.ActivationTxV2,
ids []types.NodeID,
) (bool, error) {
for _, id := range ids {
func (h *HandlerV2) checkDoublePost(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for id := range atx.ids {
atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch)
switch {
case errors.Is(err, sql.ErrNotFound):
Expand All @@ -765,7 +762,7 @@ func (h *HandlerV2) checkDoublePost(
return false, nil
}

func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) {
func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *activationTx) (bool, error) {
if watx.MarriageATX == nil {
return false, nil
}
Expand All @@ -791,20 +788,14 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire
}

// Store an ATX in the DB.
func (h *HandlerV2) storeAtx(
ctx context.Context,
atx *types.ActivationTx,
watx *wire.ActivationTxV2,
marrying []marriage,
units map[types.NodeID]uint32,
) error {
func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx *activationTx) error {
if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error {
if len(marrying) != 0 {
if len(watx.marriages) != 0 {
marriageData := identities.MarriageData{
ATX: atx.ID(),
Target: atx.SmesherID,
}
for i, m := range marrying {
for i, m := range watx.marriages {
marriageData.Signature = m.signature
marriageData.Index = i
if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil {
Expand All @@ -817,8 +808,8 @@ func (h *HandlerV2) storeAtx(
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
for id, units := range units {
err = atxs.SetUnits(tx, atx.ID(), id, units)
for id, post := range watx.ids {
err = atxs.SetPost(tx, atx.ID(), post.previous, post.previousIndex, id, post.units)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("setting atx units for ID %s: %w", id, err)
}
Expand All @@ -837,7 +828,7 @@ func (h *HandlerV2) storeAtx(
// TODO(mafa): don't store own ATX if it would mark the node as malicious
// this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in
// the gossip handler (not sync!)
err := h.checkMalicious(ctx, tx, watx, marrying, maps.Keys(units))
err := h.checkMalicious(ctx, tx, watx)
if err != nil {
return fmt.Errorf("check malicious: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,7 @@ func Test_ValidatePreviousATX(t *testing.T) {
t.Parallel()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), types.RandomNodeID(), 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, types.RandomNodeID(), 13))

_, err := atxHandler.validatePreviousAtx(types.RandomNodeID(), &wire.SubPostV2{}, []*types.ActivationTx{prev})
require.Error(t, err)
Expand All @@ -1446,8 +1446,8 @@ func Test_ValidatePreviousATX(t *testing.T) {
other := types.RandomNodeID()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), id, 7))
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), other, 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, id, 7))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, other, 13))

units, err := atxHandler.validatePreviousAtx(id, &wire.SubPostV2{NumUnits: 100}, []*types.ActivationTx{prev})
require.NoError(t, err)
Expand All @@ -1467,7 +1467,7 @@ func Test_ValidatePreviousATX(t *testing.T) {
other := types.RandomNodeID()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), other, 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, other, 13))

_, err := atxHandler.validatePreviousAtx(id, &wire.SubPostV2{NumUnits: 100}, []*types.ActivationTx{prev})
require.Error(t, err)
Expand Down
Loading

0 comments on commit 16d8ac9

Please sign in to comment.