From 4aa7f1a12bf651bcdd4355cba5ea479614603f2f Mon Sep 17 00:00:00 2001 From: Matthias Fasching <5011972+fasmat@users.noreply.github.com> Date: Tue, 4 Feb 2025 18:38:31 +0000 Subject: [PATCH] Malfeasance2 fetcher and sync (#6652) ## Motivation This adds the necessary fetcher code for malfeasance v2 so that the node can sync new malfeasance proofs. Closes https://github.com/spacemeshos/go-spacemesh/issues/6689 --- activation/handler_v1.go | 21 +- activation/handler_v2.go | 154 ++++--- activation/handler_v2_test.go | 433 +++++++++++++----- activation/interface.go | 2 +- activation/malfeasance2.go | 12 +- activation/malfeasance2_test.go | 56 ++- activation/mocks.go | 12 +- activation/nipost.go | 3 +- activation/wire/interface.go | 20 + activation/wire/malfeasance_double_marry.go | 9 + .../wire/malfeasance_double_marry_test.go | 41 ++ activation/wire/malfeasance_double_merge.go | 27 +- .../wire/malfeasance_double_merge_test.go | 74 +++ activation/wire/malfeasance_invalid_post.go | 4 + .../wire/malfeasance_invalid_prev_atx.go | 13 + .../wire/malfeasance_invalid_prev_atx_test.go | 37 ++ activation/wire/mocks.go | 77 ++++ api/grpcserver/v2alpha1/malfeasance.go | 6 +- api/grpcserver/v2alpha1/malfeasance_test.go | 12 +- api/grpcserver/v2beta1/malfeasance.go | 6 +- api/grpcserver/v2beta1/malfeasance_test.go | 12 +- checkpoint/runner.go | 7 +- checkpoint/runner_test.go | 36 +- config/logging.go | 2 + config/presets/fastnet.go | 2 +- datastore/mocks.go | 81 ++++ datastore/store.go | 137 ++++-- datastore/store_test.go | 128 +++--- fetch/fetch.go | 66 ++- fetch/fetch_test.go | 84 ++-- fetch/handler.go | 88 +++- fetch/handler_test.go | 68 ++- fetch/mesh_data.go | 87 +++- fetch/mesh_data_test.go | 77 +++- fetch/p2p_test.go | 94 +++- go.mod | 3 +- go.sum | 10 +- malfeasance/handler.go | 8 +- malfeasance/handler_test.go | 22 +- malfeasance2/handler.go | 64 ++- malfeasance2/handler_test.go | 52 --- malfeasance2/publisher.go | 225 ++++++--- malfeasance2/publisher_test.go | 213 ++++++++- mesh/mesh.go | 2 + mesh/mesh_test.go | 3 - node/node.go | 52 +-- node/node_test.go | 8 +- p2p/server/deadline_adjuster.go | 3 +- p2p/server/server.go | 6 +- sql/malfeasance/malfeasance.go | 15 + sql/malfeasance/malfeasance_test.go | 47 ++ sql/malsync/malsync.go | 43 +- sql/malsync/malsync_test.go | 25 +- sql/marriage/marriages.go | 12 +- sql/marriage/marriages_test.go | 45 ++ syncer/interface.go | 2 +- syncer/malsync/mocks/mocks.go | 125 ++++- syncer/malsync/syncer.go | 316 +++++++++++-- syncer/malsync/syncer_test.go | 295 +++++++++--- syncer/mocks/mocks.go | 114 ++--- syncer/syncer.go | 11 +- syncer/syncer_test.go | 6 + systest/Makefile | 2 +- systest/cluster/nodes.go | 33 +- systest/parameters/fastnet/smesher.json | 11 +- systest/tests/common.go | 7 +- .../distributed_post_verification_test.go | 361 ++++++++++----- systest/tests/equivocation_test.go | 7 +- tortoise/model/core.go | 7 +- 69 files changed, 3050 insertions(+), 1093 deletions(-) create mode 100644 datastore/mocks.go diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 653727bc1d..f59b4071d1 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -27,6 +27,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/system" ) @@ -222,12 +223,12 @@ func (h *HandlerV1) syntacticallyValidateDeps( watx.NumUnits, PostSubset([]byte(h.local)), // use the local peer ID as seed for random subset ) - var invalidIdx *verifying.ErrInvalidIndex - if errors.As(err, &invalidIdx) { + var errInvalidIdx *verifying.ErrInvalidIndex + if errors.As(err, &errInvalidIdx) { h.logger.Debug("ATX with invalid post index", log.ZContext(ctx), zap.Stringer("atx_id", watx.ID()), - zap.Int("index", invalidIdx.Index), + zap.Int("index", errInvalidIdx.Index), ) malicious, err := identities.IsMalicious(h.cdb, watx.SmesherID) if err != nil { @@ -236,13 +237,20 @@ func (h *HandlerV1) syntacticallyValidateDeps( if malicious { return nil, fmt.Errorf("smesher %s is known malfeasant", watx.SmesherID.ShortString()) } + malicious, err = malfeasance.IsMalicious(h.cdb, watx.SmesherID) + if err != nil { + return nil, fmt.Errorf("check if smesher is malicious: %w", err) + } + if malicious { + return nil, fmt.Errorf("smesher %s is known malfeasant", watx.SmesherID.ShortString()) + } proof := &mwire.MalfeasanceProof{ Layer: watx.PublishEpoch.FirstLayer(), Proof: mwire.Proof{ Type: mwire.InvalidPostIndex, Data: &mwire.InvalidPostIndexProof{ Atx: *watx, - InvalidIdx: uint32(invalidIdx.Index), + InvalidIdx: uint32(errInvalidIdx.Index), }, }, } @@ -489,6 +497,11 @@ func (h *HandlerV1) storeAtx(ctx context.Context, atx *types.ActivationTx, watx if err != nil { return fmt.Errorf("check if node is malicious: %w", err) } + malicious2, err := malfeasance.IsMalicious(tx, atx.SmesherID) + if err != nil { + return fmt.Errorf("check if node is malicious: %w", err) + } + malicious = malicious || malicious2 if !malicious { malicious, err = h.checkMalicious(ctx, tx, watx) if err != nil { diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 74574fae4d..a16969379f 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -679,10 +679,15 @@ func (h *HandlerV2) validatePost( if err == nil { return nil } - errInvalid := &verifying.ErrInvalidIndex{} - if !errors.As(err, &errInvalid) { + errInvalidIdx := &verifying.ErrInvalidIndex{} + if !errors.As(err, &errInvalidIdx) { return fmt.Errorf("validating post for ID %s: %w", nodeID.ShortString(), err) } + h.logger.Debug("ATX with invalid post index", + log.ZContext(ctx), + zap.Stringer("atx_id", atx.ID()), + zap.Int("index", errInvalidIdx.Index), + ) // check if post contains at least one valid label validIdx := 0 @@ -715,7 +720,7 @@ func (h *HandlerV2) validatePost( commitment, nodeID, nipostIndex, - uint32(errInvalid.Index), + uint32(errInvalidIdx.Index), uint32(validIdx), ) if err != nil { @@ -724,40 +729,35 @@ func (h *HandlerV2) validatePost( if err := h.malPublisher.Publish(ctx, nodeID, proof); err != nil { return fmt.Errorf("publishing malfeasance proof for invalid post: %w", err) } - return fmt.Errorf("invalid post for ID %s: %w", nodeID.ShortString(), errInvalid) + return fmt.Errorf("invalid post for ID %s: %w", nodeID.ShortString(), errInvalidIdx) } -func (h *HandlerV2) checkMalicious(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { - malicious, err := malfeasance.IsMalicious(tx, atx.SmesherID) - if err != nil { - return malicious, fmt.Errorf("checking if node is malicious: %w", err) - } - if malicious { - return true, nil - } - - malicious, err = h.checkDoubleMarry(ctx, tx, atx) +func (h *HandlerV2) checkMalicious( + ctx context.Context, + tx sql.Transaction, + watx *activationTx, +) (wire.Proof, types.NodeID, error) { + proof, nodeID, err := h.checkDoubleMarry(ctx, tx, watx) if err != nil { - return malicious, fmt.Errorf("checking double marry: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("checking double marry: %w", err) } - if malicious { - return true, nil + if proof != nil { + return proof, nodeID, nil } - malicious, err = h.checkDoubleMerge(ctx, tx, atx) + proof, nodeID, err = h.checkDoubleMerge(ctx, tx, watx) if err != nil { - return malicious, fmt.Errorf("checking double merge: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("checking double merge: %w", err) } - if malicious { - return true, nil + if proof != nil { + return proof, nodeID, nil } - malicious, err = h.checkPrevAtx(ctx, tx, atx) + proof, nodeID, err = h.checkPrevAtx(ctx, tx, watx) if err != nil { - return malicious, fmt.Errorf("checking previous ATX: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("checking previous ATX: %w", err) } - - return malicious, err + return proof, nodeID, nil } func (h *HandlerV2) fetchWireAtx( @@ -778,11 +778,15 @@ func (h *HandlerV2) fetchWireAtx( return atx, nil } -func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { +func (h *HandlerV2) checkDoubleMarry( + ctx context.Context, + tx sql.Transaction, + atx *activationTx, +) (wire.Proof, types.NodeID, error) { for _, m := range atx.marriages { info, err := marriage.FindByNodeID(tx, m.id) if err != nil { - return false, fmt.Errorf("checking if ID is married: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("checking if ID is married: %w", err) } if info.ATX == atx.ID() { continue @@ -795,28 +799,32 @@ func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx sql.Transaction, at zap.Stringer("atx_id", info.ATX), ) case err != nil: - return false, fmt.Errorf("fetching other ATX: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("fetching other ATX: %w", err) } proof, err := wire.NewDoubleMarryProof(tx, atx.ActivationTxV2, otherAtx, m.id) if err != nil { - return true, fmt.Errorf("creating double marry proof: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("creating double marry proof: %w", err) } - return true, h.malPublisher.Publish(ctx, m.id, proof) + return proof, m.id, nil } - return false, nil + return nil, types.EmptyNodeID, nil } -func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { +func (h *HandlerV2) checkDoubleMerge( + ctx context.Context, + tx sql.Transaction, + atx *activationTx, +) (wire.Proof, types.NodeID, error) { if atx.MarriageATX == nil { - return false, nil + return nil, types.EmptyNodeID, nil } ids, err := atxs.MergeConflict(tx, *atx.MarriageATX, atx.PublishEpoch) switch { case errors.Is(err, sql.ErrNotFound): - return false, nil + return nil, types.EmptyNodeID, nil case err != nil: - return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err) } otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != atx.ID() }) other := ids[otherIndex] @@ -836,7 +844,7 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, at // see https://github.com/spacemeshos/go-spacemesh/issues/6434 otherAtx, err := h.fetchWireAtx(ctx, tx, other) if err != nil { - return false, fmt.Errorf("fetching other ATX: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("fetching other ATX: %w", err) } // TODO(mafa): checkpoints need to include all marriage ATXs in full to be able to create malfeasance proofs @@ -845,16 +853,20 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, at // see https://github.com/spacemeshos/go-spacemesh/issues/6435 proof, err := wire.NewDoubleMergeProof(tx, atx.ActivationTxV2, otherAtx) if err != nil { - return true, fmt.Errorf("creating double merge proof: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("creating double merge proof: %w", err) } - return true, h.malPublisher.Publish(ctx, atx.ActivationTxV2.SmesherID, proof) + return proof, atx.ActivationTxV2.SmesherID, nil } -func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { +func (h *HandlerV2) checkPrevAtx( + ctx context.Context, + tx sql.Transaction, + atx *activationTx, +) (wire.Proof, types.NodeID, error) { for id, data := range atx.ids { expectedPrevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) if err != nil && !errors.Is(err, sql.ErrNotFound) { - return false, fmt.Errorf("get last atx by node id: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("get last atx by node id: %w", err) } if expectedPrevID == data.previous { continue @@ -871,7 +883,7 @@ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx sql.Transaction, atx *a case errors.Is(err, sql.ErrNotFound): continue case err != nil: - return true, fmt.Errorf("checking for previous ATX collision: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("checking for previous ATX collision: %w", err) } var wireAtxV1 *wire.ActivationTxV1 @@ -882,7 +894,7 @@ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx sql.Transaction, atx *a var blob sql.Blob v, err := atxs.LoadBlob(ctx, tx, collision.Bytes(), &blob) if err != nil { - return true, fmt.Errorf("get atx blob %s: %w", id.ShortString(), err) + return nil, types.EmptyNodeID, fmt.Errorf("get atx blob %s: %w", id.ShortString(), err) } switch v { case types.AtxV1: @@ -903,9 +915,9 @@ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx sql.Transaction, atx *a ) proof, err := wire.NewInvalidPrevAtxProofV2(tx, atx.ActivationTxV2, wireAtx, id) if err != nil { - return true, fmt.Errorf("creating invalid previous ATX proof: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("creating invalid previous ATX proof: %w", err) } - return true, h.malPublisher.Publish(ctx, id, proof) + return proof, id, nil default: h.logger.Fatal("Failed to create invalid previous ATX proof: unknown ATX version", zap.Stringer("atx_id", collision), @@ -921,16 +933,19 @@ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx sql.Transaction, atx *a ) proof, err := wire.NewInvalidPrevAtxProofV1(tx, atx.ActivationTxV2, wireAtxV1, id) if err != nil { - return true, fmt.Errorf("creating invalid previous ATX proof: %w", err) + return nil, types.EmptyNodeID, fmt.Errorf("creating invalid previous ATX proof: %w", err) } - return true, h.malPublisher.Publish(ctx, id, proof) + return proof, id, nil } - return false, nil + return nil, types.EmptyNodeID, nil } // Store an ATX in the DB. func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx *activationTx) error { republishProof := false + malicious := false + var proof wire.Proof + var nodeID types.NodeID if err := h.cdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { if len(watx.marriages) != 0 { newMarriageID, err := marriage.NewID(tx) @@ -942,7 +957,6 @@ func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx ATX: atx.ID(), Target: atx.SmesherID, } - malicious := false marriageIDs := make(map[marriage.ID]struct{}, 1) marriageIDs[newMarriageID] = struct{}{} for i, m := range watx.marriages { @@ -1009,34 +1023,44 @@ func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx return fmt.Errorf("setting atx units for ID %s: %w", id, err) } } - return nil - }); err != nil { - return fmt.Errorf("store atx: %w", err) - } - malicious := false - err := h.cdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { + if malicious || republishProof { + return nil + } + // malfeasance check happens after storing the ATX because storing updates the marriage set // that is needed for the malfeasance proof // // TODO(mafa): don't store own ATX if it would mark the node as malicious // this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in // the gossip handler (not sync!) - if republishProof { - malicious = true - return h.malPublisher.Regossip(ctx, atx.SmesherID) - } - - var err error - malicious, err = h.checkMalicious(ctx, tx, watx) + proof, nodeID, err = h.checkMalicious(ctx, tx, watx) return err - }) - if err != nil { - return fmt.Errorf("check malicious: %w", err) + }); err != nil { + return fmt.Errorf("store atx: %w", err) + } + + switch { + case republishProof: // marriage set of known malicious smesher has changed, force re-gossip of proof + if err := h.malPublisher.Regossip(ctx, watx.SmesherID); err != nil { + h.logger.Error("failed to regossip malfeasance proof", + zap.Stringer("atx_id", watx.ID()), + zap.Stringer("smesher_id", watx.SmesherID), + zap.Error(err), + ) + } + case proof != nil: // new malfeasance proof for identity created, publish proof (gossip is decided by publisher) + if err := h.malPublisher.Publish(ctx, nodeID, proof); err != nil { + h.logger.Error("failed to publish malfeasance proof", + zap.Stringer("atx_id", watx.ID()), + zap.Stringer("smesher_id", watx.SmesherID), + zap.Error(err), + ) + } } h.beacon.OnAtx(atx) - if added := h.atxsdata.AddFromAtx(atx, malicious); added != nil { + if added := h.atxsdata.AddFromAtx(atx, malicious || proof != nil); added != nil { h.tortoise.OnAtx(atx.TargetEpoch(), atx.ID(), added) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index f373b4ef77..7cbe9681d0 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -15,7 +15,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/atxsdata" @@ -35,7 +38,8 @@ import ( type v2TestHandler struct { *HandlerV2 - tb testing.TB + tb testing.TB + observedLogs *observer.ObservedLogs handlerMocks } @@ -50,8 +54,13 @@ const ( ) func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { - lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(statesql.InMemoryTest(tb), lg) + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + cdb := datastore.NewCachedDB(statesql.InMemoryTest(tb), logger) tb.Cleanup(func() { assert.NoError(tb, cdb.Close()) }) mocks := newTestHandlerMocks(tb, golden) return &v2TestHandler{ @@ -64,13 +73,14 @@ func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { tickSize: tickSize, goldenATXID: golden, nipostValidator: mocks.mValidator, - logger: lg, + logger: logger, fetcher: mocks.mockFetch, beacon: mocks.mBeacon, tortoise: mocks.mTortoise, malPublisher: mocks.mMalPublish, }, tb: tb, + observedLogs: observedLogs, handlerMocks: mocks, } } @@ -991,6 +1001,8 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return atxHandler.edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(true, nil).AnyTimes() + verifier.EXPECT().IdentityExists(signers[2].NodeID()).Return(true, nil).AnyTimes() atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) atxHandler.mMalPublish.EXPECT().Publish( @@ -1938,150 +1950,329 @@ func Test_Marriages(t *testing.T) { }) t.Run("can't marry twice (separate marriages)", func(t *testing.T) { t.Parallel() - atxHandler := newV2TestHandler(t, golden) + t.Run("publish succeeds", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) - otherSig, err := signing.NewEdSigner() - require.NoError(t, err) - atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) - // otherSig2 cannot marry sig, trying to extend its set. - otherSig2, err := signing.NewEdSigner() - require.NoError(t, err) - others2Atx := atxHandler.createAndProcessInitial(otherSig2) - atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) - atx2.Marriages = []wire.MarriageCertificate{ - { - Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - { - ReferenceAtx: others2Atx.ID(), - Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - } - atx2.Sign(sig) - atxHandler.expectAtxV2(atx2) + // otherSig2 cannot marry sig, trying to extend its set. + otherSig2, err := signing.NewEdSigner() + require.NoError(t, err) + others2Atx := atxHandler.createAndProcessInitial(otherSig2) + atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) + atx2.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: others2Atx.ID(), + Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx2.Sign(sig) + atxHandler.expectAtxV2(atx2) - verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return atxHandler.edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHandler.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(true, nil).AnyTimes() - atxHandler.mMalPublish.EXPECT().Publish( - gomock.Any(), - sig.NodeID(), - gomock.AssignableToTypeOf(&wire.ProofDoubleMarry{}), - ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { - malProof := proof.(*wire.ProofDoubleMarry) - nId, err := malProof.Valid(ctx, verifier) + atxHandler.mMalPublish.EXPECT().Publish( + gomock.Any(), + sig.NodeID(), + gomock.AssignableToTypeOf(&wire.ProofDoubleMarry{}), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofDoubleMarry) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), nId) + return nil + }) + err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) require.NoError(t, err) - require.Equal(t, sig.NodeID(), nId) - return nil + + // The equivocation set of sig and otherSig were merged + id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) }) - err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) - require.NoError(t, err) - // The equivocation set of sig and otherSig were merged - id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) - require.NoError(t, err) - equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) - require.NoError(t, err) - require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) + t.Run("publish fails", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) + + // otherSig2 cannot marry sig, trying to extend its set. + otherSig2, err := signing.NewEdSigner() + require.NoError(t, err) + others2Atx := atxHandler.createAndProcessInitial(otherSig2) + atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) + atx2.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: others2Atx.ID(), + Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx2.Sign(sig) + atxHandler.expectAtxV2(atx2) + + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHandler.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(true, nil).AnyTimes() + + atxHandler.mMalPublish.EXPECT().Publish( + gomock.Any(), + sig.NodeID(), + gomock.AssignableToTypeOf(&wire.ProofDoubleMarry{}), + ).Return(errors.New("publish failed")) + err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) + require.NoError(t, err) + + // The equivocation set of sig and otherSig were merged + id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) + + observedLogs := atxHandler.observedLogs.FilterLevelExact(zapcore.ErrorLevel) + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") + require.Equal(t, zapcore.ErrorLevel, observedLogs.All()[0].Level) + require.Equal(t, "failed to publish malfeasance proof", observedLogs.All()[0].Message) + require.Equal(t, sig.NodeID().String(), observedLogs.All()[0].ContextMap()["smesher_id"]) + require.Equal(t, atx2.ID().ShortString(), observedLogs.All()[0].ContextMap()["atx_id"]) + require.Equal(t, "publish failed", observedLogs.All()[0].ContextMap()["error"]) + }) }) t.Run("marring existing malicious equivocation set: mark all malicious and regossip proof", func(t *testing.T) { t.Parallel() - atxHandler := newV2TestHandler(t, golden) + t.Run("regossip succeeds", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) - otherSig, err := signing.NewEdSigner() - require.NoError(t, err) - atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) - // sig becomes malicious in some way and with it otherSig - id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) - require.NoError(t, err) - require.NoError(t, malfeasance.AddProof(atxHandler.cdb, sig.NodeID(), &id, []byte("proof"), 0, time.Now())) - require.NoError(t, malfeasance.SetMalicious(atxHandler.cdb, otherSig.NodeID(), id, time.Now())) + // sig becomes malicious in some way and with it otherSig + id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.NoError(t, malfeasance.AddProof(atxHandler.cdb, sig.NodeID(), &id, []byte("proof"), 0, time.Now())) + require.NoError(t, malfeasance.SetMalicious(atxHandler.cdb, otherSig.NodeID(), id, time.Now())) - // otherSig2 cannot marry sig, trying to extend its set. - otherSig2, err := signing.NewEdSigner() - require.NoError(t, err) - others2Atx := atxHandler.createAndProcessInitial(otherSig2) - atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) - atx2.Marriages = []wire.MarriageCertificate{ - { - Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - { - ReferenceAtx: others2Atx.ID(), - Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - } - atx2.Sign(sig) - atxHandler.expectAtxV2(atx2) + // otherSig2 cannot marry sig, trying to extend its set. + otherSig2, err := signing.NewEdSigner() + require.NoError(t, err) + others2Atx := atxHandler.createAndProcessInitial(otherSig2) + atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) + atx2.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: others2Atx.ID(), + Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx2.Sign(sig) + atxHandler.expectAtxV2(atx2) - atxHandler.mMalPublish.EXPECT().Regossip(gomock.Any(), sig.NodeID()) - err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) - require.NoError(t, err) + atxHandler.mMalPublish.EXPECT().Regossip(gomock.Any(), sig.NodeID()) + err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) + require.NoError(t, err) - // The equivocation set of sig and otherSig were merged - id, err = marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) - require.NoError(t, err) - equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) - require.NoError(t, err) - require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) + // The equivocation set of sig and otherSig were merged + id, err = marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) - for _, sig := range []*signing.EdSigner{sig, otherSig, otherSig2} { - m, err := malfeasance.IsMalicious(atxHandler.cdb, sig.NodeID()) + for _, sig := range []*signing.EdSigner{sig, otherSig, otherSig2} { + m, err := malfeasance.IsMalicious(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, m, "expected %s to be malicious", sig) + } + }) + + t.Run("regossip fails", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + otherSig, err := signing.NewEdSigner() require.NoError(t, err) - require.True(t, m, "expected %s to be malicious", sig) - } + atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) + + // sig becomes malicious in some way and with it otherSig + id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.NoError(t, malfeasance.AddProof(atxHandler.cdb, sig.NodeID(), &id, []byte("proof"), 0, time.Now())) + require.NoError(t, malfeasance.SetMalicious(atxHandler.cdb, otherSig.NodeID(), id, time.Now())) + + // otherSig2 cannot marry sig, trying to extend its set. + otherSig2, err := signing.NewEdSigner() + require.NoError(t, err) + others2Atx := atxHandler.createAndProcessInitial(otherSig2) + atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) + atx2.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: others2Atx.ID(), + Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx2.Sign(sig) + atxHandler.expectAtxV2(atx2) + + atxHandler.mMalPublish.EXPECT().Regossip(gomock.Any(), sig.NodeID()).Return(errors.New("regossip failed")) + err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) + require.NoError(t, err) + + // The equivocation set of sig and otherSig were merged + id, err = marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) + + for _, sig := range []*signing.EdSigner{sig, otherSig, otherSig2} { + m, err := malfeasance.IsMalicious(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, m, "expected %s to be malicious", sig) + } + + observedLogs := atxHandler.observedLogs.FilterLevelExact(zapcore.ErrorLevel) + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") + require.Equal(t, zapcore.ErrorLevel, observedLogs.All()[0].Level) + require.Equal(t, "failed to regossip malfeasance proof", observedLogs.All()[0].Message) + require.Equal(t, sig.NodeID().String(), observedLogs.All()[0].ContextMap()["smesher_id"]) + require.Equal(t, atx2.ID().ShortString(), observedLogs.All()[0].ContextMap()["atx_id"]) + require.Equal(t, "regossip failed", observedLogs.All()[0].ContextMap()["error"]) + }) }) t.Run("malicious marring existing equivocation set: mark all malicious and regossip proof", func(t *testing.T) { t.Parallel() - atxHandler := newV2TestHandler(t, golden) + t.Run("regossip succeeds", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) - otherSig, err := signing.NewEdSigner() - require.NoError(t, err) - atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) - // otherSig2 cannot marry sig, trying to extend its set. - otherSig2, err := signing.NewEdSigner() - require.NoError(t, err) - others2Atx := atxHandler.createAndProcessInitial(otherSig2) + // otherSig2 cannot marry sig, trying to extend its set. + otherSig2, err := signing.NewEdSigner() + require.NoError(t, err) + others2Atx := atxHandler.createAndProcessInitial(otherSig2) - // otherSig2 becomes malicious in some way - err = malfeasance.AddProof(atxHandler.cdb, otherSig2.NodeID(), nil, []byte("proof"), 0, time.Now()) - require.NoError(t, err) + // otherSig2 becomes malicious in some way + err = malfeasance.AddProof(atxHandler.cdb, otherSig2.NodeID(), nil, []byte("proof"), 0, time.Now()) + require.NoError(t, err) - atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) - atx2.Marriages = []wire.MarriageCertificate{ - { - Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - { - ReferenceAtx: others2Atx.ID(), - Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - } - atx2.Sign(sig) - atxHandler.expectAtxV2(atx2) + atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) + atx2.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: others2Atx.ID(), + Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx2.Sign(sig) + atxHandler.expectAtxV2(atx2) - atxHandler.mMalPublish.EXPECT().Regossip(gomock.Any(), sig.NodeID()) - err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) - require.NoError(t, err) + atxHandler.mMalPublish.EXPECT().Regossip(gomock.Any(), sig.NodeID()) + err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) + require.NoError(t, err) - // The equivocation set of sig and otherSig were merged - id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) - require.NoError(t, err) - equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) - require.NoError(t, err) - require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) + // The equivocation set of sig and otherSig were merged + id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) - for _, sig := range []*signing.EdSigner{sig, otherSig, otherSig2} { - m, err := malfeasance.IsMalicious(atxHandler.cdb, sig.NodeID()) + for _, sig := range []*signing.EdSigner{sig, otherSig, otherSig2} { + m, err := malfeasance.IsMalicious(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, m, "expected %s to be malicious", sig) + } + }) + + t.Run("regossip fails", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + otherSig, err := signing.NewEdSigner() require.NoError(t, err) - require.True(t, m, "expected %s to be malicious", sig) - } + atx, _ := marryIDs(t, atxHandler, []*signing.EdSigner{sig, otherSig}, golden) + + // otherSig2 cannot marry sig, trying to extend its set. + otherSig2, err := signing.NewEdSigner() + require.NoError(t, err) + others2Atx := atxHandler.createAndProcessInitial(otherSig2) + + // otherSig2 becomes malicious in some way + err = malfeasance.AddProof(atxHandler.cdb, otherSig2.NodeID(), nil, []byte("proof"), 0, time.Now()) + require.NoError(t, err) + + atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) + atx2.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: others2Atx.ID(), + Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx2.Sign(sig) + atxHandler.expectAtxV2(atx2) + + atxHandler.mMalPublish.EXPECT().Regossip(gomock.Any(), sig.NodeID()).Return(errors.New("regossip failed")) + err = atxHandler.processATX(context.Background(), "", atx2, time.Now()) + require.NoError(t, err) + + // The equivocation set of sig and otherSig were merged + id, err := marriage.FindIDByNodeID(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + equiv, err := marriage.NodeIDsByID(atxHandler.cdb, id) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()}, equiv) + + for _, sig := range []*signing.EdSigner{sig, otherSig, otherSig2} { + m, err := malfeasance.IsMalicious(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, m, "expected %s to be malicious", sig) + } + + observedLogs := atxHandler.observedLogs.FilterLevelExact(zapcore.ErrorLevel) + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") + require.Equal(t, zapcore.ErrorLevel, observedLogs.All()[0].Level) + require.Equal(t, "failed to regossip malfeasance proof", observedLogs.All()[0].Message) + require.Equal(t, sig.NodeID().String(), observedLogs.All()[0].ContextMap()["smesher_id"]) + require.Equal(t, atx2.ID().ShortString(), observedLogs.All()[0].ContextMap()["atx_id"]) + require.Equal(t, "regossip failed", observedLogs.All()[0].ContextMap()["error"]) + }) }) t.Run("malicious marring malicious equivocation set: no proof published", func(t *testing.T) { t.Parallel() @@ -2281,6 +2472,7 @@ func TestContextual_PreviousATX(t *testing.T) { DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return atxHdlr.edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().IdentityExists(signers[1].NodeID()).Return(true, nil).AnyTimes() atxHdlr.mMalPublish.EXPECT().Publish( gomock.Any(), @@ -2436,6 +2628,7 @@ func TestContextual_PreviousATX(t *testing.T) { DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return atxHdlr.edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().IdentityExists(otherSig.NodeID()).Return(true, nil).AnyTimes() atxHdlr.mMalPublish.EXPECT().Publish( gomock.Any(), diff --git a/activation/interface.go b/activation/interface.go index a49f05496e..91fcc9f1bb 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -116,7 +116,7 @@ type atxMalfeasancePublisher interface { // and mark the associated identity as malfeasant. We do this to prevent spamming the network with proofs for identities // where most likely the network already knows they are malicious. type malfeasancePublisher interface { - PublishATXProof(ctx context.Context, nodeID types.NodeID, proof []byte) error + PublishATXProof(ctx context.Context, nodeID types.NodeID, proof []byte, allowNoRefATXs bool) error Regossip(ctx context.Context, nodeID types.NodeID) error } diff --git a/activation/malfeasance2.go b/activation/malfeasance2.go index 9c5c6135af..6390f63768 100644 --- a/activation/malfeasance2.go +++ b/activation/malfeasance2.go @@ -12,10 +12,13 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" ) type MalfeasanceHandlerV2 struct { logger *zap.Logger + db sql.StateDatabase malPublisher malfeasancePublisher edVerifier *signing.EdVerifier @@ -27,12 +30,14 @@ type MalfeasanceHandlerV2 struct { func NewMalfeasanceHandlerV2( logger *zap.Logger, + db sql.StateDatabase, malPublisher malfeasancePublisher, edVerifier *signing.EdVerifier, validator nipostValidatorV2, ) *MalfeasanceHandlerV2 { return &MalfeasanceHandlerV2{ logger: logger, + db: db, malPublisher: malPublisher, edVerifier: edVerifier, validator: validator, @@ -80,7 +85,8 @@ func (p *MalfeasanceHandlerV2) Publish(ctx context.Context, nodeID types.NodeID, Proof: codec.MustEncode(proof), } - return p.malPublisher.PublishATXProof(ctx, nodeID, codec.MustEncode(atxProof)) + p.logger.Debug("publishing ATX malfeasance proof", log.ZShortStringer("node_id", nodeID)) + return p.malPublisher.PublishATXProof(ctx, nodeID, codec.MustEncode(atxProof), proof.AllowNoRefATXs()) } func (p *MalfeasanceHandlerV2) Regossip(ctx context.Context, nodeID types.NodeID) error { @@ -158,3 +164,7 @@ func (mh *MalfeasanceHandlerV2) PostIndex( func (mh *MalfeasanceHandlerV2) Signature(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return mh.edVerifier.Verify(d, nodeID, m, sig) } + +func (mh *MalfeasanceHandlerV2) IdentityExists(nodeID types.NodeID) (bool, error) { + return atxs.IdentityExists(mh.db, nodeID) +} diff --git a/activation/malfeasance2_test.go b/activation/malfeasance2_test.go index 3e84aa5ac8..97eb008141 100644 --- a/activation/malfeasance2_test.go +++ b/activation/malfeasance2_test.go @@ -18,6 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) @@ -41,11 +42,13 @@ func newTestMalHandler(tb testing.TB) *testMalHandler { ))) ctrl := gomock.NewController(tb) + db := statesql.InMemoryTest(tb) mPublish := NewMockmalfeasancePublisher(ctrl) mValidator := NewMocknipostValidator(ctrl) handler := NewMalfeasanceHandlerV2( logger, + db, mPublish, edVerifier, mValidator, @@ -237,8 +240,9 @@ func TestPublish(t *testing.T) { nodeID := types.RandomNodeID() proof := wire.NewMockProof(th.ctrl) + proof.EXPECT().AllowNoRefATXs().Return(false) proof.EXPECT().Valid(context.Background(), th.MalfeasanceHandlerV2).Return(nodeID, nil) - proof.EXPECT().Type().Return(wire.DoubleMarry) + proof.EXPECT().Type().Return(wire.DoubleMarry).AnyTimes() proof.EXPECT().EncodeScale(gomock.Any()) atxProof := &wire.ATXProof{ @@ -247,7 +251,32 @@ func TestPublish(t *testing.T) { Proof: []byte{}, } - th.mPublish.EXPECT().PublishATXProof(context.Background(), nodeID, codec.MustEncode(atxProof)).Return(nil) + th.mPublish.EXPECT().PublishATXProof(context.Background(), nodeID, codec.MustEncode(atxProof), false) + + err := th.Publish(context.Background(), nodeID, proof) + require.NoError(t, err) + }) + + t.Run("valid invalid post proof", func(t *testing.T) { + t.Parallel() + + th := newTestMalHandler(t) + + nodeID := types.RandomNodeID() + proof := wire.NewMockProof(th.ctrl) + + proof.EXPECT().AllowNoRefATXs().Return(true) + proof.EXPECT().Valid(context.Background(), th.MalfeasanceHandlerV2).Return(nodeID, nil) + proof.EXPECT().Type().Return(wire.InvalidPost).AnyTimes() + proof.EXPECT().EncodeScale(gomock.Any()) + + atxProof := &wire.ATXProof{ + Version: 0x01, // for now we only have one version + ProofType: wire.InvalidPost, + + Proof: []byte{}, + } + th.mPublish.EXPECT().PublishATXProof(context.Background(), nodeID, codec.MustEncode(atxProof), true) err := th.Publish(context.Background(), nodeID, proof) require.NoError(t, err) @@ -482,3 +511,26 @@ func TestValidate(t *testing.T) { require.Equal(t, types.EmptyNodeID, id) }) } + +func TestIdentityExists(t *testing.T) { + t.Parallel() + + th := newTestMalHandler(t) + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + yes, err := th.IdentityExists(sig.NodeID()) + require.NoError(t, err) + require.False(t, yes) + + atx := &types.ActivationTx{ + SmesherID: sig.NodeID(), + } + atx.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(th.db, atx, types.AtxBlob{})) + + yes, err = th.IdentityExists(sig.NodeID()) + require.NoError(t, err) + require.True(t, yes) +} diff --git a/activation/mocks.go b/activation/mocks.go index 9e46e4224b..157ce3cf4a 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -1280,17 +1280,17 @@ func (m *MockmalfeasancePublisher) EXPECT() *MockmalfeasancePublisherMockRecorde } // PublishATXProof mocks base method. -func (m *MockmalfeasancePublisher) PublishATXProof(ctx context.Context, nodeID types.NodeID, proof []byte) error { +func (m *MockmalfeasancePublisher) PublishATXProof(ctx context.Context, nodeID types.NodeID, proof []byte, allowNoRefATXs bool) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PublishATXProof", ctx, nodeID, proof) + ret := m.ctrl.Call(m, "PublishATXProof", ctx, nodeID, proof, allowNoRefATXs) ret0, _ := ret[0].(error) return ret0 } // PublishATXProof indicates an expected call of PublishATXProof. -func (mr *MockmalfeasancePublisherMockRecorder) PublishATXProof(ctx, nodeID, proof any) *MockmalfeasancePublisherPublishATXProofCall { +func (mr *MockmalfeasancePublisherMockRecorder) PublishATXProof(ctx, nodeID, proof, allowNoRefATXs any) *MockmalfeasancePublisherPublishATXProofCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishATXProof", reflect.TypeOf((*MockmalfeasancePublisher)(nil).PublishATXProof), ctx, nodeID, proof) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishATXProof", reflect.TypeOf((*MockmalfeasancePublisher)(nil).PublishATXProof), ctx, nodeID, proof, allowNoRefATXs) return &MockmalfeasancePublisherPublishATXProofCall{Call: call} } @@ -1306,13 +1306,13 @@ func (c *MockmalfeasancePublisherPublishATXProofCall) Return(arg0 error) *Mockma } // Do rewrite *gomock.Call.Do -func (c *MockmalfeasancePublisherPublishATXProofCall) Do(f func(context.Context, types.NodeID, []byte) error) *MockmalfeasancePublisherPublishATXProofCall { +func (c *MockmalfeasancePublisherPublishATXProofCall) Do(f func(context.Context, types.NodeID, []byte, bool) error) *MockmalfeasancePublisherPublishATXProofCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockmalfeasancePublisherPublishATXProofCall) DoAndReturn(f func(context.Context, types.NodeID, []byte) error) *MockmalfeasancePublisherPublishATXProofCall { +func (c *MockmalfeasancePublisherPublishATXProofCall) DoAndReturn(f func(context.Context, types.NodeID, []byte, bool) error) *MockmalfeasancePublisherPublishATXProofCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/activation/nipost.go b/activation/nipost.go index 22d8d62f1b..7c008fe542 100644 --- a/activation/nipost.go +++ b/activation/nipost.go @@ -234,7 +234,8 @@ func (nb *NIPostBuilder) BuildNIPost( ctx, signer, poetProofDeadline, - poetRoundStart, challenge.Bytes(), + poetRoundStart, + challenge.Bytes(), ) regErr := &PoetRegistrationMismatchError{} switch { diff --git a/activation/wire/interface.go b/activation/wire/interface.go index 46db9f6f69..ec39fbeb89 100644 --- a/activation/wire/interface.go +++ b/activation/wire/interface.go @@ -2,6 +2,7 @@ package wire import ( "context" + "errors" "github.com/spacemeshos/go-scale" @@ -26,14 +27,33 @@ type MalfeasanceValidator interface { // Signature validates the given signature against the given message and public key. Signature(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool + + // IdentityExists returns true if the given identity has published a valid ATX before. + IdentityExists(nodeID types.NodeID) (bool, error) } +var ErrUnknownIdentity = errors.New("unknown identity") + // Proof is an interface for all types of proofs that can be provided in an ATXProof. // Generally the proof should be able to validate itself and be scale encoded. type Proof interface { scale.Encodable scale.Decodable + // AllowNoRefATXs returns true if the proof type is valid without reference ATXs proving the existence of the + // malicious identity. + // + // To avoid spamming of malfeasance proofs for identities that do not exist, by default all proofs require reference + // ATXs (syntactically valid ATXs published by the malicious identity) to be provided. This way any identity that + // the network considers malicious must have been in good standing at some point before the malicious behavior. + // + // For some malfeasance proofs this requirement is not necessary, for example invalid post proofs. Since those + // require the creator of the proof to show that some labels in the post are valid and some invalid. If all are + // invalid, the ATX would be considered syntactically invalid by the network anyway and a proof is not needed. + // In contrast if we would require a reference ATX we couldn't proof an invalid post in an initial ATX of any new + // identity. + AllowNoRefATXs() bool + Type() ProofType TypeName() string Info() map[string]string diff --git a/activation/wire/malfeasance_double_marry.go b/activation/wire/malfeasance_double_marry.go index c81cb8b5c5..0f53c6e74b 100644 --- a/activation/wire/malfeasance_double_marry.go +++ b/activation/wire/malfeasance_double_marry.go @@ -42,6 +42,10 @@ type ProofDoubleMarry struct { Proof2 MarryProof } +func (p ProofDoubleMarry) AllowNoRefATXs() bool { + return false +} + func (p ProofDoubleMarry) TypeName() string { return "DoubleMarryProof" } @@ -107,5 +111,10 @@ func (p ProofDoubleMarry) Valid(_ context.Context, malValidator MalfeasanceValid if err := p.Proof2.Valid(malValidator, p.ATXID2, p.SmesherID2, p.NodeID); err != nil { return types.EmptyNodeID, fmt.Errorf("proof 2 is invalid: %w", err) } + if ok, err := malValidator.IdentityExists(p.NodeID); err != nil { + return types.EmptyNodeID, fmt.Errorf("checking identity: %w", err) + } else if !ok { + return types.EmptyNodeID, ErrUnknownIdentity + } return p.NodeID, nil } diff --git a/activation/wire/malfeasance_double_marry_test.go b/activation/wire/malfeasance_double_marry_test.go index 66f030da3a..550d85e516 100644 --- a/activation/wire/malfeasance_double_marry_test.go +++ b/activation/wire/malfeasance_double_marry_test.go @@ -58,12 +58,53 @@ func Test_DoubleMarryProof(t *testing.T) { DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().IdentityExists(otherSig.NodeID()).Return(true, nil).AnyTimes() id, err := proof.Valid(context.Background(), verifier) require.NoError(t, err) require.Equal(t, otherSig.NodeID(), id) }) + t.Run("identity unknown", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + atx1 := NewTestActivationTxV2( + t, + WithMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + WithMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + atx1.Sign(sig) + + atx2 := NewTestActivationTxV2( + t, + WithMarriageCertificate(otherSig, types.EmptyATXID, otherSig.NodeID()), + WithMarriageCertificate(sig, atx1.ID(), otherSig.NodeID()), + ) + atx2.Sign(otherSig) + + proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID()) + require.NoError(t, err) + require.NotNil(t, proof) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + verifier.EXPECT().IdentityExists(otherSig.NodeID()).Return(false, nil).AnyTimes() + + id, err := proof.Valid(context.Background(), verifier) + require.ErrorIs(t, err, ErrUnknownIdentity) + require.Equal(t, types.EmptyNodeID, id) + }) + t.Run("identity is not included in both ATXs", func(t *testing.T) { t.Parallel() db := statesql.InMemoryTest(t) diff --git a/activation/wire/malfeasance_double_merge.go b/activation/wire/malfeasance_double_merge.go index 3047623c94..53a59de800 100644 --- a/activation/wire/malfeasance_double_merge.go +++ b/activation/wire/malfeasance_double_merge.go @@ -59,6 +59,10 @@ type ProofDoubleMerge struct { SmesherID2MarryProof MarryProof } +func (p ProofDoubleMerge) AllowNoRefATXs() bool { + return false +} + func (p ProofDoubleMerge) TypeName() string { return "DoubleMergeProof" } @@ -145,17 +149,17 @@ func NewDoubleMergeProof(db sql.Executor, atx1, atx2 *ActivationTxV2) (*ProofDou return &proof, nil } -func (p *ProofDoubleMerge) Valid(_ context.Context, edVerifier MalfeasanceValidator) (types.NodeID, error) { +func (p *ProofDoubleMerge) Valid(_ context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) { // 1. The ATXs have different IDs. if p.ATXID1 == p.ATXID2 { return types.EmptyNodeID, errors.New("ATXs have the same ID") } // 2. Both ATXs have a valid signature. - if !edVerifier.Signature(signing.ATX, p.SmesherID1, p.ATXID1.Bytes(), p.Signature1) { + if !malValidator.Signature(signing.ATX, p.SmesherID1, p.ATXID1.Bytes(), p.Signature1) { return types.EmptyNodeID, errors.New("ATX 1 invalid signature") } - if !edVerifier.Signature(signing.ATX, p.SmesherID2, p.ATXID2.Bytes(), p.Signature2) { + if !malValidator.Signature(signing.ATX, p.SmesherID2, p.ATXID2.Bytes(), p.Signature2) { return types.EmptyNodeID, errors.New("ATX 2 invalid signature") } @@ -171,17 +175,30 @@ func (p *ProofDoubleMerge) Valid(_ context.Context, edVerifier MalfeasanceValida if !p.MarriageATXProof1.Valid(p.ATXID1, p.MarriageATX) { return types.EmptyNodeID, errors.New("ATX 1 invalid marriage ATX proof") } - err := p.SmesherID1MarryProof.Valid(edVerifier, p.MarriageATX, p.MarriageATXSmesherID, p.SmesherID1) + err := p.SmesherID1MarryProof.Valid(malValidator, p.MarriageATX, p.MarriageATXSmesherID, p.SmesherID1) if err != nil { return types.EmptyNodeID, errors.New("ATX 1 invalid marriage ATX proof") } if !p.MarriageATXProof2.Valid(p.ATXID2, p.MarriageATX) { return types.EmptyNodeID, errors.New("ATX 2 invalid marriage ATX proof") } - err = p.SmesherID2MarryProof.Valid(edVerifier, p.MarriageATX, p.MarriageATXSmesherID, p.SmesherID2) + err = p.SmesherID2MarryProof.Valid(malValidator, p.MarriageATX, p.MarriageATXSmesherID, p.SmesherID2) if err != nil { return types.EmptyNodeID, errors.New("ATX 2 invalid marriage ATX proof") } + // 6. smeshers have published valid ATXs before + if ok, err := malValidator.IdentityExists(p.SmesherID1); err != nil { + return types.EmptyNodeID, fmt.Errorf("checking identity: %w", err) + } else if !ok { + return types.EmptyNodeID, ErrUnknownIdentity + } + + if ok, err := malValidator.IdentityExists(p.SmesherID2); err != nil { + return types.EmptyNodeID, fmt.Errorf("checking identity: %w", err) + } else if !ok { + return types.EmptyNodeID, ErrUnknownIdentity + } + return p.SmesherID1, nil } diff --git a/activation/wire/malfeasance_double_merge_test.go b/activation/wire/malfeasance_double_merge_test.go index bf0370f79c..c4d546cfe3 100644 --- a/activation/wire/malfeasance_double_merge_test.go +++ b/activation/wire/malfeasance_double_merge_test.go @@ -76,6 +76,8 @@ func Test_DoubleMergeProof(t *testing.T) { DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(true, nil).AnyTimes() + verifier.EXPECT().IdentityExists(otherSig.NodeID()).Return(true, nil).AnyTimes() marriageAtx := setupMarriage(db) @@ -100,6 +102,78 @@ func Test_DoubleMergeProof(t *testing.T) { require.Equal(t, sig.NodeID(), id) }) + t.Run("identity1 unknown", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(false, nil).AnyTimes() + verifier.EXPECT().IdentityExists(otherSig.NodeID()).Return(true, nil).AnyTimes() + + marriageAtx := setupMarriage(db) + + atx1 := NewTestActivationTxV2( + t, + WithMarriageATX(marriageAtx.ID()), + WithPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx1.Sign(sig) + + atx2 := NewTestActivationTxV2( + t, + WithMarriageATX(marriageAtx.ID()), + WithPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx2.Sign(otherSig) + + proof, err := NewDoubleMergeProof(db, atx1, atx2) + require.NoError(t, err) + id, err := proof.Valid(context.Background(), verifier) + require.ErrorIs(t, err, ErrUnknownIdentity) + require.Equal(t, types.EmptyNodeID, id) + }) + + t.Run("identity2 unknown", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(true, nil).AnyTimes() + verifier.EXPECT().IdentityExists(otherSig.NodeID()).Return(false, nil).AnyTimes() + + marriageAtx := setupMarriage(db) + + atx1 := NewTestActivationTxV2( + t, + WithMarriageATX(marriageAtx.ID()), + WithPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx1.Sign(sig) + + atx2 := NewTestActivationTxV2( + t, + WithMarriageATX(marriageAtx.ID()), + WithPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx2.Sign(otherSig) + + proof, err := NewDoubleMergeProof(db, atx1, atx2) + require.NoError(t, err) + id, err := proof.Valid(context.Background(), verifier) + require.ErrorIs(t, err, ErrUnknownIdentity) + require.Equal(t, types.EmptyNodeID, id) + }) + t.Run("same ATX ID", func(t *testing.T) { t.Parallel() db := statesql.InMemoryTest(t) diff --git a/activation/wire/malfeasance_invalid_post.go b/activation/wire/malfeasance_invalid_post.go index e674d7f682..cb062e044a 100644 --- a/activation/wire/malfeasance_invalid_post.go +++ b/activation/wire/malfeasance_invalid_post.go @@ -39,6 +39,10 @@ type ProofInvalidPost struct { InvalidPostProof InvalidPostProof } +func (p ProofInvalidPost) AllowNoRefATXs() bool { + return true +} + func (p ProofInvalidPost) TypeName() string { return "InvalidPoSTProof" } diff --git a/activation/wire/malfeasance_invalid_prev_atx.go b/activation/wire/malfeasance_invalid_prev_atx.go index b1a9376832..ea9c32c3e9 100644 --- a/activation/wire/malfeasance_invalid_prev_atx.go +++ b/activation/wire/malfeasance_invalid_prev_atx.go @@ -32,6 +32,10 @@ type ProofInvalidPrevAtxV2 struct { Proofs [2]InvalidPrevAtxProof } +func (p ProofInvalidPrevAtxV2) AllowNoRefATXs() bool { + return false +} + func (p ProofInvalidPrevAtxV2) TypeName() string { return "InvalidPreviousATXProofV2" } @@ -186,6 +190,11 @@ func (p ProofInvalidPrevAtxV2) Valid(_ context.Context, malValidator Malfeasance if err := p.Proofs[1].Valid(p.PrevATXID, p.NodeID, malValidator); err != nil { return types.EmptyNodeID, fmt.Errorf("proof 2 is invalid: %w", err) } + if ok, err := malValidator.IdentityExists(p.NodeID); err != nil { + return types.EmptyNodeID, fmt.Errorf("checking identity: %w", err) + } else if !ok { + return types.EmptyNodeID, ErrUnknownIdentity + } return p.NodeID, nil } @@ -209,6 +218,10 @@ type ProofInvalidPrevAtxV1 struct { ATXv1 ActivationTxV1 } +func (p ProofInvalidPrevAtxV1) AllowNoRefATXs() bool { + return false +} + func (p ProofInvalidPrevAtxV1) TypeName() string { return "InvalidPreviousATXProofV1" } diff --git a/activation/wire/malfeasance_invalid_prev_atx_test.go b/activation/wire/malfeasance_invalid_prev_atx_test.go index 0a4356b2d4..eedb3345b5 100644 --- a/activation/wire/malfeasance_invalid_prev_atx_test.go +++ b/activation/wire/malfeasance_invalid_prev_atx_test.go @@ -125,6 +125,7 @@ func Test_InvalidPrevAtxProofV2(t *testing.T) { DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(true, nil).AnyTimes() // verify the proof id, err := proof.Valid(context.Background(), verifier) @@ -132,6 +133,41 @@ func Test_InvalidPrevAtxProofV2(t *testing.T) { require.Equal(t, sig.NodeID(), id) }) + t.Run("identity unknown", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + atx1 := NewTestActivationTxV2( + t, + WithPreviousATXs(prevATXID), + WithPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := NewTestActivationTxV2( + t, + WithPreviousATXs(prevATXID), + WithPublishEpoch(7), + ) + atx2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(false, nil).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.ErrorIs(t, err, ErrUnknownIdentity) + require.Equal(t, types.EmptyNodeID, id) + }) + t.Run("valid merged & solo atx", func(t *testing.T) { t.Parallel() db := statesql.InMemoryTest(t) @@ -158,6 +194,7 @@ func Test_InvalidPrevAtxProofV2(t *testing.T) { DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().IdentityExists(sig.NodeID()).Return(true, nil).AnyTimes() // verify the proof id, err := proof.Valid(context.Background(), verifier) diff --git a/activation/wire/mocks.go b/activation/wire/mocks.go index be5ddb8950..0994ebd820 100644 --- a/activation/wire/mocks.go +++ b/activation/wire/mocks.go @@ -43,6 +43,45 @@ func (m *MockMalfeasanceValidator) EXPECT() *MockMalfeasanceValidatorMockRecorde return m.recorder } +// IdentityExists mocks base method. +func (m *MockMalfeasanceValidator) IdentityExists(nodeID types.NodeID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IdentityExists", nodeID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IdentityExists indicates an expected call of IdentityExists. +func (mr *MockMalfeasanceValidatorMockRecorder) IdentityExists(nodeID any) *MockMalfeasanceValidatorIdentityExistsCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IdentityExists", reflect.TypeOf((*MockMalfeasanceValidator)(nil).IdentityExists), nodeID) + return &MockMalfeasanceValidatorIdentityExistsCall{Call: call} +} + +// MockMalfeasanceValidatorIdentityExistsCall wrap *gomock.Call +type MockMalfeasanceValidatorIdentityExistsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMalfeasanceValidatorIdentityExistsCall) Return(arg0 bool, arg1 error) *MockMalfeasanceValidatorIdentityExistsCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMalfeasanceValidatorIdentityExistsCall) Do(f func(types.NodeID) (bool, error)) *MockMalfeasanceValidatorIdentityExistsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMalfeasanceValidatorIdentityExistsCall) DoAndReturn(f func(types.NodeID) (bool, error)) *MockMalfeasanceValidatorIdentityExistsCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // PostIndex mocks base method. func (m *MockMalfeasanceValidator) PostIndex(ctx context.Context, smesherID types.NodeID, commitment types.ATXID, post *types.Post, challenge []byte, numUnits uint32, idx int) error { m.ctrl.T.Helper() @@ -143,6 +182,44 @@ func (m *MockProof) EXPECT() *MockProofMockRecorder { return m.recorder } +// AllowNoRefATXs mocks base method. +func (m *MockProof) AllowNoRefATXs() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AllowNoRefATXs") + ret0, _ := ret[0].(bool) + return ret0 +} + +// AllowNoRefATXs indicates an expected call of AllowNoRefATXs. +func (mr *MockProofMockRecorder) AllowNoRefATXs() *MockProofAllowNoRefATXsCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowNoRefATXs", reflect.TypeOf((*MockProof)(nil).AllowNoRefATXs)) + return &MockProofAllowNoRefATXsCall{Call: call} +} + +// MockProofAllowNoRefATXsCall wrap *gomock.Call +type MockProofAllowNoRefATXsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockProofAllowNoRefATXsCall) Return(arg0 bool) *MockProofAllowNoRefATXsCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockProofAllowNoRefATXsCall) Do(f func() bool) *MockProofAllowNoRefATXsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockProofAllowNoRefATXsCall) DoAndReturn(f func() bool) *MockProofAllowNoRefATXsCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // DecodeScale mocks base method. func (m *MockProof) DecodeScale(dec *scale.Decoder) (int, error) { m.ctrl.T.Helper() diff --git a/api/grpcserver/v2alpha1/malfeasance.go b/api/grpcserver/v2alpha1/malfeasance.go index 10e91527ac..42502056cf 100644 --- a/api/grpcserver/v2alpha1/malfeasance.go +++ b/api/grpcserver/v2alpha1/malfeasance.go @@ -280,12 +280,12 @@ func fetchMetaData( zap.String("type", properties["type"]), zap.Error(err), ) - return nil + } else { + delete(properties, "type") } - delete(properties, "type") return &spacemeshv2alpha1.MalfeasanceProof{ Smesher: id.Bytes(), - Domain: spacemeshv2alpha1.MalfeasanceProof_MalfeasanceDomain(domain), // TODO(mafa): add new domains + Domain: spacemeshv2alpha1.MalfeasanceProof_MalfeasanceDomain(domain), Type: uint32(proofType), Properties: properties, } diff --git a/api/grpcserver/v2alpha1/malfeasance_test.go b/api/grpcserver/v2alpha1/malfeasance_test.go index 4ed7616e3c..9589d9b0af 100644 --- a/api/grpcserver/v2alpha1/malfeasance_test.go +++ b/api/grpcserver/v2alpha1/malfeasance_test.go @@ -62,7 +62,7 @@ func TestMalfeasanceService_List(t *testing.T) { proofs[i] = malInfo{ID: types.RandomNodeID(), Proof: types.RandomBytes(100)} proofs[i].Properties = map[string]string{ "domain": strconv.FormatUint(uint64(i%4+1), 10), - "type": strconv.FormatUint(uint64(i%4+1), 10), + "type": fmt.Sprintf("Type %d", i%4+1), fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), } info.EXPECT().Info(gomock.Any(), proofs[i].ID).DoAndReturn( @@ -92,7 +92,7 @@ func TestMalfeasanceService_List(t *testing.T) { proofs[70].Proof = types.RandomBytes(100) proofs[70].Properties = map[string]string{ "domain": "1", - "type": "1", + "type": "Type Marry", "key": "value", } info.EXPECT().Info(gomock.Any(), proofs[70].ID).DoAndReturn( @@ -198,7 +198,7 @@ func TestMalfeasanceStreamService_Stream(t *testing.T) { proofs[i] = malInfo{ID: types.RandomNodeID(), Proof: types.RandomBytes(100)} proofs[i].Properties = map[string]string{ "domain": strconv.FormatUint(uint64(i%4+1), 10), - "type": strconv.FormatUint(uint64(i%4+1), 10), + "type": fmt.Sprintf("Type %d", i%4+1), fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), } info.EXPECT().Info(gomock.Any(), proofs[i].ID).DoAndReturn( @@ -228,7 +228,7 @@ func TestMalfeasanceStreamService_Stream(t *testing.T) { proofs[70].Proof = types.RandomBytes(100) proofs[70].Properties = map[string]string{ "domain": "1", - "type": "1", + "type": "Type Marry", "key": "value", } info.EXPECT().Info(gomock.Any(), proofs[70].ID).DoAndReturn( @@ -324,7 +324,7 @@ func TestMalfeasanceStreamService_Stream(t *testing.T) { }) properties := map[string]string{ "domain": strconv.FormatUint(uint64(i%4+1), 10), - "type": strconv.FormatUint(uint64(i%4+1), 10), + "type": fmt.Sprintf("Type %d", i%4+1), fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), } info.EXPECT().Info(gomock.Any(), streamed[i].Smesher).DoAndReturn( @@ -344,7 +344,7 @@ func TestMalfeasanceStreamService_Stream(t *testing.T) { }) properties := map[string]string{ "domain": "1", - "type": "1", + "type": "Type Marry", "key": "value", } info.EXPECT().Info(gomock.Any(), streamed[i].Smesher).DoAndReturn( diff --git a/api/grpcserver/v2beta1/malfeasance.go b/api/grpcserver/v2beta1/malfeasance.go index 1ce85fe254..3a1150dd16 100644 --- a/api/grpcserver/v2beta1/malfeasance.go +++ b/api/grpcserver/v2beta1/malfeasance.go @@ -280,12 +280,12 @@ func fetchMetaData( zap.String("type", properties["type"]), zap.Error(err), ) - return nil + } else { + delete(properties, "type") } - delete(properties, "type") return &spacemeshv2beta1.MalfeasanceProof{ Smesher: id.Bytes(), - Domain: spacemeshv2beta1.MalfeasanceProof_MalfeasanceDomain(domain), // TODO(mafa): add new domains + Domain: spacemeshv2beta1.MalfeasanceProof_MalfeasanceDomain(domain), Type: uint32(proofType), Properties: properties, } diff --git a/api/grpcserver/v2beta1/malfeasance_test.go b/api/grpcserver/v2beta1/malfeasance_test.go index 2572f5aac1..50945796c4 100644 --- a/api/grpcserver/v2beta1/malfeasance_test.go +++ b/api/grpcserver/v2beta1/malfeasance_test.go @@ -62,7 +62,7 @@ func TestMalfeasanceService_List(t *testing.T) { proofs[i] = malInfo{ID: types.RandomNodeID(), Proof: types.RandomBytes(100)} proofs[i].Properties = map[string]string{ "domain": strconv.FormatUint(uint64(i%4+1), 10), - "type": strconv.FormatUint(uint64(i%4+1), 10), + "type": fmt.Sprintf("Type %d", i%4+1), fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), } info.EXPECT().Info(gomock.Any(), proofs[i].ID).DoAndReturn( @@ -92,7 +92,7 @@ func TestMalfeasanceService_List(t *testing.T) { proofs[70].Proof = types.RandomBytes(100) proofs[70].Properties = map[string]string{ "domain": "1", - "type": "1", + "type": "Type Marry", "key": "value", } info.EXPECT().Info(gomock.Any(), proofs[70].ID).DoAndReturn( @@ -198,7 +198,7 @@ func TestMalfeasanceStreamService_Stream(t *testing.T) { proofs[i] = malInfo{ID: types.RandomNodeID(), Proof: types.RandomBytes(100)} proofs[i].Properties = map[string]string{ "domain": strconv.FormatUint(uint64(i%4+1), 10), - "type": strconv.FormatUint(uint64(i%4+1), 10), + "type": fmt.Sprintf("Type %d", i%4+1), fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), } info.EXPECT().Info(gomock.Any(), proofs[i].ID).DoAndReturn( @@ -228,7 +228,7 @@ func TestMalfeasanceStreamService_Stream(t *testing.T) { proofs[70].Proof = types.RandomBytes(100) proofs[70].Properties = map[string]string{ "domain": "1", - "type": "1", + "type": "Type Marry", "key": "value", } info.EXPECT().Info(gomock.Any(), proofs[70].ID).DoAndReturn( @@ -324,7 +324,7 @@ func TestMalfeasanceStreamService_Stream(t *testing.T) { }) properties := map[string]string{ "domain": strconv.FormatUint(uint64(i%4+1), 10), - "type": strconv.FormatUint(uint64(i%4+1), 10), + "type": fmt.Sprintf("Type %d", i%4+1), fmt.Sprintf("key%d", i): fmt.Sprintf("value%d", i), } info.EXPECT().Info(gomock.Any(), streamed[i].Smesher).DoAndReturn( @@ -344,7 +344,7 @@ func TestMalfeasanceStreamService_Stream(t *testing.T) { }) properties := map[string]string{ "domain": "1", - "type": "1", + "type": "Type Marry", "key": "value", } info.EXPECT().Info(gomock.Any(), streamed[i].Smesher).DoAndReturn( diff --git a/checkpoint/runner.go b/checkpoint/runner.go index 614a7d7ffe..a1413ca9c2 100644 --- a/checkpoint/runner.go +++ b/checkpoint/runner.go @@ -15,6 +15,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/marriage" ) @@ -68,7 +69,11 @@ func checkpointDB( if err != nil { return nil, fmt.Errorf("atxs snapshot check identity: %w", err) } - malicious[catx.SmesherID] = mal + mal2, err := malfeasance.IsMalicious(tx, catx.SmesherID) + if err != nil { + return nil, fmt.Errorf("atxs snapshot check malfeasance: %w", err) + } + malicious[catx.SmesherID] = mal || mal2 } commitmentAtx, err := atxs.CommitmentATX(tx, catx.SmesherID) if err != nil { diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index 6b3a870bbb..0528ea8f4c 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -21,6 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) @@ -35,9 +36,15 @@ type activationTx struct { previous types.ATXID } +type malProof struct { + proof []byte + domain int +} + type miner struct { - atxs []activationTx - malfeasanceProof []byte + atxs []activationTx + malfeasanceProof []byte + malfeasanceProof2 *malProof } var allMiners = []miner{ @@ -76,14 +83,25 @@ var allMiners = []miner{ }, }, - // smesher 5 is malicious and equivocated in epoch 7 + // smesher 5 is malicious and equivocated in epoch 6 { atxs: []activationTx{ - {newAtx(types.ATXID{83}, &types.ATXID{27}, 7, 0, 113, []byte("smesher5")), types.EmptyATXID}, - {newAtx(types.ATXID{97}, &types.ATXID{16}, 7, 0, 113, []byte("smesher5")), types.EmptyATXID}, + {newAtx(types.ATXID{53}, &types.ATXID{27}, 6, 0, 113, []byte("smesher5")), types.EmptyATXID}, + {newAtx(types.ATXID{57}, &types.ATXID{16}, 6, 0, 113, []byte("smesher5")), types.EmptyATXID}, }, malfeasanceProof: []byte("im bad"), }, + // smesher 6 is malicious and equivocated in epoch 7 + { + atxs: []activationTx{ + {newAtx(types.ATXID{63}, &types.ATXID{27}, 7, 0, 113, []byte("smesher6")), types.EmptyATXID}, + {newAtx(types.ATXID{67}, &types.ATXID{16}, 7, 0, 113, []byte("smesher6")), types.EmptyATXID}, + }, + malfeasanceProof2: &malProof{ + proof: []byte("im bad"), + domain: 1, + }, + }, } var allAccounts = []*types.Account{ @@ -179,6 +197,9 @@ func expectedCheckpoint(tb testing.TB, snapshot types.LayerID, numAtxs int, mine if len(miner.malfeasanceProof) > 0 { continue } + if miner.malfeasanceProof2 != nil { + continue + } atxs := miner.atxs n := len(atxs) if n > numAtxs { @@ -193,7 +214,6 @@ func expectedCheckpoint(tb testing.TB, snapshot types.LayerID, numAtxs int, mine } result.Data.Atxs = atxData - accounts := make(map[types.Address]*types.Account) for _, account := range allAccounts { if account.Layer <= snapshot { @@ -276,6 +296,10 @@ func createMesh(tb testing.TB, db sql.StateDatabase, miners []miner, accts []*ty if proof := miner.malfeasanceProof; len(proof) > 0 { require.NoError(tb, identities.SetMalicious(db, miner.atxs[0].SmesherID, proof, time.Now())) } + if proof := miner.malfeasanceProof2; proof != nil { + err := malfeasance.AddProof(db, miner.atxs[1].SmesherID, nil, proof.proof, proof.domain, time.Now()) + require.NoError(tb, err) + } } for _, it := range accts { diff --git a/config/logging.go b/config/logging.go index 15c5f647bf..fe1573589a 100644 --- a/config/logging.go +++ b/config/logging.go @@ -49,6 +49,7 @@ type LoggerConfig struct { ConStateLoggerLevel string `mapstructure:"conState"` ExecutorLoggerLevel string `mapstructure:"executor"` MalfeasanceLoggerLevel string `mapstructure:"malfeasance"` + Malfeasance2LoggerLevel string `mapstructure:"malfeasance2"` BootstrapLoggerLevel string `mapstructure:"bootstrap"` } @@ -86,6 +87,7 @@ func DefaultLoggingConfig() LoggerConfig { PostServiceLoggerLevel: defaultLoggingLevel.String(), ConStateLoggerLevel: defaultLoggingLevel.String(), MalfeasanceLoggerLevel: defaultLoggingLevel.String(), + Malfeasance2LoggerLevel: defaultLoggingLevel.String(), BootstrapLoggerLevel: defaultLoggingLevel.String(), } } diff --git a/config/presets/fastnet.go b/config/presets/fastnet.go index ba9838d26d..19292ad671 100644 --- a/config/presets/fastnet.go +++ b/config/presets/fastnet.go @@ -68,7 +68,7 @@ func fastnet() config.Config { conf.POST.K1 = 12 conf.POST.K2 = 4 - conf.POST.K3 = 1 + conf.POST.K3 = 2 conf.POST.LabelsPerUnit = 128 conf.POST.MaxNumUnits = 4 conf.POST.MinNumUnits = 2 diff --git a/datastore/mocks.go b/datastore/mocks.go new file mode 100644 index 0000000000..f2c8f836a8 --- /dev/null +++ b/datastore/mocks.go @@ -0,0 +1,81 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./store.go +// +// Generated by this command: +// +// mockgen -typed -package=datastore -destination=./mocks.go -source=./store.go +// + +// Package datastore is a generated GoMock package. +package datastore + +import ( + context "context" + reflect "reflect" + + types "github.com/spacemeshos/go-spacemesh/common/types" + gomock "go.uber.org/mock/gomock" +) + +// MockMalfeasanceProvider is a mock of MalfeasanceProvider interface. +type MockMalfeasanceProvider struct { + ctrl *gomock.Controller + recorder *MockMalfeasanceProviderMockRecorder + isgomock struct{} +} + +// MockMalfeasanceProviderMockRecorder is the mock recorder for MockMalfeasanceProvider. +type MockMalfeasanceProviderMockRecorder struct { + mock *MockMalfeasanceProvider +} + +// NewMockMalfeasanceProvider creates a new mock instance. +func NewMockMalfeasanceProvider(ctrl *gomock.Controller) *MockMalfeasanceProvider { + mock := &MockMalfeasanceProvider{ctrl: ctrl} + mock.recorder = &MockMalfeasanceProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMalfeasanceProvider) EXPECT() *MockMalfeasanceProviderMockRecorder { + return m.recorder +} + +// ProofByID mocks base method. +func (m *MockMalfeasanceProvider) ProofByID(ctx context.Context, nodeID types.NodeID) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ProofByID", ctx, nodeID) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ProofByID indicates an expected call of ProofByID. +func (mr *MockMalfeasanceProviderMockRecorder) ProofByID(ctx, nodeID any) *MockMalfeasanceProviderProofByIDCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProofByID", reflect.TypeOf((*MockMalfeasanceProvider)(nil).ProofByID), ctx, nodeID) + return &MockMalfeasanceProviderProofByIDCall{Call: call} +} + +// MockMalfeasanceProviderProofByIDCall wrap *gomock.Call +type MockMalfeasanceProviderProofByIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMalfeasanceProviderProofByIDCall) Return(arg0 []byte, arg1 error) *MockMalfeasanceProviderProofByIDCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMalfeasanceProviderProofByIDCall) Do(f func(context.Context, types.NodeID) ([]byte, error)) *MockMalfeasanceProviderProofByIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMalfeasanceProviderProofByIDCall) DoAndReturn(f func(context.Context, types.NodeID) ([]byte, error)) *MockMalfeasanceProviderProofByIDCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/datastore/store.go b/datastore/store.go index fbc9a34381..1620c188e6 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -20,6 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -113,7 +114,9 @@ func NewCachedDB(db sql.StateDatabase, lg *zap.Logger, opts ...Opt) *CachedDB { } } -// TODO(mafa): this needs to be removed, since it only works with v1 malfeasance proofs. +// MalfeasanceProof returns the malfeasance proof for the given node ID. This function is thread safe and will return +// an error if the proof is not found in the ATX DB. +// Deprecated: use functions in the `sql/identities` and `sql/malfeasance` packages. func (db *CachedDB) MalfeasanceProof(id types.NodeID) ([]byte, error) { if id == types.EmptyNodeID { panic("invalid argument to GetMalfeasanceProof") @@ -137,7 +140,8 @@ func (db *CachedDB) MalfeasanceProof(id types.NodeID) ([]byte, error) { return blob.Bytes, err } -// TODO(mafa): this needs to be removed, since it only works with v1 malfeasance proofs. +// CacheMalfeasanceProof caches the malfeasance proof for the given node ID. This function is thread safe. +// Deprecated: caching is done by the sql database automatically. func (db *CachedDB) CacheMalfeasanceProof(id types.NodeID, proof []byte) { if id == types.EmptyNodeID { panic("invalid argument to CacheMalfeasanceProof") @@ -188,10 +192,14 @@ func (db *CachedDB) GetAtx(id types.ATXID) (*types.ActivationTx, error) { } // Previous retrieves the list of previous ATXs for the given ATX ID. +// Deprecated: replaced by atxs.Previous. func (db *CachedDB) Previous(id types.ATXID) ([]types.ATXID, error) { return atxs.Previous(db, id) } +// IterateMalfeasanceProofs iterates over all malfeasance proofs in the database and calls the provided callback on +// each. +// Deprecated: replaced by identities.IterateOps and malfeasance.IterateOps. func (db *CachedDB) IterateMalfeasanceProofs( iter func(types.NodeID, []byte) error, ) error { @@ -206,6 +214,8 @@ func (db *CachedDB) IterateMalfeasanceProofs( return callbackErr } +// MaxHeightAtx returns the ATX ID with the maximum height. +// Deprecated: replaced by atxs.GetIDWithMaxHeight. func (db *CachedDB) MaxHeightAtx() (types.ATXID, error) { return atxs.GetIDWithMaxHeight(db, types.EmptyNodeID, atxs.FilterAll) } @@ -215,26 +225,46 @@ type Hint string // DB hints per DB. const ( - NoHint Hint = "" - BallotDB Hint = "ballotDB" - BlockDB Hint = "blocksDB" - ProposalDB Hint = "proposalDB" - ATXDB Hint = "ATXDB" - TXDB Hint = "TXDB" - POETDB Hint = "POETDB" - Malfeasance Hint = "malfeasance" - ActiveSet Hint = "activeset" + NoHint Hint = "" + BallotDB Hint = "ballotDB" + BlockDB Hint = "blocksDB" + ProposalDB Hint = "proposalDB" + ATXDB Hint = "ATXDB" + TXDB Hint = "TXDB" + POETDB Hint = "POETDB" + LegacyMalfeasance Hint = "malfeasance" + Malfeasance Hint = "malfeasance2" + ActiveSet Hint = "activeset" ) // NewBlobStore returns a BlobStore. -func NewBlobStore(db sql.Executor, proposals *store.Store) *BlobStore { - return &BlobStore{DB: db, proposals: proposals} +func NewBlobStore(db sql.StateDatabase, proposals *store.Store) *BlobStore { + return &BlobStore{ + DB: db, + proposals: proposals, + } +} + +// SetMalfeasanceProvider sets the malfeasance provider dependency. +// +// TODO(mafa): this is a hack because of a cyclic dependency between the packages +// +// malfeasance2 -> fetcher -> datastore -> malfeasance2 +func (bs *BlobStore) SetMalfeasanceProvider(p MalfeasanceProvider) { + bs.malfeasance = p +} + +//go:generate mockgen -typed -package=datastore -destination=./mocks.go -source=./store.go + +type MalfeasanceProvider interface { + ProofByID(ctx context.Context, nodeID types.NodeID) ([]byte, error) } // BlobStore gets data as a blob to serve direct fetch requests. type BlobStore struct { - DB sql.Executor - proposals *store.Store + DB sql.StateDatabase + proposals *store.Store + malfeasance MalfeasanceProvider } type ( @@ -247,22 +277,22 @@ var loadBlobDispatch = map[Hint]loadBlobFunc{ _, err := atxs.LoadBlob(ctx, db, key, blob) return err }, - BallotDB: ballots.LoadBlob, - BlockDB: blocks.LoadBlob, - TXDB: transactions.LoadBlob, - POETDB: poets.LoadBlob, - Malfeasance: identities.LoadMalfeasanceBlob, - ActiveSet: activesets.LoadBlob, + BallotDB: ballots.LoadBlob, + BlockDB: blocks.LoadBlob, + TXDB: transactions.LoadBlob, + POETDB: poets.LoadBlob, + LegacyMalfeasance: identities.LoadMalfeasanceBlob, + ActiveSet: activesets.LoadBlob, } var blobSizeDispatch = map[Hint]blobSizeFunc{ - ATXDB: atxs.GetBlobSizes, - BallotDB: ballots.GetBlobSizes, - BlockDB: blocks.GetBlobSizes, - TXDB: transactions.GetBlobSizes, - POETDB: poets.GetBlobSizes, - Malfeasance: identities.GetBlobSizes, - ActiveSet: activesets.GetBlobSizes, + ATXDB: atxs.GetBlobSizes, + BallotDB: ballots.GetBlobSizes, + BlockDB: blocks.GetBlobSizes, + TXDB: transactions.GetBlobSizes, + POETDB: poets.GetBlobSizes, + LegacyMalfeasance: identities.GetBlobSizes, + ActiveSet: activesets.GetBlobSizes, } func (bs *BlobStore) loadProposal(key []byte, blob *sql.Blob) error { @@ -279,7 +309,7 @@ func (bs *BlobStore) loadProposal(key []byte, blob *sql.Blob) error { } } -func (bs *BlobStore) getProposalSizes(keys [][]byte) (sizes []int, err error) { +func (bs *BlobStore) proposalSizes(keys [][]byte) (sizes []int, err error) { sizes = make([]int, len(keys)) for n, k := range keys { id := types.ProposalID(types.BytesToHash(k).ToHash20()) @@ -296,10 +326,45 @@ func (bs *BlobStore) getProposalSizes(keys [][]byte) (sizes []int, err error) { return sizes, err } +func (bs *BlobStore) loadMalfeasance(key []byte, blob *sql.Blob) error { + id := types.BytesToNodeID(key) + b, err := bs.malfeasance.ProofByID(context.Background(), id) + switch { + case err == nil: + blob.Bytes = b + return nil + case errors.Is(err, sql.ErrNotFound): + return ErrNotFound + default: + return err + } +} + +func (bs *BlobStore) malfeasanceSizes(keys [][]byte) (sizes []int, err error) { + sizes = make([]int, len(keys)) + for n, k := range keys { + id := types.NodeID(k) + b, err := bs.malfeasance.ProofByID(context.Background(), id) + switch { + case err == nil: + sizes[n] = len(b) + case errors.Is(err, store.ErrNotFound): + sizes[n] = -1 + default: + return nil, err + } + } + return sizes, err +} + // LoadBlob gets an blob as bytes by an object ID as bytes. func (bs *BlobStore) LoadBlob(ctx context.Context, hint Hint, key []byte, blob *sql.Blob) error { - if hint == ProposalDB { + switch hint { + case ProposalDB: return bs.loadProposal(key, blob) + case Malfeasance: + return bs.loadMalfeasance(key, blob) + default: } loader, found := loadBlobDispatch[hint] if !found { @@ -319,8 +384,12 @@ func (bs *BlobStore) LoadBlob(ctx context.Context, hint Hint, key []byte, blob * // GetBlobSizes returns the sizes of the blobs corresponding to the specified ids. For // non-existent objects, the corresponding items are set to -1. func (bs *BlobStore) GetBlobSizes(hint Hint, ids [][]byte) (sizes []int, err error) { - if hint == ProposalDB { - return bs.getProposalSizes(ids) + switch hint { + case ProposalDB: + return bs.proposalSizes(ids) + case Malfeasance: + return bs.malfeasanceSizes(ids) + default: } getSizes, found := blobSizeDispatch[hint] if !found { @@ -349,8 +418,10 @@ func (bs *BlobStore) Has(hint Hint, key []byte) (bool, error) { return transactions.Has(bs.DB, types.TransactionID(types.BytesToHash(key))) case POETDB: return poets.Has(bs.DB, types.ByteToPoetProofRef(key)) - case Malfeasance: + case LegacyMalfeasance: return identities.IsMalicious(bs.DB, types.BytesToNodeID(key)) + case Malfeasance: + return malfeasance.IsMalicious(bs.DB, types.BytesToNodeID(key)) case ActiveSet: return activesets.Has(bs.DB, types.BytesToHash(key)) } diff --git a/datastore/store_test.go b/datastore/store_test.go index 8683eb54bb..c5bdce9648 100644 --- a/datastore/store_test.go +++ b/datastore/store_test.go @@ -1,7 +1,6 @@ package datastore_test import ( - "bytes" "context" "errors" "os" @@ -9,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" "github.com/spacemeshos/go-spacemesh/codec" @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" @@ -35,18 +36,6 @@ func TestMain(m *testing.M) { os.Exit(res) } -type blobId interface { - Bytes() []byte -} - -func getBytes(ctx context.Context, bs *datastore.BlobStore, hint datastore.Hint, id blobId) ([]byte, error) { - var blob sql.Blob - if err := bs.LoadBlob(ctx, hint, id.Bytes(), &blob); err != nil { - return nil, err - } - return blob.Bytes, nil -} - func TestMalfeasanceProof_Dishonest(t *testing.T) { db := statesql.InMemoryTest(t) cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t)) @@ -110,21 +99,19 @@ func TestBlobStore_GetATXBlob(t *testing.T) { require.NoError(t, err) require.False(t, has) - _, err = getBytes(context.Background(), bs, datastore.ATXDB, atx.ID()) + var blob sql.Blob + err = bs.LoadBlob(context.Background(), datastore.ATXDB, atx.ID().Bytes(), &blob) require.ErrorIs(t, err, datastore.ErrNotFound) - blob := types.AtxBlob{Blob: types.RandomBytes(100)} - require.NoError(t, atxs.Add(db, atx, blob)) + atxBlob := types.AtxBlob{Blob: types.RandomBytes(100)} + require.NoError(t, atxs.Add(db, atx, atxBlob)) has, err = bs.Has(datastore.ATXDB, atx.ID().Bytes()) require.NoError(t, err) require.True(t, has) - got, err := getBytes(context.Background(), bs, datastore.ATXDB, atx.ID()) + err = bs.LoadBlob(context.Background(), datastore.ATXDB, atx.ID().Bytes(), &blob) require.NoError(t, err) - require.Equal(t, blob.Blob, got) - - _, err = getBytes(context.Background(), bs, datastore.BallotDB, atx.ID()) - require.ErrorIs(t, err, datastore.ErrNotFound) + require.Equal(t, atxBlob.Blob, blob.Bytes) } func TestBlobStore_GetBallotBlob(t *testing.T) { @@ -142,23 +129,23 @@ func TestBlobStore_GetBallotBlob(t *testing.T) { has, err := bs.Has(datastore.BallotDB, blt.ID().Bytes()) require.NoError(t, err) require.False(t, has) - _, err = getBytes(context.Background(), bs, datastore.BallotDB, blt.ID()) + + var blob sql.Blob + err = bs.LoadBlob(context.Background(), datastore.BallotDB, blt.ID().Bytes(), &blob) require.ErrorIs(t, err, datastore.ErrNotFound) require.NoError(t, ballots.Add(db, blt)) has, err = bs.Has(datastore.BallotDB, blt.ID().Bytes()) require.NoError(t, err) require.True(t, has) - got, err := getBytes(context.Background(), bs, datastore.BallotDB, blt.ID()) + + err = bs.LoadBlob(context.Background(), datastore.BallotDB, blt.ID().Bytes(), &blob) require.NoError(t, err) var gotB types.Ballot - require.NoError(t, codec.Decode(got, &gotB)) + require.NoError(t, codec.Decode(blob.Bytes, &gotB)) require.NoError(t, gotB.Initialize()) require.Equal(t, *blt, gotB) - - _, err = getBytes(context.Background(), bs, datastore.BlockDB, blt.ID()) - require.ErrorIs(t, err, datastore.ErrNotFound) } func TestBlobStore_GetBlockBlob(t *testing.T) { @@ -177,22 +164,21 @@ func TestBlobStore_GetBlockBlob(t *testing.T) { require.NoError(t, err) require.False(t, has) - _, err = getBytes(context.Background(), bs, datastore.BlockDB, blk.ID()) + var blob sql.Blob + err = bs.LoadBlob(context.Background(), datastore.BlockDB, blk.ID().Bytes(), &blob) require.ErrorIs(t, err, datastore.ErrNotFound) require.NoError(t, blocks.Add(db, &blk)) has, err = bs.Has(datastore.BlockDB, blk.ID().Bytes()) require.NoError(t, err) require.True(t, has) - got, err := getBytes(context.Background(), bs, datastore.BlockDB, blk.ID()) + + err = bs.LoadBlob(context.Background(), datastore.BlockDB, blk.ID().Bytes(), &blob) require.NoError(t, err) var gotB types.Block - require.NoError(t, codec.Decode(got, &gotB)) + require.NoError(t, codec.Decode(blob.Bytes, &gotB)) gotB.Initialize() require.Equal(t, blk, gotB) - - _, err = getBytes(context.Background(), bs, datastore.ProposalDB, blk.ID()) - require.ErrorIs(t, err, datastore.ErrNotFound) } func TestBlobStore_GetPoetBlob(t *testing.T) { @@ -212,15 +198,14 @@ func TestBlobStore_GetPoetBlob(t *testing.T) { var poetRef types.PoetProofRef copy(poetRef[:], ref) require.NoError(t, poets.Add(db, poetRef, poet, sid, rid)) + has, err = bs.Has(datastore.POETDB, ref) require.NoError(t, err) require.True(t, has) var blob sql.Blob require.NoError(t, bs.LoadBlob(context.Background(), datastore.POETDB, poetRef[:], &blob)) - require.True(t, bytes.Equal(poet, blob.Bytes)) - - require.ErrorIs(t, bs.LoadBlob(context.Background(), datastore.BlockDB, ref, &sql.Blob{}), datastore.ErrNotFound) + require.Equal(t, poet, blob.Bytes) } func TestBlobStore_GetProposalBlob(t *testing.T) { @@ -245,17 +230,20 @@ func TestBlobStore_GetProposalBlob(t *testing.T) { has, err := bs.Has(datastore.ProposalDB, p.ID().Bytes()) require.NoError(t, err) require.False(t, has) - _, err = getBytes(context.Background(), bs, datastore.ProposalDB, p.ID()) + + var blob sql.Blob + err = bs.LoadBlob(context.Background(), datastore.ProposalDB, p.ID().Bytes(), &blob) require.ErrorIs(t, err, datastore.ErrNotFound) require.NoError(t, proposals.Add(&p)) has, err = bs.Has(datastore.ProposalDB, p.ID().Bytes()) require.NoError(t, err) require.True(t, has) - got, err := getBytes(context.Background(), bs, datastore.ProposalDB, p.ID()) + + err = bs.LoadBlob(context.Background(), datastore.ProposalDB, p.ID().Bytes(), &blob) require.NoError(t, err) var gotP types.Proposal - require.NoError(t, codec.Decode(got, &gotP)) + require.NoError(t, codec.Decode(blob.Bytes, &gotP)) require.NoError(t, gotP.Initialize()) require.Equal(t, p, gotP) } @@ -272,22 +260,21 @@ func TestBlobStore_GetTXBlob(t *testing.T) { require.NoError(t, err) require.False(t, has) - _, err = getBytes(context.Background(), bs, datastore.TXDB, tx.ID) + var blob sql.Blob + err = bs.LoadBlob(context.Background(), datastore.TXDB, tx.ID.Bytes(), &blob) require.ErrorIs(t, err, datastore.ErrNotFound) require.NoError(t, transactions.Add(db, tx, time.Now())) has, err = bs.Has(datastore.TXDB, tx.ID.Bytes()) require.NoError(t, err) require.True(t, has) - got, err := getBytes(context.Background(), bs, datastore.TXDB, tx.ID) - require.NoError(t, err) - require.Equal(t, tx.Raw, got) - _, err = getBytes(context.Background(), bs, datastore.BlockDB, tx.ID) - require.ErrorIs(t, err, datastore.ErrNotFound) + err = bs.LoadBlob(context.Background(), datastore.TXDB, tx.ID.Bytes(), &blob) + require.NoError(t, err) + require.Equal(t, tx.Raw, blob.Bytes) } -func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { +func TestBlobStore_GetLegacyMalfeasanceBlob(t *testing.T) { db := statesql.InMemoryTest(t) bs := datastore.NewBlobStore(db, store.New()) @@ -304,20 +291,53 @@ func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { require.NoError(t, err) nodeID := types.NodeID{1, 2, 3} - has, err := bs.Has(datastore.Malfeasance, nodeID.Bytes()) + has, err := bs.Has(datastore.LegacyMalfeasance, nodeID.Bytes()) require.NoError(t, err) require.False(t, has) - _, err = getBytes(context.Background(), bs, datastore.Malfeasance, nodeID) + var blob sql.Blob + err = bs.LoadBlob(context.Background(), datastore.LegacyMalfeasance, nodeID.Bytes(), &blob) require.ErrorIs(t, err, datastore.ErrNotFound) require.NoError(t, identities.SetMalicious(db, nodeID, encoded, time.Now())) + has, err = bs.Has(datastore.LegacyMalfeasance, nodeID.Bytes()) + require.NoError(t, err) + require.True(t, has) + + err = bs.LoadBlob(context.Background(), datastore.LegacyMalfeasance, nodeID.Bytes(), &blob) + require.NoError(t, err) + require.Equal(t, encoded, blob.Bytes) +} + +func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { + db := statesql.InMemoryTest(t) + bs := datastore.NewBlobStore(db, store.New()) + + ctrl := gomock.NewController(t) + mMal := datastore.NewMockMalfeasanceProvider(ctrl) + bs.SetMalfeasanceProvider(mMal) + + proofBytes := types.RandomBytes(100) + nodeID := types.NodeID{1, 2, 3} + + has, err := bs.Has(datastore.Malfeasance, nodeID.Bytes()) + require.NoError(t, err) + require.False(t, has) + + mMal.EXPECT().ProofByID(gomock.Any(), nodeID).Return(nil, sql.ErrNotFound) + var blob sql.Blob + err = bs.LoadBlob(context.Background(), datastore.Malfeasance, nodeID.Bytes(), &blob) + require.ErrorIs(t, err, datastore.ErrNotFound) + + require.NoError(t, malfeasance.AddProof(db, nodeID, nil, proofBytes, 1, time.Now())) has, err = bs.Has(datastore.Malfeasance, nodeID.Bytes()) require.NoError(t, err) require.True(t, has) - got, err := getBytes(context.Background(), bs, datastore.Malfeasance, nodeID) + + mMal.EXPECT().ProofByID(gomock.Any(), nodeID).Return(proofBytes, nil) + err = bs.LoadBlob(context.Background(), datastore.Malfeasance, nodeID.Bytes(), &blob) require.NoError(t, err) - require.Equal(t, encoded, got) + require.Equal(t, proofBytes, blob.Bytes) } func TestBlobStore_GetActiveSet(t *testing.T) { @@ -331,14 +351,16 @@ func TestBlobStore_GetActiveSet(t *testing.T) { require.NoError(t, err) require.False(t, has) - _, err = getBytes(context.Background(), bs, datastore.ActiveSet, hash) + var blob sql.Blob + err = bs.LoadBlob(context.Background(), datastore.ActiveSet, hash.Bytes(), &blob) require.ErrorIs(t, err, datastore.ErrNotFound) require.NoError(t, activesets.Add(db, hash, as)) has, err = bs.Has(datastore.ActiveSet, hash.Bytes()) require.NoError(t, err) require.True(t, has) - got, err := getBytes(context.Background(), bs, datastore.ActiveSet, hash) + + err = bs.LoadBlob(context.Background(), datastore.ActiveSet, hash.Bytes(), &blob) require.NoError(t, err) - require.Equal(t, codec.MustEncode(as), got) + require.Equal(t, codec.MustEncode(as), blob.Bytes) } diff --git a/fetch/fetch.go b/fetch/fetch.go index a1206baf3f..807704c608 100644 --- a/fetch/fetch.go +++ b/fetch/fetch.go @@ -25,6 +25,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/proposals/store" + "github.com/spacemeshos/go-spacemesh/sql" ) const ( @@ -33,7 +34,8 @@ const ( hashProtocol = "hs/1" activeSetProtocol = "as/1" meshHashProtocol = "mh/1" - malProtocol = "ml/1" + legacyMalProtocol = "ml/1" + malProtocol = "ml/2" OpnProtocol = "lp/2" cacheSize = 1000 @@ -178,9 +180,11 @@ func DefaultConfig() Config { hashProtocol: {Queue: 2000, Requests: 200, Interval: time.Second}, // active sets (can get quite large) activeSetProtocol: {Queue: 10, Requests: 1, Interval: time.Second}, - // serves at most 100 hashes - 3KB + // serves at most 100 hashes - 3 KB meshHashProtocol: {Queue: 1000, Requests: 100, Interval: time.Second}, - // serves all malicious ids (id - 32 byte) - 10KB + // serves all legacy malicious ids (á 32 byte, ~2255 as of Jan 2025) - <100 KB + legacyMalProtocol: {Queue: 100, Requests: 10, Interval: time.Second}, + // serves all malicious ids (á 32 byte, 0 as of Jan 2025) - <100 KB malProtocol: {Queue: 100, Requests: 10, Interval: time.Second}, // 64 bytes OpnProtocol: {Queue: 10000, Requests: 1000, Interval: time.Second}, @@ -268,13 +272,13 @@ type Fetch struct { // NewFetch creates a new Fetch struct. func NewFetch( - cdb *datastore.CachedDB, + db sql.StateDatabase, proposals *store.Store, host *p2p.Host, peerCache *peers.Peers, opts ...Option, ) (*Fetch, error) { - bs := datastore.NewBlobStore(cdb, proposals) + bs := datastore.NewBlobStore(db, proposals) hashPeerCache, err := NewHashPeersCache(cacheSize) if err != nil { @@ -347,7 +351,7 @@ func NewFetch( f.batchTimeout = time.NewTicker(f.cfg.BatchTimeout) if len(f.servers) == 0 { - h := newHandler(cdb, bs, f.logger.Named("handler")) + h := newHandler(db, bs, f.logger.Named("handler")) if f.cfg.Streaming { f.registerServer(host, atxProtocol, h.handleEpochInfoReqStream) f.registerServer(host, hashProtocol, h.handleHashReqStream) @@ -357,6 +361,7 @@ func NewFetch( return h.doHandleHashReqStream(ctx, msg, s, datastore.ActiveSet) }) f.registerServer(host, meshHashProtocol, h.handleMeshHashReqStream) + f.registerServer(host, legacyMalProtocol, h.handleLegacyMaliciousIDsReqStream) f.registerServer(host, malProtocol, h.handleMaliciousIDsReqStream) } else { f.registerServer(host, atxProtocol, server.WrapHandler(h.handleEpochInfoReq)) @@ -367,6 +372,7 @@ func NewFetch( return h.doHandleHashReq(ctx, data, datastore.ActiveSet) })) f.registerServer(host, meshHashProtocol, server.WrapHandler(h.handleMeshHashReq)) + f.registerServer(host, legacyMalProtocol, server.WrapHandler(h.handleLegacyMaliciousIDsReq)) f.registerServer(host, malProtocol, server.WrapHandler(h.handleMaliciousIDsReq)) } f.registerServer(host, lyrDataProtocol, server.WrapHandler(h.handleLayerDataReq)) @@ -394,15 +400,25 @@ func (f *Fetch) registerServer( } type dataValidators struct { - atx SyncValidator - poet SyncValidator - ballot SyncValidator - activeset SyncValidator - block SyncValidator - proposal SyncValidator - txBlock SyncValidator - txProposal SyncValidator - malfeasance SyncValidator + atx SyncValidator + poet SyncValidator + ballot SyncValidator + activeset SyncValidator + block SyncValidator + proposal SyncValidator + txBlock SyncValidator + txProposal SyncValidator + legacyMalfeasance SyncValidator + malfeasance SyncValidator +} + +// SetMalfeasanceProvider sets the malfeasance provider dependency. +// +// TODO(mafa): this is a hack because of a cyclic dependency between the packages +// +// malfeasance2 -> fetcher -> datastore -> malfeasance2 +func (f *Fetch) SetMalfeasanceProvider(p datastore.MalfeasanceProvider) { + f.bs.SetMalfeasanceProvider(p) } // SetValidators sets the handlers to validate various mesh data fetched from peers. @@ -416,17 +432,19 @@ func (f *Fetch) SetValidators( txBlock SyncValidator, txProposal SyncValidator, mal SyncValidator, + mal2 SyncValidator, ) { f.validators = &dataValidators{ - atx: atx, - poet: poet, - ballot: ballot, - activeset: activeset, - block: block, - proposal: prop, - txBlock: txBlock, - txProposal: txProposal, - malfeasance: mal, + atx: atx, + poet: poet, + ballot: ballot, + activeset: activeset, + block: block, + proposal: prop, + txBlock: txBlock, + txProposal: txProposal, + legacyMalfeasance: mal, + malfeasance: mal2, } } diff --git a/fetch/fetch_test.go b/fetch/fetch_test.go index 24748d95ff..5b4ee38ccb 100644 --- a/fetch/fetch_test.go +++ b/fetch/fetch_test.go @@ -28,15 +28,15 @@ import ( type testFetch struct { *Fetch - mh *mocks.Mockhost - mMalS *mocks.Mockrequester - mAtxS *mocks.Mockrequester - mLyrS *mocks.Mockrequester - mHashS *mocks.Mockrequester - mMHashS *mocks.Mockrequester - mOpn2S *mocks.Mockrequester + mh *mocks.Mockhost + mAtxS *mocks.Mockrequester + mLyrS *mocks.Mockrequester + mHashS *mocks.Mockrequester + mMHashS *mocks.Mockrequester + mOpn2S *mocks.Mockrequester + mLegacyMalS *mocks.Mockrequester + mMalS *mocks.Mockrequester - mMalH *mocks.MockSyncValidator mAtxH *mocks.MockSyncValidator mBallotH *mocks.MockSyncValidator mActiveSetH *mocks.MockSyncValidator @@ -46,19 +46,23 @@ type testFetch struct { mTxBlocksH *mocks.MockSyncValidator mTxProposalH *mocks.MockSyncValidator mPoetH *mocks.MockSyncValidator + mLegacyMalH *mocks.MockSyncValidator + mMalH *mocks.MockSyncValidator } func createFetch(tb testing.TB) *testFetch { ctrl := gomock.NewController(tb) tf := &testFetch{ - mh: mocks.NewMockhost(ctrl), - mMalS: mocks.NewMockrequester(ctrl), - mAtxS: mocks.NewMockrequester(ctrl), - mLyrS: mocks.NewMockrequester(ctrl), - mHashS: mocks.NewMockrequester(ctrl), - mMHashS: mocks.NewMockrequester(ctrl), - mOpn2S: mocks.NewMockrequester(ctrl), - mMalH: mocks.NewMockSyncValidator(ctrl), + mh: mocks.NewMockhost(ctrl), + + mAtxS: mocks.NewMockrequester(ctrl), + mLyrS: mocks.NewMockrequester(ctrl), + mHashS: mocks.NewMockrequester(ctrl), + mMHashS: mocks.NewMockrequester(ctrl), + mOpn2S: mocks.NewMockrequester(ctrl), + mLegacyMalS: mocks.NewMockrequester(ctrl), + mMalS: mocks.NewMockrequester(ctrl), + mAtxH: mocks.NewMockSyncValidator(ctrl), mBallotH: mocks.NewMockSyncValidator(ctrl), mActiveSetH: mocks.NewMockSyncValidator(ctrl), @@ -67,10 +71,18 @@ func createFetch(tb testing.TB) *testFetch { mTxBlocksH: mocks.NewMockSyncValidator(ctrl), mTxProposalH: mocks.NewMockSyncValidator(ctrl), mPoetH: mocks.NewMockSyncValidator(ctrl), + mLegacyMalH: mocks.NewMockSyncValidator(ctrl), + mMalH: mocks.NewMockSyncValidator(ctrl), } - for _, srv := range []*mocks.Mockrequester{tf.mMalS, tf.mAtxS, tf.mLyrS, tf.mHashS, tf.mMHashS, tf.mOpn2S} { - srv.EXPECT().Run(gomock.Any()).AnyTimes() - } + + tf.mAtxS.EXPECT().Run(gomock.Any()).AnyTimes() + tf.mLyrS.EXPECT().Run(gomock.Any()).AnyTimes() + tf.mHashS.EXPECT().Run(gomock.Any()).AnyTimes() + tf.mMHashS.EXPECT().Run(gomock.Any()).AnyTimes() + tf.mOpn2S.EXPECT().Run(gomock.Any()).AnyTimes() + tf.mLegacyMalS.EXPECT().Run(gomock.Any()).AnyTimes() + tf.mMalS.EXPECT().Run(gomock.Any()).AnyTimes() + cfg := Config{ BatchTimeout: 2 * time.Second, // make sure we never hit the batch timeout BatchSize: 3, @@ -82,10 +94,9 @@ func createFetch(tb testing.TB) *testFetch { } lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(statesql.InMemoryTest(tb), lg) - tb.Cleanup(func() { require.NoError(tb, cdb.Close()) }) + db := statesql.InMemoryTest(tb) fetch, err := NewFetch( - cdb, + db, store.New(), nil, peers.New(), @@ -93,12 +104,13 @@ func createFetch(tb testing.TB) *testFetch { WithConfig(cfg), WithLogger(lg), withServers(map[string]requester{ - malProtocol: tf.mMalS, - atxProtocol: tf.mAtxS, - lyrDataProtocol: tf.mLyrS, - hashProtocol: tf.mHashS, - meshHashProtocol: tf.mMHashS, - OpnProtocol: tf.mOpn2S, + atxProtocol: tf.mAtxS, + lyrDataProtocol: tf.mLyrS, + hashProtocol: tf.mHashS, + meshHashProtocol: tf.mMHashS, + OpnProtocol: tf.mOpn2S, + legacyMalProtocol: tf.mLegacyMalS, + malProtocol: tf.mMalS, }), withHost(tf.mh), ) @@ -114,6 +126,7 @@ func createFetch(tb testing.TB) *testFetch { tf.mProposalH, tf.mTxBlocksH, tf.mTxProposalH, + tf.mLegacyMalH, tf.mMalH, ) return tf @@ -129,10 +142,9 @@ func badReceiver(context.Context, types.Hash32, p2p.Peer, []byte) error { func TestFetch_Start(t *testing.T) { lg := zaptest.NewLogger(t) - cdb := datastore.NewCachedDB(statesql.InMemoryTest(t), lg) - t.Cleanup(func() { require.NoError(t, cdb.Close()) }) + db := statesql.InMemoryTest(t) f, err := NewFetch( - cdb, + db, store.New(), nil, peers.New(), @@ -140,7 +152,7 @@ func TestFetch_Start(t *testing.T) { WithConfig(DefaultConfig()), WithLogger(lg), withServers(map[string]requester{ - malProtocol: nil, + atxProtocol: nil, }), ) require.NoError(t, err) @@ -404,10 +416,9 @@ func TestFetch_PeerDroppedWhenMessageResultsInValidationReject(t *testing.T) { }) defer eg.Wait() - cdb := datastore.NewCachedDB(statesql.InMemoryTest(t), lg) - t.Cleanup(func() { require.NoError(t, cdb.Close()) }) + db := statesql.InMemoryTest(t) fetcher, err := NewFetch( - cdb, + db, store.New(), h, peers.New(), @@ -422,7 +433,7 @@ func TestFetch_PeerDroppedWhenMessageResultsInValidationReject(t *testing.T) { vf := ValidatorFunc( func(context.Context, types.Hash32, peer.ID, []byte) error { return pubsub.ErrValidationReject }, ) - fetcher.SetValidators(vf, nil, nil, nil, nil, nil, nil, nil, nil) + fetcher.SetValidators(vf, nil, nil, nil, nil, nil, nil, nil, nil, nil) // Request an atx by hash _, err = fetcher.getHash( @@ -452,6 +463,7 @@ func TestFetch_PeerDroppedWhenMessageResultsInValidationReject(t *testing.T) { nil, nil, nil, + nil, ) // Request an atx by hash diff --git a/fetch/handler.go b/fetch/handler.go index 7f10f696af..21e9181da4 100644 --- a/fetch/handler.go +++ b/fetch/handler.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "time" "github.com/spacemeshos/go-scale" "go.uber.org/zap" @@ -18,56 +19,107 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" + "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" ) type handler struct { logger *zap.Logger - cdb *datastore.CachedDB + db sql.StateDatabase bs *datastore.BlobStore } func newHandler( - cdb *datastore.CachedDB, + db sql.StateDatabase, bs *datastore.BlobStore, lg *zap.Logger, ) *handler { return &handler{ logger: lg, - cdb: cdb, + db: db, bs: bs, } } +// handleLegacyMaliciousIDsReq returns the IDs of all known malicious nodes. +func (h *handler) handleLegacyMaliciousIDsReq(ctx context.Context, _ p2p.Peer, _ []byte) ([]byte, error) { + nodeIDs, err := identities.AllMalicious(h.db) + if err != nil { + return nil, fmt.Errorf("getting malicious IDs: %w", err) + } + h.logger.Debug("responded to malicious IDs request", log.ZContext(ctx), zap.Int("num_malicious", len(nodeIDs))) + malicious := &MaliciousIDs{ + NodeIDs: nodeIDs, + } + return codec.MustEncode(malicious), nil +} + // handleMaliciousIDsReq returns the IDs of all known malicious nodes. func (h *handler) handleMaliciousIDsReq(ctx context.Context, _ p2p.Peer, _ []byte) ([]byte, error) { - nodes, err := identities.AllMalicious(h.cdb) + tx, err := h.db.TxImmediate(ctx) + if err != nil { + return nil, fmt.Errorf("starting transaction: %w", err) + } + defer tx.Release() + total, err := malfeasance.Count(tx) + if err != nil { + return nil, fmt.Errorf("counting malicious nodes: %w", err) + } + nodeIDs := make([]types.NodeID, 0, total) + err = malfeasance.IterateOps(h.db, builder.Operations{}, + func(nodeID types.NodeID, _ []byte, _ int, _ time.Time) bool { + nodeIDs = append(nodeIDs, nodeID) + return true + }) if err != nil { return nil, fmt.Errorf("getting malicious IDs: %w", err) } - h.logger.Debug("responded to malicious IDs request", log.ZContext(ctx), zap.Int("num_malicious", len(nodes))) + h.logger.Debug("responded to malicious IDs request", log.ZContext(ctx), zap.Int("num_malicious", len(nodeIDs))) malicious := &MaliciousIDs{ - NodeIDs: nodes, + NodeIDs: nodeIDs, } return codec.MustEncode(malicious), nil } -func (h *handler) handleMaliciousIDsReqStream(ctx context.Context, _ p2p.Peer, msg []byte, s io.ReadWriter) error { +func (h *handler) handleLegacyMaliciousIDsReqStream(ctx context.Context, _ p2p.Peer, _ []byte, s io.ReadWriter) error { if err := h.streamIDs(ctx, s, func(cbk retrieveCallback) error { - nodeIDs, err := identities.AllMalicious(h.cdb) + nodeIDs, err := identities.AllMalicious(h.db) if err != nil { return fmt.Errorf("getting malicious IDs: %w", err) } for _, nodeID := range nodeIDs { - cbk(len(nodeIDs), nodeID[:]) + cbk(len(nodeIDs), nodeID.Bytes()) } return nil }); err != nil { h.logger.Debug("failed to stream malicious node IDs", log.ZContext(ctx), zap.Error(err)) } + return nil +} +func (h *handler) handleMaliciousIDsReqStream(ctx context.Context, _ p2p.Peer, _ []byte, s io.ReadWriter) error { + err := h.streamIDs(ctx, s, func(cbk retrieveCallback) error { + return h.db.WithTxImmediate(ctx, func(tx sql.Transaction) error { + total, err := malfeasance.Count(tx) + if err != nil { + return fmt.Errorf("counting malicious nodes: %w", err) + } + return malfeasance.IterateOps(tx, builder.Operations{}, + func(nodeID types.NodeID, _ []byte, _ int, _ time.Time) bool { + if err := cbk(total, nodeID.Bytes()); err != nil { + h.logger.Debug("failed to stream malicious node IDs", log.ZContext(ctx), zap.Error(err)) + return false + } + return true + }) + }) + }) + if err != nil { + h.logger.Debug("failed to stream malicious node IDs", log.ZContext(ctx), zap.Error(err)) + } return nil } @@ -78,7 +130,7 @@ func (h *handler) handleEpochInfoReq(ctx context.Context, _ p2p.Peer, msg []byte return nil, err } - atxids, err := atxs.GetIDsByEpoch(ctx, h.cdb, epoch) + atxids, err := atxs.GetIDsByEpoch(ctx, h.db, epoch) if err != nil { return nil, fmt.Errorf("getting ATX IDs: %w", err) } @@ -103,7 +155,7 @@ func (h *handler) handleEpochInfoReqStream(ctx context.Context, _ p2p.Peer, msg return err } if err := h.streamIDs(ctx, s, func(cbk retrieveCallback) error { - atxids, err := atxs.GetIDsByEpoch(ctx, h.cdb, epoch) + atxids, err := atxs.GetIDsByEpoch(ctx, h.db, epoch) if err != nil { return fmt.Errorf("getting ATX IDs: %w", err) } @@ -140,7 +192,7 @@ func (h *handler) streamIDs(ctx context.Context, s io.ReadWriter, retrieve retri return err } } - if _, err := s.Write(id[:]); err != nil { + if _, err := s.Write(id); err != nil { return err } return nil @@ -184,7 +236,7 @@ func (h *handler) handleLayerDataReq(ctx context.Context, _ p2p.Peer, req []byte if err := codec.Decode(req, &lid); err != nil { return nil, err } - ld.Ballots, err = ballots.IDsInLayer(h.cdb, lid) + ld.Ballots, err = ballots.IDsInLayer(h.db, lid) if err != nil && !errors.Is(err, sql.ErrNotFound) { return nil, fmt.Errorf("getting ballots for layer %d: %w", lid, err) } @@ -213,11 +265,11 @@ func (h *handler) handleLayerOpinionsReq2(ctx context.Context, _ p2p.Peer, data ) opnReqV2.Inc() - lo.PrevAggHash, err = layers.GetAggregatedHash(h.cdb, lid.Sub(1)) + lo.PrevAggHash, err = layers.GetAggregatedHash(h.db, lid.Sub(1)) if err != nil && !errors.Is(err, sql.ErrNotFound) { return nil, fmt.Errorf("getting aggregated hash for layer %d: %w", lid.Sub(1), err) } - bid, err := certificates.CertifiedBlock(h.cdb, lid) + bid, err := certificates.CertifiedBlock(h.db, lid) if err != nil && !errors.Is(err, sql.ErrNotFound) { return nil, fmt.Errorf("getting certified block for layer %d: %w", lid, err) } @@ -233,7 +285,7 @@ func (h *handler) handleLayerOpinionsReq2(ctx context.Context, _ p2p.Peer, data func (h *handler) handleCertReq(ctx context.Context, lid types.LayerID, bid types.BlockID) ([]byte, error) { certReq.Inc() - certs, err := certificates.Get(h.cdb, lid) + certs, err := certificates.Get(h.db, lid) if err != nil && !errors.Is(err, sql.ErrNotFound) { return nil, fmt.Errorf("getting certificates for layer %d: %w", lid, err) } @@ -425,7 +477,7 @@ func (h *handler) handleMeshHashReq(ctx context.Context, _ p2p.Peer, reqData []b if err := req.Validate(); err != nil { return nil, fmt.Errorf("validating request: %w", err) } - hashes, err = layers.GetAggHashes(h.cdb, req.From, req.To, req.Step) + hashes, err = layers.GetAggHashes(h.db, req.From, req.To, req.Step) if err != nil { return nil, err } @@ -454,7 +506,7 @@ func (h *handler) handleMeshHashReqStream(ctx context.Context, _ p2p.Peer, reqDa return fmt.Errorf("validating request: %w", err) } - hashes, err := layers.GetAggHashes(h.cdb, req.From, req.To, req.Step) + hashes, err := layers.GetAggHashes(h.db, req.From, req.To, req.Step) if err != nil { return err } diff --git a/fetch/handler_test.go b/fetch/handler_test.go index 2fcaba63ed..0b0b121703 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -23,28 +23,23 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testHandler struct { *handler - db sql.StateDatabase - cdb *datastore.CachedDB } func createTestHandler(tb testing.TB, opts ...sql.Opt) *testHandler { lg := zaptest.NewLogger(tb) db := statesql.InMemoryTest(tb, opts...) - cdb := datastore.NewCachedDB(db, lg) - tb.Cleanup(func() { require.NoError(tb, cdb.Close()) }) return &testHandler{ - handler: newHandler(cdb, datastore.NewBlobStore(cdb, store.New()), lg), - db: db, - cdb: cdb, + handler: newHandler(db, datastore.NewBlobStore(db, store.New()), lg), } } -func createLayer(tb testing.TB, db *datastore.CachedDB, lid types.LayerID) ([]types.BallotID, []types.BlockID) { +func createLayer(tb testing.TB, db sql.StateDatabase, lid types.LayerID) ([]types.BallotID, []types.BlockID) { num := 5 blts := make([]types.BallotID, 0, num) blks := make([]types.BlockID, 0, num) @@ -69,7 +64,7 @@ func createLayer(tb testing.TB, db *datastore.CachedDB, lid types.LayerID) ([]ty func createOpinions( tb testing.TB, - db *datastore.CachedDB, + db sql.StateDatabase, lid types.LayerID, genCert bool, ) (types.BlockID, types.Hash32) { @@ -104,7 +99,7 @@ func TestHandleLayerDataReq(t *testing.T) { lid := types.LayerID(111) th := createTestHandler(t) - blts, _ := createLayer(t, th.cdb, lid) + blts, _ := createLayer(t, th.db, lid) lidBytes, err := codec.Encode(&lid) require.NoError(t, err) @@ -143,13 +138,13 @@ func TestHandleLayerOpinionsReq(t *testing.T) { th := createTestHandler(t) lid := types.LayerID(111) - _, aggHash := createOpinions(t, th.cdb, lid, !tc.missingCert) + _, aggHash := createOpinions(t, th.db, lid, !tc.missingCert) if tc.multipleCerts { bid := types.RandomBlockID() - require.NoError(t, certificates.Add(th.cdb, lid, &types.Certificate{ + require.NoError(t, certificates.Add(th.db, lid, &types.Certificate{ BlockID: bid, })) - require.NoError(t, certificates.SetInvalid(th.cdb, lid, bid)) + require.NoError(t, certificates.SetInvalid(th.db, lid, bid)) } req := OpinionRequest{Layer: lid} @@ -188,7 +183,7 @@ func TestHandleCertReq(t *testing.T) { require.Nil(t, resp) cert := &types.Certificate{BlockID: bid} - require.NoError(t, certificates.Add(th.cdb, lid, cert)) + require.NoError(t, certificates.Add(th.db, lid, cert)) resp, err = th.handleLayerOpinionsReq2(context.Background(), p2p.Peer(""), reqData) require.NoError(t, err) @@ -244,7 +239,7 @@ func TestHandleMeshHashReq(t *testing.T) { } if !tc.hashMissing { for lid := req.From; !lid.After(req.To); lid = lid.Add(1) { - require.NoError(t, layers.SetMeshHash(th.cdb, lid, types.RandomHash())) + require.NoError(t, layers.SetMeshHash(th.db, lid, types.RandomHash())) } } reqData, err := codec.Encode(req) @@ -300,7 +295,7 @@ func TestHandleEpochInfoReq(t *testing.T) { if !tc.missingData { for i := 0; i < 10; i++ { vatx := newAtx(t, epoch) - require.NoError(t, atxs.Add(th.cdb, vatx, types.AtxBlob{})) + require.NoError(t, atxs.Add(th.db, vatx, types.AtxBlob{})) expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) } } @@ -341,6 +336,41 @@ func TestHandleEpochInfoReq(t *testing.T) { } } +func TestHandleLegacyMaliciousIDsReq(t *testing.T) { + tt := []struct { + name string + numBad int + }{ + { + name: "some bad guys", + numBad: 11, + }, + { + name: "no bad guys", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + th := createTestHandler(t) + var bad []types.NodeID + for i := 0; i < tc.numBad; i++ { + nodeID := types.NodeID{byte(i + 1)} + bad = append(bad, nodeID) + require.NoError(t, identities.SetMalicious(th.db, nodeID, types.RandomBytes(11), time.Now())) + } + + out, err := th.handleLegacyMaliciousIDsReq(context.Background(), p2p.Peer(""), []byte{}) + require.NoError(t, err) + var got MaliciousIDs + require.NoError(t, codec.Decode(out, &got)) + require.ElementsMatch(t, bad, got.NodeIDs) + }) + } +} + func TestHandleMaliciousIDsReq(t *testing.T) { tt := []struct { name string @@ -362,9 +392,9 @@ func TestHandleMaliciousIDsReq(t *testing.T) { th := createTestHandler(t) var bad []types.NodeID for i := 0; i < tc.numBad; i++ { - nid := types.NodeID{byte(i + 1)} - bad = append(bad, nid) - require.NoError(t, identities.SetMalicious(th.cdb, nid, types.RandomBytes(11), time.Now())) + nodeID := types.NodeID{byte(i + 1)} + bad = append(bad, nodeID) + require.NoError(t, malfeasance.AddProof(th.db, nodeID, nil, types.RandomBytes(11), 1, time.Now())) } out, err := th.handleMaliciousIDsReq(context.Background(), p2p.Peer(""), []byte{}) diff --git a/fetch/mesh_data.go b/fetch/mesh_data.go index 34a5f62181..c060659adc 100644 --- a/fetch/mesh_data.go +++ b/fetch/mesh_data.go @@ -160,12 +160,28 @@ func (f *Fetch) GetActiveSet(ctx context.Context, set types.Hash32) error { return f.getHashes(ctx, []types.Hash32{set}, datastore.ActiveSet, f.validators.activeset.HandleMessage) } -// GetMalfeasanceProofs gets malfeasance proofs for the specified NodeIDs and validates them. -func (f *Fetch) GetMalfeasanceProofs(ctx context.Context, ids []types.NodeID) error { +// LegacyMalfeasanceProofs gets legacy malfeasance proofs (v1) for the specified NodeIDs and validates them. +func (f *Fetch) LegacyMalfeasanceProofs(ctx context.Context, ids []types.NodeID) error { if len(ids) == 0 { return nil } - f.logger.Debug("requesting malfeasance proofs from peer", log.ZContext(ctx), zap.Int("num_proofs", len(ids))) + f.logger.Debug("requesting legacy malfeasance proofs from peers", + log.ZContext(ctx), + zap.Int("num_proofs", len(ids)), + ) + hashes := types.NodeIDsToHashes(ids) + return f.getHashes(ctx, hashes, datastore.LegacyMalfeasance, f.validators.legacyMalfeasance.HandleMessage) +} + +// MalfeasanceProofs gets malfeasance proofs (v2) for the specified NodeIDs and validates them. +func (f *Fetch) MalfeasanceProofs(ctx context.Context, ids []types.NodeID) error { + if len(ids) == 0 { + return nil + } + f.logger.Debug("requesting malfeasance proofs from peers", + log.ZContext(ctx), + zap.Int("num_proofs", len(ids)), + ) hashes := types.NodeIDsToHashes(ids) return f.getHashes(ctx, hashes, datastore.Malfeasance, f.validators.malfeasance.HandleMessage) } @@ -269,27 +285,49 @@ func (f *Fetch) GetPoetProof(ctx context.Context, id types.Hash32) error { log.ZContext(ctx), zap.String("hint", string(datastore.POETDB)), zap.Stringer("hash", id), - zap.Error(pm.err)) + zap.Error(pm.err), + ) return pm.err } } -func (f *Fetch) GetMaliciousIDs(ctx context.Context, peer p2p.Peer) ([]types.NodeID, error) { +// LegacyMaliciousIDs gets the malicious IDs from the specified peer. Proofs for those IDs can be fetched via the +// legacy malfeasance proofs protocol (see also LegacyMalfeasanceProofs). +func (f *Fetch) LegacyMaliciousIDs(ctx context.Context, peer p2p.Peer) ([]types.NodeID, error) { var malIDs MaliciousIDs - if f.cfg.Streaming { - if err := f.meteredStreamRequest( - ctx, malProtocol, peer, []byte{}, - func(ctx context.Context, s io.ReadWriter) (int, error) { - total, err := readIDSlice(s, &malIDs.NodeIDs, maxMaliciousIDs) - if ctx.Err() != nil { - return total, ctx.Err() - } - return total, err - }, - ); err != nil { + if !f.cfg.Streaming { + data, err := f.meteredRequest(ctx, legacyMalProtocol, peer, []byte{}) + if err != nil { return nil, err } - } else { + if err := codec.Decode(data, &malIDs); err != nil { + return nil, err + } + f.RegisterPeerHashes(peer, types.NodeIDsToHashes(malIDs.NodeIDs)) + return malIDs.NodeIDs, nil + } + + err := f.meteredStreamRequest(ctx, legacyMalProtocol, peer, []byte{}, + func(ctx context.Context, s io.ReadWriter) (int, error) { + total, err := readIDSlice(s, &malIDs.NodeIDs, maxMaliciousIDs) + if ctx.Err() != nil { + return total, ctx.Err() + } + return total, err + }, + ) + if err != nil { + return nil, err + } + f.RegisterPeerHashes(peer, types.NodeIDsToHashes(malIDs.NodeIDs)) + return malIDs.NodeIDs, nil +} + +// MaliciousIDs gets the malicious IDs from the specified peer. Proofs for those IDs can be fetched via the malfeasance +// proof protocol (see also MalfeasanceProofs). +func (f *Fetch) MaliciousIDs(ctx context.Context, peer p2p.Peer) ([]types.NodeID, error) { + var malIDs MaliciousIDs + if !f.cfg.Streaming { data, err := f.meteredRequest(ctx, malProtocol, peer, []byte{}) if err != nil { return nil, err @@ -297,6 +335,21 @@ func (f *Fetch) GetMaliciousIDs(ctx context.Context, peer p2p.Peer) ([]types.Nod if err := codec.Decode(data, &malIDs); err != nil { return nil, err } + f.RegisterPeerHashes(peer, types.NodeIDsToHashes(malIDs.NodeIDs)) + return malIDs.NodeIDs, nil + } + + err := f.meteredStreamRequest(ctx, malProtocol, peer, []byte{}, + func(ctx context.Context, s io.ReadWriter) (int, error) { + total, err := readIDSlice(s, &malIDs.NodeIDs, maxMaliciousIDs) + if ctx.Err() != nil { + return total, ctx.Err() + } + return total, err + }, + ) + if err != nil { + return nil, err } f.RegisterPeerHashes(peer, types.NodeIDsToHashes(malIDs.NodeIDs)) return malIDs.NodeIDs, nil diff --git a/fetch/mesh_data_test.go b/fetch/mesh_data_test.go index 82c5d422f3..74a003f7d4 100644 --- a/fetch/mesh_data_test.go +++ b/fetch/mesh_data_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/codec" @@ -100,15 +99,6 @@ func startTestLoop(tb testing.TB, f *Fetch, eg *errgroup.Group, stop chan struct }) } -func generateMaliciousIDs(tb testing.TB) []types.NodeID { - tb.Helper() - malIDs := make([]types.NodeID, numMalicious) - for i := range malIDs { - malIDs[i] = types.RandomNodeID() - } - return malIDs -} - func generateLayerContent(tb testing.TB) []byte { tb.Helper() ballotIDs := make([]types.BallotID, 0, numBallots) @@ -384,7 +374,24 @@ func TestFetch_getHashesStreaming(t *testing.T) { }) } -func TestFetch_GetMalfeasanceProofs(t *testing.T) { +func TestFetch_LegacyMalfeasanceProofs(t *testing.T) { + nodeIDs := []types.NodeID{{1}, {2}, {3}} + f := createFetch(t) + f.mLegacyMalH.EXPECT(). + HandleMessage(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + Times(len(nodeIDs)) + + stop := make(chan struct{}, 1) + var eg errgroup.Group + startTestLoop(t, f.Fetch, &eg, stop) + + require.NoError(t, f.LegacyMalfeasanceProofs(context.Background(), nodeIDs)) + close(stop) + require.NoError(t, eg.Wait()) +} + +func TestFetch_MalfeasanceProofs(t *testing.T) { nodeIDs := []types.NodeID{{1}, {2}, {3}} f := createFetch(t) f.mMalH.EXPECT(). @@ -396,7 +403,7 @@ func TestFetch_GetMalfeasanceProofs(t *testing.T) { var eg errgroup.Group startTestLoop(t, f.Fetch, &eg, stop) - require.NoError(t, f.GetMalfeasanceProofs(context.Background(), nodeIDs)) + require.NoError(t, f.MalfeasanceProofs(context.Background(), nodeIDs)) close(stop) require.NoError(t, eg.Wait()) } @@ -659,24 +666,53 @@ func TestGetPoetProof(t *testing.T) { require.NoError(t, eg.Wait()) } -func TestFetch_GetMaliciousIDs(t *testing.T) { +func TestFetch_LegacyMaliciousIDs(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() f := createFetch(t) - expectedIds := generateMaliciousIDs(t) - resp := codec.MustEncode(&MaliciousIDs{NodeIDs: expectedIds}) + expectedIDs := make([]types.NodeID, numMalicious) + for i := range expectedIDs { + expectedIDs[i] = types.RandomNodeID() + } + resp := codec.MustEncode(&MaliciousIDs{NodeIDs: expectedIDs}) + f.mh.EXPECT().ID().Return("self").AnyTimes() + f.mLegacyMalS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), []byte{}).Return(resp, nil) + ids, err := f.LegacyMaliciousIDs(context.Background(), "p0") + require.NoError(t, err) + require.Equal(t, expectedIDs, ids) + }) + t.Run("failure", func(t *testing.T) { + t.Parallel() + errUnknown := errors.New("unknown") + f := createFetch(t) + f.mLegacyMalS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), []byte{}).Return(nil, errUnknown) + ids, err := f.LegacyMaliciousIDs(context.Background(), "p0") + require.ErrorIs(t, err, errUnknown) + require.Nil(t, ids) + }) +} + +func TestFetch_MaliciousIDs(t *testing.T) { + t.Run("success", func(t *testing.T) { + t.Parallel() + f := createFetch(t) + expectedIDs := make([]types.NodeID, numMalicious) + for i := range expectedIDs { + expectedIDs[i] = types.RandomNodeID() + } + resp := codec.MustEncode(&MaliciousIDs{NodeIDs: expectedIDs}) f.mh.EXPECT().ID().Return("self").AnyTimes() f.mMalS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), []byte{}).Return(resp, nil) - ids, err := f.GetMaliciousIDs(context.Background(), "p0") + ids, err := f.MaliciousIDs(context.Background(), "p0") require.NoError(t, err) - require.Equal(t, expectedIds, ids) + require.Equal(t, expectedIDs, ids) }) t.Run("failure", func(t *testing.T) { t.Parallel() errUnknown := errors.New("unknown") f := createFetch(t) f.mMalS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), []byte{}).Return(nil, errUnknown) - ids, err := f.GetMaliciousIDs(context.Background(), "p0") + ids, err := f.MaliciousIDs(context.Background(), "p0") require.ErrorIs(t, err, errUnknown) require.Nil(t, ids) }) @@ -1019,13 +1055,12 @@ func Test_GetAtxsLimiting(t *testing.T) { cfg.QueueSize = 1000 cfg.GetAtxsConcurrency = getAtxConcurrency - cdb := datastore.NewCachedDB(statesql.InMemoryTest(t), zaptest.NewLogger(t)) - t.Cleanup(func() { require.NoError(t, cdb.Close()) }) + db := statesql.InMemoryTest(t) client := server.New(wrapHost(mesh.Hosts()[0]), hashProtocol, nil) host, err := p2p.Upgrade(mesh.Hosts()[0]) require.NoError(t, err) ps := peers.New() - f, err := NewFetch(cdb, store.New(), host, + f, err := NewFetch(db, store.New(), host, ps, WithContext(context.Background()), withServers(map[string]requester{hashProtocol: client}), diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index 183e5cafa1..52ba8bf5e2 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -27,6 +27,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" @@ -42,14 +43,17 @@ type testP2PFetch struct { tb testing.TB // client proposals clientPDB *store.Store - clientCDB *datastore.CachedDB + clientDB sql.StateDatabase clientFetch *Fetch serverID peer.ID serverDB sql.StateDatabase // server proposals - serverPDB *store.Store - serverCDB *datastore.CachedDB - serverFetch *Fetch + serverPDB *store.Store + serverCDB *datastore.CachedDB + serverFetch *Fetch + + malProvider *malProvider + recvMtx sync.Mutex receivedData map[blobKey][]byte } @@ -83,6 +87,17 @@ func p2pCfg(tb testing.TB) p2p.Config { return p2pconf } +type malProvider struct { + db sql.StateDatabase +} + +func (m *malProvider) ProofByID(ctx context.Context, nodeID types.NodeID) ([]byte, error) { + // this is an incomplete implementation, the proof returned here normally needs to be wrapped into a + // malfeasance2.MalfeasanceProof struct before being returned to the fetcher, but for the test it is sufficient + proof, _, err := malfeasance.NodeIDProof(m.db, nodeID) + return proof, err +} + func createP2PFetch( tb testing.TB, clientStreaming, @@ -113,24 +128,23 @@ func createP2PFetch( sqlOpts = []sql.Opt{sql.WithQueryCache(true)} } clientDB := statesql.InMemoryTest(tb, sqlOpts...) - clientCDB := datastore.NewCachedDB(clientDB, lg) - tb.Cleanup(func() { assert.NoError(tb, clientDB.Close()) }) serverDB := statesql.InMemoryTest(tb, sqlOpts...) serverCDB := datastore.NewCachedDB(serverDB, lg) tb.Cleanup(func() { assert.NoError(tb, serverDB.Close()) }) tpf := &testP2PFetch{ tb: tb, clientPDB: store.New(store.WithLogger(lg)), - clientCDB: clientCDB, + clientDB: clientDB, serverID: serverHost.ID(), serverDB: serverDB, serverPDB: store.New(store.WithLogger(lg)), serverCDB: serverCDB, + malProvider: &malProvider{db: serverDB}, receivedData: make(map[blobKey][]byte), } fetcher, err := NewFetch( - tpf.serverCDB, + tpf.serverDB, tpf.serverPDB, serverHost, peers.New(), @@ -143,7 +157,8 @@ func createP2PFetch( vf := ValidatorFunc( func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }, ) - tpf.serverFetch.SetValidators(vf, vf, vf, vf, vf, vf, vf, vf, vf) + tpf.serverFetch.SetMalfeasanceProvider(tpf.malProvider) + tpf.serverFetch.SetValidators(vf, vf, vf, vf, vf, vf, vf, vf, vf, vf) require.NoError(tb, tpf.serverFetch.Start()) tb.Cleanup(tpf.serverFetch.Stop) @@ -152,7 +167,7 @@ func createP2PFetch( }, 10*time.Second, 10*time.Millisecond) fetcher, err = NewFetch( - tpf.clientCDB, + tpf.clientDB, tpf.clientPDB, clientHost, peers.New(), @@ -172,6 +187,7 @@ func createP2PFetch( mkFakeValidator(tpf, "txBlock"), mkFakeValidator(tpf, "txProposal"), mkFakeValidator(tpf, "mal"), + mkFakeValidator(tpf, "mal2"), ) require.NoError(tb, tpf.clientFetch.Start()) tb.Cleanup(tpf.clientFetch.Stop) @@ -340,6 +356,30 @@ func TestP2PPeerMeshHashes(t *testing.T) { }) } +func TestP2PLegacyMaliciousIDs(t *testing.T) { + forStreaming( + t, "database closed", false, + func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { + var bad []types.NodeID + for i := 0; i < 11; i++ { + nid := types.NodeID{byte(i + 1)} + bad = append(bad, nid) + require.NoError(t, identities.SetMalicious(tpf.serverCDB, nid, types.RandomBytes(11), time.Now())) + } + if errStr != "" { + tpf.serverDB.Close() + } + + malIDs, err := tpf.clientFetch.LegacyMaliciousIDs(context.Background(), tpf.serverID) + if errStr == "" { + require.NoError(t, err) + require.ElementsMatch(t, bad, malIDs) + } else { + require.ErrorContains(t, err, errStr) + } + }) +} + func TestP2PMaliciousIDs(t *testing.T) { forStreaming( t, "database closed", false, @@ -348,14 +388,13 @@ func TestP2PMaliciousIDs(t *testing.T) { for i := 0; i < 11; i++ { nid := types.NodeID{byte(i + 1)} bad = append(bad, nid) - require.NoError(t, identities.SetMalicious( - tpf.serverCDB, nid, types.RandomBytes(11), time.Now())) + require.NoError(t, malfeasance.AddProof(tpf.serverCDB, nid, nil, types.RandomBytes(11), 1, time.Now())) } if errStr != "" { tpf.serverDB.Close() } - malIDs, err := tpf.clientFetch.GetMaliciousIDs(context.Background(), tpf.serverID) + malIDs, err := tpf.clientFetch.MaliciousIDs(context.Background(), tpf.serverID) if errStr == "" { require.NoError(t, err) require.ElementsMatch(t, bad, malIDs) @@ -522,16 +561,35 @@ func TestP2PGetProposalTransactions(t *testing.T) { }) } -func TestP2PGetMalfeasanceProofs(t *testing.T) { +func TestP2PLegacyMalfeasanceProofs(t *testing.T) { forStreaming( t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { - nid := types.RandomNodeID() + nodeID := types.RandomNodeID() proof := types.RandomBytes(11) - require.NoError(t, identities.SetMalicious(tpf.serverCDB, nid, proof, time.Now())) + require.NoError(t, identities.SetMalicious(tpf.serverCDB, nodeID, proof, time.Now())) tpf.verifyGetHash( - func() error { return tpf.clientFetch.GetMalfeasanceProofs(context.Background(), []types.NodeID{nid}) }, - errStr, "mal", "hs/1", types.Hash32(nid), nid.Bytes(), + func() error { + return tpf.clientFetch.LegacyMalfeasanceProofs(context.Background(), []types.NodeID{nodeID}) + }, + errStr, "mal", "hs/1", types.Hash32(nodeID), nodeID.Bytes(), + proof, + ) + }) +} + +func TestP2PMalfeasanceProofs(t *testing.T) { + forStreaming( + t, "database closed", false, + func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { + nodeID := types.RandomNodeID() + proof := types.RandomBytes(11) + require.NoError(t, malfeasance.AddProof(tpf.serverCDB, nodeID, nil, proof, 1, time.Now())) + tpf.verifyGetHash( + func() error { + return tpf.clientFetch.MalfeasanceProofs(context.Background(), []types.NodeID{nodeID}) + }, + errStr, "mal2", "hs/1", types.Hash32(nodeID), nodeID.Bytes(), proof, ) }) diff --git a/go.mod b/go.mod index 0fffd92243..3b83b3ef57 100644 --- a/go.mod +++ b/go.mod @@ -88,7 +88,6 @@ require ( github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/c0mm4nd/go-ripemd v0.0.0-20200326052756-bd1759ad7d10 // indirect - github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect github.com/containerd/cgroups v1.1.0 // indirect @@ -99,7 +98,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/elastic/gosigar v0.14.3 // indirect github.com/emicklei/go-restful/v3 v3.12.1 // indirect - github.com/envoyproxy/go-control-plane v0.13.1 // indirect + github.com/envoyproxy/go-control-plane/envoy v1.32.3 // indirect github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect github.com/ericlagergren/decimal v0.0.0-20240411145413-00de7ca16731 // indirect github.com/evanphx/json-patch/v5 v5.9.0 // indirect diff --git a/go.sum b/go.sum index 977abd4458..f64d4e7441 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,6 @@ github.com/bxcodec/faker v2.0.1+incompatible/go.mod h1:BNzfpVdTwnFJ6GtfYTcQu6l6r github.com/c0mm4nd/go-ripemd v0.0.0-20200326052756-bd1759ad7d10 h1:wJ2csnFApV9G1jgh5KmYdxVOQMi+fihIggVTjcbM7ts= github.com/c0mm4nd/go-ripemd v0.0.0-20200326052756-bd1759ad7d10/go.mod h1:mYPR+a1fzjnHY3VFH5KL3PkEjMlVfGXP7c8rbWlkLJg= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= -github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chaos-mesh/chaos-mesh/api v0.0.0-20250108051104-b3d81ecc62fa h1:0OwWndUnfgo4ZC1jKF08aRZmPFxHGEBpa5AQPYZOu5E= @@ -122,8 +120,12 @@ github.com/emicklei/go-restful/v3 v3.12.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRr github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/go-control-plane v0.13.1 h1:vPfJZCkob6yTMEgS+0TwfTUfbHjfy/6vOJ8hUWX/uXE= -github.com/envoyproxy/go-control-plane v0.13.1/go.mod h1:X45hY0mufo6Fd0KW3rqsGvQMw58jvjymeCzBU3mWyHw= +github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M= +github.com/envoyproxy/go-control-plane v0.13.4/go.mod h1:kDfuBlDVsSj2MjrLEtRWtHlsWIFcGyB2RMO44Dc5GZA= +github.com/envoyproxy/go-control-plane/envoy v1.32.3 h1:hVEaommgvzTjTd4xCaFd+kEQ2iYBtGxP6luyLrx6uOk= +github.com/envoyproxy/go-control-plane/envoy v1.32.3/go.mod h1:F6hWupPfh75TBXGKA++MCT/CZHFq5r9/uwt/kQYkZfE= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM= github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= diff --git a/malfeasance/handler.go b/malfeasance/handler.go index 005a90b86d..88b05b948a 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -152,8 +152,8 @@ func (h *Handler) Info(ctx context.Context, nodeID types.NodeID) (map[string]str return properties, nil } -// HandleSyncedMalfeasanceProof is the sync validator for MalfeasanceProof. -func (h *Handler) HandleSyncedMalfeasanceProof( +// HandleSynced is the sync validator for MalfeasanceProof. +func (h *Handler) HandleSynced( ctx context.Context, expHash types.Hash32, peer p2p.Peer, @@ -188,8 +188,8 @@ func (h *Handler) HandleSyncedMalfeasanceProof( return err } -// HandleMalfeasanceProof is the gossip receiver for MalfeasanceGossip. -func (h *Handler) HandleMalfeasanceProof(ctx context.Context, peer p2p.Peer, data []byte) error { +// HandleGossip is the gossip receiver for MalfeasanceGossip. +func (h *Handler) HandleGossip(ctx context.Context, peer p2p.Peer, data []byte) error { var p wire.MalfeasanceGossip if err := codec.Decode(data, &p); err != nil { h.numMalformed.Inc() diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index 85dc77304d..ebf5093e83 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -73,7 +73,7 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { t.Run("malformed data", func(t *testing.T) { h := newHandler(t) - err := h.HandleMalfeasanceProof(context.Background(), "peer", []byte{0x01}) + err := h.HandleGossip(context.Background(), "peer", []byte{0x01}) require.ErrorIs(t, err, errMalformedData) require.ErrorIs(t, err, pubsub.ErrValidationReject) @@ -98,7 +98,7 @@ spacemesh_malfeasance_num_invalid_proofs{type="mal"} 1 }, } - err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip)) + err := h.HandleGossip(context.Background(), "peer", codec.MustEncode(gossip)) require.ErrorIs(t, err, errUnknownProof) require.ErrorIs(t, err, pubsub.ErrValidationReject) @@ -134,7 +134,7 @@ spacemesh_malfeasance_num_invalid_proofs{type="mal"} 1 }, } - err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip)) + err := h.HandleGossip(context.Background(), "peer", codec.MustEncode(gossip)) require.ErrorContains(t, err, "invalid proof") require.ErrorIs(t, err, pubsub.ErrValidationReject) @@ -173,7 +173,7 @@ spacemesh_malfeasance_num_invalid_proofs{type="multiATXs"} 1 } h.mockTrt.EXPECT().OnMalfeasance(nodeID) - err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip)) + err := h.HandleGossip(context.Background(), "peer", codec.MustEncode(gossip)) require.NoError(t, err) var blob sql.Blob @@ -221,7 +221,7 @@ spacemesh_malfeasance_num_proofs{type="multiATXs"} 1 }, } - err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip)) + err := h.HandleGossip(context.Background(), "peer", codec.MustEncode(gossip)) require.NoError(t, err) var blob sql.Blob @@ -234,7 +234,7 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { t.Run("malformed data", func(t *testing.T) { h := newHandler(t) - err := h.HandleSyncedMalfeasanceProof( + err := h.HandleSynced( context.Background(), types.RandomHash(), "peer", @@ -262,7 +262,7 @@ spacemesh_malfeasance_num_invalid_proofs{type="mal"} 1 }, } - err := h.HandleSyncedMalfeasanceProof( + err := h.HandleSynced( context.Background(), types.RandomHash(), "peer", @@ -304,7 +304,7 @@ spacemesh_malfeasance_num_invalid_proofs{type="mal"} 1 expectedHash := types.RandomHash() h.mockTrt.EXPECT().OnMalfeasance(nodeID) - err := h.HandleSyncedMalfeasanceProof( + err := h.HandleSynced( context.Background(), expectedHash, "peer", @@ -351,7 +351,7 @@ spacemesh_malfeasance_num_proofs{type="multiATXs"} 1 }, } - err := h.HandleSyncedMalfeasanceProof( + err := h.HandleSynced( context.Background(), types.Hash32(nodeID), "peer", @@ -394,7 +394,7 @@ spacemesh_malfeasance_num_invalid_proofs{type="multiATXs"} 1 proofBytes := codec.MustEncode(proof) h.mockTrt.EXPECT().OnMalfeasance(nodeID) - err := h.HandleSyncedMalfeasanceProof(context.Background(), types.Hash32(nodeID), "peer", proofBytes) + err := h.HandleSynced(context.Background(), types.Hash32(nodeID), "peer", proofBytes) require.NoError(t, err) var blob sql.Blob @@ -443,7 +443,7 @@ spacemesh_malfeasance_num_proofs{type="multiATXs"} 1 newProofBytes := codec.MustEncode(newProof) require.NotEqual(t, proofBytes, newProofBytes) - err := h.HandleSyncedMalfeasanceProof(context.Background(), types.Hash32(nodeID), "peer", newProofBytes) + err := h.HandleSynced(context.Background(), types.Hash32(nodeID), "peer", newProofBytes) require.NoError(t, err) var blob sql.Blob diff --git a/malfeasance2/handler.go b/malfeasance2/handler.go index c2d04a7374..c4d75c58c3 100644 --- a/malfeasance2/handler.go +++ b/malfeasance2/handler.go @@ -21,7 +21,6 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/marriage" "github.com/spacemeshos/go-spacemesh/system" @@ -36,7 +35,7 @@ var ( type Handler struct { logger *zap.Logger - db sql.Executor + db sql.StateDatabase self p2p.Peer nodeIDs []types.NodeID fetcher system.Fetcher @@ -51,7 +50,7 @@ type Handler struct { } func NewHandler( - db sql.Executor, + db sql.StateDatabase, lg *zap.Logger, self p2p.Peer, nodeIDs []types.NodeID, @@ -262,24 +261,20 @@ func (h *Handler) handleProof(ctx context.Context, peer p2p.Peer, proof Malfeasa return nil, fmt.Errorf("%w: %d", ErrUnknownDomain, proof.Domain) } + if err := h.fetchReferences(ctx, peer, proof.RefATXs); err != nil { + return nil, fmt.Errorf("fetch references: %w", err) + } + nodeID, err := handler.Validate(ctx, proof.Proof) if err != nil { h.countInvalidProof(proof) return nil, err } - if err := h.fetchReferences(ctx, peer, proof.RefATXs); err != nil { - return nil, fmt.Errorf("fetch references: %w", err) - } - mID, err := marriage.FindIDByNodeID(h.db, nodeID) switch { case errors.Is(err, sql.ErrNotFound): - // smesher is not married, check if identity exists in the DB - _, err := atxs.GetFirstIDByNodeID(h.db, nodeID) - if err != nil { - return nil, fmt.Errorf("%w: missing proof for identities existence", ErrMalformedData) - } + // smesher is not married return []types.NodeID{nodeID}, nil case err != nil: return nil, fmt.Errorf("get marriage ID for %s: %w", nodeID.ShortString(), err) @@ -315,33 +310,30 @@ func (h *Handler) fetchReferences(ctx context.Context, peer p2p.Peer, atxIDs []t } func (h *Handler) storeProof(ctx context.Context, nodeIDs []types.NodeID, proof []byte, domain ProofDomain) error { - if len(nodeIDs) == 1 { - // smesher is not married - malicious, err := malfeasance.IsMalicious(h.db, nodeIDs[0]) + // Persisting the proof in the DB has to be done within a transaction to ensure consistency. The ATX handler could + // update the (or merge multiple) marriage set in parallel, so we need to make sure data is consistent while we + // update the malfeasance table. + return h.db.WithTxImmediate(ctx, func(tx sql.Transaction) error { + if len(nodeIDs) == 1 { + // smesher is not married + if err := malfeasance.AddProof(tx, nodeIDs[0], nil, proof, int(domain), time.Now()); err != nil { + return fmt.Errorf("store malfeasance proof for %s: %w", nodeIDs[0], err) + } + return nil + } + + mID, err := marriage.FindIDByNodeID(tx, nodeIDs[0]) if err != nil { - return fmt.Errorf("check if smesher is malicious: %w", err) + return fmt.Errorf("get marriage ID for %s: %w", nodeIDs[0].ShortString(), err) } - if malicious { - h.logger.Debug("smesher is already marked as malicious", zap.String("smesher_id", nodeIDs[0].ShortString())) - return nil + if err := malfeasance.AddProof(tx, nodeIDs[0], &mID, proof, int(domain), time.Now()); err != nil { + return fmt.Errorf("store malfeasance proof for %s: %w", nodeIDs[0].ShortString(), err) } - if err := malfeasance.AddProof(h.db, nodeIDs[0], nil, proof, int(domain), time.Now()); err != nil { - return fmt.Errorf("store malfeasance proof for %s: %w", nodeIDs[0], err) + for _, nodeID := range nodeIDs[1:] { + if err := malfeasance.SetMalicious(tx, nodeID, mID, time.Now()); err != nil { + return fmt.Errorf("update malfeasance state for %s: %w", nodeID.ShortString(), err) + } } return nil - } - - mID, err := marriage.FindIDByNodeID(h.db, nodeIDs[0]) - if err != nil { - return fmt.Errorf("get marriage ID for %s: %w", nodeIDs[0].ShortString(), err) - } - if err := malfeasance.AddProof(h.db, nodeIDs[0], &mID, proof, int(domain), time.Now()); err != nil { - return fmt.Errorf("store malfeasance proof for %s: %w", nodeIDs[0].ShortString(), err) - } - for _, nodeID := range nodeIDs[1:] { - if err := malfeasance.SetMalicious(h.db, nodeID, mID, time.Now()); err != nil { - return fmt.Errorf("update malfeasance state for %s: %w", nodeID.ShortString(), err) - } - } - return nil + }) } diff --git a/malfeasance2/handler_test.go b/malfeasance2/handler_test.go index 81bf411fd2..f2d065cb2e 100644 --- a/malfeasance2/handler_test.go +++ b/malfeasance2/handler_test.go @@ -308,7 +308,6 @@ spacemesh_malfeasance2_num_proofs{domain="ATX",type="invalidPost"} 1 nodeID := types.RandomNodeID() atxID := types.RandomATXID() mockHandler := malfeasance2.NewMockMalfeasanceHandler(th.ctrl) - mockHandler.EXPECT().Validate(gomock.Any(), validProof).Return(nodeID, nil) th.RegisterHandler(malfeasance2.InvalidActivation, mockHandler) th.mockFetch.EXPECT().RegisterPeerHashes(p2p.Peer("peer"), []types.Hash32{atxID.Hash32()}) errFetchFailed := errors.New("fetch failed") @@ -370,31 +369,6 @@ spacemesh_malfeasance2_num_proofs{domain="ATX",type="invalidPost"} 1 require.True(t, malicious) }) - t.Run("valid proof, no reference ATX and identity is unknown", func(t *testing.T) { - t.Parallel() - th := newTestHandler(t) - validProof := []byte("valid") - nodeID := types.RandomNodeID() - mockHandler := malfeasance2.NewMockMalfeasanceHandler(th.ctrl) - mockHandler.EXPECT().Validate(gomock.Any(), validProof).Return(nodeID, nil) - th.RegisterHandler(malfeasance2.InvalidActivation, mockHandler) - - proof := &malfeasance2.MalfeasanceProof{ - Version: 0, - // no reference ATX - Domain: malfeasance2.InvalidActivation, - Proof: validProof, - } - - err := th.HandleSynced(context.Background(), types.Hash32(nodeID), "peer", codec.MustEncode(proof)) - require.ErrorIs(t, err, pubsub.ErrValidationReject) - - // not marked malicious since no proof of existence - malicious, err := malfeasance.IsMalicious(th.db, nodeID) - require.NoError(t, err) - require.False(t, malicious) - }) - t.Run("valid proof, wrong hash", func(t *testing.T) { t.Parallel() th := newTestHandler(t) @@ -642,7 +616,6 @@ spacemesh_malfeasance2_num_proofs{domain="ATX",type="invalidPost"} 1 nodeID := types.RandomNodeID() atxID := types.RandomATXID() mockHandler := malfeasance2.NewMockMalfeasanceHandler(th.ctrl) - mockHandler.EXPECT().Validate(gomock.Any(), validProof).Return(nodeID, nil) th.RegisterHandler(malfeasance2.InvalidActivation, mockHandler) th.mockFetch.EXPECT().RegisterPeerHashes(p2p.Peer("peer"), []types.Hash32{atxID.Hash32()}) errFetchFailed := errors.New("fetch failed") @@ -704,31 +677,6 @@ spacemesh_malfeasance2_num_proofs{domain="ATX",type="invalidPost"} 1 require.True(t, malicious) }) - t.Run("valid proof, no reference ATX and identity is unknown", func(t *testing.T) { - t.Parallel() - th := newTestHandler(t) - validProof := []byte("valid") - nodeID := types.RandomNodeID() - mockHandler := malfeasance2.NewMockMalfeasanceHandler(th.ctrl) - mockHandler.EXPECT().Validate(gomock.Any(), validProof).Return(nodeID, nil) - th.RegisterHandler(malfeasance2.InvalidActivation, mockHandler) - - proof := &malfeasance2.MalfeasanceProof{ - Version: 0, - // no reference ATX - Domain: malfeasance2.InvalidActivation, - Proof: validProof, - } - - err := th.HandleGossip(context.Background(), "peer", codec.MustEncode(proof)) - require.ErrorIs(t, err, pubsub.ErrValidationReject) - - // not marked malicious since no proof of existence - malicious, err := malfeasance.IsMalicious(th.db, nodeID) - require.NoError(t, err) - require.False(t, malicious) - }) - t.Run("valid proof for known malicious identity", func(t *testing.T) { t.Parallel() th := newTestHandler(t) diff --git a/malfeasance2/publisher.go b/malfeasance2/publisher.go index 4bd1dfaf6c..6f545b0a67 100644 --- a/malfeasance2/publisher.go +++ b/malfeasance2/publisher.go @@ -12,6 +12,7 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" @@ -21,7 +22,7 @@ import ( type Publisher struct { logger *zap.Logger - db sql.Executor + db sql.StateDatabase sync syncer tortoise tortoise publisher pubsub.Publisher @@ -29,7 +30,7 @@ type Publisher struct { func NewPublisher( logger *zap.Logger, - db sql.Executor, + db sql.StateDatabase, sync syncer, tortoise tortoise, publisher pubsub.Publisher, @@ -43,78 +44,101 @@ func NewPublisher( } } -func (p *Publisher) PublishATXProof(ctx context.Context, nodeID types.NodeID, proof []byte) error { - marriageID, err := marriage.FindIDByNodeID(p.db, nodeID) - switch { - case errors.Is(err, sql.ErrNotFound): // smesher is not married - malicious, err := malfeasance.IsMalicious(p.db, nodeID) - if err != nil { - return fmt.Errorf("check if smesher is malicious: %w", err) - } - if malicious { - p.logger.Debug("smesher is already marked as malicious", zap.String("smesher_id", nodeID.ShortString())) +func (p *Publisher) PublishATXProof(ctx context.Context, nodeID types.NodeID, proof []byte, allowNoRefATXs bool) error { + publish := false // whether to publish the proof + var set []types.NodeID + var refATXs []types.ATXID + + // Persisting the proof in the DB has to be done within a transaction to ensure consistency. The ATX handler could + // update the (or merge multiple) marriage set in parallel, so we need to make sure data is consistent while we + // update the malfeasance table. + err := p.db.WithTxImmediate(ctx, func(tx sql.Transaction) error { + marriageID, err := marriage.FindIDByNodeID(tx, nodeID) + switch { + case errors.Is(err, sql.ErrNotFound): // smesher is not married + malicious, err := malfeasance.IsMalicious(tx, nodeID) + if err != nil { + return fmt.Errorf("check if smesher is malicious: %w", err) + } + if malicious { + p.logger.Debug("smesher is already marked as malicious", zap.String("smesher_id", nodeID.ShortString())) + return nil + } + if err := malfeasance.AddProof(tx, nodeID, nil, proof, int(InvalidActivation), time.Now()); err != nil { + return fmt.Errorf("setting malfeasance proof: %w", err) + } + atxID, err := atxs.GetFirstIDByNodeID(tx, nodeID) + switch { + case errors.Is(err, sql.ErrNotFound) && allowNoRefATXs: + // no ATXs found for this node, but we allow it + case err != nil: + return fmt.Errorf("getting atx id: %w", err) + default: // ATX found + refATXs = []types.ATXID{atxID} + } + publish = true + set = []types.NodeID{nodeID} return nil + case err != nil: + return fmt.Errorf("getting equivocation set: %w", err) + default: // smesher is married } - if err := malfeasance.AddProof(p.db, nodeID, nil, proof, int(InvalidActivation), time.Now()); err != nil { - return fmt.Errorf("setting malfeasance proof: %w", err) - } - atxID, err := atxs.GetFirstIDByNodeID(p.db, nodeID) - if err != nil { - return fmt.Errorf("getting atx id: %w", err) - } - p.tortoise.OnMalfeasance(nodeID) - return p.publish(ctx, []types.NodeID{nodeID}, []types.ATXID{atxID}, proof, InvalidActivation) - case err != nil: - return fmt.Errorf("getting equivocation set: %w", err) - default: // smesher is married - } - - // Combine IDs from the present equivocation set for atx.SmesherID and IDs in atx.Marriages. - set, err := marriage.NodeIDsByID(p.db, marriageID) - if err != nil { - return fmt.Errorf("getting equivocation set: %w", err) - } - publish := false // whether to publish the proof - malicious, err := malfeasance.IsMalicious(p.db, nodeID) - if err != nil { - return fmt.Errorf("check if smesher is malicious: %w", err) - } - if !malicious { - err := malfeasance.AddProof(p.db, nodeID, &marriageID, proof, int(InvalidActivation), time.Now()) + // Combine IDs from the present equivocation set for atx.SmesherID and IDs in atx.Marriages. + set, err = marriage.NodeIDsByID(tx, marriageID) if err != nil { - return fmt.Errorf("setting malfeasance proof: %w", err) + return fmt.Errorf("getting equivocation set: %w", err) } - publish = true - } else { - p.logger.Debug("smesher is already marked as malicious", zap.String("smesher_id", nodeID.ShortString())) - } - mATXs := make(map[types.ATXID]struct{}) - for _, id := range set { - info, err := marriage.FindByNodeID(p.db, id) - if err != nil { - return fmt.Errorf("getting marriage info: %w", err) - } - mATXs[info.ATX] = struct{}{} - if id == nodeID { - // already handled - continue - } - malicious, err := malfeasance.IsMalicious(p.db, id) + malicious, err := malfeasance.IsMalicious(tx, nodeID) if err != nil { return fmt.Errorf("check if smesher is malicious: %w", err) } - if malicious { - p.logger.Debug("smesher is already marked as malicious", zap.String("smesher_id", id.ShortString())) - continue + if !malicious { + err := malfeasance.AddProof(tx, nodeID, &marriageID, proof, int(InvalidActivation), time.Now()) + if err != nil { + return fmt.Errorf("setting malfeasance proof: %w", err) + } + publish = true + } else { + p.logger.Debug("smesher is already marked as malicious", zap.String("smesher_id", nodeID.ShortString())) } - publish = true - if err := malfeasance.SetMalicious(p.db, id, marriageID, time.Now()); err != nil { - return fmt.Errorf("setting malicious: %w", err) + + mATXs := make(map[types.ATXID]struct{}) + for _, id := range set { + info, err := marriage.FindByNodeID(tx, id) + if err != nil { + return fmt.Errorf("getting marriage info: %w", err) + } + mATXs[info.ATX] = struct{}{} + if id == nodeID { + // already handled + continue + } + malicious, err := malfeasance.IsMalicious(tx, id) + if err != nil { + return fmt.Errorf("check if smesher is malicious: %w", err) + } + if malicious { + p.logger.Debug("smesher is already marked as malicious", zap.String("smesher_id", id.ShortString())) + continue + } + publish = true + if err := malfeasance.SetMalicious(tx, id, marriageID, time.Now()); err != nil { + return fmt.Errorf("setting malicious: %w", err) + } } + refATXs = maps.Keys(mATXs) + return nil + }) + if err != nil { + p.logger.Error("failed to persist malfeasance proof", + zap.Error(err), + log.ZShortStringer("node_id", nodeID), + ) + return err } - + p.logger.Debug("persisted malfeasance proof", log.ZShortStringer("node_id", nodeID)) if !publish { // all smeshers were already marked as malicious - no gossip to void spamming the network return nil @@ -122,20 +146,25 @@ func (p *Publisher) PublishATXProof(ctx context.Context, nodeID types.NodeID, pr for _, nodeID := range set { p.tortoise.OnMalfeasance(nodeID) } - return p.publish(ctx, set, maps.Keys(mATXs), proof, ProofDomain(InvalidActivation)) + return p.publish(ctx, set, refATXs, proof, ProofDomain(InvalidActivation)) } func (p *Publisher) Regossip(ctx context.Context, nodeID types.NodeID) error { - marriageID, err := marriage.FindIDByNodeID(p.db, nodeID) + tx, err := p.db.TxImmediate(ctx) + if err != nil { + return fmt.Errorf("starting transaction: %w", err) + } + defer tx.Release() + marriageID, err := marriage.FindIDByNodeID(tx, nodeID) switch { case errors.Is(err, sql.ErrNotFound): // smesher is not married - proof, domain, err := malfeasance.NodeIDProof(p.db, nodeID) + proof, domain, err := malfeasance.NodeIDProof(tx, nodeID) if err != nil { return fmt.Errorf("getting malfeasance proof: %w", err) } - atxID, err := atxs.GetFirstIDByNodeID(p.db, nodeID) + atxID, err := atxs.GetFirstIDByNodeID(tx, nodeID) if err != nil { - return fmt.Errorf("getting atx id: %w", err) + return fmt.Errorf("getting first atx of identity %s: %w", nodeID.ShortString(), err) } return p.publish(ctx, []types.NodeID{nodeID}, []types.ATXID{atxID}, proof, ProofDomain(domain)) case err != nil: @@ -143,17 +172,17 @@ func (p *Publisher) Regossip(ctx context.Context, nodeID types.NodeID) error { default: // smesher is married } - proof, domain, err := malfeasance.MarriageProof(p.db, marriageID) + proof, domain, err := malfeasance.MarriageProof(tx, marriageID) if err != nil { return fmt.Errorf("getting malfeasance proof: %w", err) } - nodeIDs, err := marriage.NodeIDsByID(p.db, marriageID) + nodeIDs, err := marriage.NodeIDsByID(tx, marriageID) if err != nil { return fmt.Errorf("getting equivocation set: %w", err) } - atxs, err := marriage.MarriageATXs(p.db, marriageID) + atxs, err := marriage.MarriageATXs(tx, marriageID) if err != nil { return fmt.Errorf("getting equivocation info: %w", err) } @@ -191,6 +220,60 @@ func (p *Publisher) publish( p.logger.Error("failed to broadcast malfeasance proof", zap.Error(err)) return fmt.Errorf("broadcast atx malfeasance proof: %w", err) } - + p.logger.Debug("broadcast malfeasance proof", + zap.Array("smesher_ids", zapcore.ArrayMarshalerFunc(func(enc zapcore.ArrayEncoder) error { + for _, nodeID := range nodeID { + enc.AppendString(nodeID.ShortString()) + } + return nil + })), + ) return nil } + +func (p *Publisher) ProofByID(ctx context.Context, nodeID types.NodeID) ([]byte, error) { + tx, err := p.db.TxImmediate(ctx) + if err != nil { + return nil, fmt.Errorf("starting transaction: %w", err) + } + defer tx.Release() + mID, err := marriage.FindIDByNodeID(tx, nodeID) + switch { + case errors.Is(err, sql.ErrNotFound): // smesher is not married + proof, domain, err := malfeasance.NodeIDProof(tx, nodeID) + if err != nil { + return nil, fmt.Errorf("getting malfeasance proof: %w", err) + } + atxID, err := atxs.GetFirstIDByNodeID(tx, nodeID) + if err != nil { + return nil, fmt.Errorf("getting first atx of identity %s: %w", nodeID.ShortString(), err) + } + malfeasanceProof := &MalfeasanceProof{ + Version: 0, + RefATXs: []types.ATXID{atxID}, + Domain: ProofDomain(domain), + Proof: proof, + } + return codec.MustEncode(malfeasanceProof), nil + case err != nil: + return nil, fmt.Errorf("getting equivocation set: %w", err) + default: // smesher is married + } + + atxs, err := marriage.MarriageATXs(tx, mID) + if err != nil { + return nil, fmt.Errorf("getting equivocation info: %w", err) + } + + proof, domain, err := malfeasance.MarriageProof(tx, mID) + if err != nil { + return nil, fmt.Errorf("getting malfeasance proof: %w", err) + } + malfeasanceProof := &MalfeasanceProof{ + Version: 0, + RefATXs: atxs, + Domain: ProofDomain(domain), + Proof: proof, + } + return codec.MustEncode(malfeasanceProof), nil +} diff --git a/malfeasance2/publisher_test.go b/malfeasance2/publisher_test.go index 0353250a5c..5fb9210951 100644 --- a/malfeasance2/publisher_test.go +++ b/malfeasance2/publisher_test.go @@ -95,7 +95,33 @@ func TestPublishATXProof(t *testing.T) { tp.mockSync.EXPECT().ListenToATXGossip().Return(true) tp.mockPub.EXPECT().Publish(gomock.Any(), pubsub.MalfeasanceProof2, codec.MustEncode(malfeasanceProof)) - err := tp.PublishATXProof(context.Background(), nodeID, proof) + err := tp.PublishATXProof(context.Background(), nodeID, proof, false) + require.NoError(t, err) + + dbProof, domain, err := malfeasance.NodeIDProof(tp.db, nodeID) + require.NoError(t, err) + require.Equal(t, malfeasance2.InvalidActivation, malfeasance2.ProofDomain(domain)) + require.Equal(t, proof, dbProof) + }) + + t.Run("not married and in sync, allow without refATXs", func(t *testing.T) { + t.Parallel() + tp := newTestPublisher(t) + proof := types.RandomBytes(10) + nodeID := types.RandomNodeID() + + malfeasanceProof := &malfeasance2.MalfeasanceProof{ + Version: 0, + RefATXs: []types.ATXID{}, + Domain: malfeasance2.InvalidActivation, + Proof: proof, + } + + tp.mockTrt.EXPECT().OnMalfeasance(nodeID) + tp.mockSync.EXPECT().ListenToATXGossip().Return(true) + tp.mockPub.EXPECT().Publish(gomock.Any(), pubsub.MalfeasanceProof2, codec.MustEncode(malfeasanceProof)) + + err := tp.PublishATXProof(context.Background(), nodeID, proof, true) require.NoError(t, err) dbProof, domain, err := malfeasance.NodeIDProof(tp.db, nodeID) @@ -128,7 +154,7 @@ func TestPublishATXProof(t *testing.T) { tp.mockPub.EXPECT().Publish(gomock.Any(), pubsub.MalfeasanceProof2, codec.MustEncode(malfeasanceProof)). Return(errPublish) - err := tp.PublishATXProof(context.Background(), nodeID, proof) + err := tp.PublishATXProof(context.Background(), nodeID, proof, false) require.ErrorIs(t, err, errPublish) logs := tp.observedLogs.FilterLevelExact(zap.ErrorLevel) @@ -158,7 +184,25 @@ func TestPublishATXProof(t *testing.T) { tp.mockTrt.EXPECT().OnMalfeasance(nodeID) tp.mockSync.EXPECT().ListenToATXGossip().Return(false) // results in no gossip but only storing the proof - err := tp.PublishATXProof(context.Background(), nodeID, proof) + err := tp.PublishATXProof(context.Background(), nodeID, proof, false) + require.NoError(t, err) + + dbProof, domain, err := malfeasance.NodeIDProof(tp.db, nodeID) + require.NoError(t, err) + require.Equal(t, malfeasance2.InvalidActivation, malfeasance2.ProofDomain(domain)) + require.Equal(t, proof, dbProof) + }) + + t.Run("not married, not in sync, allow no ref ATXs", func(t *testing.T) { + t.Parallel() + tp := newTestPublisher(t) + proof := types.RandomBytes(10) + nodeID := types.RandomNodeID() + + tp.mockTrt.EXPECT().OnMalfeasance(nodeID) + tp.mockSync.EXPECT().ListenToATXGossip().Return(false) // results in no gossip but only storing the proof + + err := tp.PublishATXProof(context.Background(), nodeID, proof, true) require.NoError(t, err) dbProof, domain, err := malfeasance.NodeIDProof(tp.db, nodeID) @@ -240,7 +284,7 @@ func TestPublishATXProof(t *testing.T) { }, ) - err = tp.PublishATXProof(context.Background(), nodeIDs[2], proof) + err = tp.PublishATXProof(context.Background(), nodeIDs[2], proof, false) require.NoError(t, err) for i := range nodeIDs { @@ -268,7 +312,7 @@ func TestPublishATXProof(t *testing.T) { err := malfeasance.AddProof(tp.db, nodeID, nil, proof, int(malfeasance2.InvalidActivation), time.Now()) require.NoError(t, err) - err = tp.PublishATXProof(context.Background(), nodeID, proof) + err = tp.PublishATXProof(context.Background(), nodeID, proof, false) require.NoError(t, err) dbProof, domain, err := malfeasance.NodeIDProof(tp.db, nodeID) @@ -278,9 +322,11 @@ func TestPublishATXProof(t *testing.T) { logs := tp.observedLogs.FilterLevelExact(zap.DebugLevel) - require.Equal(t, 1, logs.Len()) + require.Equal(t, 2, logs.Len()) require.Equal(t, zap.DebugLevel, logs.All()[0].Level) require.Contains(t, logs.All()[0].Message, "smesher is already marked as malicious") + require.Equal(t, zap.DebugLevel, logs.All()[1].Level) + require.Contains(t, logs.All()[1].Message, "persisted malfeasance proof") }) t.Run("married and all already malicious", func(t *testing.T) { @@ -343,7 +389,7 @@ func TestPublishATXProof(t *testing.T) { require.NoError(t, malfeasance.SetMalicious(tp.db, nodeID, mID, time.Now())) } - err = tp.PublishATXProof(context.Background(), nodeIDs[2], proof) + err = tp.PublishATXProof(context.Background(), nodeIDs[2], proof, false) require.NoError(t, err) for i := range nodeIDs { @@ -358,11 +404,13 @@ func TestPublishATXProof(t *testing.T) { logs := tp.observedLogs.FilterLevelExact(zap.DebugLevel) - require.Equal(t, 30, logs.Len()) + require.Equal(t, 31, logs.Len()) require.Equal(t, zap.DebugLevel, logs.All()[0].Level) for i := range nodeIDs { require.Contains(t, logs.All()[i].Message, "smesher is already marked as malicious") } + require.Equal(t, zap.DebugLevel, logs.All()[30].Level) + require.Contains(t, logs.All()[30].Message, "persisted malfeasance proof") }) t.Run("married and some already malicious", func(t *testing.T) { @@ -450,7 +498,7 @@ func TestPublishATXProof(t *testing.T) { }, ) - err = tp.PublishATXProof(context.Background(), nodeIDs[2], proof) + err = tp.PublishATXProof(context.Background(), nodeIDs[2], proof, false) require.NoError(t, err) for i := range nodeIDs { @@ -465,12 +513,16 @@ func TestPublishATXProof(t *testing.T) { logs := tp.observedLogs.FilterLevelExact(zap.DebugLevel) - require.Equal(t, 20, logs.Len()) + require.Equal(t, 22, logs.Len()) require.Equal(t, zap.DebugLevel, logs.All()[0].Level) for i := range nodeIDs[:20] { // first 20 were already malicious require.Contains(t, logs.All()[i].Message, "smesher is already marked as malicious") } + require.Equal(t, zap.DebugLevel, logs.All()[20].Level) + require.Contains(t, logs.All()[20].Message, "persisted malfeasance proof") + require.Equal(t, zap.DebugLevel, logs.All()[21].Level) + require.Contains(t, logs.All()[21].Message, "broadcast malfeasance proof") }) } @@ -627,3 +679,144 @@ func TestRegossip(t *testing.T) { require.NoError(t, err) }) } + +func TestProofByID(t *testing.T) { + t.Run("not married no proof", func(t *testing.T) { + t.Parallel() + tp := newTestPublisher(t) + nodeID := types.RandomNodeID() + + proofBytes, err := tp.ProofByID(context.Background(), nodeID) + require.ErrorIs(t, err, sql.ErrNotFound) + require.Nil(t, proofBytes) + }) + + t.Run("not married with proof", func(t *testing.T) { + t.Parallel() + tp := newTestPublisher(t) + nodeID := types.RandomNodeID() + atx := types.ActivationTx{ + SmesherID: nodeID, + } + atx.SetID(types.RandomATXID()) + atxs.Add(tp.db, &atx, types.AtxBlob{}) + proof := types.RandomBytes(10) + err := malfeasance.AddProof(tp.db, nodeID, nil, proof, int(malfeasance2.InvalidActivation), time.Now()) + require.NoError(t, err) + + proofBytes, err := tp.ProofByID(context.Background(), nodeID) + require.NoError(t, err) + + var malProof malfeasance2.MalfeasanceProof + require.NoError(t, codec.Decode(proofBytes, &malProof)) + require.Equal(t, proof, malProof.Proof) + require.Equal(t, malfeasance2.InvalidActivation, malProof.Domain) + require.Len(t, malProof.RefATXs, 1) + require.Equal(t, atx.ID(), malProof.RefATXs[0]) + }) + + t.Run("married no proof", func(t *testing.T) { + t.Parallel() + tp := newTestPublisher(t) + nodeIDs := make([]types.NodeID, 20) + for i := range nodeIDs { + nodeIDs[i] = types.RandomNodeID() + } + mATXID := types.RandomATXID() + atx := &types.ActivationTx{ + SmesherID: nodeIDs[0], + } + atx.SetID(mATXID) + require.NoError(t, atxs.Add(tp.db, atx, types.AtxBlob{})) + mATX2ID := types.RandomATXID() + atx2 := &types.ActivationTx{ + SmesherID: nodeIDs[0], + } + atx2.SetID(mATX2ID) + require.NoError(t, atxs.Add(tp.db, atx2, types.AtxBlob{})) + + mID, err := marriage.NewID(tp.db) + require.NoError(t, err) + + for i := range nodeIDs { + require.NoError(t, marriage.Add(tp.db, marriage.Info{ + ID: mID, + NodeID: nodeIDs[i], + ATX: mATXID, + MarriageIndex: i, + Target: nodeIDs[0], + Signature: types.RandomEdSignature(), + })) + } + + proofBytes, err := tp.ProofByID(context.Background(), nodeIDs[4]) + require.ErrorIs(t, err, sql.ErrNotFound) + require.Nil(t, proofBytes) + }) + + t.Run("married with proof", func(t *testing.T) { + t.Parallel() + tp := newTestPublisher(t) + proof := types.RandomBytes(10) + nodeIDs := make([]types.NodeID, 20) + for i := range nodeIDs { + nodeIDs[i] = types.RandomNodeID() + } + atx := &types.ActivationTx{ + SmesherID: nodeIDs[0], + } + atx.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(tp.db, atx, types.AtxBlob{})) + atx2 := &types.ActivationTx{ + SmesherID: nodeIDs[0], + } + atx2.SetID(types.RandomATXID()) + require.NoError(t, atxs.Add(tp.db, atx2, types.AtxBlob{})) + + mID, err := marriage.NewID(tp.db) + require.NoError(t, err) + + for i := range nodeIDs[:10] { + require.NoError(t, marriage.Add(tp.db, marriage.Info{ + ID: mID, + NodeID: nodeIDs[i], + ATX: atx.ID(), + MarriageIndex: i, + Target: nodeIDs[0], + Signature: types.RandomEdSignature(), + })) + if i == 0 { + require.NoError(t, malfeasance.AddProof( + tp.db, + nodeIDs[i], + &mID, + proof, + int(malfeasance2.InvalidActivation), + time.Now()), + ) + continue + } + require.NoError(t, malfeasance.SetMalicious(tp.db, nodeIDs[i], mID, time.Now())) + } + for i := range nodeIDs[10:] { // smesher has married twice + require.NoError(t, marriage.Add(tp.db, marriage.Info{ + ID: mID, + NodeID: nodeIDs[i+10], + ATX: atx2.ID(), + MarriageIndex: i, + Target: nodeIDs[0], + Signature: types.RandomEdSignature(), + })) + require.NoError(t, malfeasance.SetMalicious(tp.db, nodeIDs[i], mID, time.Now())) + } + + proofBytes, err := tp.ProofByID(context.Background(), nodeIDs[4]) + require.NoError(t, err) + + var malProof malfeasance2.MalfeasanceProof + require.NoError(t, codec.Decode(proofBytes, &malProof)) + require.Equal(t, proof, malProof.Proof) + require.Equal(t, malfeasance2.InvalidActivation, malProof.Domain) + require.ElementsMatch(t, []types.ATXID{atx.ID(), atx2.ID()}, malProof.RefATXs) + }) +} diff --git a/mesh/mesh.go b/mesh/mesh.go index 1d509fc9f7..3c8ef57a99 100644 --- a/mesh/mesh.go +++ b/mesh/mesh.go @@ -382,6 +382,8 @@ func (msh *Mesh) applyResults(ctx context.Context, results []result.Layer) error } } if err := msh.executor.Execute(ctx, layer.Layer, block); err != nil { + // TODO(mafa): sometimes this fails because the block executed references a tx that is not in the DB + // maybe in that case the node should try to fetch the missing txs and retry executing the block? return fmt.Errorf("execute block %v/%v: %w", layer.Layer, target, err) } } diff --git a/mesh/mesh_test.go b/mesh/mesh_test.go index f76c1a77df..8a4d7115c5 100644 --- a/mesh/mesh_test.go +++ b/mesh/mesh_test.go @@ -383,9 +383,6 @@ func TestMesh_MaliciousBallots(t *testing.T) { require.NoError(t, err) require.Nil(t, malProof) require.False(t, blts[0].IsMalicious()) - mal, err := identities.IsMalicious(tm.cdb, sig.NodeID()) - require.NoError(t, err) - require.False(t, mal) malicious, err := identities.IsMalicious(tm.cdb, sig.NodeID()) require.NoError(t, err) diff --git a/node/node.go b/node/node.go index 8a98ccd131..ff7ccda2e4 100644 --- a/node/node.go +++ b/node/node.go @@ -749,7 +749,7 @@ func (app *App) initServices(ctx context.Context) error { peerCache := peers.New() flog := app.addLogger(Fetcher, lg).Zap() fetcher, err := fetch.NewFetch( - app.cachedDB, + app.db, proposalsStore, app.host, peerCache, @@ -852,13 +852,14 @@ func (app *App) initServices(ctx context.Context) error { malfeasanceLogger := app.addLogger(Malfeasance2Logger, lg).Zap() malfeasancePublisher := malfeasance2.NewPublisher( malfeasanceLogger, - app.cachedDB, + app.db, syncer, trtl, app.host, ) atxMalHandler := activation.NewMalfeasanceHandlerV2( malfeasanceLogger, + app.db, malfeasancePublisher, app.edVerifier, validator, @@ -1197,7 +1198,7 @@ func (app *App) initServices(ctx context.Context) error { malHandler.RegisterHandler(malfeasance.InvalidPrevATX, invalidPrevMH) malHandler2 := malfeasance2.NewHandler( - app.cachedDB, + app.db, malfeasanceLogger, app.host.ID(), nodeIDs, @@ -1206,6 +1207,7 @@ func (app *App) initServices(ctx context.Context) error { ) malHandler2.RegisterHandler(malfeasance2.InvalidActivation, atxMalHandler) + fetcher.SetMalfeasanceProvider(malfeasancePublisher) fetcher.SetValidators( fetch.ValidatorFunc( pubsub.DropPeerOnSyncValidationReject(atxHandler.HandleSyncedAtx, app.host, lg.Zap()), @@ -1214,11 +1216,7 @@ func (app *App) initServices(ctx context.Context) error { pubsub.DropPeerOnSyncValidationReject(poetDb.ValidateAndStoreMsg, app.host, lg.Zap()), ), fetch.ValidatorFunc( - pubsub.DropPeerOnSyncValidationReject( - proposalListener.HandleSyncedBallot, - app.host, - lg.Zap(), - ), + pubsub.DropPeerOnSyncValidationReject(proposalListener.HandleSyncedBallot, app.host, lg.Zap()), ), fetch.ValidatorFunc( pubsub.DropPeerOnSyncValidationReject(proposalListener.HandleActiveSet, app.host, lg.Zap()), @@ -1227,41 +1225,20 @@ func (app *App) initServices(ctx context.Context) error { pubsub.DropPeerOnSyncValidationReject(blockHandler.HandleSyncedBlock, app.host, lg.Zap()), ), fetch.ValidatorFunc( - pubsub.DropPeerOnSyncValidationReject( - proposalListener.HandleSyncedProposal, - app.host, - lg.Zap(), - ), + pubsub.DropPeerOnSyncValidationReject(proposalListener.HandleSyncedProposal, app.host, lg.Zap()), + ), + fetch.ValidatorFunc( + pubsub.DropPeerOnSyncValidationReject(app.txHandler.HandleBlockTransaction, app.host, lg.Zap()), ), fetch.ValidatorFunc( - pubsub.DropPeerOnSyncValidationReject( - app.txHandler.HandleBlockTransaction, - app.host, - lg.Zap(), - ), + pubsub.DropPeerOnSyncValidationReject(app.txHandler.HandleProposalTransaction, app.host, lg.Zap()), ), fetch.ValidatorFunc( - pubsub.DropPeerOnSyncValidationReject( - app.txHandler.HandleProposalTransaction, - app.host, - lg.Zap(), - ), + pubsub.DropPeerOnSyncValidationReject(malHandler.HandleSynced, app.host, lg.Zap()), ), fetch.ValidatorFunc( - pubsub.DropPeerOnSyncValidationReject( - malHandler.HandleSyncedMalfeasanceProof, - app.host, - lg.Zap(), - ), + pubsub.DropPeerOnSyncValidationReject(malHandler2.HandleSynced, app.host, lg.Zap()), ), - // TODO(mafa): add malfeasance2 handler to fetcher - // fetch.ValidatorFunc( - // pubsub.DropPeerOnSyncValidationReject( - // malHandler2.HandleSyncedMalfeasanceProof, - // app.host, - // lg.Zap(), - // ), - // ), ) checkSynced := func(_ context.Context, _ p2p.Peer, _ []byte) error { @@ -1318,7 +1295,7 @@ func (app *App) initServices(ctx context.Context) error { ) app.host.Register( pubsub.MalfeasanceProof, - pubsub.ChainGossipHandler(checkAtxSynced, malHandler.HandleMalfeasanceProof), + pubsub.ChainGossipHandler(checkAtxSynced, malHandler.HandleGossip), ) app.host.Register( pubsub.MalfeasanceProof2, @@ -2194,7 +2171,6 @@ func (app *App) Start(ctx context.Context) error { Msg: "node is shutting down", Level: zapcore.InfoLevel, }) - // TODO: pass app.eg to components and wait for them collectively if app.ptimesync != nil { app.eg.Go(func() error { app.errCh <- app.ptimesync.Wait() diff --git a/node/node_test.go b/node/node_test.go index 835bb69fbb..782258446d 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -377,13 +377,12 @@ func TestSpacemeshApp_NodeService(t *testing.T) { return app.Start(appCtx) } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - // Run the app in a goroutine. As noted above, it blocks if it succeeds. // If there's an error in the args, it will return immediately. var eg errgroup.Group eg.Go(func() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() str, err := testArgs(ctx, cmdWithRun(run)) assert.Empty(t, str) assert.NoError(t, err) @@ -397,7 +396,10 @@ func TestSpacemeshApp_NodeService(t *testing.T) { ) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, conn.Close()) }) + c := pb.NewNodeServiceClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() eg.Go(func() error { streamStatus, err := c.StatusStream(ctx, &pb.StatusStreamRequest{}) diff --git a/p2p/server/deadline_adjuster.go b/p2p/server/deadline_adjuster.go index 634e638272..e03a33f546 100644 --- a/p2p/server/deadline_adjuster.go +++ b/p2p/server/deadline_adjuster.go @@ -39,7 +39,8 @@ func (err *deadlineAdjusterError) Error() string { err.totalWritten, err.timeout, err.hardTimeout, - err.innerErr) + err.innerErr, + ) } type deadlineAdjuster struct { diff --git a/p2p/server/server.go b/p2p/server/server.go index ebc424cc24..45d2222607 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -496,12 +496,10 @@ func (s *Server) NumAcceptedRequests() int { func writeResponse(w io.Writer, resp *Response) error { wr := bufio.NewWriter(w) if _, err := codec.EncodeTo(wr, resp); err != nil { - return fmt.Errorf("failed to write response (len %d err len %d): %w", - len(resp.Data), len(resp.Error), err) + return fmt.Errorf("failed to write response (len %d err len %d): %w", len(resp.Data), len(resp.Error), err) } if err := wr.Flush(); err != nil { - return fmt.Errorf("failed to write response (len %d err len %d): %w", - len(resp.Data), len(resp.Error), err) + return fmt.Errorf("failed to write response (len %d err len %d): %w", len(resp.Data), len(resp.Error), err) } return nil } diff --git a/sql/malfeasance/malfeasance.go b/sql/malfeasance/malfeasance.go index ce313721ca..5524758c0a 100644 --- a/sql/malfeasance/malfeasance.go +++ b/sql/malfeasance/malfeasance.go @@ -70,6 +70,21 @@ func IsMalicious(db sql.Executor, nodeID types.NodeID) (bool, error) { return rows > 0, nil } +func Count(db sql.Executor) (int, error) { + var count int + _, err := db.Exec(` + SELECT COUNT(*) + FROM malfeasance + `, nil, func(stmt *sql.Statement) bool { + count = stmt.ColumnInt(0) + return true + }) + if err != nil { + return 0, fmt.Errorf("count identities: %w", err) + } + return count, nil +} + func IterateOps( db sql.Executor, operations builder.Operations, diff --git a/sql/malfeasance/malfeasance_test.go b/sql/malfeasance/malfeasance_test.go index 078f021f70..f2538de421 100644 --- a/sql/malfeasance/malfeasance_test.go +++ b/sql/malfeasance/malfeasance_test.go @@ -202,6 +202,53 @@ func TestIsMalicious(t *testing.T) { }) } +func Test_Count(t *testing.T) { + db := statesql.InMemoryTest(t) + + count, err := malfeasance.Count(db) + require.NoError(t, err) + require.Zero(t, count) + + nodeID := types.RandomNodeID() + err = malfeasance.AddProof(db, nodeID, nil, types.RandomBytes(100), 1, time.Now()) + require.NoError(t, err) + + count, err = malfeasance.Count(db) + require.NoError(t, err) + require.Equal(t, 1, count) + + nodeID = types.RandomNodeID() + marriageATX := types.RandomATXID() + id, err := marriage.NewID(db) + require.NoError(t, err) + + err = marriage.Add(db, marriage.Info{ + ID: id, + NodeID: nodeID, + ATX: marriageATX, + MarriageIndex: 0, + Target: nodeID, + Signature: types.RandomEdSignature(), + }) + require.NoError(t, err) + + ids := make([]types.NodeID, 5) + ids[0] = nodeID + proof := types.RandomBytes(11) + err = malfeasance.AddProof(db, ids[0], &id, proof, 1, time.Now()) + require.NoError(t, err) + + for i := 1; i < len(ids); i++ { + ids[i] = types.RandomNodeID() + err := malfeasance.SetMalicious(db, ids[i], id, time.Now()) + require.NoError(t, err) + } + + count, err = malfeasance.Count(db) + require.NoError(t, err) + require.Equal(t, 6, count) +} + func Test_IterateMaliciousOps(t *testing.T) { db := statesql.InMemoryTest(t) tt := []struct { diff --git a/sql/malsync/malsync.go b/sql/malsync/malsync.go index fac8b50fba..18e6439645 100644 --- a/sql/malsync/malsync.go +++ b/sql/malsync/malsync.go @@ -7,10 +7,21 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" ) -func GetSyncState(db sql.Executor) (time.Time, error) { +func LegacySyncState(db sql.Executor) (time.Time, error) { + return getSyncState(db, 1) +} + +func SyncState(db sql.Executor) (time.Time, error) { + return getSyncState(db, 2) +} + +func getSyncState(db sql.Executor, version int64) (time.Time, error) { var timestamp time.Time - rows, err := db.Exec("select timestamp from malfeasance_sync_state where id = 1", - nil, func(stmt *sql.Statement) bool { + rows, err := db.Exec("select timestamp from malfeasance_sync_state where id = ?1", + func(s *sql.Statement) { + s.BindInt64(1, version) + }, + func(stmt *sql.Statement) bool { v := stmt.ColumnInt64(0) if v > 0 { timestamp = time.Unix(v, 0) @@ -27,23 +38,31 @@ func GetSyncState(db sql.Executor) (time.Time, error) { } } -func updateSyncState(db sql.Executor, ts int64) error { - if _, err := db.Exec( - `insert into malfeasance_sync_state (id, timestamp) values(1, ?1) - on conflict (id) do update set timestamp = ?1`, - func(stmt *sql.Statement) { - stmt.BindInt64(1, ts) - }, nil, +func updateSyncState(db sql.Executor, version, ts int64) error { + if _, err := db.Exec(` + insert into malfeasance_sync_state (id, timestamp) values(?1, ?2) + on conflict (id) do update set timestamp = ?2 + `, func(stmt *sql.Statement) { + stmt.BindInt64(1, version) + stmt.BindInt64(2, ts) + }, nil, ); err != nil { return fmt.Errorf("error initializing malfeasance sync state: %w", err) } return nil } +func UpdateLegacySyncState(db sql.Executor, timestamp time.Time) error { + return updateSyncState(db, 1, timestamp.Unix()) +} + func UpdateSyncState(db sql.Executor, timestamp time.Time) error { - return updateSyncState(db, timestamp.Unix()) + return updateSyncState(db, 2, timestamp.Unix()) } func Clear(db sql.Executor) error { - return updateSyncState(db, 0) + if err := updateSyncState(db, 1, 0); err != nil { + return err + } + return updateSyncState(db, 2, 0) } diff --git a/sql/malsync/malsync_test.go b/sql/malsync/malsync_test.go index 38624ca27e..d994e94fe8 100644 --- a/sql/malsync/malsync_test.go +++ b/sql/malsync/malsync_test.go @@ -9,21 +9,40 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql" ) +func TestLegacyMalfeasanceSyncState(t *testing.T) { + db := localsql.InMemoryTest(t) + timestamp, err := LegacySyncState(db) + require.NoError(t, err) + require.Equal(t, time.Time{}, timestamp) + ts := time.Now() + for i := 0; i < 3; i++ { + require.NoError(t, UpdateLegacySyncState(db, ts)) + timestamp, err = LegacySyncState(db) + require.NoError(t, err) + require.Equal(t, ts.Truncate(time.Second), timestamp) + ts = ts.Add(3 * time.Minute) + } + require.NoError(t, Clear(db)) + timestamp, err = LegacySyncState(db) + require.NoError(t, err) + require.Equal(t, time.Time{}, timestamp) +} + func TestMalfeasanceSyncState(t *testing.T) { db := localsql.InMemoryTest(t) - timestamp, err := GetSyncState(db) + timestamp, err := SyncState(db) require.NoError(t, err) require.Equal(t, time.Time{}, timestamp) ts := time.Now() for i := 0; i < 3; i++ { require.NoError(t, UpdateSyncState(db, ts)) - timestamp, err = GetSyncState(db) + timestamp, err = SyncState(db) require.NoError(t, err) require.Equal(t, ts.Truncate(time.Second), timestamp) ts = ts.Add(3 * time.Minute) } require.NoError(t, Clear(db)) - timestamp, err = GetSyncState(db) + timestamp, err = SyncState(db) require.NoError(t, err) require.Equal(t, time.Time{}, timestamp) } diff --git a/sql/marriage/marriages.go b/sql/marriage/marriages.go index 58410e1a9c..1c7d1f8a7f 100644 --- a/sql/marriage/marriages.go +++ b/sql/marriage/marriages.go @@ -3,8 +3,8 @@ package marriage import ( "bytes" "fmt" + "maps" "slices" - "sort" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" @@ -132,7 +132,7 @@ func FindByNodeID(db sql.Executor, nodeID types.NodeID) (Info, error) { } func MarriageATXs(db sql.Executor, id ID) ([]types.ATXID, error) { - var atxs []types.ATXID + atxs := make(map[types.ATXID]struct{}) rows, err := db.Exec(` SELECT marriage_atx FROM marriages @@ -142,7 +142,7 @@ func MarriageATXs(db sql.Executor, id ID) ([]types.ATXID, error) { }, func(s *sql.Statement) bool { var atx types.ATXID s.ColumnBytes(0, atx[:]) - atxs = append(atxs, atx) + atxs[atx] = struct{}{} return true }) if err != nil { @@ -151,8 +151,10 @@ func MarriageATXs(db sql.Executor, id ID) ([]types.ATXID, error) { if rows == 0 { return nil, sql.ErrNotFound } - sort.Slice(atxs, func(i, j int) bool { return bytes.Compare(atxs[i].Bytes(), atxs[j].Bytes()) < 0 }) - return slices.Compact(atxs), nil + + return slices.SortedFunc(maps.Keys(atxs), func(a, b types.ATXID) int { + return bytes.Compare(a.Bytes(), b.Bytes()) + }), nil } func NodeIDsByID(db sql.Executor, id ID) ([]types.NodeID, error) { diff --git a/sql/marriage/marriages_test.go b/sql/marriage/marriages_test.go index bb5f940e65..46eed31191 100644 --- a/sql/marriage/marriages_test.go +++ b/sql/marriage/marriages_test.go @@ -48,6 +48,51 @@ func TestFind(t *testing.T) { require.ErrorIs(t, err, sql.ErrNotFound) } +func TestUpdateMarriageID(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + id1, err := marriage.NewID(db) + require.NoError(t, err) + require.NotZero(t, id1) + + nodeID1 := types.RandomNodeID() + nodeID2 := types.RandomNodeID() + info := marriage.Info{ + ID: id1, + NodeID: nodeID1, + ATX: types.RandomATXID(), + MarriageIndex: rand.N(256), + Target: types.RandomNodeID(), + Signature: types.RandomEdSignature(), + } + err = marriage.Add(db, info) + require.NoError(t, err) + + info.NodeID = nodeID2 + info.MarriageIndex = (info.MarriageIndex + 1) % 256 + err = marriage.Add(db, info) + require.NoError(t, err) + + id2, err := marriage.NewID(db) + require.NoError(t, err) + require.NotZero(t, id2) + require.NotEqual(t, id1, id2) + + err = marriage.UpdateMarriageID(db, id1, id2) + require.NoError(t, err) + + for _, nodeID := range []types.NodeID{nodeID1, nodeID2} { + id, err := marriage.FindIDByNodeID(db, nodeID) + require.NoError(t, err) + require.Equal(t, id2, id) + + info, err = marriage.FindByNodeID(db, nodeID) + require.NoError(t, err) + require.Equal(t, id2, info.ID) + } +} + func TestAdd(t *testing.T) { t.Parallel() db := statesql.InMemoryTest(t) diff --git a/syncer/interface.go b/syncer/interface.go index 95d8f17125..13367152ed 100644 --- a/syncer/interface.go +++ b/syncer/interface.go @@ -36,6 +36,7 @@ type atxSyncer interface { } type malSyncer interface { + EnsureLegacyInSync(parent context.Context, epochStart, epochEnd time.Time) error EnsureInSync(parent context.Context, epochStart, epochEnd time.Time) error DownloadLoop(parent context.Context) error } @@ -47,7 +48,6 @@ type fetcher interface { GetCert(context.Context, types.LayerID, types.BlockID, []p2p.Peer) (*types.Certificate, error) GetAtxs(context.Context, []types.ATXID, ...system.GetAtxOpt) error - GetMalfeasanceProofs(context.Context, []types.NodeID) error GetBallots(context.Context, []types.BallotID) error GetBlocks(context.Context, []types.BlockID) error RegisterPeerHashes(peer p2p.Peer, hashes []types.Hash32) diff --git a/syncer/malsync/mocks/mocks.go b/syncer/malsync/mocks/mocks.go index e1417c676a..3be31c849d 100644 --- a/syncer/malsync/mocks/mocks.go +++ b/syncer/malsync/mocks/mocks.go @@ -42,79 +42,156 @@ func (m *Mockfetcher) EXPECT() *MockfetcherMockRecorder { return m.recorder } -// GetMalfeasanceProofs mocks base method. -func (m *Mockfetcher) GetMalfeasanceProofs(arg0 context.Context, arg1 []types.NodeID) error { +// LegacyMalfeasanceProofs mocks base method. +func (m *Mockfetcher) LegacyMalfeasanceProofs(arg0 context.Context, arg1 []types.NodeID) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMalfeasanceProofs", arg0, arg1) + ret := m.ctrl.Call(m, "LegacyMalfeasanceProofs", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// GetMalfeasanceProofs indicates an expected call of GetMalfeasanceProofs. -func (mr *MockfetcherMockRecorder) GetMalfeasanceProofs(arg0, arg1 any) *MockfetcherGetMalfeasanceProofsCall { +// LegacyMalfeasanceProofs indicates an expected call of LegacyMalfeasanceProofs. +func (mr *MockfetcherMockRecorder) LegacyMalfeasanceProofs(arg0, arg1 any) *MockfetcherLegacyMalfeasanceProofsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMalfeasanceProofs", reflect.TypeOf((*Mockfetcher)(nil).GetMalfeasanceProofs), arg0, arg1) - return &MockfetcherGetMalfeasanceProofsCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LegacyMalfeasanceProofs", reflect.TypeOf((*Mockfetcher)(nil).LegacyMalfeasanceProofs), arg0, arg1) + return &MockfetcherLegacyMalfeasanceProofsCall{Call: call} } -// MockfetcherGetMalfeasanceProofsCall wrap *gomock.Call -type MockfetcherGetMalfeasanceProofsCall struct { +// MockfetcherLegacyMalfeasanceProofsCall wrap *gomock.Call +type MockfetcherLegacyMalfeasanceProofsCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockfetcherGetMalfeasanceProofsCall) Return(arg0 error) *MockfetcherGetMalfeasanceProofsCall { +func (c *MockfetcherLegacyMalfeasanceProofsCall) Return(arg0 error) *MockfetcherLegacyMalfeasanceProofsCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockfetcherGetMalfeasanceProofsCall) Do(f func(context.Context, []types.NodeID) error) *MockfetcherGetMalfeasanceProofsCall { +func (c *MockfetcherLegacyMalfeasanceProofsCall) Do(f func(context.Context, []types.NodeID) error) *MockfetcherLegacyMalfeasanceProofsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockfetcherGetMalfeasanceProofsCall) DoAndReturn(f func(context.Context, []types.NodeID) error) *MockfetcherGetMalfeasanceProofsCall { +func (c *MockfetcherLegacyMalfeasanceProofsCall) DoAndReturn(f func(context.Context, []types.NodeID) error) *MockfetcherLegacyMalfeasanceProofsCall { c.Call = c.Call.DoAndReturn(f) return c } -// GetMaliciousIDs mocks base method. -func (m *Mockfetcher) GetMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]types.NodeID, error) { +// LegacyMaliciousIDs mocks base method. +func (m *Mockfetcher) LegacyMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]types.NodeID, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMaliciousIDs", arg0, arg1) + ret := m.ctrl.Call(m, "LegacyMaliciousIDs", arg0, arg1) ret0, _ := ret[0].([]types.NodeID) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetMaliciousIDs indicates an expected call of GetMaliciousIDs. -func (mr *MockfetcherMockRecorder) GetMaliciousIDs(arg0, arg1 any) *MockfetcherGetMaliciousIDsCall { +// LegacyMaliciousIDs indicates an expected call of LegacyMaliciousIDs. +func (mr *MockfetcherMockRecorder) LegacyMaliciousIDs(arg0, arg1 any) *MockfetcherLegacyMaliciousIDsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaliciousIDs", reflect.TypeOf((*Mockfetcher)(nil).GetMaliciousIDs), arg0, arg1) - return &MockfetcherGetMaliciousIDsCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LegacyMaliciousIDs", reflect.TypeOf((*Mockfetcher)(nil).LegacyMaliciousIDs), arg0, arg1) + return &MockfetcherLegacyMaliciousIDsCall{Call: call} } -// MockfetcherGetMaliciousIDsCall wrap *gomock.Call -type MockfetcherGetMaliciousIDsCall struct { +// MockfetcherLegacyMaliciousIDsCall wrap *gomock.Call +type MockfetcherLegacyMaliciousIDsCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockfetcherGetMaliciousIDsCall) Return(arg0 []types.NodeID, arg1 error) *MockfetcherGetMaliciousIDsCall { +func (c *MockfetcherLegacyMaliciousIDsCall) Return(arg0 []types.NodeID, arg1 error) *MockfetcherLegacyMaliciousIDsCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockfetcherGetMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherGetMaliciousIDsCall { +func (c *MockfetcherLegacyMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherLegacyMaliciousIDsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockfetcherGetMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherGetMaliciousIDsCall { +func (c *MockfetcherLegacyMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherLegacyMaliciousIDsCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MalfeasanceProofs mocks base method. +func (m *Mockfetcher) MalfeasanceProofs(arg0 context.Context, arg1 []types.NodeID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MalfeasanceProofs", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// MalfeasanceProofs indicates an expected call of MalfeasanceProofs. +func (mr *MockfetcherMockRecorder) MalfeasanceProofs(arg0, arg1 any) *MockfetcherMalfeasanceProofsCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MalfeasanceProofs", reflect.TypeOf((*Mockfetcher)(nil).MalfeasanceProofs), arg0, arg1) + return &MockfetcherMalfeasanceProofsCall{Call: call} +} + +// MockfetcherMalfeasanceProofsCall wrap *gomock.Call +type MockfetcherMalfeasanceProofsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockfetcherMalfeasanceProofsCall) Return(arg0 error) *MockfetcherMalfeasanceProofsCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockfetcherMalfeasanceProofsCall) Do(f func(context.Context, []types.NodeID) error) *MockfetcherMalfeasanceProofsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockfetcherMalfeasanceProofsCall) DoAndReturn(f func(context.Context, []types.NodeID) error) *MockfetcherMalfeasanceProofsCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MaliciousIDs mocks base method. +func (m *Mockfetcher) MaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]types.NodeID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaliciousIDs", arg0, arg1) + ret0, _ := ret[0].([]types.NodeID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MaliciousIDs indicates an expected call of MaliciousIDs. +func (mr *MockfetcherMockRecorder) MaliciousIDs(arg0, arg1 any) *MockfetcherMaliciousIDsCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaliciousIDs", reflect.TypeOf((*Mockfetcher)(nil).MaliciousIDs), arg0, arg1) + return &MockfetcherMaliciousIDsCall{Call: call} +} + +// MockfetcherMaliciousIDsCall wrap *gomock.Call +type MockfetcherMaliciousIDsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockfetcherMaliciousIDsCall) Return(arg0 []types.NodeID, arg1 error) *MockfetcherMaliciousIDsCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockfetcherMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherMaliciousIDsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockfetcherMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherMaliciousIDsCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index a9b2605d36..7ccf4672d6 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -18,6 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/malsync" ) @@ -25,8 +26,10 @@ import ( type fetcher interface { SelectBestShuffled(int) []p2p.Peer - GetMaliciousIDs(context.Context, p2p.Peer) ([]types.NodeID, error) - GetMalfeasanceProofs(context.Context, []types.NodeID) error + LegacyMaliciousIDs(context.Context, p2p.Peer) ([]types.NodeID, error) + MaliciousIDs(context.Context, p2p.Peer) ([]types.NodeID, error) + LegacyMalfeasanceProofs(context.Context, []types.NodeID) error + MalfeasanceProofs(context.Context, []types.NodeID) error } type Opt func(*Syncer) @@ -237,8 +240,20 @@ func New(fetcher fetcher, db sql.Executor, localdb sql.LocalDatabase, opts ...Op return s } +func (s *Syncer) shouldSyncLegacy(epochStart, epochEnd time.Time) (bool, error) { + timestamp, err := malsync.LegacySyncState(s.localdb) + if err != nil { + return false, fmt.Errorf("error getting malfeasance sync state: %w", err) + } + if timestamp.Before(epochStart) { + return true, nil + } + cutoff := epochEnd.Sub(epochStart).Seconds() * s.cfg.MaxEpochFraction + return s.clock.Now().Sub(timestamp).Seconds() > cutoff, nil +} + func (s *Syncer) shouldSync(epochStart, epochEnd time.Time) (bool, error) { - timestamp, err := malsync.GetSyncState(s.localdb) + timestamp, err := malsync.SyncState(s.localdb) if err != nil { return false, fmt.Errorf("error getting malfeasance sync state: %w", err) } @@ -249,6 +264,25 @@ func (s *Syncer) shouldSync(epochStart, epochEnd time.Time) (bool, error) { return s.clock.Now().Sub(timestamp).Seconds() > cutoff, nil } +func (s *Syncer) downloadLegacy(parent context.Context, initial bool) error { + s.logger.Info("starting legacy malfeasance proof sync", log.ZContext(parent)) + defer s.logger.Debug("legacy malfeasance proof sync terminated", log.ZContext(parent)) + ctx, cancel := context.WithCancel(parent) + eg, ctx := errgroup.WithContext(ctx) + updates := make(chan malUpdate, s.cfg.MalfeasanceIDPeers) + eg.Go(func() error { + return s.downloadLegacyNodeIDs(ctx, initial, updates) + }) + eg.Go(func() error { + defer cancel() + return s.downloadLegacyMalfeasanceProofs(ctx, initial, updates) + }) + if err := eg.Wait(); err != nil { + return err + } + return parent.Err() +} + func (s *Syncer) download(parent context.Context, initial bool) error { s.logger.Info("starting malfeasance proof sync", log.ZContext(parent)) defer s.logger.Debug("malfeasance proof sync terminated", log.ZContext(parent)) @@ -268,6 +302,78 @@ func (s *Syncer) download(parent context.Context, initial bool) error { return parent.Err() } +func (s *Syncer) downloadLegacyNodeIDs(ctx context.Context, initial bool, updates chan<- malUpdate) error { + interval := s.cfg.IDRequestInterval + if initial { + interval = 0 + } + for { + if interval != 0 { + s.logger.Debug( + "pausing between legacy malicious node ID requests", + zap.Duration("duration", interval), + ) + select { + case <-ctx.Done(): + return nil + // TODO(ivan4th) this has to be randomized in a followup + // when sync will be scheduled in advance, in order to smooth out request rate across the network + case <-s.clock.After(interval): + } + } + + peers := s.fetcher.SelectBestShuffled(s.cfg.MalfeasanceIDPeers) + if len(peers) == 0 { + s.logger.Debug( + "don't have enough peers for legacy malfeasance sync", + zap.Int("nPeers", s.cfg.MalfeasanceIDPeers), + ) + if interval == 0 { + interval = s.cfg.RetryInterval + } + continue + } + + var eg errgroup.Group + for _, peer := range peers { + eg.Go(func() error { + malIDs, err := s.fetcher.LegacyMaliciousIDs(ctx, peer) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + s.peerErrMetric.Inc() + s.logger.Warn("failed to download legacy malicious node IDs", + log.ZContext(ctx), + zap.String("peer", peer.String()), + zap.Error(err), + ) + return nil + } + s.logger.Debug("downloaded legacy malicious node IDs", + log.ZContext(ctx), + zap.String("peer", peer.String()), + zap.Int("ids", len(malIDs)), + ) + select { + case <-ctx.Done(): + return ctx.Err() + case updates <- malUpdate{peer: peer, nodeIDs: malIDs}: + } + return nil + }) + } + + if err := eg.Wait(); err != nil { + return err + } + + if interval == 0 { + interval = s.cfg.RetryInterval + } + } +} + func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan<- malUpdate) error { interval := s.cfg.IDRequestInterval if initial { @@ -276,8 +382,9 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan for { if interval != 0 { s.logger.Debug( - "pausing between malfeasant node ID requests", - zap.Duration("duration", interval)) + "pausing between malicious node ID requests", + zap.Duration("duration", interval), + ) select { case <-ctx.Done(): return nil @@ -302,20 +409,20 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan var eg errgroup.Group for _, peer := range peers { eg.Go(func() error { - malIDs, err := s.fetcher.GetMaliciousIDs(ctx, peer) + malIDs, err := s.fetcher.MaliciousIDs(ctx, peer) if err != nil { if errors.Is(err, context.Canceled) { return nil } s.peerErrMetric.Inc() - s.logger.Warn("failed to download malfeasant node IDs", + s.logger.Warn("failed to download malicious node IDs", log.ZContext(ctx), zap.String("peer", peer.String()), zap.Error(err), ) return nil } - s.logger.Debug("downloaded malfeasant node IDs", + s.logger.Debug("downloaded malicious node IDs", log.ZContext(ctx), zap.String("peer", peer.String()), zap.Int("ids", len(malIDs)), @@ -339,6 +446,22 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan } } +func (s *Syncer) updateLegacyState(ctx context.Context) error { + if err := s.localdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { + return malsync.UpdateLegacySyncState(tx, s.clock.Now()) + }); err != nil { + if ctx.Err() != nil { + // FIXME: with crawshaw, canceling the context which has been used to get + // a connection from the pool may cause "database: no free connection" errors. + // Related: #6273 + err = ctx.Err() + } + return fmt.Errorf("error updating legacy malsync state: %w", err) + } + + return nil +} + func (s *Syncer) updateState(ctx context.Context) error { if err := s.localdb.WithTxImmediate(ctx, func(tx sql.Transaction) error { return malsync.UpdateSyncState(tx, s.clock.Now()) @@ -355,7 +478,7 @@ func (s *Syncer) updateState(ctx context.Context) error { return nil } -func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, updates <-chan malUpdate) error { +func (s *Syncer) downloadLegacyMalfeasanceProofs(ctx context.Context, initial bool, updates <-chan malUpdate) error { var ( update malUpdate sst = newSyncState(s.cfg.RequestsLimit, initial) @@ -366,13 +489,13 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up if nothingToDownload { sst.done() if initial && sst.numSyncedPeers() >= s.cfg.MinSyncPeers { - if err := s.updateState(ctx); err != nil { + if err := s.updateLegacyState(ctx); err != nil { return err } - s.logger.Info("initial sync of malfeasance proofs completed", log.ZContext(ctx)) + s.logger.Info("initial sync of legacy malfeasance proofs completed", log.ZContext(ctx)) return nil } else if !initial && gotUpdate { - if err := s.updateState(ctx); err != nil { + if err := s.updateLegacyState(ctx); err != nil { return err } } @@ -380,7 +503,7 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up case <-ctx.Done(): return ctx.Err() case update = <-updates: - s.logger.Debug("malfeasance sync update", + s.logger.Debug("legacy malfeasance sync update", log.ZContext(ctx), zap.Int("count", len(update.nodeIDs)), ) @@ -392,7 +515,7 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up case <-ctx.Done(): return ctx.Err() case update = <-updates: - s.logger.Debug("malfeasance sync update", + s.logger.Debug("legacy malfeasance sync update", log.ZContext(ctx), zap.Int("count", len(update.nodeIDs)), ) @@ -412,57 +535,168 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up return isMalicious, err }) if err != nil { - return fmt.Errorf("error checking malfeasant node IDs: %w", err) + return fmt.Errorf("error checking legacy malicious node IDs: %w", err) } nothingToDownload = len(batch) == 0 + if len(batch) == 0 { + s.logger.Debug("no new legacy malicious identities", log.ZContext(ctx)) + continue + } - if len(batch) != 0 { - s.logger.Debug("retrieving malfeasant identities", + s.logger.Debug("retrieving legacy malicious identities", + log.ZContext(ctx), + zap.Int("count", len(batch)), + ) + batchError := &fetch.BatchError{} + err = s.fetcher.LegacyMalfeasanceProofs(ctx, batch) + switch { + case errors.Is(err, context.Canceled): + return ctx.Err() + case errors.As(err, &batchError): + for hash, err := range batchError.Errors { + nodeID := types.NodeID(hash) + switch { + case !sst.has(nodeID): + continue + case errors.Is(err, pubsub.ErrValidationReject): + sst.rejected(nodeID) + default: + sst.failed(nodeID) + } + } + case err != nil: + s.logger.Debug("failed to download malfeasance proofs", log.ZContext(ctx), - zap.Int("count", len(batch)), + log.NiceZapError(err), ) - err := s.fetcher.GetMalfeasanceProofs(ctx, batch) - if err != nil { - if errors.Is(err, context.Canceled) { - return ctx.Err() + } + } +} + +func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, updates <-chan malUpdate) error { + var ( + update malUpdate + sst = newSyncState(s.cfg.RequestsLimit, initial) + nothingToDownload = true + gotUpdate = false + ) + for { + if nothingToDownload { + sst.done() + if initial && sst.numSyncedPeers() >= s.cfg.MinSyncPeers { + if err := s.updateState(ctx); err != nil { + return err } - s.logger.Debug("failed to download malfeasance proofs", + s.logger.Info("initial sync of malfeasance proofs completed", log.ZContext(ctx)) + return nil + } else if !initial && gotUpdate { + if err := s.updateState(ctx); err != nil { + return err + } + } + select { + case <-ctx.Done(): + return ctx.Err() + case update = <-updates: + s.logger.Debug("malfeasance sync update", log.ZContext(ctx), - log.NiceZapError(err), + zap.Int("count", len(update.nodeIDs)), ) + sst.update(update) + gotUpdate = true } - batchError := &fetch.BatchError{} - if errors.As(err, &batchError) { - for hash, err := range batchError.Errors { - nodeID := types.NodeID(hash) - switch { - case !sst.has(nodeID): - continue - case errors.Is(err, pubsub.ErrValidationReject): - sst.rejected(nodeID) - default: - sst.failed(nodeID) - } + } else { + select { + case <-ctx.Done(): + return ctx.Err() + case update = <-updates: + s.logger.Debug("malfeasance sync update", + log.ZContext(ctx), + zap.Int("count", len(update.nodeIDs)), + ) + sst.update(update) + gotUpdate = true + default: + // If we have some hashes to fetch already, don't wait for + // another update + } + } + batch, err := sst.missing(s.cfg.MaxBatchSize, func(nodeID types.NodeID) (bool, error) { + // TODO(mafa): check multiple node IDs at once in a single SQL query + isMalicious, err := malfeasance.IsMalicious(s.db, nodeID) + if err != nil && errors.Is(err, sql.ErrNotFound) { + return false, nil + } + return isMalicious, err + }) + if err != nil { + return fmt.Errorf("error checking malicious node IDs: %w", err) + } + + nothingToDownload = len(batch) == 0 + if len(batch) == 0 { + s.logger.Debug("no new malicious identities", log.ZContext(ctx)) + continue + } + + s.logger.Debug("retrieving malicious identities", + log.ZContext(ctx), + zap.Int("count", len(batch)), + ) + batchError := &fetch.BatchError{} + err = s.fetcher.MalfeasanceProofs(ctx, batch) + switch { + case errors.Is(err, context.Canceled): + return ctx.Err() + case errors.As(err, &batchError): + for hash, err := range batchError.Errors { + nodeID := types.NodeID(hash) + switch { + case !sst.has(nodeID): + continue + case errors.Is(err, pubsub.ErrValidationReject): + sst.rejected(nodeID) + default: + sst.failed(nodeID) } } - } else { - s.logger.Debug("no new malfeasant identities", log.ZContext(ctx)) + case err != nil: + s.logger.Debug("failed to download malfeasance proofs", + log.ZContext(ctx), + log.NiceZapError(err), + ) } } } -func (s *Syncer) EnsureInSync(parent context.Context, epochStart, epochEnd time.Time) error { +func (s *Syncer) EnsureLegacyInSync(ctx context.Context, epochStart, epochEnd time.Time) error { + if shouldSync, err := s.shouldSyncLegacy(epochStart, epochEnd); err != nil { + return err + } else if !shouldSync { + return nil + } + return s.downloadLegacy(ctx, true) +} + +func (s *Syncer) EnsureInSync(ctx context.Context, epochStart, epochEnd time.Time) error { if shouldSync, err := s.shouldSync(epochStart, epochEnd); err != nil { return err } else if !shouldSync { return nil } - return s.download(parent, true) + return s.download(ctx, true) } func (s *Syncer) DownloadLoop(parent context.Context) error { - return s.download(parent, false) + eg, ctx := errgroup.WithContext(parent) + eg.Go(func() error { + return s.downloadLegacy(ctx, false) + }) + eg.Go(func() error { + return s.download(ctx, false) + }) + return eg.Wait() } type malUpdate struct { diff --git a/syncer/malsync/syncer_test.go b/syncer/malsync/syncer_test.go index 498cea440e..28c7e4c042 100644 --- a/syncer/malsync/syncer_test.go +++ b/syncer/malsync/syncer_test.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/malsync/mocks" ) @@ -136,16 +137,19 @@ func malData(ids ...string) []types.NodeID { } type tester struct { - tb testing.TB - syncer *Syncer - db sql.StateDatabase - cfg Config - fetcher *mocks.Mockfetcher - clock *clockwork.FakeClock - received map[types.NodeID]bool - attempts map[types.NodeID]int - peers []p2p.Peer - peerErrCount *fakeCounter + tb testing.TB + syncer *Syncer + db sql.StateDatabase + cfg Config + fetcher *mocks.Mockfetcher + clock *clockwork.FakeClock + + peers []p2p.Peer + peerErrCount *fakeCounter + receivedLegacy map[types.NodeID]bool + attemptsLegacy map[types.NodeID]int + received map[types.NodeID]bool + attempts map[types.NodeID]int } func newTester(tb testing.TB, cfg Config) *tester { @@ -162,49 +166,87 @@ func newTester(tb testing.TB, cfg Config) *tester { WithPeerErrMetric(peerErrCount), ) return &tester{ - tb: tb, - syncer: syncer, - db: db, - cfg: cfg, - fetcher: fetcher, - clock: clock, - received: make(map[types.NodeID]bool), - attempts: make(map[types.NodeID]int), - peers: []p2p.Peer{"a", "b", "c"}, - peerErrCount: peerErrCount, + tb: tb, + syncer: syncer, + db: db, + cfg: cfg, + fetcher: fetcher, + clock: clock, + receivedLegacy: make(map[types.NodeID]bool), + attemptsLegacy: make(map[types.NodeID]int), + received: make(map[types.NodeID]bool), + attempts: make(map[types.NodeID]int), + peers: []p2p.Peer{"a", "b", "c"}, + peerErrCount: peerErrCount, } } -func (tester *tester) expectGetMaliciousIDs() { - // "2" comes just from a single peer +func (tester *tester) expectLegacyMaliciousIDs() { + // "2" comes just from a single peer via legacy protocol tester.fetcher.EXPECT(). - GetMaliciousIDs(gomock.Any(), tester.peers[0]). + LegacyMaliciousIDs(gomock.Any(), tester.peers[0]). Return(malData("4", "1", "3", "2"), nil) for _, p := range tester.peers[1:] { tester.fetcher.EXPECT(). - GetMaliciousIDs(gomock.Any(), p). + LegacyMaliciousIDs(gomock.Any(), p). Return(malData("4", "1", "3"), nil) } } -func (tester *tester) expectGetProofs(errMap map[types.NodeID]error) { +func (tester *tester) expectMaliciousIDs() { + // "102" comes just from a single peer tester.fetcher.EXPECT(). - GetMalfeasanceProofs(gomock.Any(), gomock.Any()). + MaliciousIDs(gomock.Any(), tester.peers[0]). + Return(malData("104", "101", "103", "102"), nil) + for _, p := range tester.peers[1:] { + tester.fetcher.EXPECT(). + MaliciousIDs(gomock.Any(), p). + Return(malData("104", "101", "103"), nil) + } +} + +func (t *tester) expectLegacyProofs(errMap map[types.NodeID]error) { + t.fetcher.EXPECT(). + LegacyMalfeasanceProofs(gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, ids []types.NodeID) error { batchErr := &fetch.BatchError{ Errors: make(map[types.Hash32]error), } for _, id := range ids { - tester.attempts[id]++ - require.NotContains(tester.tb, tester.received, id) + t.attemptsLegacy[id]++ + require.NotContains(t.tb, t.receivedLegacy, id) if err := errMap[id]; err != nil { batchErr.Errors[types.Hash32(id)] = err continue } - tester.received[id] = true + t.receivedLegacy[id] = true proofData := codec.MustEncode(mproof(id)) - require.NoError(tester.tb, identities.SetMalicious( - tester.db, id, proofData, tester.syncer.clock.Now())) + require.NoError(t.tb, identities.SetMalicious(t.db, id, proofData, t.syncer.clock.Now())) + } + if len(batchErr.Errors) != 0 { + return batchErr + } + return nil + }).AnyTimes() +} + +func (t *tester) expectProofs(errMap map[types.NodeID]error) { + t.fetcher.EXPECT(). + MalfeasanceProofs(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, ids []types.NodeID) error { + batchErr := &fetch.BatchError{ + Errors: make(map[types.Hash32]error), + } + for _, id := range ids { + t.attempts[id]++ + require.NotContains(t.tb, t.received, id) + if err := errMap[id]; err != nil { + batchErr.Errors[types.Hash32(id)] = err + continue + } + t.received[id] = true + proof := codec.MustEncode(mproof(id)) + require.NoError(t.tb, malfeasance.AddProof(t.db, id, nil, proof, 1, t.syncer.clock.Now())) } if len(batchErr.Errors) != 0 { return batchErr @@ -218,26 +260,60 @@ func (tester *tester) expectPeers(peers []p2p.Peer) { } func TestSyncer(t *testing.T) { - t.Run("EnsureInSync", func(t *testing.T) { + t.Run("EnsureLegacyInSync", func(t *testing.T) { tester := newTester(t, DefaultConfig()) tester.expectPeers(tester.peers) - tester.expectGetMaliciousIDs() - tester.expectGetProofs(nil) + tester.expectLegacyMaliciousIDs() + tester.expectLegacyProofs(nil) epochStart := tester.clock.Now().Truncate(time.Second) epochEnd := epochStart.Add(10 * time.Minute) - require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) require.ElementsMatch(t, []types.NodeID{ nid("1"), nid("2"), nid("3"), nid("4"), - }, maps.Keys(tester.received)) + }, maps.Keys(tester.receivedLegacy)) require.Equal(t, map[types.NodeID]int{ nid("1"): 1, nid("2"): 1, nid("3"): 1, nid("4"): 1, + }, tester.attemptsLegacy) + tester.clock.Advance(1 * time.Minute) + // second call does nothing after recent sync + require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) + require.Zero(t, tester.peerErrCount.n) + }) + t.Run("EnsureInSync", func(t *testing.T) { + tester := newTester(t, DefaultConfig()) + tester.expectPeers(tester.peers) + tester.expectMaliciousIDs() + tester.expectProofs(nil) + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.ElementsMatch(t, []types.NodeID{ + nid("101"), nid("102"), nid("103"), nid("104"), + }, maps.Keys(tester.received)) + require.Equal(t, map[types.NodeID]int{ + nid("101"): 1, + nid("102"): 1, + nid("103"): 1, + nid("104"): 1, }, tester.attempts) tester.clock.Advance(1 * time.Minute) // second call does nothing after recent sync require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + }) + t.Run("EnsureLegacyInSync with no malfeasant identities", func(t *testing.T) { + tester := newTester(t, DefaultConfig()) + tester.expectPeers(tester.peers) + for _, p := range tester.peers { + tester.fetcher.EXPECT(). + LegacyMaliciousIDs(gomock.Any(), p). + Return(nil, nil) + } + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) require.Zero(t, tester.peerErrCount.n) }) t.Run("EnsureInSync with no malfeasant identities", func(t *testing.T) { @@ -245,13 +321,12 @@ func TestSyncer(t *testing.T) { tester.expectPeers(tester.peers) for _, p := range tester.peers { tester.fetcher.EXPECT(). - GetMaliciousIDs(gomock.Any(), p). + MaliciousIDs(gomock.Any(), p). Return(nil, nil) } epochStart := tester.clock.Now().Truncate(time.Second) epochEnd := epochStart.Add(10 * time.Minute) - require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) require.Zero(t, tester.peerErrCount.n) }) t.Run("interruptible", func(t *testing.T) { @@ -260,10 +335,16 @@ func TestSyncer(t *testing.T) { cancel() tester.expectPeers([]p2p.Peer{"a"}) tester.fetcher.EXPECT(). - GetMaliciousIDs(gomock.Any(), gomock.Any()). + LegacyMaliciousIDs(gomock.Any(), gomock.Any()). Return(malData("1"), nil).AnyTimes() tester.fetcher.EXPECT(). - GetMalfeasanceProofs(gomock.Any(), gomock.Any()). + LegacyMalfeasanceProofs(gomock.Any(), gomock.Any()). + Return(errors.New("no atxs")).AnyTimes() + tester.fetcher.EXPECT(). + MaliciousIDs(gomock.Any(), gomock.Any()). + Return(malData("101"), nil).AnyTimes() + tester.fetcher.EXPECT(). + MalfeasanceProofs(gomock.Any(), gomock.Any()). Return(errors.New("no atxs")).AnyTimes() require.ErrorIs(t, tester.syncer.DownloadLoop(ctx), context.Canceled) }) @@ -280,98 +361,182 @@ func TestSyncer(t *testing.T) { require.ErrorIs(t, tester.syncer.DownloadLoop(ctx), context.Canceled) return nil }) - tester.clock.BlockUntilContext(context.Background(), 1) + tester.clock.BlockUntilContext(context.Background(), 2) tester.clock.Advance(tester.cfg.IDRequestInterval) ch <- nil - tester.clock.BlockUntilContext(context.Background(), 1) + ch <- nil + tester.clock.BlockUntilContext(context.Background(), 2) tester.clock.Advance(tester.cfg.IDRequestInterval) - tester.expectGetMaliciousIDs() - tester.expectGetProofs(nil) + tester.expectLegacyMaliciousIDs() + tester.expectLegacyProofs(nil) + tester.expectMaliciousIDs() + tester.expectProofs(nil) ch <- tester.peers - tester.clock.BlockUntilContext(context.Background(), 1) + ch <- tester.peers + tester.clock.BlockUntilContext(context.Background(), 2) cancel() eg.Wait() }) - t.Run("getting ids from MinSyncPeers peers is enough", func(t *testing.T) { + t.Run("getting ids from MinSyncPeers peers is enough - legacy", func(t *testing.T) { cfg := DefaultConfig() cfg.MinSyncPeers = 2 tester := newTester(t, cfg) tester.expectPeers(tester.peers) tester.fetcher.EXPECT(). - GetMaliciousIDs(gomock.Any(), tester.peers[0]). + LegacyMaliciousIDs(gomock.Any(), tester.peers[0]). Return(nil, errors.New("fail")) for _, p := range tester.peers[1:] { tester.fetcher.EXPECT(). - GetMaliciousIDs(gomock.Any(), p). + LegacyMaliciousIDs(gomock.Any(), p). Return(malData("4", "1", "3", "2"), nil) } - tester.expectGetProofs(nil) + tester.expectLegacyProofs(nil) epochStart := tester.clock.Now().Truncate(time.Second) epochEnd := epochStart.Add(10 * time.Minute) require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) require.ElementsMatch(t, []types.NodeID{ nid("1"), nid("2"), nid("3"), nid("4"), - }, maps.Keys(tester.received)) + }, maps.Keys(tester.receivedLegacy)) require.Equal(t, map[types.NodeID]int{ nid("1"): 1, nid("2"): 1, nid("3"): 1, nid("4"): 1, + }, tester.attemptsLegacy) + tester.clock.Advance(1 * time.Minute) + // second call does nothing after recent sync + require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) + require.Equal(t, 1, tester.peerErrCount.n) + }) + t.Run("getting ids from MinSyncPeers peers is enough", func(t *testing.T) { + cfg := DefaultConfig() + cfg.MinSyncPeers = 2 + tester := newTester(t, cfg) + tester.expectPeers(tester.peers) + tester.fetcher.EXPECT(). + MaliciousIDs(gomock.Any(), tester.peers[0]). + Return(nil, errors.New("fail")) + for _, p := range tester.peers[1:] { + tester.fetcher.EXPECT(). + MaliciousIDs(gomock.Any(), p). + Return(malData("104", "101", "103", "102"), nil) + } + tester.expectProofs(nil) + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.ElementsMatch(t, []types.NodeID{ + nid("101"), nid("102"), nid("103"), nid("104"), + }, maps.Keys(tester.received)) + require.Equal(t, map[types.NodeID]int{ + nid("101"): 1, + nid("102"): 1, + nid("103"): 1, + nid("104"): 1, }, tester.attempts) tester.clock.Advance(1 * time.Minute) // second call does nothing after recent sync require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) require.Equal(t, 1, tester.peerErrCount.n) }) - t.Run("skip hashes after max retries", func(t *testing.T) { + t.Run("skip hashes after max retries - legacy", func(t *testing.T) { cfg := DefaultConfig() cfg.RequestsLimit = 3 tester := newTester(t, cfg) tester.expectPeers(tester.peers) - tester.expectGetMaliciousIDs() - tester.expectGetProofs(map[types.NodeID]error{ + tester.expectLegacyMaliciousIDs() + tester.expectLegacyProofs(map[types.NodeID]error{ nid("2"): errors.New("fail"), }) epochStart := tester.clock.Now().Truncate(time.Second) epochEnd := epochStart.Add(10 * time.Minute) - require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) require.ElementsMatch(t, []types.NodeID{ nid("1"), nid("3"), nid("4"), - }, maps.Keys(tester.received)) + }, maps.Keys(tester.receivedLegacy)) require.Equal(t, map[types.NodeID]int{ nid("1"): 1, nid("2"): tester.cfg.RequestsLimit, nid("3"): 1, nid("4"): 1, + }, tester.attemptsLegacy) + tester.clock.Advance(1 * time.Minute) + // second call does nothing after recent sync + require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) + }) + t.Run("skip hashes after max retries", func(t *testing.T) { + cfg := DefaultConfig() + cfg.RequestsLimit = 3 + tester := newTester(t, cfg) + tester.expectPeers(tester.peers) + tester.expectMaliciousIDs() + tester.expectProofs(map[types.NodeID]error{ + nid("102"): errors.New("fail"), + }) + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.ElementsMatch(t, []types.NodeID{ + nid("101"), nid("103"), nid("104"), + }, maps.Keys(tester.received)) + require.Equal(t, map[types.NodeID]int{ + nid("101"): 1, + nid("102"): tester.cfg.RequestsLimit, + nid("103"): 1, + nid("104"): 1, }, tester.attempts) tester.clock.Advance(1 * time.Minute) // second call does nothing after recent sync require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) }) - t.Run("skip hashes after validation reject", func(t *testing.T) { + t.Run("skip hashes after validation reject - legacy", func(t *testing.T) { tester := newTester(t, DefaultConfig()) tester.expectPeers(tester.peers) - tester.expectGetMaliciousIDs() - tester.expectGetProofs(map[types.NodeID]error{ + tester.expectLegacyMaliciousIDs() + tester.expectLegacyProofs(map[types.NodeID]error{ // note that "2" comes just from a single peer - // (see expectGetMaliciousIDs) + // (see expectMaliciousIDs) nid("2"): pubsub.ErrValidationReject, }) epochStart := tester.clock.Now().Truncate(time.Second) epochEnd := epochStart.Add(10 * time.Minute) - require.NoError(t, - tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) require.ElementsMatch(t, []types.NodeID{ nid("1"), nid("3"), nid("4"), - }, maps.Keys(tester.received)) + }, maps.Keys(tester.receivedLegacy)) require.Equal(t, map[types.NodeID]int{ nid("1"): 1, nid("2"): 1, nid("3"): 1, nid("4"): 1, + }, tester.attemptsLegacy) + tester.clock.Advance(1 * time.Minute) + // second call does nothing after recent sync + require.NoError(t, tester.syncer.EnsureLegacyInSync(context.Background(), epochStart, epochEnd)) + }) + t.Run("skip hashes after validation reject", func(t *testing.T) { + tester := newTester(t, DefaultConfig()) + tester.expectPeers(tester.peers) + tester.expectMaliciousIDs() + tester.expectProofs(map[types.NodeID]error{ + // note that "102" comes just from a single peer + // (see expectMaliciousIDs) + nid("102"): pubsub.ErrValidationReject, + }) + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.ElementsMatch(t, []types.NodeID{ + nid("101"), nid("103"), nid("104"), + }, maps.Keys(tester.received)) + require.Equal(t, map[types.NodeID]int{ + nid("101"): 1, + nid("102"): 1, + nid("103"): 1, + nid("104"): 1, }, tester.attempts) tester.clock.Advance(1 * time.Minute) // second call does nothing after recent sync diff --git a/syncer/mocks/mocks.go b/syncer/mocks/mocks.go index 8ad4707d01..8c567af5a3 100644 --- a/syncer/mocks/mocks.go +++ b/syncer/mocks/mocks.go @@ -381,44 +381,6 @@ func (c *MockfetchLogicGetLayerOpinionsCall) DoAndReturn(f func(context.Context, return c } -// GetMalfeasanceProofs mocks base method. -func (m *MockfetchLogic) GetMalfeasanceProofs(arg0 context.Context, arg1 []types.NodeID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMalfeasanceProofs", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetMalfeasanceProofs indicates an expected call of GetMalfeasanceProofs. -func (mr *MockfetchLogicMockRecorder) GetMalfeasanceProofs(arg0, arg1 any) *MockfetchLogicGetMalfeasanceProofsCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMalfeasanceProofs", reflect.TypeOf((*MockfetchLogic)(nil).GetMalfeasanceProofs), arg0, arg1) - return &MockfetchLogicGetMalfeasanceProofsCall{Call: call} -} - -// MockfetchLogicGetMalfeasanceProofsCall wrap *gomock.Call -type MockfetchLogicGetMalfeasanceProofsCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockfetchLogicGetMalfeasanceProofsCall) Return(arg0 error) *MockfetchLogicGetMalfeasanceProofsCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockfetchLogicGetMalfeasanceProofsCall) Do(f func(context.Context, []types.NodeID) error) *MockfetchLogicGetMalfeasanceProofsCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockfetchLogicGetMalfeasanceProofsCall) DoAndReturn(f func(context.Context, []types.NodeID) error) *MockfetchLogicGetMalfeasanceProofsCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // PeerEpochInfo mocks base method. func (m *MockfetchLogic) PeerEpochInfo(arg0 context.Context, arg1 p2p.Peer, arg2 types.EpochID) (*fetch.EpochData, error) { m.ctrl.T.Helper() @@ -816,6 +778,44 @@ func (c *MockmalSyncerEnsureInSyncCall) DoAndReturn(f func(context.Context, time return c } +// EnsureLegacyInSync mocks base method. +func (m *MockmalSyncer) EnsureLegacyInSync(parent context.Context, epochStart, epochEnd time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnsureLegacyInSync", parent, epochStart, epochEnd) + ret0, _ := ret[0].(error) + return ret0 +} + +// EnsureLegacyInSync indicates an expected call of EnsureLegacyInSync. +func (mr *MockmalSyncerMockRecorder) EnsureLegacyInSync(parent, epochStart, epochEnd any) *MockmalSyncerEnsureLegacyInSyncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureLegacyInSync", reflect.TypeOf((*MockmalSyncer)(nil).EnsureLegacyInSync), parent, epochStart, epochEnd) + return &MockmalSyncerEnsureLegacyInSyncCall{Call: call} +} + +// MockmalSyncerEnsureLegacyInSyncCall wrap *gomock.Call +type MockmalSyncerEnsureLegacyInSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockmalSyncerEnsureLegacyInSyncCall) Return(arg0 error) *MockmalSyncerEnsureLegacyInSyncCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockmalSyncerEnsureLegacyInSyncCall) Do(f func(context.Context, time.Time, time.Time) error) *MockmalSyncerEnsureLegacyInSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockmalSyncerEnsureLegacyInSyncCall) DoAndReturn(f func(context.Context, time.Time, time.Time) error) *MockmalSyncerEnsureLegacyInSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Mockfetcher is a mock of fetcher interface. type Mockfetcher struct { ctrl *gomock.Controller @@ -1076,44 +1076,6 @@ func (c *MockfetcherGetLayerOpinionsCall) DoAndReturn(f func(context.Context, p2 return c } -// GetMalfeasanceProofs mocks base method. -func (m *Mockfetcher) GetMalfeasanceProofs(arg0 context.Context, arg1 []types.NodeID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMalfeasanceProofs", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// GetMalfeasanceProofs indicates an expected call of GetMalfeasanceProofs. -func (mr *MockfetcherMockRecorder) GetMalfeasanceProofs(arg0, arg1 any) *MockfetcherGetMalfeasanceProofsCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMalfeasanceProofs", reflect.TypeOf((*Mockfetcher)(nil).GetMalfeasanceProofs), arg0, arg1) - return &MockfetcherGetMalfeasanceProofsCall{Call: call} -} - -// MockfetcherGetMalfeasanceProofsCall wrap *gomock.Call -type MockfetcherGetMalfeasanceProofsCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockfetcherGetMalfeasanceProofsCall) Return(arg0 error) *MockfetcherGetMalfeasanceProofsCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockfetcherGetMalfeasanceProofsCall) Do(f func(context.Context, []types.NodeID) error) *MockfetcherGetMalfeasanceProofsCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockfetcherGetMalfeasanceProofsCall) DoAndReturn(f func(context.Context, []types.NodeID) error) *MockfetcherGetMalfeasanceProofsCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // PeerEpochInfo mocks base method. func (m *Mockfetcher) PeerEpochInfo(arg0 context.Context, arg1 p2p.Peer, arg2 types.EpochID) (*fetch.EpochData, error) { m.ctrl.T.Helper() diff --git a/syncer/syncer.go b/syncer/syncer.go index af05f1463e..d227401b7d 100644 --- a/syncer/syncer.go +++ b/syncer/syncer.go @@ -788,10 +788,17 @@ func (s *Syncer) setStateAfterSync(ctx context.Context, success bool) { } } -func (s *Syncer) syncMalfeasance(ctx context.Context, epoch types.EpochID) error { +func (s *Syncer) syncMalfeasance(parent context.Context, epoch types.EpochID) error { epochStart := s.ticker.LayerToTime(epoch.FirstLayer()) epochEnd := s.ticker.LayerToTime(epoch.Add(1).FirstLayer()) - if err := s.malsyncer.EnsureInSync(ctx, epochStart, epochEnd); err != nil { + eg, ctx := errgroup.WithContext(parent) + eg.Go(func() error { + return s.malsyncer.EnsureLegacyInSync(ctx, epochStart, epochEnd) + }) + eg.Go(func() error { + return s.malsyncer.EnsureInSync(ctx, epochStart, epochEnd) + }) + if err := eg.Wait(); err != nil { return fmt.Errorf("syncing malfeasance proof: %w", err) } return nil diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go index febf439f16..cb25fc905c 100644 --- a/syncer/syncer_test.go +++ b/syncer/syncer_test.go @@ -85,6 +85,11 @@ type testSyncer struct { } func (ts *testSyncer) expectMalEnsureInSync(current types.LayerID) { + ts.mMalSyncer.EXPECT().EnsureLegacyInSync( + gomock.Any(), + ts.mTicker.LayerToTime(current.GetEpoch().FirstLayer()), + ts.mTicker.LayerToTime(current.GetEpoch().Add(1).FirstLayer()), + ) ts.mMalSyncer.EXPECT().EnsureInSync( gomock.Any(), ts.mTicker.LayerToTime(current.GetEpoch().FirstLayer()), @@ -392,6 +397,7 @@ func TestSynchronize_FetchMalfeasanceFailed(t *testing.T) { ts.mTicker.advanceToLayer(current) lyr := current.Sub(1) ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + ts.mMalSyncer.EXPECT().EnsureLegacyInSync(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("meh")) ts.mMalSyncer.EXPECT().EnsureInSync(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("meh")) require.False(t, ts.syncer.synchronize(context.Background())) diff --git a/systest/Makefile b/systest/Makefile index 0d854af452..c16c27cd52 100644 --- a/systest/Makefile +++ b/systest/Makefile @@ -10,7 +10,7 @@ poet_image ?= $(org)/poet:v0.10.10 post_service_image ?= $(org)/post-service:v0.8.4 post_init_image ?= $(org)/postcli:v0.12.10 smesher_image ?= $(org)/go-spacemesh-dev:$(version_info) -old_smesher_image ?= $(org)/go-spacemesh-dev:v1.7.7 +old_smesher_image ?= $(org)/go-spacemesh-dev:8d89a07 # TODO: update this after merging, new version updated config bs_image ?= $(org)/go-spacemesh-dev-bs:$(version_info) test_id ?= systest-$(version_info) diff --git a/systest/cluster/nodes.go b/systest/cluster/nodes.go index 685fe9c44a..9b7116b02a 100644 --- a/systest/cluster/nodes.go +++ b/systest/cluster/nodes.go @@ -615,30 +615,27 @@ func areContainersReady(pod *apiv1.Pod) bool { } func waitPod(ctx *testcontext.Context, id string) (*apiv1.Pod, error) { - watcher, err := ctx.Client.CoreV1().Pods(ctx.Namespace).Watch(ctx, apimetav1.ListOptions{ - LabelSelector: labelSelector(id), - }) + watcherCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + watcher, err := ctx.Client.CoreV1(). + Pods(ctx.Namespace). + Watch(watcherCtx, apimetav1.ListOptions{ + LabelSelector: labelSelector(id), + }) if err != nil { return nil, err } defer watcher.Stop() - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case ev, open := <-watcher.ResultChan(): - if !open { - return nil, fmt.Errorf("watcher is terminated while waiting for pod with id %v", id) - } - pod, ok := ev.Object.(*apiv1.Pod) - if !ok { - continue - } - if pod.Status.Phase == apiv1.PodRunning && areContainersReady(pod) { - return pod, nil - } + for ev := range watcher.ResultChan() { + pod, ok := ev.Object.(*apiv1.Pod) + if !ok { + continue + } + if pod.Status.Phase == apiv1.PodRunning && areContainersReady(pod) { + return pod, nil } } + return nil, fmt.Errorf("watcher terminated while waiting for pod with id %v: %w", id, ctx.Err()) } func nodeLabels(name, id string) map[string]string { diff --git a/systest/parameters/fastnet/smesher.json b/systest/parameters/fastnet/smesher.json index fca3cf6fc7..05d066b454 100644 --- a/systest/parameters/fastnet/smesher.json +++ b/systest/parameters/fastnet/smesher.json @@ -44,10 +44,15 @@ } }, "logging": { - "log-encoder": "json", - "txHandler": "debug", + "atxHandler": "debug", + "fetcher": "debug", "grpc": "debug", + "log-encoder": "json", + "malfeasance": "debug", + "malfeasance2": "debug", + "nipostBuilder": "debug", + "nipostValidator": "debug", "sync": "debug", - "fetcher": "debug" + "txHandler": "debug" } } diff --git a/systest/tests/common.go b/systest/tests/common.go index 8b488630bd..99950fb61b 100644 --- a/systest/tests/common.go +++ b/systest/tests/common.go @@ -9,6 +9,7 @@ import ( "time" pb "github.com/spacemeshos/api/release/go/spacemesh/v1" + pb2 "github.com/spacemeshos/api/release/go/spacemesh/v2beta1" "github.com/stretchr/testify/require" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -228,12 +229,12 @@ func malfeasanceStream( ctx context.Context, node *cluster.NodeClient, logger *zap.Logger, - collector func(*pb.MalfeasanceStreamResponse) (bool, error), + collector func(*pb2.MalfeasanceProof) (bool, error), ) error { retries := 0 BACKOFF: - meshapi := pb.NewMeshServiceClient(node.PubConn()) - proofs, err := meshapi.MalfeasanceStream(ctx, &pb.MalfeasanceStreamRequest{IncludeProof: true}) + malapi := pb2.NewMalfeasanceStreamServiceClient(node.PrivConn()) + proofs, err := malapi.Stream(ctx, &pb2.MalfeasanceStreamRequest{Watch: true}) if err != nil { return err } diff --git a/systest/tests/distributed_post_verification_test.go b/systest/tests/distributed_post_verification_test.go index 513466bdde..4b50846fda 100644 --- a/systest/tests/distributed_post_verification_test.go +++ b/systest/tests/distributed_post_verification_test.go @@ -1,21 +1,27 @@ package tests import ( + "bytes" "context" + "errors" "fmt" "os" "path/filepath" + "slices" "testing" "time" "github.com/libp2p/go-libp2p/core/peer" pb "github.com/spacemeshos/api/release/go/spacemesh/v1" + pb2 "github.com/spacemeshos/api/release/go/spacemesh/v2beta1" + "github.com/spacemeshos/go-scale" "github.com/spacemeshos/post/shared" "github.com/spacemeshos/post/verifying" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/activation" @@ -24,10 +30,10 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/config" "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/fetch" "github.com/spacemeshos/go-spacemesh/fetch/peers" - mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/handshake" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" @@ -41,50 +47,74 @@ import ( "github.com/spacemeshos/go-spacemesh/timesync" ) +type builtAtx interface { + ID() types.ATXID + + scale.Encodable + zapcore.ObjectMarshaler +} + +func version(cfg *config.Config, publish types.EpochID) types.AtxVersion { + cfg.AtxVersions[0] = types.AtxV1 + epochs := maps.Keys(cfg.AtxVersions) + slices.Sort(epochs) + version := types.AtxV1 + for _, epoch := range epochs { + if publish >= epoch { + version = cfg.AtxVersions[epoch] + } + } + return version +} + // TestPostMalfeasanceProof tests that nodes can detect an invalid PoST and create a malfeasance proof against it. func TestPostMalfeasanceProof(t *testing.T) { t.Parallel() - testDir := t.TempDir() ctx := testcontext.New(t) - logger := ctx.Log.Desugar().WithOptions(zap.IncreaseLevel(zap.InfoLevel), zap.WithCaller(false)) // Prepare cluster ctx.PoetSize = 1 // one poet guarantees everybody gets the same proof - ctx.ClusterSize = 3 + ctx.ClusterSize = 8 cl := cluster.New(ctx, cluster.WithKeys(10)) require.NoError(t, cl.AddBootnodes(ctx, 1)) require.NoError(t, cl.AddBootstrappers(ctx)) require.NoError(t, cl.AddPoets(ctx)) require.NoError(t, cl.AddSmeshers(ctx, ctx.ClusterSize-cl.Total(), cluster.WithFlags(cluster.PostK3(1)))) + logger := ctx.Log.Desugar().WithOptions(zap.IncreaseLevel(zap.InfoLevel), zap.WithCaller(false)) + cfg := getConfig(t, logger, cl, ctx) + + // Test malfeasance for each ATX version, malfeasance1 in first epoch + publishEpoch := types.EpochID(1) + testPostMalfeasance(t, cfg, cl, logger, ctx, publishEpoch) + + for k, v := range cfg.AtxVersions { + if v == 2 { + publishEpoch = types.EpochID(k) + } + } + + // malfeasance2 in first epoch with ATXv2 + testPostMalfeasance(t, cfg, cl, logger, ctx, publishEpoch) +} + +func testPostMalfeasance( + t *testing.T, + cfg *config.Config, + cl *cluster.Cluster, + logger *zap.Logger, + ctx *testcontext.Context, + publishEpoch types.EpochID, +) { // Prepare config - cfg, err := cl.NodeConfig(ctx) - require.NoError(t, err) + testDir := t.TempDir() - types.SetLayersPerEpoch(cfg.LayersPerEpoch) cfg.DataDirParent = testDir cfg.SMESHING.Opts.DataDir = filepath.Join(testDir, "post-data") cfg.P2P.DataDir = filepath.Join(testDir, "p2p-dir") require.NoError(t, os.Mkdir(cfg.P2P.DataDir, os.ModePerm)) - cfg.POET.RequestTimeout = time.Minute - cfg.POET.MaxRequestRetries = 10 - - var bootnodes []*cluster.NodeClient - for i := 0; i < cl.Bootnodes(); i++ { - bootnodes = append(bootnodes, cl.Client(i)) - } - - endpoints, err := cluster.ExtractP2PEndpoints(ctx, bootnodes) - require.NoError(t, err) - cfg.P2P.Bootnodes = endpoints - cfg.P2P.PrivateNetwork = true - cfg.Bootstrap.URL = cluster.BootstrapperGlobalEndpoint(ctx.Namespace, 0) - cfg.P2P.MinPeers = 2 - ctx.Log.Debugw("Prepared config", "cfg", cfg) - - goldenATXID := cl.GoldenATX() signer, err := signing.NewEdSigner(signing.WithPrefix(cl.GenesisID().Bytes())) require.NoError(t, err) @@ -99,11 +129,11 @@ func TestPostMalfeasanceProof(t *testing.T) { logger.Info("p2p host created", zap.Stringer("id", host.ID())) host.Register(pubsub.AtxProtocol, func(context.Context, peer.ID, []byte) error { return nil }) require.NoError(t, host.Start()) - t.Cleanup(func() { assert.NoError(t, host.Stop()) }) + defer host.Stop() db := statesql.InMemoryTest(t) cdb := datastore.NewCachedDB(db, zap.NewNop()) - t.Cleanup(func() { assert.NoError(t, cdb.Close()) }) + defer cdb.Close() clock, err := timesync.NewClock( timesync.WithLayerDuration(cfg.LayerDuration), @@ -112,7 +142,7 @@ func TestPostMalfeasanceProof(t *testing.T) { timesync.WithLogger(logger.Named("clock")), ) require.NoError(t, err) - t.Cleanup(clock.Close) + defer clock.Close() proposalsStore := store.New( store.WithEvictedLayer(clock.CurrentLayer()), @@ -120,7 +150,7 @@ func TestPostMalfeasanceProof(t *testing.T) { store.WithCapacity(cfg.Tortoise.Zdist+1), ) - fetcher, err := fetch.NewFetch(cdb, proposalsStore, host, + fetcher, err := fetch.NewFetch(db, proposalsStore, host, peers.New(), fetch.WithContext(ctx), fetch.WithConfig(cfg.FETCH), @@ -138,10 +168,11 @@ func TestPostMalfeasanceProof(t *testing.T) { fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), + fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), ) require.NoError(t, fetcher.Start()) - t.Cleanup(fetcher.Stop) + defer fetcher.Stop() ctrl := gomock.NewController(t) syncer := activation.NewMocksyncer(ctrl) @@ -173,7 +204,7 @@ func TestPostMalfeasanceProof(t *testing.T) { builder, ) require.NoError(t, postSupervisor.Start(cfg.POSTService, cfg.SMESHING.Opts, signer)) - t.Cleanup(func() { assert.NoError(t, postSupervisor.Stop(false)) }) + defer postSupervisor.Stop(false) // 2. create ATX with invalid POST labels grpcPostService := grpcserver.NewPostService( @@ -190,7 +221,7 @@ func TestPostMalfeasanceProof(t *testing.T) { ) require.NoError(t, err) require.NoError(t, grpcPrivateServer.Start()) - t.Cleanup(func() { assert.NoError(t, grpcPrivateServer.Close()) }) + defer grpcPrivateServer.Close() localDb := localsql.InMemoryTest(t) certClient := activation.NewCertifierClient(db, localDb, logger.Named("certifier")) @@ -232,153 +263,231 @@ func TestPostMalfeasanceProof(t *testing.T) { require.NoError(t, err) // 2.1. Create initial POST - var challenge *wire.NIPostChallengeV1 + var client activation.PostClient for { - client, err := grpcPostService.Client(signer.NodeID()) + client, err = grpcPostService.Client(signer.NodeID()) if err != nil { - ctx.Log.Info("waiting for poet service to connect") + logger.Info("waiting for post service to connect") time.Sleep(time.Second) continue } - ctx.Log.Info("poet service to connected") - post, postInfo, err := client.Proof(ctx, shared.ZeroChallenge) - require.NoError(t, err) - - err = nipost.AddPost(localDb, signer.NodeID(), nipost.Post{ - Nonce: post.Nonce, - Indices: post.Indices, - Pow: post.Pow, - Challenge: shared.ZeroChallenge, - NumUnits: postInfo.NumUnits, - CommitmentATX: postInfo.CommitmentATX, - VRFNonce: *postInfo.Nonce, - }) - require.NoError(t, err) - - challenge = &wire.NIPostChallengeV1{ - PrevATXID: types.EmptyATXID, - PublishEpoch: 1, - PositioningATXID: goldenATXID, - CommitmentATXID: &postInfo.CommitmentATX, - InitialPost: &wire.PostV1{ - Nonce: post.Nonce, - Indices: post.Indices, - Pow: post.Pow, - }, - } break } + logger.Info("post service connected") + initialPost, initialPostInfo, err := client.Proof(ctx, shared.ZeroChallenge) + require.NoError(t, err) + + err = nipost.AddPost(localDb, signer.NodeID(), nipost.Post{ + Nonce: initialPost.Nonce, + Indices: initialPost.Indices, + Pow: initialPost.Pow, + Challenge: shared.ZeroChallenge, + NumUnits: initialPostInfo.NumUnits, + CommitmentATX: initialPostInfo.CommitmentATX, + VRFNonce: *initialPostInfo.Nonce, + }) + require.NoError(t, err) + + registerEpoch := publishEpoch - 1 + logger.Info("waiting for epoch to register at poet", + zap.Uint32("register_epoch", uint32(registerEpoch)), + zap.Uint32("publish_epoch", uint32(publishEpoch)), + ) + select { + case <-ctx.Done(): + require.Fail(t, "context canceled") + return + case <-clock.AwaitLayer(registerEpoch.FirstLayer()): + } + logger.Info("reached register epoch", zap.Uint32("register_epoch", uint32(registerEpoch))) + + registerEpoch = clock.CurrentLayer().GetEpoch() + publishEpoch = registerEpoch + 1 nipostChallenge := &types.NIPostChallenge{ - PublishEpoch: challenge.PublishEpoch, + PublishEpoch: publishEpoch, PrevATXID: types.EmptyATXID, - PositioningATX: challenge.PositioningATXID, - CommitmentATX: challenge.CommitmentATXID, + PositioningATX: cl.GoldenATX(), + CommitmentATX: &initialPostInfo.CommitmentATX, InitialPost: &types.Post{ - Nonce: challenge.InitialPost.Nonce, - Indices: challenge.InitialPost.Indices, - Pow: challenge.InitialPost.Pow, + Nonce: initialPost.Nonce, + Indices: initialPost.Indices, Pow: initialPost.Pow, }, } err = nipost.AddChallenge(localDb, signer.NodeID(), nipostChallenge) require.NoError(t, err) - nipost, err := nipostBuilder.BuildNIPost(ctx, signer, challenge.Hash(), nipostChallenge) + version := version(cfg, nipostChallenge.PublishEpoch) + var challengeHash types.Hash32 + switch version { + case types.AtxV1: + challengeHash = wire.NIPostChallengeToWireV1(nipostChallenge).Hash() + case types.AtxV2: + challengeHash = wire.NIPostChallengeToWireV2(nipostChallenge).Hash() + default: + require.Fail(t, fmt.Sprintf("unsupported ATX version: %v", version)) + } + nipost, err := nipostBuilder.BuildNIPost(ctx, signer, challengeHash, nipostChallenge) require.NoError(t, err) // 2.2 Create ATX with invalid POST + logger.Info("invalidating PoST") + invalidPost := false for i := range nipost.Post.Indices { - nipost.Post.Indices[i] += 1 + for range 256 { + nipost.Post.Indices[i] += 1 + err = verifier.Verify(ctx, (*shared.Proof)(nipost.Post), &shared.ProofMetadata{ + NodeId: signer.NodeID().Bytes(), + CommitmentAtxId: nipostChallenge.CommitmentATX.Bytes(), + NumUnits: nipost.NumUnits, + Challenge: nipost.PostMetadata.Challenge, + LabelsPerUnit: nipost.PostMetadata.LabelsPerUnit, + }) + var invalidIdxError *verifying.ErrInvalidIndex + if errors.As(err, &invalidIdxError) { + invalidPost = true + break + } + } + if invalidPost { + break + } } + require.True(t, invalidPost, "expected invalid POST") + logger.Info("PoST invalidated") - // Sanity check that the POST is invalid - err = verifier.Verify(ctx, (*shared.Proof)(nipost.Post), &shared.ProofMetadata{ - NodeId: signer.NodeID().Bytes(), - CommitmentAtxId: challenge.CommitmentATXID.Bytes(), - NumUnits: nipost.NumUnits, - Challenge: nipost.PostMetadata.Challenge, - LabelsPerUnit: nipost.PostMetadata.LabelsPerUnit, - }) - var invalidIdxError *verifying.ErrInvalidIndex - require.ErrorAs(t, err, &invalidIdxError) - - nodeID := signer.NodeID() - atx := wire.ActivationTxV1{ - InnerActivationTxV1: wire.InnerActivationTxV1{ - NIPostChallengeV1: *challenge, - Coinbase: types.Address{1, 2, 3, 4}, - NumUnits: nipost.NumUnits, - NIPost: wire.NiPostToWireV1(nipost.NIPost), - NodeID: &nodeID, - VRFNonce: (*uint64)(&nipost.VRFNonce), - }, + var ( + atx builtAtx + expectedDomain pb2.MalfeasanceProof_MalfeasanceDomain + expectedType uint32 + ) + expectedProperties := make(map[string]string) + switch version { + case types.AtxV1: + watx := &wire.ActivationTxV1{ + InnerActivationTxV1: wire.InnerActivationTxV1{ + NIPostChallengeV1: *wire.NIPostChallengeToWireV1(nipostChallenge), + Coinbase: types.Address{1, 2, 3, 4}, + NumUnits: nipost.NumUnits, + NIPost: wire.NiPostToWireV1(nipost.NIPost), + VRFNonce: (*uint64)(&nipost.VRFNonce), + }, + } + watx.Sign(signer) + atx = watx + expectedDomain = pb2.MalfeasanceProof_DOMAIN_UNSPECIFIED + expectedType = 4 + expectedProperties["atx"] = atx.ID().String() + case types.AtxV2: + watx := &wire.ActivationTxV2{ + PublishEpoch: nipostChallenge.PublishEpoch, + PositioningATX: nipostChallenge.PositioningATX, + Coinbase: types.Address{1, 2, 3, 4}, + VRFNonce: (uint64)(nipost.VRFNonce), + NIPosts: []wire.NIPostV2{ + { + Membership: wire.MerkleProofV2{ + Nodes: nipost.NIPost.Membership.Nodes, + }, + Challenge: types.Hash32(nipost.PostMetadata.Challenge), + Posts: []wire.SubPostV2{ + { + Post: *wire.PostToWireV1(nipost.Post), + NumUnits: nipost.NumUnits, + MembershipLeafIndex: nipost.NIPost.Membership.LeafIndex, + }, + }, + }, + }, + Initial: &wire.InitialAtxPartsV2{ + Post: *wire.PostToWireV1(nipostChallenge.InitialPost), + CommitmentATX: *nipostChallenge.CommitmentATX, + }, + } + watx.Sign(signer) + atx = watx + expectedDomain = pb2.MalfeasanceProof_DOMAIN_ACTIVATION + expectedType = 0 + expectedProperties["type"] = "InvalidPoSTProof" + expectedProperties["atx"] = atx.ID().String() + default: + require.Fail(t, fmt.Sprintf("unsupported ATX version: %v", version)) + return } - atx.Sign(signer) // 3. Wait for publish epoch require.NoError(t, cl.WaitAll(ctx)) - epoch := atx.PublishEpoch - logger.Sugar().Infow("waiting for publish epoch", "epoch", epoch, "layer", epoch.FirstLayer()) + logger.Info("waiting for publish epoch", + zap.Uint32("epoch", publishEpoch.Uint32()), + zap.Uint32("layer", publishEpoch.FirstLayer().Uint32()), + ) err = layersStream(ctx, cl.Client(0), logger, func(resp *pb.LayerStreamResponse) (bool, error) { logger.Info("new layer", zap.Uint32("layer", resp.Layer.Number.Number)) - return resp.Layer.Number.Number < epoch.FirstLayer().Uint32(), nil + return resp.Layer.Number.Number < publishEpoch.FirstLayer().Uint32(), nil }) require.NoError(t, err) // 4. Publish ATX - publishCtx, stopPublishing := context.WithCancel(ctx.Context) + timeout := time.Minute * 2 + publishCtx, stopPublishing := context.WithTimeout(ctx.Context, timeout) defer stopPublishing() var eg errgroup.Group - t.Cleanup(func() { assert.NoError(t, eg.Wait()) }) + defer eg.Wait() eg.Go(func() error { for { - logger.Info("publishing ATX", zap.Object("atx", &atx)) - buf := codec.MustEncode(&atx) + logger.Info("publishing ATX", zap.Object("atx", atx)) + buf := codec.MustEncode(atx) err = host.Publish(ctx, pubsub.AtxProtocol, buf) require.NoError(t, err) select { case <-publishCtx.Done(): return nil - case <-time.After(10 * time.Second): + case <-time.After(10 * time.Second): // retry every 10 seconds until context is done } } }) // 5. Wait for POST malfeasance proof receivedProof := false - timeout := time.Minute * 2 logger.Info("waiting for malfeasance proof", zap.Duration("timeout", timeout)) - awaitCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - err = malfeasanceStream(awaitCtx, cl.Client(0), logger, func(malf *pb.MalfeasanceStreamResponse) (bool, error) { + err = malfeasanceStream(publishCtx, cl.Client(0), logger, func(proof *pb2.MalfeasanceProof) (bool, error) { + if !bytes.Equal(proof.GetSmesher(), signer.NodeID().Bytes()) { + return true, nil + } stopPublishing() logger.Info("malfeasance proof received") - require.Equal(t, malf.GetProof().GetSmesherId().Id, signer.NodeID().Bytes()) - require.Equal(t, pb.MalfeasanceProof_MALFEASANCE_POST_INDEX, malf.GetProof().GetKind()) - - var proof mwire.MalfeasanceProof - require.NoError(t, codec.Decode(malf.Proof.Proof, &proof)) - require.Equal(t, mwire.InvalidPostIndex, proof.Proof.Type) - invalidPostProof := proof.Proof.Data.(*mwire.InvalidPostIndexProof) - logger.Info("malfeasance post proof", zap.Object("proof", invalidPostProof)) - invalidAtx := invalidPostProof.Atx - require.Equal(t, atx.PublishEpoch, invalidAtx.PublishEpoch) - require.Equal(t, atx.SmesherID, invalidAtx.SmesherID) - require.Equal(t, atx.ID(), invalidAtx.ID()) - - meta := &shared.ProofMetadata{ - NodeId: invalidAtx.NodeID.Bytes(), - CommitmentAtxId: invalidAtx.CommitmentATXID.Bytes(), - NumUnits: invalidAtx.NumUnits, - Challenge: invalidAtx.NIPost.PostMetadata.Challenge, - LabelsPerUnit: invalidAtx.NIPost.PostMetadata.LabelsPerUnit, - } - err := verifier.Verify(awaitCtx, (*shared.Proof)(invalidAtx.NIPost.Post), meta) - var invalidIdxError *verifying.ErrInvalidIndex - require.ErrorAs(t, err, &invalidIdxError) + require.Equal(t, expectedDomain, proof.Domain) + require.Equal(t, expectedType, proof.Type) + require.Subset(t, proof.Properties, expectedProperties) + require.Equal(t, atx.ID().ShortString(), proof.Properties["atx"]) receivedProof = true return false, nil }) require.NoError(t, err) require.True(t, receivedProof, "malfeasance proof not received") } + +func getConfig(t testing.TB, logger *zap.Logger, cl *cluster.Cluster, ctx *testcontext.Context) *config.Config { + cfg, err := cl.NodeConfig(ctx) + require.NoError(t, err) + + types.SetLayersPerEpoch(cfg.LayersPerEpoch) + + cfg.POET.RequestTimeout = time.Minute + cfg.POET.MaxRequestRetries = 10 + + var bootnodes []*cluster.NodeClient + for i := 0; i < cl.Bootnodes(); i++ { + bootnodes = append(bootnodes, cl.Client(i)) + } + + endpoints, err := cluster.ExtractP2PEndpoints(ctx, bootnodes) + require.NoError(t, err) + cfg.P2P.Bootnodes = endpoints + cfg.P2P.PrivateNetwork = true + + cfg.Bootstrap.URL = cluster.BootstrapperGlobalEndpoint(ctx.Namespace, 0) + cfg.P2P.MinPeers = 2 + logger.Debug("Prepared config", zap.Any("cfg", cfg)) + return cfg +} diff --git a/systest/tests/equivocation_test.go b/systest/tests/equivocation_test.go index bb7104b695..7a21628f6a 100644 --- a/systest/tests/equivocation_test.go +++ b/systest/tests/equivocation_test.go @@ -8,6 +8,7 @@ import ( "github.com/oasisprotocol/curve25519-voi/primitives/ed25519" pb "github.com/spacemeshos/api/release/go/spacemesh/v1" + pb2 "github.com/spacemeshos/api/release/go/spacemesh/v2beta1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" @@ -125,9 +126,9 @@ func TestEquivocation(t *testing.T) { proofs := make([]types.NodeID, 0, len(malfeasants)) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() - malfeasanceStream(ctx, client, cctx.Log.Desugar(), func(malf *pb.MalfeasanceStreamResponse) (bool, error) { - malfeasant := malf.GetProof().GetSmesherId().Id - proofs = append(proofs, types.NodeID(malfeasant)) + malfeasanceStream(ctx, client, cctx.Log.Desugar(), func(proof *pb2.MalfeasanceProof) (bool, error) { + malfeasant := proof.GetSmesher() + proofs = append(proofs, types.BytesToNodeID(malfeasant)) return len(proofs) < len(malfeasants), nil }) assert.ElementsMatchf(t, expected, proofs, "client: %s", cl.Client(i).Name) diff --git a/tortoise/model/core.go b/tortoise/model/core.go index cd762fb102..38750ab3e7 100644 --- a/tortoise/model/core.go +++ b/tortoise/model/core.go @@ -22,6 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -184,7 +185,11 @@ func (c *core) OnMessage(m Messenger, event Message) { if err != nil { c.logger.Fatal("failed is malicious lookup", zap.Error(err)) } - c.atxdata.AddFromAtx(ev.Atx, malicious) + malicious2, err := malfeasance.IsMalicious(c.cdb, ev.Atx.SmesherID) + if err != nil { + c.logger.Fatal("failed is malicious lookup", zap.Error(err)) + } + c.atxdata.AddFromAtx(ev.Atx, malicious || malicious2) case MessageBeacon: beacons.Add(c.cdb, ev.EpochID+1, ev.Beacon) case MessageCoinflip: