diff --git a/activation/activation_test.go b/activation/activation_test.go index ba9acfa0b7..e073d2a334 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -137,6 +137,7 @@ func publishAtxV1( return codec.Decode(got, &watx) }) require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob())) + require.NoError(tb, atxs.SetPost(tab.db, watx.ID(), watx.PrevATXID, 0, watx.SmesherID, watx.NumUnits)) tab.atxsdata.AddFromAtx(toAtx(tb, &watx), false) return &watx } diff --git a/activation/e2e/atx_merge_test.go b/activation/e2e/atx_merge_test.go index 49f3f6d0a0..e0dfcca0ff 100644 --- a/activation/e2e/atx_merge_test.go +++ b/activation/e2e/atx_merge_test.go @@ -513,6 +513,9 @@ func Test_MarryAndMerge(t *testing.T) { require.Equal(t, units[i], atxFromDb.NumUnits) require.Equal(t, signer.NodeID(), atxFromDb.SmesherID) require.Equal(t, publish, atxFromDb.PublishEpoch) - require.Equal(t, mergedATX2.ID(), atxFromDb.PrevATXID) + prev, err := atxs.Previous(db, atxFromDb.ID()) + require.NoError(t, err) + require.Len(t, prev, 1) + require.Equal(t, mergedATX2.ID(), prev[0]) } } diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 4cd988bd9d..481c90df24 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -524,7 +524,7 @@ func (h *HandlerV1) storeAtx( if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) } - err = atxs.SetUnits(tx, atx.ID(), atx.SmesherID, watx.NumUnits) + err = atxs.SetPost(tx, atx.ID(), watx.PrevATXID, 0, atx.SmesherID, watx.NumUnits) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("set atx units: %w", err) } diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 0f0b22a853..42d139161e 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -111,27 +111,25 @@ func (h *HandlerV2) processATX( return fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err) } - parts, err := h.syntacticallyValidateDeps(ctx, watx) + atxData, err := h.syntacticallyValidateDeps(ctx, watx) if err != nil { return fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err) } + atxData.marriages = marrying atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, MarriageATX: watx.MarriageATX, Coinbase: watx.Coinbase, BaseTickHeight: baseTickHeight, - NumUnits: parts.effectiveUnits, - TickCount: parts.ticks, - Weight: parts.weight, + NumUnits: atxData.effectiveUnits, + TickCount: atxData.ticks, + Weight: atxData.weight, VRFNonce: types.VRFPostIndex(watx.VRFNonce), SmesherID: watx.SmesherID, } - if watx.Initial == nil { - // FIXME: update to keep many previous ATXs to support merged ATXs - atx.PrevATXID = watx.PreviousATXs[0] - } else { + if watx.Initial != nil { atx.CommitmentATX = &watx.Initial.CommitmentATX } @@ -141,12 +139,12 @@ func (h *HandlerV2) processATX( atx.SetID(watx.ID()) atx.SetReceived(received) - if err := h.storeAtx(ctx, atx, watx, marrying, parts.units); err != nil { + if err := h.storeAtx(ctx, atx, atxData); err != nil { return fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err) } events.ReportNewActivation(atx) - h.logger.Info("new atx", log.ZContext(ctx), zap.Inline(atx)) + h.logger.Debug("new atx", log.ZContext(ctx), zap.Inline(atx)) return err } @@ -434,11 +432,19 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX) } -type atxParts struct { +type idData struct { + previous types.ATXID + previousIndex int + units uint32 +} + +type activationTx struct { + *wire.ActivationTxV2 ticks uint64 weight uint64 effectiveUnits uint32 - units map[types.NodeID]uint32 + ids map[types.NodeID]idData + marriages []marriage } type nipostSize struct { @@ -496,9 +502,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error func (h *HandlerV2) syntacticallyValidateDeps( ctx context.Context, atx *wire.ActivationTxV2, -) (*atxParts, error) { - parts := atxParts{ - units: make(map[types.NodeID]uint32), +) (*activationTx, error) { + result := activationTx{ + ActivationTxV2: atx, + ids: make(map[types.NodeID]idData), } if atx.Initial != nil { if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil { @@ -586,7 +593,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( nipostSizes[i].ticks = leaves / h.tickSize } - parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp() + result.effectiveUnits, result.weight, err = nipostSizes.sumUp() if err != nil { return nil, err } @@ -597,6 +604,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( for _, post := range niposts.Posts { id := equivocationSet[post.MarriageIndex] var commitment types.ATXID + var previous types.ATXID if atx.Initial != nil { commitment = atx.Initial.CommitmentATX } else { @@ -608,6 +616,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( if id == atx.SmesherID { smesherCommitment = &commitment } + previous = previousAtxs[post.PrevATXIndex].ID() } err := h.nipostValidator.PostV2( @@ -635,7 +644,11 @@ func (h *HandlerV2) syntacticallyValidateDeps( if err != nil { return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err) } - parts.units[id] = post.NumUnits + result.ids[id] = idData{ + previous: previous, + previousIndex: int(post.PrevATXIndex), + units: post.NumUnits, + } } } @@ -649,18 +662,12 @@ func (h *HandlerV2) syntacticallyValidateDeps( } } - parts.ticks = nipostSizes.minTicks() - return &parts, nil + result.ticks = nipostSizes.minTicks() + return &result, nil } -func (h *HandlerV2) checkMalicious( - ctx context.Context, - tx *sql.Tx, - watx *wire.ActivationTxV2, - marrying []marriage, - ids []types.NodeID, -) error { - malicious, err := identities.IsMalicious(tx, watx.SmesherID) +func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activationTx) error { + malicious, err := identities.IsMalicious(tx, atx.SmesherID) if err != nil { return fmt.Errorf("checking if node is malicious: %w", err) } @@ -668,7 +675,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoubleMarry(ctx, tx, watx, marrying) + malicious, err = h.checkDoubleMarry(ctx, tx, atx) if err != nil { return fmt.Errorf("checking double marry: %w", err) } @@ -676,7 +683,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoublePost(ctx, tx, watx, ids) + malicious, err = h.checkDoublePost(ctx, tx, atx) if err != nil { return fmt.Errorf("checking double post: %w", err) } @@ -684,7 +691,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoubleMerge(ctx, tx, watx) + malicious, err = h.checkDoubleMerge(ctx, tx, atx) if err != nil { return fmt.Errorf("checking double merge: %w", err) } @@ -700,13 +707,8 @@ func (h *HandlerV2) checkMalicious( return nil } -func (h *HandlerV2) checkDoubleMarry( - ctx context.Context, - tx *sql.Tx, - atx *wire.ActivationTxV2, - marrying []marriage, -) (bool, error) { - for _, m := range marrying { +func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { + for _, m := range atx.marriages { mATX, err := identities.MarriageATX(tx, m.id) if err != nil { return false, fmt.Errorf("checking if ID is married: %w", err) @@ -725,7 +727,7 @@ func (h *HandlerV2) checkDoubleMarry( var otherAtx wire.ActivationTxV2 codec.MustDecode(blob.Bytes, &otherAtx) - proof, err := wire.NewDoubleMarryProof(tx, atx, &otherAtx, m.id) + proof, err := wire.NewDoubleMarryProof(tx, atx.ActivationTxV2, &otherAtx, m.id) if err != nil { return true, fmt.Errorf("creating double marry proof: %w", err) } @@ -735,13 +737,8 @@ func (h *HandlerV2) checkDoubleMarry( return false, nil } -func (h *HandlerV2) checkDoublePost( - ctx context.Context, - tx *sql.Tx, - atx *wire.ActivationTxV2, - ids []types.NodeID, -) (bool, error) { - for _, id := range ids { +func (h *HandlerV2) checkDoublePost(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { + for id := range atx.ids { atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch) switch { case errors.Is(err, sql.ErrNotFound): @@ -765,7 +762,7 @@ func (h *HandlerV2) checkDoublePost( return false, nil } -func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) { +func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *activationTx) (bool, error) { if watx.MarriageATX == nil { return false, nil } @@ -791,20 +788,14 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire } // Store an ATX in the DB. -func (h *HandlerV2) storeAtx( - ctx context.Context, - atx *types.ActivationTx, - watx *wire.ActivationTxV2, - marrying []marriage, - units map[types.NodeID]uint32, -) error { +func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx *activationTx) error { if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { - if len(marrying) != 0 { + if len(watx.marriages) != 0 { marriageData := identities.MarriageData{ ATX: atx.ID(), Target: atx.SmesherID, } - for i, m := range marrying { + for i, m := range watx.marriages { marriageData.Signature = m.signature marriageData.Index = i if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil { @@ -817,8 +808,8 @@ func (h *HandlerV2) storeAtx( if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) } - for id, units := range units { - err = atxs.SetUnits(tx, atx.ID(), id, units) + for id, post := range watx.ids { + err = atxs.SetPost(tx, atx.ID(), post.previous, post.previousIndex, id, post.units) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("setting atx units for ID %s: %w", id, err) } @@ -837,7 +828,7 @@ func (h *HandlerV2) storeAtx( // TODO(mafa): don't store own ATX if it would mark the node as malicious // this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in // the gossip handler (not sync!) - err := h.checkMalicious(ctx, tx, watx, marrying, maps.Keys(units)) + err := h.checkMalicious(ctx, tx, watx) if err != nil { return fmt.Errorf("check malicious: %w", err) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 3397b709ac..46f6e76834 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -1435,7 +1435,7 @@ func Test_ValidatePreviousATX(t *testing.T) { t.Parallel() prev := &types.ActivationTx{} prev.SetID(types.RandomATXID()) - require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), types.RandomNodeID(), 13)) + require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, types.RandomNodeID(), 13)) _, err := atxHandler.validatePreviousAtx(types.RandomNodeID(), &wire.SubPostV2{}, []*types.ActivationTx{prev}) require.Error(t, err) @@ -1446,8 +1446,8 @@ func Test_ValidatePreviousATX(t *testing.T) { other := types.RandomNodeID() prev := &types.ActivationTx{} prev.SetID(types.RandomATXID()) - require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), id, 7)) - require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), other, 13)) + require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, id, 7)) + require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, other, 13)) units, err := atxHandler.validatePreviousAtx(id, &wire.SubPostV2{NumUnits: 100}, []*types.ActivationTx{prev}) require.NoError(t, err) @@ -1467,7 +1467,7 @@ func Test_ValidatePreviousATX(t *testing.T) { other := types.RandomNodeID() prev := &types.ActivationTx{} prev.SetID(types.RandomATXID()) - require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), other, 13)) + require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, 0, other, 13)) _, err := atxHandler.validatePreviousAtx(id, &wire.SubPostV2{NumUnits: 100}, []*types.ActivationTx{prev}) require.Error(t, err) diff --git a/activation/validation.go b/activation/validation.go index 6050ff3d45..d6d070d895 100644 --- a/activation/validation.go +++ b/activation/validation.go @@ -424,9 +424,9 @@ func (v *Validator) VerifyChain(ctx context.Context, id, goldenATXID types.ATXID } type atxDeps struct { - nipost types.NIPost + niposts []types.NIPost positioning types.ATXID - previous types.ATXID + previous []types.ATXID commitment types.ATXID } @@ -455,14 +455,16 @@ func (v *Validator) getAtxDeps(ctx context.Context, id types.ATXID) (*atxDeps, e } deps := &atxDeps{ - nipost: *wire.NiPostFromWireV1(atx.NIPost), + niposts: []types.NIPost{*wire.NiPostFromWireV1(atx.NIPost)}, positioning: atx.PositioningATXID, - previous: atx.PrevATXID, commitment: commitment, } + if atx.PrevATXID != types.EmptyATXID { + deps.previous = []types.ATXID{atx.PrevATXID} + } + return deps, nil case types.AtxV2: - // TODO: support merged ATXs var atx wire.ActivationTxV2 if err := codec.Decode(blob.Bytes, &atx); err != nil { return nil, fmt.Errorf("decoding ATX blob: %w", err) @@ -478,23 +480,23 @@ func (v *Validator) getAtxDeps(ctx context.Context, id types.ATXID) (*atxDeps, e } commitment = catx } - var previous types.ATXID - if len(atx.PreviousATXs) != 0 { - previous = atx.PreviousATXs[0] - } deps := &atxDeps{ - nipost: types.NIPost{ - Post: wire.PostFromWireV1(&atx.NiPosts[0].Posts[0].Post), - PostMetadata: &types.PostMetadata{ - Challenge: atx.NiPosts[0].Challenge[:], - LabelsPerUnit: v.cfg.LabelsPerUnit, - }, - }, positioning: atx.PositioningATX, - previous: previous, + previous: atx.PreviousATXs, commitment: commitment, } + for _, nipost := range atx.NiPosts { + for _, post := range nipost.Posts { + deps.niposts = append(deps.niposts, types.NIPost{ + Post: wire.PostFromWireV1(&post.Post), + PostMetadata: &types.PostMetadata{ + Challenge: nipost.Challenge[:], + LabelsPerUnit: v.cfg.LabelsPerUnit, + }, + }) + } + } return deps, nil } @@ -511,12 +513,11 @@ func (v *Validator) verifyChainWithOpts( if err != nil { return fmt.Errorf("get atx: %w", err) } - if atx.Golden() { - log.Debug("not verifying ATX chain", zap.Stringer("atx_id", id), zap.String("reason", "golden")) - return nil - } switch { + case atx.Golden(): + log.Debug("not verifying ATX chain", zap.Stringer("atx_id", id), zap.String("reason", "golden")) + return nil case atx.Validity() == types.Valid: log.Debug("not verifying ATX chain", zap.Stringer("atx_id", id), zap.String("reason", "already verified")) return nil @@ -542,20 +543,21 @@ func (v *Validator) verifyChainWithOpts( if err != nil { return fmt.Errorf("getting ATX dependencies: %w", err) } - - if err := v.Post( - ctx, - atx.SmesherID, - deps.commitment, - deps.nipost.Post, - deps.nipost.PostMetadata, - atx.NumUnits, - []validatorOption{PrioritizeCall()}..., - ); err != nil { - if err := atxs.SetValidity(v.db, id, types.Invalid); err != nil { - log.Warn("failed to persist atx validity", zap.Error(err), zap.Stringer("atx_id", id)) + for _, nipost := range deps.niposts { + if err := v.Post( + ctx, + atx.SmesherID, + deps.commitment, + nipost.Post, + nipost.PostMetadata, + atx.NumUnits, + []validatorOption{PrioritizeCall()}..., + ); err != nil { + if err := atxs.SetValidity(v.db, id, types.Invalid); err != nil { + log.Warn("failed to persist atx validity", zap.Error(err), zap.Stringer("atx_id", id)) + } + return &InvalidChainError{ID: id, src: err} } - return &InvalidChainError{ID: id, src: err} } err = v.verifyChainDeps(ctx, deps, goldenATXID, opts) @@ -579,9 +581,9 @@ func (v *Validator) verifyChainDeps( goldenATXID types.ATXID, opts verifyChainOpts, ) error { - if deps.previous != types.EmptyATXID { - if err := v.verifyChainWithOpts(ctx, deps.previous, goldenATXID, opts); err != nil { - return fmt.Errorf("validating previous ATX %s chain: %w", deps.previous.ShortString(), err) + for _, prev := range deps.previous { + if err := v.verifyChainWithOpts(ctx, prev, goldenATXID, opts); err != nil { + return fmt.Errorf("validating previous ATX %s chain: %w", prev.ShortString(), err) } } if deps.positioning != goldenATXID { @@ -591,7 +593,7 @@ func (v *Validator) verifyChainDeps( } // verify commitment only if arrived at the first ATX in the chain // to avoid verifying the same commitment ATX multiple times. - if deps.previous == types.EmptyATXID && deps.commitment != goldenATXID { + if len(deps.previous) == 0 && deps.commitment != goldenATXID { if err := v.verifyChainWithOpts(ctx, deps.commitment, goldenATXID, opts); err != nil { return fmt.Errorf("validating commitment ATX %s chain: %w", deps.commitment.ShortString(), err) } diff --git a/activation/validation_test.go b/activation/validation_test.go index a96c977f6e..6ab64fd5c7 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -616,6 +616,49 @@ func TestVerifyChainDeps(t *testing.T) { err = validator.VerifyChain(ctx, watx.ID(), goldenATXID) require.NoError(t, err) }) + t.Run("merged ATX", func(t *testing.T) { + initialAtx := newInitialATXv1(t, goldenATXID) + initialAtx.Sign(signer) + require.NoError(t, atxs.Add(db, toAtx(t, initialAtx), initialAtx.Blob())) + + // second ID for the merged ATX + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + initialAtx2 := newInitialATXv1(t, goldenATXID) + initialAtx2.Sign(otherSig) + require.NoError(t, atxs.Add(db, toAtx(t, initialAtx2), initialAtx2.Blob())) + + watx := newSoloATXv2(t, initialAtx.PublishEpoch+1, initialAtx.ID(), initialAtx.ID()) + watx.NiPosts[0].Posts = append(watx.NiPosts[0].Posts, wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + Post: wire.PostV1{ + Nonce: 99, + Pow: 55, + Indices: types.RandomBytes(33), + }, + NumUnits: 77, + }) + watx.PreviousATXs = append(watx.PreviousATXs, initialAtx2.ID()) + watx.Sign(signer) + atx := &types.ActivationTx{ + PublishEpoch: watx.PublishEpoch, + SmesherID: watx.SmesherID, + } + atx.SetID(watx.ID()) + require.NoError(t, atxs.Add(db, atx, watx.Blob())) + + v := NewMockPostVerifier(gomock.NewController(t)) + expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post)) + expectedPost2 := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[1].Post)) + v.EXPECT().Verify(ctx, (*shared.Proof)(initialAtx.NIPost.Post), gomock.Any(), gomock.Any()) + v.EXPECT().Verify(ctx, (*shared.Proof)(initialAtx2.NIPost.Post), gomock.Any(), gomock.Any()) + v.EXPECT().Verify(ctx, expectedPost, gomock.Any(), gomock.Any()) + v.EXPECT().Verify(ctx, expectedPost2, gomock.Any(), gomock.Any()) + validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v) + err = validator.VerifyChain(ctx, watx.ID(), goldenATXID) + require.NoError(t, err) + }) } func TestVerifyChainDepsAfterCheckpoint(t *testing.T) { diff --git a/activation/wire/wire_v1.go b/activation/wire/wire_v1.go index d76e343ab0..79daf37f66 100644 --- a/activation/wire/wire_v1.go +++ b/activation/wire/wire_v1.go @@ -196,7 +196,6 @@ func ActivationTxFromWireV1(atx *ActivationTxV1) *types.ActivationTx { result := &types.ActivationTx{ PublishEpoch: atx.PublishEpoch, Sequence: atx.Sequence, - PrevATXID: atx.PrevATXID, CommitmentATX: atx.CommitmentATXID, Coinbase: atx.Coinbase, NumUnits: atx.NumUnits, diff --git a/api/grpcserver/activation_service.go b/api/grpcserver/activation_service.go index 5e7bf3c303..4ba005a970 100644 --- a/api/grpcserver/activation_service.go +++ b/api/grpcserver/activation_service.go @@ -63,6 +63,15 @@ func (s *activationService) Get(ctx context.Context, request *pb.GetRequest) (*p ) return nil, status.Error(codes.NotFound, "id was not found") } + prev, err := s.atxProvider.Previous(atxId) + if err != nil { + ctxzap.Error(ctx, "failed to get previous ATX", + zap.Stringer("id", atxId), + zap.Error(err), + ) + return nil, status.Error(codes.Internal, "couldn't get previous ATXs") + } + proof, err := s.atxProvider.GetMalfeasanceProof(atx.SmesherID) if err != nil && !errors.Is(err, sql.ErrNotFound) { ctxzap.Error(ctx, "failed to get malfeasance proof", @@ -74,7 +83,7 @@ func (s *activationService) Get(ctx context.Context, request *pb.GetRequest) (*p return nil, status.Error(codes.NotFound, "id was not found") } resp := &pb.GetResponse{ - Atx: convertActivation(atx), + Atx: convertActivation(atx, prev), } if proof != nil { resp.MalfeasanceProof = events.ToMalfeasancePB(atx.SmesherID, proof, false) @@ -95,7 +104,16 @@ func (s *activationService) Highest(ctx context.Context, req *emptypb.Empty) (*p if err != nil || atx == nil { return nil, status.Error(codes.NotFound, fmt.Sprintf("atx id %v not found: %v", highest, err.Error())) } + prev, err := s.atxProvider.Previous(highest) + if err != nil { + ctxzap.Error(ctx, "failed to get previous ATX", + zap.Stringer("id", highest), + zap.Error(err), + ) + return nil, status.Error(codes.Internal, "couldn't get previous ATXs") + } + return &pb.HighestResponse{ - Atx: convertActivation(atx), + Atx: convertActivation(atx, prev), }, nil } diff --git a/api/grpcserver/activation_service_test.go b/api/grpcserver/activation_service_test.go index 2cf5ad0a3c..eac99561f6 100644 --- a/api/grpcserver/activation_service_test.go +++ b/api/grpcserver/activation_service_test.go @@ -32,7 +32,7 @@ func Test_Highest_ReturnsGoldenAtxOnError(t *testing.T) { require.Nil(t, response.Atx.Layer) require.Nil(t, response.Atx.SmesherId) require.Nil(t, response.Atx.Coinbase) - require.Nil(t, response.Atx.PrevAtx) + require.Nil(t, response.Atx.PrevAtx) // nolint:staticcheck // SA1019 (deprecated) require.EqualValues(t, 0, response.Atx.NumUnits) require.EqualValues(t, 0, response.Atx.Sequence) } @@ -43,9 +43,9 @@ func Test_Highest_ReturnsMaxTickHeight(t *testing.T) { goldenAtx := types.ATXID{2, 3, 4} activationService := grpcserver.NewActivationService(atxProvider, goldenAtx) + previous := types.RandomATXID() atx := types.ActivationTx{ Sequence: rand.Uint64(), - PrevATXID: types.RandomATXID(), PublishEpoch: 0, Coinbase: types.GenerateAddress(types.RandomBytes(32)), NumUnits: rand.Uint32(), @@ -54,6 +54,7 @@ func Test_Highest_ReturnsMaxTickHeight(t *testing.T) { atx.SetID(id) atxProvider.EXPECT().MaxHeightAtx().Return(id, nil) atxProvider.EXPECT().GetAtx(id).Return(&atx, nil) + atxProvider.EXPECT().Previous(id).Return([]types.ATXID{previous}, nil) response, err := activationService.Highest(context.Background(), &emptypb.Empty{}) require.NoError(t, err) @@ -61,7 +62,7 @@ func Test_Highest_ReturnsMaxTickHeight(t *testing.T) { require.Equal(t, atx.PublishEpoch.Uint32(), response.Atx.Layer.Number) require.Equal(t, atx.SmesherID.Bytes(), response.Atx.SmesherId.Id) require.Equal(t, atx.Coinbase.String(), response.Atx.Coinbase.Address) - require.Equal(t, atx.PrevATXID.Bytes(), response.Atx.PrevAtx.Id) + require.Equal(t, previous.Bytes(), response.Atx.PrevAtx.Id) // nolint:staticcheck // SA1019 (deprecated) require.Equal(t, atx.NumUnits, response.Atx.NumUnits) require.Equal(t, atx.Sequence, response.Atx.Sequence) } @@ -102,15 +103,29 @@ func TestGet_AtxProviderReturnsFailure(t *testing.T) { require.Equal(t, codes.NotFound, status.Code(err)) } +func TestGet_AtxProviderFailsObtainPreviousAtxs(t *testing.T) { + ctrl := gomock.NewController(t) + atxProvider := grpcserver.NewMockatxProvider(ctrl) + activationService := grpcserver.NewActivationService(atxProvider, types.ATXID{1}) + + id := types.RandomATXID() + atxProvider.EXPECT().GetAtx(id).Return(&types.ActivationTx{}, nil) + atxProvider.EXPECT().Previous(id).Return(nil, errors.New("")) + + _, err := activationService.Get(context.Background(), &pb.GetRequest{Id: id.Bytes()}) + require.Error(t, err) + require.Equal(t, codes.Internal, status.Code(err)) +} + func TestGet_HappyPath(t *testing.T) { ctrl := gomock.NewController(t) atxProvider := grpcserver.NewMockatxProvider(ctrl) activationService := grpcserver.NewActivationService(atxProvider, types.ATXID{1}) + previous := []types.ATXID{types.RandomATXID(), types.RandomATXID()} id := types.RandomATXID() atx := types.ActivationTx{ Sequence: rand.Uint64(), - PrevATXID: types.RandomATXID(), PublishEpoch: 0, Coinbase: types.GenerateAddress(types.RandomBytes(32)), NumUnits: rand.Uint32(), @@ -118,6 +133,7 @@ func TestGet_HappyPath(t *testing.T) { atx.SetID(id) atxProvider.EXPECT().GetAtx(id).Return(&atx, nil) atxProvider.EXPECT().GetMalfeasanceProof(gomock.Any()).Return(nil, sql.ErrNotFound) + atxProvider.EXPECT().Previous(id).Return(previous, nil) response, err := activationService.Get(context.Background(), &pb.GetRequest{Id: id.Bytes()}) require.NoError(t, err) @@ -126,7 +142,10 @@ func TestGet_HappyPath(t *testing.T) { require.Equal(t, atx.PublishEpoch.Uint32(), response.Atx.Layer.Number) require.Equal(t, atx.SmesherID.Bytes(), response.Atx.SmesherId.Id) require.Equal(t, atx.Coinbase.String(), response.Atx.Coinbase.Address) - require.Equal(t, atx.PrevATXID.Bytes(), response.Atx.PrevAtx.Id) + require.Equal(t, previous[0].Bytes(), response.Atx.PrevAtx.Id) // nolint:staticcheck // SA1019 (deprecated) + require.Len(t, response.Atx.PreviousAtxs, 2) + require.Equal(t, previous[0].Bytes(), response.Atx.PreviousAtxs[0].Id) + require.Equal(t, previous[1].Bytes(), response.Atx.PreviousAtxs[1].Id) require.Equal(t, atx.NumUnits, response.Atx.NumUnits) require.Equal(t, atx.Sequence, response.Atx.Sequence) require.Nil(t, response.MalfeasanceProof) @@ -138,10 +157,10 @@ func TestGet_IdentityCanceled(t *testing.T) { activationService := grpcserver.NewActivationService(atxProvider, types.ATXID{1}) smesher, proof := grpcserver.BallotMalfeasance(t, sql.InMemory()) + previous := types.RandomATXID() id := types.RandomATXID() atx := types.ActivationTx{ Sequence: rand.Uint64(), - PrevATXID: types.RandomATXID(), PublishEpoch: 0, Coinbase: types.GenerateAddress(types.RandomBytes(32)), NumUnits: rand.Uint32(), @@ -150,6 +169,7 @@ func TestGet_IdentityCanceled(t *testing.T) { atx.SetID(id) atxProvider.EXPECT().GetAtx(id).Return(&atx, nil) atxProvider.EXPECT().GetMalfeasanceProof(smesher).Return(proof, nil) + atxProvider.EXPECT().Previous(id).Return([]types.ATXID{previous}, nil) response, err := activationService.Get(context.Background(), &pb.GetRequest{Id: id.Bytes()}) require.NoError(t, err) @@ -158,7 +178,9 @@ func TestGet_IdentityCanceled(t *testing.T) { require.Equal(t, atx.PublishEpoch.Uint32(), response.Atx.Layer.Number) require.Equal(t, atx.SmesherID.Bytes(), response.Atx.SmesherId.Id) require.Equal(t, atx.Coinbase.String(), response.Atx.Coinbase.Address) - require.Equal(t, atx.PrevATXID.Bytes(), response.Atx.PrevAtx.Id) + require.Equal(t, previous.Bytes(), response.Atx.PrevAtx.Id) // nolint:staticcheck // SA1019 (deprecated) + require.Len(t, response.Atx.PreviousAtxs, 1) + require.Equal(t, previous.Bytes(), response.Atx.PreviousAtxs[0].Id) require.Equal(t, atx.NumUnits, response.Atx.NumUnits) require.Equal(t, atx.Sequence, response.Atx.Sequence) require.Equal(t, events.ToMalfeasancePB(smesher, proof, false), response.MalfeasanceProof) diff --git a/api/grpcserver/admin_service_test.go b/api/grpcserver/admin_service_test.go index 6e3ae789d8..cdd087c303 100644 --- a/api/grpcserver/admin_service_test.go +++ b/api/grpcserver/admin_service_test.go @@ -38,7 +38,7 @@ func newAtx(tb testing.TB, db *sql.Database) { atx.SmesherID = types.BytesToNodeID(types.RandomBytes(20)) atx.SetReceived(time.Now().Local()) require.NoError(tb, atxs.Add(db, atx, types.AtxBlob{})) - require.NoError(tb, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) + require.NoError(tb, atxs.SetPost(db, atx.ID(), types.EmptyATXID, 0, atx.SmesherID, atx.NumUnits)) } func createMesh(tb testing.TB, db *sql.Database) { diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index 08a282ad2a..b8862f1bd2 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -160,7 +160,6 @@ func TestMain(m *testing.M) { globalAtx = &types.ActivationTx{ PublishEpoch: postGenesisEpoch, Sequence: 1, - PrevATXID: types.ATXID{4, 4, 4, 4}, Coinbase: addr1, NumUnits: numUnits, Weight: numUnits, @@ -172,7 +171,6 @@ func TestMain(m *testing.M) { globalAtx2 = &types.ActivationTx{ PublishEpoch: postGenesisEpoch, Sequence: 1, - PrevATXID: types.ATXID{5, 5, 5, 5}, Coinbase: addr2, NumUnits: numUnits, Weight: numUnits, diff --git a/api/grpcserver/interface.go b/api/grpcserver/interface.go index 47a009a47c..bfca33b78b 100644 --- a/api/grpcserver/interface.go +++ b/api/grpcserver/interface.go @@ -56,6 +56,7 @@ type txValidator interface { // atxProvider is used by ActivationService to get ATXes. type atxProvider interface { GetAtx(id types.ATXID) (*types.ActivationTx, error) + Previous(id types.ATXID) ([]types.ATXID, error) MaxHeightAtx() (types.ATXID, error) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof, error) } diff --git a/api/grpcserver/mesh_service.go b/api/grpcserver/mesh_service.go index 9e95a8b82f..6d8b23ffd2 100644 --- a/api/grpcserver/mesh_service.go +++ b/api/grpcserver/mesh_service.go @@ -252,16 +252,25 @@ func castTransaction(t *types.Transaction) *pb.Transaction { return tx } -func convertActivation(a *types.ActivationTx) *pb.Activation { - return &pb.Activation{ +func convertActivation(a *types.ActivationTx, previous []types.ATXID) *pb.Activation { + atx := &pb.Activation{ Id: &pb.ActivationId{Id: a.ID().Bytes()}, Layer: &pb.LayerNumber{Number: a.PublishEpoch.Uint32()}, SmesherId: &pb.SmesherId{Id: a.SmesherID.Bytes()}, Coinbase: &pb.AccountId{Address: a.Coinbase.String()}, - PrevAtx: &pb.ActivationId{Id: a.PrevATXID.Bytes()}, NumUnits: uint32(a.NumUnits), Sequence: a.Sequence, } + + if len(previous) == 0 { + previous = []types.ATXID{types.EmptyATXID} + } + // nolint:staticcheck // SA1019 (deprecated) + atx.PrevAtx = &pb.ActivationId{Id: previous[0].Bytes()} + for _, prev := range previous { + atx.PreviousAtxs = append(atx.PreviousAtxs, &pb.ActivationId{Id: prev.Bytes()}) + } + return atx } func (s *MeshService) readLayer( @@ -440,10 +449,14 @@ func (s *MeshService) AccountMeshDataStream( activation := activationEvent.ActivationTx // Apply address filter if activation.Coinbase == addr { + previous, err := s.cdb.Previous(activation.ID()) + if err != nil { + return status.Error(codes.Internal, "getting previous ATXs failed") + } resp := &pb.AccountMeshDataStreamResponse{ Datum: &pb.AccountMeshData{ Datum: &pb.AccountMeshData_Activation{ - Activation: convertActivation(activation), + Activation: convertActivation(activation, previous), }, }, } diff --git a/api/grpcserver/mocks.go b/api/grpcserver/mocks.go index ab849fa95c..1dab284f2b 100644 --- a/api/grpcserver/mocks.go +++ b/api/grpcserver/mocks.go @@ -990,6 +990,45 @@ func (c *MockatxProviderMaxHeightAtxCall) DoAndReturn(f func() (types.ATXID, err return c } +// Previous mocks base method. +func (m *MockatxProvider) Previous(id types.ATXID) ([]types.ATXID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Previous", id) + ret0, _ := ret[0].([]types.ATXID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Previous indicates an expected call of Previous. +func (mr *MockatxProviderMockRecorder) Previous(id any) *MockatxProviderPreviousCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Previous", reflect.TypeOf((*MockatxProvider)(nil).Previous), id) + return &MockatxProviderPreviousCall{Call: call} +} + +// MockatxProviderPreviousCall wrap *gomock.Call +type MockatxProviderPreviousCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockatxProviderPreviousCall) Return(arg0 []types.ATXID, arg1 error) *MockatxProviderPreviousCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockatxProviderPreviousCall) Do(f func(types.ATXID) ([]types.ATXID, error)) *MockatxProviderPreviousCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockatxProviderPreviousCall) DoAndReturn(f func(types.ATXID) ([]types.ATXID, error)) *MockatxProviderPreviousCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // MockpostState is a mock of postState interface. type MockpostState struct { ctrl *gomock.Controller diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index 472f62a8be..fc1a33eda3 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -29,52 +29,57 @@ func TestMain(m *testing.M) { os.Exit(res) } +type activationTx struct { + *types.ActivationTx + previous types.ATXID +} + type miner struct { - atxs []*types.ActivationTx + atxs []activationTx malfeasanceProof []byte } var allMiners = []miner{ // smesher 1 has 7 ATXs, one in each epoch from 1 to 7 { - atxs: []*types.ActivationTx{ - newAtx(types.ATXID{17}, types.ATXID{16}, nil, 7, 6, 123, []byte("smesher1")), - newAtx(types.ATXID{16}, types.ATXID{15}, nil, 6, 5, 123, []byte("smesher1")), - newAtx(types.ATXID{15}, types.ATXID{14}, nil, 5, 4, 123, []byte("smesher1")), - newAtx(types.ATXID{14}, types.ATXID{13}, nil, 4, 3, 123, []byte("smesher1")), - newAtx(types.ATXID{13}, types.ATXID{12}, nil, 3, 2, 123, []byte("smesher1")), - newAtx(types.ATXID{12}, types.ATXID{11}, nil, 2, 1, 123, []byte("smesher1")), - newAtx(types.ATXID{11}, types.EmptyATXID, &types.ATXID{1}, 1, 0, 123, []byte("smesher1")), + atxs: []activationTx{ + {newAtx(types.ATXID{17}, nil, 7, 6, 123, []byte("smesher1")), types.ATXID{16}}, + {newAtx(types.ATXID{16}, nil, 6, 5, 123, []byte("smesher1")), types.ATXID{15}}, + {newAtx(types.ATXID{15}, nil, 5, 4, 123, []byte("smesher1")), types.ATXID{14}}, + {newAtx(types.ATXID{14}, nil, 4, 3, 123, []byte("smesher1")), types.ATXID{13}}, + {newAtx(types.ATXID{13}, nil, 3, 2, 123, []byte("smesher1")), types.ATXID{12}}, + {newAtx(types.ATXID{12}, nil, 2, 1, 123, []byte("smesher1")), types.ATXID{11}}, + {newAtx(types.ATXID{11}, &types.ATXID{1}, 1, 0, 123, []byte("smesher1")), types.EmptyATXID}, }, }, // smesher 2 has 1 ATX in epoch 7 { - atxs: []*types.ActivationTx{ - newAtx(types.ATXID{27}, types.EmptyATXID, &types.ATXID{2}, 7, 0, 152, []byte("smesher2")), + atxs: []activationTx{ + {newAtx(types.ATXID{27}, &types.ATXID{2}, 7, 0, 152, []byte("smesher2")), types.EmptyATXID}, }, }, // smesher 3 has 1 ATX in epoch 2 { - atxs: []*types.ActivationTx{ - newAtx(types.ATXID{32}, types.EmptyATXID, &types.ATXID{3}, 2, 0, 211, []byte("smesher3")), + atxs: []activationTx{ + {newAtx(types.ATXID{32}, &types.ATXID{3}, 2, 0, 211, []byte("smesher3")), types.EmptyATXID}, }, }, // smesher 4 has 1 ATX in epoch 3 and one in epoch 7 { - atxs: []*types.ActivationTx{ - newAtx(types.ATXID{47}, types.ATXID{43}, nil, 7, 1, 420, []byte("smesher4")), - newAtx(types.ATXID{43}, types.EmptyATXID, &types.ATXID{4}, 4, 0, 420, []byte("smesher4")), + atxs: []activationTx{ + {newAtx(types.ATXID{47}, nil, 7, 1, 420, []byte("smesher4")), types.ATXID{43}}, + {newAtx(types.ATXID{43}, &types.ATXID{4}, 4, 0, 420, []byte("smesher4")), types.EmptyATXID}, }, }, // smesher 5 is malicious and equivocated in epoch 7 { - atxs: []*types.ActivationTx{ - newAtx(types.ATXID{83}, types.EmptyATXID, &types.ATXID{27}, 7, 0, 113, []byte("smesher5")), - newAtx(types.ATXID{97}, types.EmptyATXID, &types.ATXID{16}, 7, 0, 113, []byte("smesher5")), + atxs: []activationTx{ + {newAtx(types.ATXID{83}, &types.ATXID{27}, 7, 0, 113, []byte("smesher5")), types.EmptyATXID}, + {newAtx(types.ATXID{97}, &types.ATXID{16}, 7, 0, 113, []byte("smesher5")), types.EmptyATXID}, }, malfeasanceProof: []byte("im bad"), }, @@ -181,7 +186,7 @@ func expectedCheckpoint(t testing.TB, snapshot types.LayerID, numAtxs int, miner for i := 0; i < n; i++ { atxData = append( atxData, - asAtxSnapshot(atxs[i], atxs[len(atxs)-1].CommitmentATX), + asAtxSnapshot(atxs[i].ActivationTx, atxs[len(atxs)-1].CommitmentATX), ) } } @@ -215,7 +220,7 @@ func expectedCheckpoint(t testing.TB, snapshot types.LayerID, numAtxs int, miner } func newAtx( - id, prevID types.ATXID, + id types.ATXID, commitAtx *types.ATXID, epoch uint32, seq, vrfnonce uint64, @@ -225,7 +230,6 @@ func newAtx( PublishEpoch: types.EpochID(epoch), Sequence: seq, CommitmentATX: commitAtx, - PrevATXID: prevID, NumUnits: 2, Coinbase: types.Address{1, 2, 3}, TickCount: 1, @@ -262,8 +266,8 @@ func createMesh(t testing.TB, db *sql.Database, miners []miner, accts []*types.A t.Helper() for _, miner := range miners { for _, atx := range miner.atxs { - require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) - require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) + require.NoError(t, atxs.Add(db, atx.ActivationTx, types.AtxBlob{})) + require.NoError(t, atxs.SetPost(db, atx.ID(), atx.previous, 0, atx.SmesherID, atx.NumUnits)) } if proof := miner.malfeasanceProof; len(proof) > 0 { require.NoError(t, identities.SetMalicious(db, miner.atxs[0].SmesherID, proof, time.Now())) @@ -344,8 +348,8 @@ func TestRunner_Generate_Error(t *testing.T) { db := sql.InMemory() snapshot := types.LayerID(5) - atx := newAtx(types.ATXID{13}, types.EmptyATXID, nil, 2, 1, 11, types.RandomNodeID().Bytes()) - createMesh(t, db, []miner{{atxs: []*types.ActivationTx{atx}}}, allAccounts) + atx := newAtx(types.ATXID{13}, nil, 2, 1, 11, types.RandomNodeID().Bytes()) + createMesh(t, db, []miner{{atxs: []activationTx{{atx, types.EmptyATXID}}}}, allAccounts) fs := afero.NewMemMapFs() dir, err := afero.TempDir(fs, "", "Generate") @@ -395,7 +399,7 @@ func TestRunner_Generate_PreservesMarriageATX(t *testing.T) { } atx.SetID(types.RandomATXID()) require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) - require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) + require.NoError(t, atxs.SetPost(db, atx.ID(), types.EmptyATXID, 0, atx.SmesherID, atx.NumUnits)) fs := afero.NewMemMapFs() dir, err := afero.TempDir(fs, "", "Generate") diff --git a/common/types/activation.go b/common/types/activation.go index 4a1da34c15..10a585e80d 100644 --- a/common/types/activation.go +++ b/common/types/activation.go @@ -174,8 +174,6 @@ type ActivationTx struct { // Two ATXs with the same sequence number from the same miner can be used as the proof of malfeasance against // that miner. Sequence uint64 - // the previous ATX's ID (for all but the first in the sequence) - PrevATXID ATXID // CommitmentATX is the ATX used in the commitment for initializing the PoST of the node. CommitmentATX *ATXID @@ -228,7 +226,6 @@ func (atx *ActivationTx) MarshalLogObject(encoder log.ObjectEncoder) error { encoder.AddString("atx_id", atx.id.String()) encoder.AddString("smesher", atx.SmesherID.String()) encoder.AddUint32("publish_epoch", atx.PublishEpoch.Uint32()) - encoder.AddString("prev_atx_id", atx.PrevATXID.String()) if atx.CommitmentATX != nil { encoder.AddString("commitment_atx_id", atx.CommitmentATX.String()) diff --git a/datastore/store.go b/datastore/store.go index 506117f505..7a2cccb416 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -200,6 +200,11 @@ func (db *CachedDB) GetAtx(id types.ATXID) (*types.ActivationTx, error) { return atx, nil } +// Previous retrieves the list of previous ATXs for the given ATX ID. +func (db *CachedDB) Previous(id types.ATXID) ([]types.ATXID, error) { + return atxs.Previous(db, id) +} + func (db *CachedDB) IterateMalfeasanceProofs( iter func(types.NodeID, *wire.MalfeasanceProof) error, ) error { diff --git a/go.mod b/go.mod index f40a85a1c0..d4b4248087 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/seehuhn/mt19937 v1.0.0 github.com/slok/go-http-metrics v0.12.0 - github.com/spacemeshos/api/release/go v1.51.0 + github.com/spacemeshos/api/release/go v1.52.0 github.com/spacemeshos/economics v0.1.3 github.com/spacemeshos/fixed v0.1.1 github.com/spacemeshos/go-scale v1.2.0 diff --git a/go.sum b/go.sum index 9c4255fc40..394e1c8c26 100644 --- a/go.sum +++ b/go.sum @@ -602,8 +602,8 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= -github.com/spacemeshos/api/release/go v1.51.0 h1:MSKRIUiXBAoDrj2Lj24q9g52ZaSIC3I0UH/Y0Oaz95o= -github.com/spacemeshos/api/release/go v1.51.0/go.mod h1:Qr/pVPMmN5Q5qLHSXqVMDKDCu6LkHWzGPNflylE0u00= +github.com/spacemeshos/api/release/go v1.52.0 h1:3cohOoFIk0RLF5fdL0y6pFgZ7Ngg1Yht+aeN3Xm5Qn8= +github.com/spacemeshos/api/release/go v1.52.0/go.mod h1:Qr/pVPMmN5Q5qLHSXqVMDKDCu6LkHWzGPNflylE0u00= github.com/spacemeshos/economics v0.1.3 h1:ACkq3mTebIky4Zwbs9SeSSRZrUCjU/Zk0wq9Z0BTh2A= github.com/spacemeshos/economics v0.1.3/go.mod h1:FH7u0FzTIm6Kpk+X5HOZDvpkgNYBKclmH86rVwYaDAo= github.com/spacemeshos/fixed v0.1.1 h1:N1y4SUpq1EV+IdJrWJwUCt1oBFzeru/VKVcBsvPc2Fk= diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 5e14cddde1..361a2da573 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -22,7 +22,7 @@ const ( // filters that refer to the id column. const fieldsQuery = `select atxs.id, atxs.nonce, atxs.base_tick_height, atxs.tick_count, atxs.pubkey, atxs.effective_num_units, -atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx, atxs.weight, +atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.commitment_atx, atxs.weight, atxs.marriage_atx` const fullQuery = fieldsQuery + ` from atxs` @@ -56,16 +56,13 @@ func decoder(fn decoderCallback) sql.Decoder { stmt.ColumnBytes(9, a.Coinbase[:]) a.SetValidity(types.Validity(stmt.ColumnInt(10))) if stmt.ColumnType(11) != sqlite.SQLITE_NULL { - stmt.ColumnBytes(11, a.PrevATXID[:]) - } - if stmt.ColumnType(12) != sqlite.SQLITE_NULL { a.CommitmentATX = new(types.ATXID) - stmt.ColumnBytes(12, a.CommitmentATX[:]) + stmt.ColumnBytes(11, a.CommitmentATX[:]) } - a.Weight = uint64(stmt.ColumnInt64(13)) - if stmt.ColumnType(14) != sqlite.SQLITE_NULL { + a.Weight = uint64(stmt.ColumnInt64(12)) + if stmt.ColumnType(13) != sqlite.SQLITE_NULL { a.MarriageATX = new(types.ATXID) - stmt.ColumnBytes(14, a.MarriageATX[:]) + stmt.ColumnBytes(13, a.MarriageATX[:]) } return fn(&a) @@ -404,6 +401,40 @@ func getBlob(ctx context.Context, db sql.Executor, id []byte, blob *sql.Blob) (t return version, nil } +// Previous gets all previous ATXs for a given ATX ID. +func Previous(db sql.Executor, id types.ATXID) ([]types.ATXID, error) { + var previous []types.ATXID + enc := func(stmt *sql.Statement) { + stmt.BindBytes(1, id.Bytes()) + } + dec := func(stmt *sql.Statement) bool { + var prev types.ATXID + if stmt.ColumnType(0) != sqlite.SQLITE_NULL { + stmt.ColumnBytes(0, prev[:]) + } + // Index is returned in descending order, so the first one defines the length of the slice. + index := stmt.ColumnInt(1) + if previous == nil { + previous = make([]types.ATXID, index+1) + } + previous[index] = prev + return true + } + + rows, err := db.Exec( + "SELECT prev_atxid, prev_atx_index FROM posts WHERE atxid = ?1 ORDER BY prev_atx_index DESC;", + enc, + dec, + ) + switch { + case err != nil: + return nil, fmt.Errorf("previous ATXs for ATX ID %v: %w", id, err) + case rows == 0: + return nil, sql.ErrNotFound + } + return previous, nil +} + // NonceByID retrieves VRFNonce corresponding to the specified ATX ID. func NonceByID(db sql.Executor, id types.ATXID) (nonce types.VRFPostIndex, err error) { enc := func(stmt *sql.Statement) { @@ -439,20 +470,17 @@ func Add(db sql.Executor, atx *types.ActivationTx, blob types.AtxBlob) error { stmt.BindInt64(10, int64(atx.Sequence)) stmt.BindBytes(11, atx.Coinbase.Bytes()) stmt.BindInt64(12, int64(atx.Validity())) - if atx.PrevATXID != types.EmptyATXID { - stmt.BindBytes(13, atx.PrevATXID.Bytes()) - } - stmt.BindInt64(14, int64(atx.Weight)) + stmt.BindInt64(13, int64(atx.Weight)) if atx.MarriageATX != nil { - stmt.BindBytes(15, atx.MarriageATX.Bytes()) + stmt.BindBytes(14, atx.MarriageATX.Bytes()) } } _, err := db.Exec(` insert into atxs (id, epoch, effective_num_units, commitment_atx, nonce, pubkey, received, base_tick_height, tick_count, sequence, coinbase, - validity, prev_id, weight, marriage_atx) - values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)`, enc, nil) + validity, weight, marriage_atx) + values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)`, enc, nil) if err != nil { return fmt.Errorf("insert ATX ID %v: %w", atx.ID(), err) } @@ -466,6 +494,7 @@ func AddBlob(db sql.Executor, id types.ATXID, blob []byte, version types.AtxVers stmt.BindBytes(2, blob) stmt.BindInt64(3, int64(version)) } + _, err := db.Exec("insert into atx_blobs (id, atx, version) values (?1, ?2, ?3)", enc, nil) if err != nil { return fmt.Errorf("insert ATX blob %v: %w", id, err) @@ -647,7 +676,8 @@ func AddCheckpointed(db sql.Executor, catx *CheckpointAtx) error { } for id, units := range catx.Units { - if err := SetUnits(db, catx.ID, id, units); err != nil { + // FIXME: should a checkpointed ATX reference its real previous ATX? + if err := SetPost(db, catx.ID, types.EmptyATXID, 0, id, units); err != nil { return fmt.Errorf("insert checkpoint ATX units %v: %w", catx.ID, err) } } @@ -816,7 +846,7 @@ func IterateAtxsWithMalfeasance( func(s *sql.Statement) { s.BindInt64(1, int64(publish)) }, func(s *sql.Statement) bool { return decoder(func(atx *types.ActivationTx) bool { - return fn(atx, s.ColumnInt(15) != 0) + return fn(atx, s.ColumnInt(14) != 0) })(s) }, ) @@ -875,10 +905,10 @@ func PrevATXCollisions(db sql.Executor) ([]PrevATXCollision, error) { // we are joining the table with itself to find ATXs with the same prevATX // the WHERE clause ensures that we only get the pairs once if _, err := db.Exec(` - SELECT t1.pubkey, t2.pubkey, t1.id, t2.id - FROM atxs t1 - INNER JOIN atxs t2 ON t1.prev_id = t2.prev_id - WHERE t1.id < t2.id;`, nil, dec); err != nil { + SELECT p1.pubkey, p2.pubkey, p1.atxid, p2.atxid + FROM posts p1 + INNER JOIN posts p2 ON p1.prev_atxid = p2.prev_atxid + WHERE p1.atxid < p2.atxid;`, nil, dec); err != nil { return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err) } @@ -959,13 +989,17 @@ func AllUnits(db sql.Executor, id types.ATXID) (map[types.NodeID]uint32, error) return units, nil } -func SetUnits(db sql.Executor, atxID types.ATXID, id types.NodeID, units uint32) error { +func SetPost(db sql.Executor, atxID, prev types.ATXID, prevIndex int, id types.NodeID, units uint32) error { _, err := db.Exec( - `INSERT INTO posts (atxid, pubkey, units) VALUES (?1, ?2, ?3);`, + `INSERT INTO posts (atxid, pubkey, prev_atxid, prev_atx_index, units) VALUES (?1, ?2, ?3, ?4, ?5);`, func(stmt *sql.Statement) { stmt.BindBytes(1, atxID.Bytes()) stmt.BindBytes(2, id.Bytes()) - stmt.BindInt64(3, int64(units)) + if prev != types.EmptyATXID { + stmt.BindBytes(3, prev.Bytes()) + } + stmt.BindInt64(4, int64(prevIndex)) + stmt.BindInt64(5, int64(units)) }, nil, ) @@ -984,7 +1018,7 @@ func AtxWithPrevious(db sql.Executor, prev types.ATXID, id types.NodeID) (types. return false } if prev == types.EmptyATXID { - rows, err = db.Exec("SELECT id FROM atxs WHERE pubkey = ?1 AND prev_id IS NULL ORDER BY received ASC;", + rows, err = db.Exec("SELECT atxid FROM posts WHERE pubkey = ?1 AND prev_atxid IS NULL;", func(s *sql.Statement) { s.BindBytes(1, id.Bytes()) }, @@ -992,7 +1026,7 @@ func AtxWithPrevious(db sql.Executor, prev types.ATXID, id types.NodeID) (types. ) } else { rows, err = db.Exec(` - SELECT id FROM atxs WHERE pubkey = ?1 AND prev_id = ?2 ORDER BY received ASC;`, + SELECT atxid FROM posts WHERE pubkey = ?1 AND prev_atxid = ?2;`, func(s *sql.Statement) { s.BindBytes(1, id.Bytes()) s.BindBytes(2, prev.Bytes()) diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 1dd1914968..eb507fe969 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -4,10 +4,12 @@ import ( "context" "errors" "os" + "slices" "testing" "time" "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/common/fixture" @@ -159,7 +161,7 @@ func TestLatestN(t *testing.T) { for _, atx := range []*types.ActivationTx{atx1, atx2, atx3, atx4, atx5, atx6} { require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) - require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) + require.NoError(t, atxs.SetPost(db, atx.ID(), types.EmptyATXID, 0, atx.SmesherID, atx.NumUnits)) } for _, tc := range []struct { @@ -522,7 +524,7 @@ func TestVRFNonce(t *testing.T) { atx1, blob := newAtx(t, sig, withPublishEpoch(20), withNonce(333)) require.NoError(t, atxs.Add(db, atx1, blob)) - atx2, blob := newAtx(t, sig, withPublishEpoch(50), withNonce(777), withPrevATXID(atx1.ID())) + atx2, blob := newAtx(t, sig, withPublishEpoch(50), withNonce(777)) require.NoError(t, atxs.Add(db, atx2, blob)) // Act & Assert @@ -814,12 +816,6 @@ func withNonce(nonce types.VRFPostIndex) createAtxOpt { } } -func withPrevATXID(id types.ATXID) createAtxOpt { - return func(atx *types.ActivationTx) { - atx.PrevATXID = id - } -} - func withCoinbase(addr types.Address) createAtxOpt { return func(atx *types.ActivationTx) { atx.Coinbase = addr @@ -1035,11 +1031,13 @@ func Test_PrevATXCollisions(t *testing.T) { // create two ATXs with the same PrevATXID prevATXID := types.RandomATXID() - atx1, blob1 := newAtx(t, sig, withPublishEpoch(1), withPrevATXID(prevATXID)) - atx2, blob2 := newAtx(t, sig, withPublishEpoch(2), withPrevATXID(prevATXID)) + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), prevATXID, 0, sig.NodeID(), 10)) require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, sig.NodeID(), 10)) // verify that the ATXs were added got1, err := atxs.Get(db, atx1.ID()) @@ -1060,9 +1058,9 @@ func Test_PrevATXCollisions(t *testing.T) { atx2, blob2 := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i+1)), - withPrevATXID(atx.ID()), ) require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), atx.ID(), 0, sig.NodeID(), 10)) } // get the collisions @@ -1121,7 +1119,7 @@ func TestUnits(t *testing.T) { t.Parallel() db := sql.InMemory() atxID := types.RandomATXID() - require.NoError(t, atxs.SetUnits(db, atxID, types.RandomNodeID(), 10)) + require.NoError(t, atxs.SetPost(db, atxID, types.EmptyATXID, 0, types.RandomNodeID(), 10)) _, err := atxs.Units(db, atxID, types.RandomNodeID()) require.ErrorIs(t, err, sql.ErrNotFound) }) @@ -1134,7 +1132,7 @@ func TestUnits(t *testing.T) { {4, 5, 6}: 20, } for id, units := range units { - require.NoError(t, atxs.SetUnits(db, atxID, id, units)) + require.NoError(t, atxs.SetPost(db, atxID, types.EmptyATXID, 0, id, units)) } nodeID := types.NodeID{1, 2, 3} @@ -1163,40 +1161,46 @@ func Test_AtxWithPrevious(t *testing.T) { t.Run("finds other ATX with same previous", func(t *testing.T) { db := sql.InMemory() + prev := types.RandomATXID() atx, blob := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx, blob)) + require.NoError(t, atxs.SetPost(db, atx.ID(), prev, 0, sig.NodeID(), 10)) - id, err := atxs.AtxWithPrevious(db, atx.PrevATXID, sig.NodeID()) + id, err := atxs.AtxWithPrevious(db, prev, sig.NodeID()) require.NoError(t, err) require.Equal(t, atx.ID(), id) }) t.Run("finds other ATX with same previous (empty)", func(t *testing.T) { db := sql.InMemory() - atx, blob := newAtx(t, sig, withPrevATXID(types.EmptyATXID)) + atx, blob := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx, blob)) + require.NoError(t, atxs.SetPost(db, atx.ID(), types.EmptyATXID, 0, sig.NodeID(), 10)) - id, err := atxs.AtxWithPrevious(db, atx.PrevATXID, sig.NodeID()) + id, err := atxs.AtxWithPrevious(db, types.EmptyATXID, sig.NodeID()) require.NoError(t, err) require.Equal(t, atx.ID(), id) }) - t.Run("filters out by node ID", func(t *testing.T) { + t.Run("same previous used by 2 IDs in two ATXs", func(t *testing.T) { db := sql.InMemory() sig2, err := signing.NewEdSigner() require.NoError(t, err) + prev := types.RandomATXID() - atx, blob := newAtx(t, sig, withPrevATXID(types.EmptyATXID)) + atx, blob := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx, blob)) + require.NoError(t, atxs.SetPost(db, atx.ID(), prev, 0, sig.NodeID(), 10)) - atx2, blob := newAtx(t, sig2, withPrevATXID(types.EmptyATXID)) + atx2, blob := newAtx(t, sig2) require.NoError(t, atxs.Add(db, atx2, blob)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), prev, 0, sig2.NodeID(), 10)) - id, err := atxs.AtxWithPrevious(db, atx.PrevATXID, sig.NodeID()) + id, err := atxs.AtxWithPrevious(db, prev, sig.NodeID()) require.NoError(t, err) require.Equal(t, atx.ID(), id) - id, err = atxs.AtxWithPrevious(db, atx.PrevATXID, sig2.NodeID()) + id, err = atxs.AtxWithPrevious(db, prev, sig2.NodeID()) require.NoError(t, err) require.Equal(t, atx2.ID(), id) }) @@ -1220,7 +1224,7 @@ func Test_FindDoublePublish(t *testing.T) { // one atx atx0, blob := newAtx(t, sig, withPublishEpoch(1)) require.NoError(t, atxs.Add(db, atx0, blob)) - require.NoError(t, atxs.SetUnits(db, atx0.ID(), atx0.SmesherID, 10)) + require.NoError(t, atxs.SetPost(db, atx0.ID(), types.EmptyATXID, 0, atx0.SmesherID, 10)) _, err = atxs.FindDoublePublish(db, atx0.SmesherID, atx0.PublishEpoch) require.ErrorIs(t, err, sql.ErrNotFound) @@ -1228,7 +1232,7 @@ func Test_FindDoublePublish(t *testing.T) { // two atxs in different epochs atx1, blob := newAtx(t, sig, withPublishEpoch(atx0.PublishEpoch+1)) require.NoError(t, atxs.Add(db, atx1, blob)) - require.NoError(t, atxs.SetUnits(db, atx1.ID(), atx0.SmesherID, 10)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, atx0.SmesherID, 10)) _, err = atxs.FindDoublePublish(db, atx0.SmesherID, atx0.PublishEpoch) require.ErrorIs(t, err, sql.ErrNotFound) @@ -1239,11 +1243,11 @@ func Test_FindDoublePublish(t *testing.T) { atx0, blob := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx0, blob)) - require.NoError(t, atxs.SetUnits(db, atx0.ID(), atx0.SmesherID, 10)) + require.NoError(t, atxs.SetPost(db, atx0.ID(), types.EmptyATXID, 0, atx0.SmesherID, 10)) atx1, blob := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx1, blob)) - require.NoError(t, atxs.SetUnits(db, atx1.ID(), atx0.SmesherID, 10)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, atx0.SmesherID, 10)) atxids, err := atxs.FindDoublePublish(db, atx0.SmesherID, atx0.PublishEpoch) require.NoError(t, err) @@ -1262,16 +1266,16 @@ func Test_FindDoublePublish(t *testing.T) { atx0, blob := newAtx(t, atx0Signer) require.NoError(t, atxs.Add(db, atx0, blob)) - require.NoError(t, atxs.SetUnits(db, atx0.ID(), atx0.SmesherID, 10)) - require.NoError(t, atxs.SetUnits(db, atx0.ID(), sig.NodeID(), 10)) + require.NoError(t, atxs.SetPost(db, atx0.ID(), types.EmptyATXID, 0, atx0.SmesherID, 10)) + require.NoError(t, atxs.SetPost(db, atx0.ID(), types.EmptyATXID, 0, sig.NodeID(), 10)) atx1Signer, err := signing.NewEdSigner() require.NoError(t, err) atx1, blob := newAtx(t, atx1Signer) require.NoError(t, atxs.Add(db, atx1, blob)) - require.NoError(t, atxs.SetUnits(db, atx1.ID(), atx1.SmesherID, 10)) - require.NoError(t, atxs.SetUnits(db, atx1.ID(), sig.NodeID(), 10)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, atx1.SmesherID, 10)) + require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, sig.NodeID(), 10)) atxIDs, err := atxs.FindDoublePublish(db, sig.NodeID(), atx0.PublishEpoch) require.NoError(t, err) @@ -1330,3 +1334,31 @@ func Test_MergeConflict(t *testing.T) { require.Len(t, ids, 2) }) } + +func Test_Previous(t *testing.T) { + t.Run("not found", func(t *testing.T) { + db := sql.InMemoryTest(t) + _, err := atxs.Previous(db, types.RandomATXID()) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("returns ATXs in order", func(t *testing.T) { + db := sql.InMemoryTest(t) + + atx := types.RandomATXID() + var previousAtxs []types.ATXID + // 10 previous ATXs + for range 10 { + previousAtxs = append(previousAtxs, types.RandomATXID()) + } + // used by 50 IDs randomly + for range 50 { + prev := previousAtxs[rand.Intn(len(previousAtxs))] + index := slices.Index(previousAtxs, prev) + require.NoError(t, atxs.SetPost(db, atx, prev, index, types.RandomNodeID(), 10)) + } + + got, err := atxs.Previous(db, atx) + require.NoError(t, err) + require.Equal(t, previousAtxs, got) + }) +} diff --git a/sql/ballots/ballots_test.go b/sql/ballots/ballots_test.go index 901161d0f4..519dbf80aa 100644 --- a/sql/ballots/ballots_test.go +++ b/sql/ballots/ballots_test.go @@ -147,7 +147,6 @@ func TestLayerBallotBySmesher(t *testing.T) { func newAtx(signer *signing.EdSigner, layerID types.LayerID) *types.ActivationTx { atx := &types.ActivationTx{ PublishEpoch: layerID.GetEpoch(), - PrevATXID: types.RandomATXID(), NumUnits: 2, TickCount: 1, SmesherID: signer.NodeID(), diff --git a/sql/migrations/state/0021_atx_posts.sql b/sql/migrations/state/0021_atx_posts.sql index 25ec2e2ca5..a009bd0655 100644 --- a/sql/migrations/state/0021_atx_posts.sql +++ b/sql/migrations/state/0021_atx_posts.sql @@ -1,9 +1,14 @@ --- Table showing the exact number of PoST units commited by smesher in given ATX. +-- Table showing the PoST commitment by a smesher in given ATX. +-- It shows the exact number of space units committed and the previous ATX id. CREATE TABLE posts ( - atxid CHAR(32) NOT NULL, - pubkey CHAR(32) NOT NULL, - units INT NOT NULL, + atxid CHAR(32) NOT NULL, + pubkey CHAR(32) NOT NULL, + prev_atxid CHAR(32), + prev_atx_index INT, + units INT NOT NULL, UNIQUE (atxid, pubkey) ); -CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey); +CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey, prev_atxid); + +ALTER TABLE atxs DROP COLUMN prev_id; diff --git a/sql/migrations/state_0021_migration.go b/sql/migrations/state_0021_migration.go index 8a98145e57..221289788b 100644 --- a/sql/migrations/state_0021_migration.go +++ b/sql/migrations/state_0021_migration.go @@ -38,7 +38,7 @@ func (*migration0021) Rollback() error { } func (m *migration0021) Apply(db sql.Executor) error { - if err := m.createTable(db); err != nil { + if err := m.applySql(db); err != nil { return err } var total int @@ -66,10 +66,12 @@ func (m *migration0021) Apply(db sql.Executor) error { } } -func (m *migration0021) createTable(db sql.Executor) error { +func (m *migration0021) applySql(db sql.Executor) error { query := `CREATE TABLE posts ( atxid CHAR(32) NOT NULL, pubkey CHAR(32) NOT NULL, + prev_atxid CHAR(32), + prev_atx_index INT, units INT NOT NULL, UNIQUE (atxid, pubkey) );` @@ -78,16 +80,23 @@ func (m *migration0021) createTable(db sql.Executor) error { return fmt.Errorf("creating posts table: %w", err) } - query = "CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey);" + query = "CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey, prev_atxid);" _, err = db.Exec(query, nil, nil) if err != nil { return fmt.Errorf("creating index `posts_by_atxid_by_pubkey`: %w", err) } + + query = "ALTER TABLE atxs DROP COLUMN prev_id;" + _, err = db.Exec(query, nil, nil) + if err != nil { + return fmt.Errorf("dropping column `prev_id` from `atxs`: %w", err) + } return nil } type update struct { id types.NodeID + prev types.ATXID units uint32 } @@ -135,7 +144,7 @@ func (m *migration0021) processBatch(db sql.Executor, offset, size int) (int, er func (m *migration0021) applyPendingUpdates(db sql.Executor, updates map[types.ATXID]*update) error { for atxID, upd := range updates { - if err := atxs.SetUnits(db, atxID, upd.id, upd.units); err != nil { + if err := atxs.SetPost(db, atxID, upd.prev, 0, upd.id, upd.units); err != nil { return err } } @@ -153,7 +162,7 @@ func processATX(blob types.AtxBlob) (*update, error) { if err := codec.Decode(blob.Blob, &watx); err != nil { return nil, fmt.Errorf("decoding ATX V1: %w", err) } - return &update{watx.SmesherID, watx.NumUnits}, nil + return &update{watx.SmesherID, watx.PrevATXID, watx.NumUnits}, nil default: return nil, fmt.Errorf("unsupported ATX version: %d", blob.Version) }