From 21b620873bebc8fd67f2de45528c4cea1578181e Mon Sep 17 00:00:00 2001 From: Roman Behma <13855864+begmaroman@users.noreply.github.com> Date: Fri, 11 Nov 2022 11:46:27 +0000 Subject: [PATCH 01/17] Minor fixes in ibft and tests (#45) --- .golangci.yml | 13 +++++++------ core/ibft.go | 23 ++++++++++++++--------- core/mock_test.go | 26 ++++++++++++++++---------- core/rapid_test.go | 17 +++++++---------- 4 files changed, 44 insertions(+), 35 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 7f4683f..b5bf311 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -58,12 +58,10 @@ linters: - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unused # Checks Go code for unused constants, variables, functions and types - containedctx # containedctx is a linter that detects struct contained context.Context field - - cyclop # checks function and package cyclomatic complexity - durationcheck # check for two durations multiplied together - errchkjson - gochecknoglobals # check that no global variables exist - goerr113 # Golang linter to check the errors handling expressions - - gomnd # An analyzer to detect magic numbers. - ireturn # Accept Interfaces, Return Concrete Types - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - promlinter # Check Prometheus metrics naming via promlint @@ -167,7 +165,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#add-constant - name: add-constant severity: warning - disabled: false + disabled: true arguments: - maxLitCount: "3" allowStrs: '""' @@ -206,7 +204,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#cognitive-complexity - name: cognitive-complexity severity: warning - disabled: false + disabled: true arguments: [ 7 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#confusing-naming - name: confusing-naming @@ -233,7 +231,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#cyclomatic - name: cyclomatic severity: warning - disabled: false + disabled: true arguments: [ 3 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#datarace - name: datarace @@ -310,7 +308,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#function-length - name: function-length severity: warning - disabled: false + disabled: true arguments: [ 10, 0 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#get-return - name: get-return @@ -495,6 +493,9 @@ issues: - gosec - unparam - lll + - containedctx + - goerr113 + - revive include: - EXC0012 # Exported (.+) should have comment( \(or a comment on this block\))? or be unexported - EXC0013 # Package comment should be of the form "(.+)... diff --git a/core/ibft.go b/core/ibft.go index 58cb90b..72f5e91 100644 --- a/core/ibft.go +++ b/core/ibft.go @@ -12,12 +12,14 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) +// Logger represents the logger behaviour type Logger interface { Info(msg string, args ...interface{}) Debug(msg string, args ...interface{}) Error(msg string, args ...interface{}) } +// Messages represents the message managing behaviour type Messages interface { // Messages modifiers // AddMessage(message *proto.Message) @@ -36,10 +38,13 @@ type Messages interface { Unsubscribe(id messages.SubscriptionID) } +const ( + round0Timeout = 10 * time.Second + roundFactorBase = float64(2) +) + var ( errTimeoutExpired = errors.New("round timeout expired") - - round0Timeout = 10 * time.Second ) // IBFT represents a single instance of the IBFT state machine @@ -125,7 +130,7 @@ func (i *IBFT) startRoundTimer(ctx context.Context, round uint64) { var ( duration = int(i.baseRoundTimeout) - roundFactor = int(math.Pow(float64(2), float64(round))) + roundFactor = int(math.Pow(roundFactorBase, float64(round))) roundTimeout = time.Duration(duration * roundFactor) ) @@ -344,7 +349,7 @@ func (i *IBFT) RunSequence(ctx context.Context, h uint64) { teardown() return - case <-ctx.Done(): + case <-ctxRound.Done(): teardown() i.log.Debug("sequence cancelled") @@ -681,11 +686,11 @@ func (i *IBFT) validateProposal(msg *proto.Message, view *proto.View) bool { roundsAndPreparedBlockHashes := make([]roundHashTuple, 0) for _, rcMessage := range rcc.RoundChangeMessages { - certificate := messages.ExtractLatestPC(rcMessage) + cert := messages.ExtractLatestPC(rcMessage) // Check if there is a certificate, and if it's a valid PC - if certificate != nil && i.validPC(certificate, msg.View.Round, height) { - hash := messages.ExtractProposalHash(certificate.ProposalMessage) + if cert != nil && i.validPC(cert, msg.View.Round, height) { + hash := messages.ExtractProposalHash(cert.ProposalMessage) roundsAndPreparedBlockHashes = append(roundsAndPreparedBlockHashes, roundHashTuple{ round: rcMessage.View.Round, @@ -700,8 +705,8 @@ func (i *IBFT) validateProposal(msg *proto.Message, view *proto.View) bool { // Find the max round var ( - maxRound uint64 = 0 - expectedHash []byte = nil + maxRound uint64 + expectedHash []byte ) for _, tuple := range roundsAndPreparedBlockHashes { diff --git a/core/mock_test.go b/core/mock_test.go index 2170a97..f31d75d 100644 --- a/core/mock_test.go +++ b/core/mock_test.go @@ -322,11 +322,7 @@ func newMockCluster( nodes[index] = NewIBFT(logger, backend, transport) // Instantiate context for the nodes - ctx, cancelFn := context.WithCancel(context.Background()) - nodeCtxs[index] = mockNodeContext{ - ctx: ctx, - cancelFn: cancelFn, - } + nodeCtxs[index] = newMockNodeContext() } return &mockCluster{ @@ -341,6 +337,16 @@ type mockNodeContext struct { cancelFn context.CancelFunc } +// newMockNodeContext is the constructor of mockNodeContext +func newMockNodeContext() mockNodeContext { + ctx, cancelFn := context.WithCancel(context.Background()) + + return mockNodeContext{ + ctx: ctx, + cancelFn: cancelFn, + } +} + // mockNodeWg is the WaitGroup wrapper for the cluster nodes type mockNodeWg struct { sync.WaitGroup @@ -383,12 +389,10 @@ func (m *mockCluster) runSequence(height uint64) { node *IBFT, height uint64, ) { - defer func() { - m.wg.Done() - }() - // Start the main run loop for the node node.RunSequence(ctx, height) + + m.wg.Done() }(m.ctxs[nodeIndex].ctx, node, height) } } @@ -405,8 +409,10 @@ func (m *mockCluster) awaitCompletion() { // in the cluster, and awaits their completion func (m *mockCluster) forceShutdown() { // Send a stop signal to all the nodes - for _, ctx := range m.ctxs { + for i, ctx := range m.ctxs { ctx.cancelFn() + + m.ctxs[i] = newMockNodeContext() } // Wait for all the nodes to finish diff --git a/core/rapid_test.go b/core/rapid_test.go index 0a5d24f..1637b81 100644 --- a/core/rapid_test.go +++ b/core/rapid_test.go @@ -3,7 +3,6 @@ package core import ( "bytes" "context" - "fmt" "sync" "testing" "time" @@ -15,6 +14,10 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) +const ( + testRoundTimeout = time.Second +) + // mockInsertedProposals keeps track of inserted proposals for a cluster // of nodes type mockInsertedProposals struct { @@ -359,7 +362,7 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { // Set a small timeout, because of situations // where the byzantine node is the proposer - cluster.setBaseTimeout(time.Second * 2) + cluster.setBaseTimeout(testRoundTimeout) // Set the multicast callback to relay the message // to the entire cluster @@ -374,14 +377,8 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { // Wait until Quorum nodes finish their run loop ctx, cancelFn := context.WithTimeout(context.Background(), time.Second*5) - if err := cluster.awaitNCompletions(ctx, int64(quorum(numNodes))); err != nil { - t.Fatalf( - fmt.Sprintf( - "unable to wait for nodes to complete, %v", - err, - ), - ) - } + err := cluster.awaitNCompletions(ctx, int64(quorum(numNodes))) + assert.NoError(t, err, "unable to wait for nodes to complete") // Shutdown the remaining nodes that might be hanging cluster.forceShutdown() From 70fa07d86cf079ba4dee59c73bcfc3a0458ff79b Mon Sep 17 00:00:00 2001 From: Roman Behma <13855864+begmaroman@users.noreply.github.com> Date: Tue, 15 Nov 2022 09:13:53 +0000 Subject: [PATCH 02/17] Add rapid test with bad proposal coming from byzantine node (#44) --- .github/workflows/main.yml | 4 +- .golangci.yml | 2 + core/consensus_test.go | 100 ++++----------- core/ibft.go | 25 +++- core/ibft_test.go | 149 ++++++++++------------ core/mock_test.go | 77 ++++++++--- core/rapid_test.go | 254 +++++++++++++++++++++++++++---------- 7 files changed, 364 insertions(+), 247 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e896990..a8ecb87 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,7 +23,7 @@ jobs: submodules: recursive - name: Go test - run: go test -covermode=atomic -shuffle=on -coverprofile coverage.out -timeout 2m ./... + run: go test -covermode=atomic -shuffle=on -coverprofile coverage.out -test.short -timeout 5m ./... - name: Upload coverage file to Codecov uses: codecov/codecov-action@v3 @@ -45,7 +45,7 @@ jobs: submodules: recursive - name: Run Go Test with race - run: go test -race -shuffle=on -timeout 2m ./... + run: go test -test.short -race -shuffle=on -timeout 5m ./... reproducible-builds: runs-on: ubuntu-latest diff --git a/.golangci.yml b/.golangci.yml index b5bf311..ad47552 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -496,6 +496,8 @@ issues: - containedctx - goerr113 - revive + - gochecknoglobals + - exhaustive include: - EXC0012 # Exported (.+) should have comment( \(or a comment on this block\))? or be unexported - EXC0013 # Package comment should be of the form "(.+)... diff --git a/core/consensus_test.go b/core/consensus_test.go index 62c55a5..9759e7c 100644 --- a/core/consensus_test.go +++ b/core/consensus_test.go @@ -147,14 +147,13 @@ func commonHasQuorumFn(numNodes uint64) func(blockNumber uint64, messages []*pro func TestConsensus_ValidFlow(t *testing.T) { t.Parallel() - var multicastFn func(message *proto.Message) + var ( + multicastFn func(message *proto.Message) - proposal := []byte("proposal") - proposalHash := []byte("proposal hash") - committedSeal := []byte("seal") - numNodes := uint64(4) - nodes := generateNodeAddresses(numNodes) - insertedBlocks := make([][]byte, numNodes) + numNodes = uint64(4) + nodes = generateNodeAddresses(numNodes) + insertedBlocks = make([][]byte, numNodes) + ) // commonTransportCallback is the common method modification // required for Transport, for all nodes @@ -182,12 +181,12 @@ func TestConsensus_ValidFlow(t *testing.T) { // Make sure the proposal is valid if it matches what node 0 proposed backend.isValidBlockFn = func(newProposal []byte) bool { - return bytes.Equal(newProposal, proposal) + return bytes.Equal(newProposal, correctRoundMessage.proposal) } // Make sure the proposal hash matches backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { - return bytes.Equal(p, proposal) && bytes.Equal(ph, proposalHash) + return bytes.Equal(p, correctRoundMessage.proposal) && bytes.Equal(ph, correctRoundMessage.hash) } // Make sure the preprepare message is built correctly @@ -198,7 +197,7 @@ func TestConsensus_ValidFlow(t *testing.T) { ) *proto.Message { return buildBasicPreprepareMessage( proposal, - proposalHash, + correctRoundMessage.hash, certificate, nodes[nodeIndex], view) @@ -206,12 +205,12 @@ func TestConsensus_ValidFlow(t *testing.T) { // Make sure the prepare message is built correctly backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicPrepareMessage(proposalHash, nodes[nodeIndex], view) + return buildBasicPrepareMessage(correctRoundMessage.hash, nodes[nodeIndex], view) } // Make sure the commit message is built correctly backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicCommitMessage(proposalHash, committedSeal, nodes[nodeIndex], view) + return buildBasicCommitMessage(correctRoundMessage.hash, correctRoundMessage.seal, nodes[nodeIndex], view) } // Make sure the round change message is built correctly @@ -227,44 +226,19 @@ func TestConsensus_ValidFlow(t *testing.T) { backend.insertBlockFn = func(proposal []byte, _ []*messages.CommittedSeal) { insertedBlocks[nodeIndex] = proposal } - } - var ( - backendCallbackMap = map[int]backendConfigCallback{ - 0: func(backend *mockBackend) { - // Execute the common backend setup - commonBackendCallback(backend, 0) - - // Set the proposal creation method for node 0, since - // they are the proposer - backend.buildProposalFn = func(_ *proto.View) []byte { - return proposal - } - }, - 1: func(backend *mockBackend) { - commonBackendCallback(backend, 1) - }, - 2: func(backend *mockBackend) { - commonBackendCallback(backend, 2) - }, - 3: func(backend *mockBackend) { - commonBackendCallback(backend, 3) - }, - } - transportCallbackMap = map[int]transportConfigCallback{ - 0: commonTransportCallback, - 1: commonTransportCallback, - 2: commonTransportCallback, - 3: commonTransportCallback, + // Set the proposal creation method + backend.buildProposalFn = func(_ *proto.View) []byte { + return correctRoundMessage.proposal } - ) + } // Create the mock cluster cluster := newMockCluster( numNodes, - backendCallbackMap, + commonBackendCallback, nil, - transportCallbackMap, + commonTransportCallback, ) // Set the multicast callback to relay the message @@ -281,7 +255,7 @@ func TestConsensus_ValidFlow(t *testing.T) { // Make sure the inserted blocks match what node 0 proposed for _, block := range insertedBlocks { - assert.True(t, bytes.Equal(block, proposal)) + assert.True(t, bytes.Equal(block, correctRoundMessage.proposal)) } } @@ -394,45 +368,19 @@ func TestConsensus_InvalidBlock(t *testing.T) { backend.insertBlockFn = func(proposal []byte, _ []*messages.CommittedSeal) { insertedBlocks[nodeIndex] = proposal } - } - - var ( - backendCallbackMap = map[int]backendConfigCallback{ - 0: func(backend *mockBackend) { - commonBackendCallback(backend, 0) - - backend.buildProposalFn = func(_ *proto.View) []byte { - return proposals[0] - } - }, - 1: func(backend *mockBackend) { - commonBackendCallback(backend, 1) - backend.buildProposalFn = func(_ *proto.View) []byte { - return proposals[1] - } - }, - 2: func(backend *mockBackend) { - commonBackendCallback(backend, 2) - }, - 3: func(backend *mockBackend) { - commonBackendCallback(backend, 3) - }, + // Build proposal function + backend.buildProposalFn = func(_ *proto.View) []byte { + return proposals[nodeIndex] } - transportCallbackMap = map[int]transportConfigCallback{ - 0: commonTransportCallback, - 1: commonTransportCallback, - 2: commonTransportCallback, - 3: commonTransportCallback, - } - ) + } // Create the mock cluster cluster := newMockCluster( numNodes, - backendCallbackMap, + commonBackendCallback, nil, - transportCallbackMap, + commonTransportCallback, ) // Set the base timeout to be lower than usual diff --git a/core/ibft.go b/core/ibft.go index 72f5e91..01f92ab 100644 --- a/core/ibft.go +++ b/core/ibft.go @@ -128,14 +128,10 @@ func NewIBFT( func (i *IBFT) startRoundTimer(ctx context.Context, round uint64) { defer i.wg.Done() - var ( - duration = int(i.baseRoundTimeout) - roundFactor = int(math.Pow(roundFactorBase, float64(round))) - roundTimeout = time.Duration(duration * roundFactor) - ) + roundTimeout := getRoundTimeout(i.baseRoundTimeout, i.additionalTimeout, round) // Create a new timer instance - timer := time.NewTimer(roundTimeout + i.additionalTimeout) + timer := time.NewTimer(roundTimeout) select { case <-ctx.Done(): @@ -1147,3 +1143,20 @@ func (i *IBFT) sendCommitMessage(view *proto.View) { ), ) } + +// getRoundTimeout creates a round timeout based on the base timeout and the current round. +// Exponentially increases timeout depending on the round number. +// For instance: +// - round 1: 1 sec +// - round 2: 2 sec +// - round 3: 4 sec +// - round 4: 8 sec +func getRoundTimeout(baseRoundTimeout, additionalTimeout time.Duration, round uint64) time.Duration { + var ( + duration = int(baseRoundTimeout) + roundFactor = int(math.Pow(roundFactorBase, float64(round))) + roundTimeout = time.Duration(duration * roundFactor) + ) + + return roundTimeout + additionalTimeout +} diff --git a/core/ibft_test.go b/core/ibft_test.go index 6cbe895..9479e26 100644 --- a/core/ibft_test.go +++ b/core/ibft_test.go @@ -306,8 +306,6 @@ func TestRunNewRound_Proposer(t *testing.T) { var ( multicastedPreprepare *proto.Message = nil multicastedPrepare *proto.Message = nil - proposalHash = []byte("proposal hash") - proposal = []byte("proposal") notifyCh = make(chan uint64, 1) log = mockLogger{} @@ -327,7 +325,7 @@ func TestRunNewRound_Proposer(t *testing.T) { }, hasQuorumFn: defaultHasQuorumFn(quorum), buildProposalFn: func(_ *proto.View) []byte { - return proposal + return correctRoundMessage.proposal }, buildPrepareMessageFn: func(_ []byte, view *proto.View) *proto.Message { return &proto.Message{ @@ -335,7 +333,7 @@ func TestRunNewRound_Proposer(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -350,8 +348,8 @@ func TestRunNewRound_Proposer(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -401,7 +399,7 @@ func TestRunNewRound_Proposer(t *testing.T) { assert.Equal(t, multicastedPreprepare, i.state.proposalMessage) // Make sure the correct proposal value was multicasted - assert.True(t, proposalMatches(proposal, multicastedPreprepare)) + assert.True(t, proposalMatches(correctRoundMessage.proposal, multicastedPreprepare)) // Make sure the prepare message was not multicasted assert.Nil(t, multicastedPrepare) @@ -414,7 +412,6 @@ func TestRunNewRound_Proposer(t *testing.T) { t.Parallel() lastPreparedProposedBlock := []byte("last prepared block") - proposalHash := []byte("proposal hash") quorum := uint64(4) ctx, cancelFn := context.WithCancel(context.Background()) @@ -425,7 +422,7 @@ func TestRunNewRound_Proposer(t *testing.T) { for index, message := range prepareMessages { message.Payload = &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, } @@ -450,7 +447,7 @@ func TestRunNewRound_Proposer(t *testing.T) { Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ Proposal: lastPreparedProposedBlock, - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, Certificate: nil, }, }, @@ -489,7 +486,7 @@ func TestRunNewRound_Proposer(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -505,7 +502,7 @@ func TestRunNewRound_Proposer(t *testing.T) { Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ Proposal: proposal, - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, Certificate: certificate, }, }, @@ -572,8 +569,6 @@ func TestRunNewRound_Validator_Zero(t *testing.T) { ctx, cancelFn := context.WithCancel(context.Background()) var ( - proposal = []byte("new block") - proposalHash = []byte("proposal hash") proposer = []byte("proposer") multicastedPrepare *proto.Message = nil notifyCh = make(chan uint64, 1) @@ -599,7 +594,7 @@ func TestRunNewRound_Validator_Zero(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -634,7 +629,7 @@ func TestRunNewRound_Validator_Zero(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, + Proposal: correctRoundMessage.proposal, }, }, }, @@ -660,10 +655,10 @@ func TestRunNewRound_Validator_Zero(t *testing.T) { assert.Equal(t, prepare, i.state.name) // Make sure the accepted proposal is the one that was sent out - assert.Equal(t, proposal, i.state.getProposal()) + assert.Equal(t, correctRoundMessage.proposal, i.state.getProposal()) // Make sure the correct proposal hash was multicasted - assert.True(t, prepareHashMatches(proposalHash, multicastedPrepare)) + assert.True(t, prepareHashMatches(correctRoundMessage.hash, multicastedPrepare)) } // TestRunNewRound_Validator_NonZero validates the behavior @@ -672,8 +667,6 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { t.Parallel() quorum := uint64(4) - proposalHash := []byte("proposal hash") - proposal := []byte("new block") proposer := []byte("proposer") generateProposalWithNoPrevious := func() *proto.Message { @@ -689,8 +682,8 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, Certificate: &proto.RoundChangeCertificate{ RoundChangeMessages: roundChangeMessages, }, @@ -709,10 +702,14 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, Certificate: &proto.RoundChangeCertificate{ - RoundChangeMessages: generateFilledRCMessages(quorum, proposal, proposalHash), + RoundChangeMessages: generateFilledRCMessages( + quorum, + correctRoundMessage.proposal, + correctRoundMessage.hash, + ), }, }, }, @@ -767,7 +764,7 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -823,10 +820,10 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { assert.Equal(t, prepare, i.state.name) // Make sure the accepted proposal is the one that was sent out - assert.Equal(t, proposal, i.state.getProposal()) + assert.Equal(t, correctRoundMessage.proposal, i.state.getProposal()) // Make sure the correct proposal hash was multicasted - assert.True(t, prepareHashMatches(proposalHash, multicastedPrepare)) + assert.True(t, prepareHashMatches(correctRoundMessage.hash, multicastedPrepare)) }) } } @@ -844,8 +841,6 @@ func TestRunPrepare(t *testing.T) { ctx, cancelFn := context.WithCancel(context.Background()) var ( - proposal = []byte("block proposal") - proposalHash = []byte("proposal hash") multicastedCommit *proto.Message = nil notifyCh = make(chan uint64, 1) @@ -862,7 +857,7 @@ func TestRunPrepare(t *testing.T) { Type: proto.MessageType_COMMIT, Payload: &proto.Message_CommitData{ CommitData: &proto.CommitMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -871,7 +866,7 @@ func TestRunPrepare(t *testing.T) { return len(messages) >= 1 }, isValidProposalHashFn: func(_ []byte, hash []byte) bool { - return bytes.Equal(proposalHash, hash) + return bytes.Equal(correctRoundMessage.hash, hash) }, } messages = mockMessages{ @@ -896,7 +891,7 @@ func TestRunPrepare(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, }, @@ -913,8 +908,8 @@ func TestRunPrepare(t *testing.T) { i.state.proposalMessage = &proto.Message{ Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -932,10 +927,10 @@ func TestRunPrepare(t *testing.T) { assert.Equal(t, commit, i.state.name) // Make sure the proposal didn't change - assert.Equal(t, proposal, i.state.getProposal()) + assert.Equal(t, correctRoundMessage.proposal, i.state.getProposal()) // Make sure the proper proposal hash was multicasted - assert.True(t, commitHashMatches(proposalHash, multicastedCommit)) + assert.True(t, commitHashMatches(correctRoundMessage.hash, multicastedCommit)) }, ) } @@ -953,8 +948,6 @@ func TestRunCommit(t *testing.T) { var ( wg sync.WaitGroup - proposal = []byte("block proposal") - proposalHash = []byte("proposal hash") signer = []byte("signer") insertedProposal []byte = nil insertedCommittedSeals []*messages.CommittedSeal = nil @@ -978,7 +971,7 @@ func TestRunCommit(t *testing.T) { return len(messages) >= 1 }, isValidProposalHashFn: func(_ []byte, hash []byte) bool { - return bytes.Equal(proposalHash, hash) + return bytes.Equal(correctRoundMessage.hash, hash) }, } messages = mockMessages{ @@ -1000,7 +993,7 @@ func TestRunCommit(t *testing.T) { Type: proto.MessageType_COMMIT, Payload: &proto.Message_CommitData{ CommitData: &proto.CommitMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, CommittedSeal: committedSeals[0].Signature, }, }, @@ -1018,8 +1011,8 @@ func TestRunCommit(t *testing.T) { i.state.proposalMessage = &proto.Message{ Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -1056,7 +1049,7 @@ func TestRunCommit(t *testing.T) { assert.Equal(t, fin, i.state.name) // Make sure the inserted proposal was the one present - assert.Equal(t, insertedProposal, proposal) + assert.Equal(t, insertedProposal, correctRoundMessage.proposal) // Make sure the inserted committed seals were correct assert.Equal(t, insertedCommittedSeals, committedSeals) @@ -1271,8 +1264,6 @@ func TestIBFT_FutureProposal(t *testing.T) { nodeID := []byte("node ID") proposer := []byte("proposer") - proposal := []byte("proposal") - proposalHash := []byte("proposal hash") quorum := uint64(4) generateEmptyRCMessages := func(count uint64) []*proto.Message { @@ -1313,7 +1304,11 @@ func TestIBFT_FutureProposal(t *testing.T) { Height: 0, Round: 2, }, - generateFilledRCMessages(quorum, proposal, proposalHash), + generateFilledRCMessages( + quorum, + correctRoundMessage.proposal, + correctRoundMessage.hash, + ), 2, }, } @@ -1332,8 +1327,8 @@ func TestIBFT_FutureProposal(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, Certificate: &proto.RoundChangeCertificate{ RoundChangeMessages: testCase.roundChangeMessages, }, @@ -1355,7 +1350,8 @@ func TestIBFT_FutureProposal(t *testing.T) { return nodeID }, isValidProposalHashFn: func(p []byte, hash []byte) bool { - return bytes.Equal(hash, proposalHash) && bytes.Equal(p, proposal) + return bytes.Equal(hash, correctRoundMessage.hash) && + bytes.Equal(p, correctRoundMessage.proposal) }, hasQuorumFn: defaultHasQuorumFn(quorum), } @@ -1413,7 +1409,7 @@ func TestIBFT_FutureProposal(t *testing.T) { } assert.Equal(t, testCase.notifyRound, receivedProposalEvent.round) - assert.Equal(t, proposal, messages.ExtractProposal(receivedProposalEvent.proposalMessage)) + assert.Equal(t, correctRoundMessage.proposal, messages.ExtractProposal(receivedProposalEvent.proposalMessage)) }) } } @@ -1602,10 +1598,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1627,7 +1622,7 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit+1) @@ -1639,10 +1634,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1670,7 +1664,7 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit-1) @@ -1682,10 +1676,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1710,7 +1703,7 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit-1) @@ -1722,10 +1715,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1754,7 +1746,7 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit-1) @@ -1766,10 +1758,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1797,7 +1788,7 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit-1) @@ -2067,11 +2058,9 @@ func TestIBFT_WatchForFutureRCC(t *testing.T) { t.Parallel() quorum := uint64(4) - proposal := []byte("proposal") rccRound := uint64(10) - proposalHash := []byte("proposal hash") - roundChangeMessages := generateFilledRCMessages(quorum, proposal, proposalHash) + roundChangeMessages := generateFilledRCMessages(quorum, correctRoundMessage.proposal, correctRoundMessage.hash) setRoundForMessages(roundChangeMessages, rccRound) var ( @@ -2143,6 +2132,8 @@ func TestIBFT_WatchForFutureRCC(t *testing.T) { // TestState_String makes sure the string representation // of states is correct func TestState_String(t *testing.T) { + t.Parallel() + stringMap := map[stateType]string{ newRound: "new round", prepare: "prepare", diff --git a/core/mock_test.go b/core/mock_test.go index f31d75d..b62af7c 100644 --- a/core/mock_test.go +++ b/core/mock_test.go @@ -11,6 +11,24 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) +const ( + testRoundTimeout = time.Second +) + +var ( + correctRoundMessage = roundMessage{ + proposal: []byte("proposal"), + hash: []byte("proposal hash"), + seal: []byte("seal"), + } + + badRoundMessage = roundMessage{ + proposal: []byte("bad proposal"), + hash: []byte("bad proposal hash"), + seal: []byte("bad seal"), + } +) + // Define delegation methods type isValidBlockDelegate func([]byte) bool type isValidSenderDelegate func(*proto.Message) bool @@ -281,14 +299,38 @@ type transportConfigCallback func(*mockTransport) // newMockCluster creates a new IBFT cluster func newMockCluster( numNodes uint64, - backendCallbackMap map[int]backendConfigCallback, - loggerCallbackMap map[int]loggerConfigCallback, - transportCallbackMap map[int]transportConfigCallback, + backendCallback func(backend *mockBackend, i int), + loggerCallback loggerConfigCallback, + transportCallback transportConfigCallback, ) *mockCluster { if numNodes < 1 { return nil } + // Initialize the backend and transport callbacks for + // each node in the arbitrary cluster + backendCallbackMap := make(map[int]backendConfigCallback) + loggerCallbackMap := make(map[int]loggerConfigCallback) + transportCallbackMap := make(map[int]transportConfigCallback) + + for i := 0; i < int(numNodes); i++ { + i := i + + if backendCallback != nil { + backendCallbackMap[i] = func(backend *mockBackend) { + backendCallback(backend, i) + } + } + + if transportCallback != nil { + transportCallbackMap[i] = transportCallback + } + + if loggerCallback != nil { + loggerCallbackMap[i] = loggerCallback + } + } + nodes := make([]*IBFT, numNodes) nodeCtxs := make([]mockNodeContext, numNodes) @@ -300,21 +342,21 @@ func newMockCluster( ) // Execute set callbacks, if any - if backendCallbackMap != nil { - if backendCallback, isSet := backendCallbackMap[index]; isSet { - backendCallback(backend) + if len(backendCallbackMap) > 0 { + if bc, isSet := backendCallbackMap[index]; isSet { + bc(backend) } } - if loggerCallbackMap != nil { - if loggerCallback, isSet := loggerCallbackMap[index]; isSet { - loggerCallback(logger) + if len(loggerCallbackMap) > 0 { + if lc, isSet := loggerCallbackMap[index]; isSet { + lc(logger) } } - if transportCallbackMap != nil { - if transportCallback, isSet := transportCallbackMap[index]; isSet { - transportCallback(transport) + if len(transportCallbackMap) > 0 { + if tc, isSet := transportCallbackMap[index]; isSet { + tc(transport) } } @@ -325,10 +367,16 @@ func newMockCluster( nodeCtxs[index] = newMockNodeContext() } - return &mockCluster{ + cr := &mockCluster{ nodes: nodes, ctxs: nodeCtxs, } + + // Set a small timeout, because of situations + // where the byzantine node is the proposer + cr.setBaseTimeout(testRoundTimeout) + + return cr } // mockNodeContext keeps track of the node runtime context @@ -387,13 +435,12 @@ func (m *mockCluster) runSequence(height uint64) { go func( ctx context.Context, node *IBFT, - height uint64, ) { // Start the main run loop for the node node.RunSequence(ctx, height) m.wg.Done() - }(m.ctxs[nodeIndex].ctx, node, height) + }(m.ctxs[nodeIndex].ctx, node) } } diff --git a/core/rapid_test.go b/core/rapid_test.go index 1637b81..bbd2e3e 100644 --- a/core/rapid_test.go +++ b/core/rapid_test.go @@ -5,18 +5,21 @@ import ( "context" "sync" "testing" - "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "pgregory.net/rapid" "github.com/0xPolygon/go-ibft/messages" "github.com/0xPolygon/go-ibft/messages/proto" ) -const ( - testRoundTimeout = time.Second -) +// roundMessage contains message data within consensus round +type roundMessage struct { + proposal []byte + seal []byte + hash []byte +} // mockInsertedProposals keeps track of inserted proposals for a cluster // of nodes @@ -64,10 +67,6 @@ func TestProperty_AllHonestNodes(t *testing.T) { var multicastFn func(message *proto.Message) var ( - proposal = []byte("proposal") - proposalHash = []byte("proposal hash") - committedSeal = []byte("seal") - numNodes = rapid.Uint64Range(4, 30).Draw(t, "number of cluster nodes") desiredHeight = rapid.Uint64Range(10, 20).Draw(t, "minimum height to be reached") @@ -100,12 +99,12 @@ func TestProperty_AllHonestNodes(t *testing.T) { // Make sure the proposal is valid if it matches what node 0 proposed backend.isValidBlockFn = func(newProposal []byte) bool { - return bytes.Equal(newProposal, proposal) + return bytes.Equal(newProposal, correctRoundMessage.proposal) } // Make sure the proposal hash matches backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { - return bytes.Equal(p, proposal) && bytes.Equal(ph, proposalHash) + return bytes.Equal(p, correctRoundMessage.proposal) && bytes.Equal(ph, correctRoundMessage.hash) } // Make sure the preprepare message is built correctly @@ -116,7 +115,7 @@ func TestProperty_AllHonestNodes(t *testing.T) { ) *proto.Message { return buildBasicPreprepareMessage( proposal, - proposalHash, + correctRoundMessage.hash, certificate, nodes[nodeIndex], view) @@ -124,12 +123,12 @@ func TestProperty_AllHonestNodes(t *testing.T) { // Make sure the prepare message is built correctly backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicPrepareMessage(proposalHash, nodes[nodeIndex], view) + return buildBasicPrepareMessage(correctRoundMessage.hash, nodes[nodeIndex], view) } // Make sure the commit message is built correctly backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicCommitMessage(proposalHash, committedSeal, nodes[nodeIndex], view) + return buildBasicCommitMessage(correctRoundMessage.hash, correctRoundMessage.seal, nodes[nodeIndex], view) } // Make sure the round change message is built correctly @@ -148,31 +147,12 @@ func TestProperty_AllHonestNodes(t *testing.T) { // Make sure the proposal can be built backend.buildProposalFn = func(_ *proto.View) []byte { - return proposal + return correctRoundMessage.proposal } } - // Initialize the backend and transport callbacks for - // each node in the arbitrary cluster - backendCallbackMap := make(map[int]backendConfigCallback) - transportCallbackMap := make(map[int]transportConfigCallback) - - for i := 0; i < int(numNodes); i++ { - i := i - backendCallbackMap[i] = func(backend *mockBackend) { - commonBackendCallback(backend, i) - } - - transportCallbackMap[i] = commonTransportCallback - } - - // Create the mock cluster - cluster := newMockCluster( - numNodes, - backendCallbackMap, - nil, - transportCallbackMap, - ) + // Create default cluster for rapid tests + cluster := newMockCluster(numNodes, commonBackendCallback, nil, commonTransportCallback) // Set the multicast callback to relay the message // to the entire cluster @@ -195,7 +175,7 @@ func TestProperty_AllHonestNodes(t *testing.T) { assert.Len(t, proposalMap, int(desiredHeight)) for _, insertedProposal := range proposalMap { - assert.True(t, bytes.Equal(proposal, insertedProposal)) + assert.Equal(t, correctRoundMessage.proposal, insertedProposal) } } }) @@ -227,10 +207,6 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { var multicastFn func(message *proto.Message) var ( - proposal = []byte("proposal") - proposalHash = []byte("proposal hash") - committedSeal = []byte("seal") - numNodes = rapid.Uint64Range(4, 30).Draw(t, "number of cluster nodes") numByzantineNodes = rapid.Uint64Range(1, maxFaulty(numNodes)).Draw(t, "number of byzantine nodes") desiredHeight = rapid.Uint64Range(1, 5).Draw(t, "minimum height to be reached") @@ -238,6 +214,7 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { nodes = generateNodeAddresses(numNodes) insertedProposals = newMockInsertedProposals(numNodes) ) + // Initialize the byzantine nodes byzantineNodes := getByzantineNodes( numByzantineNodes, @@ -285,12 +262,12 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { // Make sure the proposal is valid if it matches what node 0 proposed backend.isValidBlockFn = func(newProposal []byte) bool { - return bytes.Equal(newProposal, proposal) + return bytes.Equal(newProposal, correctRoundMessage.proposal) } // Make sure the proposal hash matches backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { - return bytes.Equal(p, proposal) && bytes.Equal(ph, proposalHash) + return bytes.Equal(p, correctRoundMessage.proposal) && bytes.Equal(ph, correctRoundMessage.hash) } // Make sure the preprepare message is built correctly @@ -301,7 +278,7 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { ) *proto.Message { return buildBasicPreprepareMessage( proposal, - proposalHash, + correctRoundMessage.hash, certificate, nodes[nodeIndex], view, @@ -310,12 +287,17 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { // Make sure the prepare message is built correctly backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicPrepareMessage(proposalHash, nodes[nodeIndex], view) + return buildBasicPrepareMessage(correctRoundMessage.hash, nodes[nodeIndex], view) } // Make sure the commit message is built correctly backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicCommitMessage(proposalHash, committedSeal, nodes[nodeIndex], view) + return buildBasicCommitMessage( + correctRoundMessage.hash, + correctRoundMessage.seal, + nodes[nodeIndex], + view, + ) } // Make sure the round change message is built correctly @@ -334,35 +316,158 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { // Make sure the proposal can be built backend.buildProposalFn = func(_ *proto.View) []byte { - return proposal + return correctRoundMessage.proposal } } - // Initialize the backend and transport callbacks for - // each node in the arbitrary cluster - backendCallbackMap := make(map[int]backendConfigCallback) - transportCallbackMap := make(map[int]transportConfigCallback) + // Create default cluster for rapid tests + cluster := newMockCluster(numNodes, commonBackendCallback, nil, commonTransportCallback) - for i := 0; i < int(numNodes); i++ { - i := i - backendCallbackMap[i] = func(backend *mockBackend) { - commonBackendCallback(backend, i) - } + // Set the multicast callback to relay the message + // to the entire cluster + multicastFn = func(message *proto.Message) { + cluster.pushMessage(message) + } + + // Create context timeout based on the bad nodes number + ctxTimeout := getRoundTimeout(testRoundTimeout, 0, numByzantineNodes+1) + + // Run the sequence up until a certain height + for height := uint64(0); height < desiredHeight; height++ { + // Start the main run loops + cluster.runSequence(height) + + // Wait until Quorum nodes finish their run loop + ctx, cancelFn := context.WithTimeout(context.Background(), ctxTimeout) + err := cluster.awaitNCompletions(ctx, int64(quorum(numNodes))) + assert.NoError(t, err, "unable to wait for nodes to complete") + + // Shutdown the remaining nodes that might be hanging + cluster.forceShutdown() + cancelFn() + } + + // Make sure proposals map is not empty + require.Len(t, insertedProposals.proposals, int(numNodes)) - transportCallbackMap[i] = commonTransportCallback + // Make sure that the inserted proposal is valid for each height + for _, proposalMap := range insertedProposals.proposals { + for _, insertedProposal := range proposalMap { + assert.Equal(t, correctRoundMessage.proposal, insertedProposal) + } } + }) +} + +// TestProperty_MajorityHonestNodes_BroadcastBadMessage is a property-based test +// that assures the cluster can reach consensus on if "bad" nodes send +// wrong message during the round execution. Avoiding a scenario when +// a bad node is proposer +func TestProperty_MajorityHonestNodes_BroadcastBadMessage(t *testing.T) { + t.Parallel() + + rapid.Check(t, func(t *rapid.T) { + var multicastFn func(message *proto.Message) + + var ( + numNodes = rapid.Uint64Range(4, 15).Draw(t, "number of cluster nodes") + numByzantineNodes = rapid.Uint64Range(1, maxFaulty(numNodes)).Draw(t, "number of byzantine nodes") + desiredHeight = rapid.Uint64Range(1, 5).Draw(t, "minimum height to be reached") - // Create the mock cluster - cluster := newMockCluster( - numNodes, - backendCallbackMap, - nil, - transportCallbackMap, + nodes = generateNodeAddresses(numNodes) + insertedProposals = newMockInsertedProposals(numNodes) ) - // Set a small timeout, because of situations - // where the byzantine node is the proposer - cluster.setBaseTimeout(testRoundTimeout) + // commonTransportCallback is the common method modification + // required for Transport, for all nodes + commonTransportCallback := func(transport *mockTransport) { + transport.multicastFn = func(message *proto.Message) { + multicastFn(message) + } + } + + // commonBackendCallback is the common method modification required + // for the Backend, for all nodes + commonBackendCallback := func(backend *mockBackend, nodeIndex int) { + // Use a bad message if the current node is a malicious one + message := correctRoundMessage + if uint64(nodeIndex) < numByzantineNodes { + message = badRoundMessage + } + + // Make sure the quorum function is Quorum optimal + backend.hasQuorumFn = commonHasQuorumFn(numNodes) + + // Make sure the node ID is properly relayed + backend.idFn = func() []byte { + return nodes[nodeIndex] + } + + // Make sure the only proposer is picked using Round Robin + backend.isProposerFn = func(from []byte, height uint64, round uint64) bool { + return bytes.Equal( + from, + nodes[int(height+round)%len(nodes)], + ) + } + + // Make sure the proposal is valid if it matches what node 0 proposed + backend.isValidBlockFn = func(newProposal []byte) bool { + return bytes.Equal(newProposal, message.proposal) + } + + // Make sure the proposal hash matches + backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { + return bytes.Equal(p, message.proposal) && bytes.Equal(ph, message.hash) + } + + // Make sure the preprepare message is built correctly + backend.buildPrePrepareMessageFn = func( + proposal []byte, + certificate *proto.RoundChangeCertificate, + view *proto.View, + ) *proto.Message { + return buildBasicPreprepareMessage( + proposal, + message.hash, + certificate, + nodes[nodeIndex], + view, + ) + } + + // Make sure the prepare message is built correctly + backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { + return buildBasicPrepareMessage(message.hash, nodes[nodeIndex], view) + } + + // Make sure the commit message is built correctly + backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { + return buildBasicCommitMessage(message.hash, message.seal, nodes[nodeIndex], view) + } + + // Make sure the round change message is built correctly + backend.buildRoundChangeMessageFn = func( + proposal []byte, + certificate *proto.PreparedCertificate, + view *proto.View, + ) *proto.Message { + return buildBasicRoundChangeMessage(proposal, certificate, view, nodes[nodeIndex]) + } + + // Make sure the inserted proposal is noted + backend.insertBlockFn = func(proposal []byte, _ []*messages.CommittedSeal) { + insertedProposals.insertProposal(nodeIndex, proposal) + } + + // Make sure the proposal can be built + backend.buildProposalFn = func(_ *proto.View) []byte { + return message.proposal + } + } + + // Create default cluster for rapid tests + cluster := newMockCluster(numNodes, commonBackendCallback, nil, commonTransportCallback) // Set the multicast callback to relay the message // to the entire cluster @@ -370,13 +475,16 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { cluster.pushMessage(message) } + // Create context timeout based on the bad nodes number + ctxTimeout := getRoundTimeout(testRoundTimeout, 0, numByzantineNodes+1) + // Run the sequence up until a certain height for height := uint64(0); height < desiredHeight; height++ { // Start the main run loops cluster.runSequence(height) // Wait until Quorum nodes finish their run loop - ctx, cancelFn := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancelFn := context.WithTimeout(context.Background(), ctxTimeout) err := cluster.awaitNCompletions(ctx, int64(quorum(numNodes))) assert.NoError(t, err, "unable to wait for nodes to complete") @@ -385,10 +493,18 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { cancelFn() } + // Make sure proposals map is not empty + require.Len(t, insertedProposals.proposals, int(numNodes)) + // Make sure that the inserted proposal is valid for each height - for _, proposalMap := range insertedProposals.proposals { - for _, insertedProposal := range proposalMap { - assert.True(t, bytes.Equal(proposal, insertedProposal)) + for i, proposalMap := range insertedProposals.proposals { + if i < int(numByzantineNodes) { + // Proposals map must be empty when a byzantine node is proposer + assert.Empty(t, proposalMap) + } else { + for _, insertedProposal := range proposalMap { + assert.Equal(t, correctRoundMessage.proposal, insertedProposal) + } } } }) From 2ac556e04f627fde173aa50e9092092974331286 Mon Sep 17 00:00:00 2001 From: Roman Behma <13855864+begmaroman@users.noreply.github.com> Date: Mon, 21 Nov 2022 12:20:15 +0000 Subject: [PATCH 03/17] Implement event generator for rapid testing (#46) --- .github/workflows/main.yml | 4 +- core/consensus_test.go | 4 +- core/ibft_test.go | 46 +++++ core/mock_test.go | 14 +- core/rapid_test.go | 385 +++++++------------------------------ 5 files changed, 127 insertions(+), 326 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a8ecb87..9aaf1cd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,7 +23,7 @@ jobs: submodules: recursive - name: Go test - run: go test -covermode=atomic -shuffle=on -coverprofile coverage.out -test.short -timeout 5m ./... + run: go test -test.short -covermode=atomic -shuffle=on -coverprofile coverage.out -timeout 10m ./... - name: Upload coverage file to Codecov uses: codecov/codecov-action@v3 @@ -45,7 +45,7 @@ jobs: submodules: recursive - name: Run Go Test with race - run: go test -test.short -race -shuffle=on -timeout 5m ./... + run: go test -test.short -race -shuffle=on -timeout 10m ./... reproducible-builds: runs-on: ubuntu-latest diff --git a/core/consensus_test.go b/core/consensus_test.go index 9759e7c..901f71d 100644 --- a/core/consensus_test.go +++ b/core/consensus_test.go @@ -157,7 +157,7 @@ func TestConsensus_ValidFlow(t *testing.T) { // commonTransportCallback is the common method modification // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { + commonTransportCallback := func(transport *mockTransport, _ int) { transport.multicastFn = func(message *proto.Message) { multicastFn(message) } @@ -290,7 +290,7 @@ func TestConsensus_InvalidBlock(t *testing.T) { // commonTransportCallback is the common method modification // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { + commonTransportCallback := func(transport *mockTransport, _ int) { transport.multicastFn = func(message *proto.Message) { multicastFn(message) } diff --git a/core/ibft_test.go b/core/ibft_test.go index 9479e26..e31fd06 100644 --- a/core/ibft_test.go +++ b/core/ibft_test.go @@ -2283,3 +2283,49 @@ func TestIBFT_ExtendRoundTimer(t *testing.T) { // Make sure the round timeout was extended assert.Equal(t, additionalTimeout, i.additionalTimeout) } + +func Test_getRoundTimeout(t *testing.T) { + t.Parallel() + + type args struct { + baseRoundTimeout time.Duration + additionalTimeout time.Duration + round uint64 + } + + tests := []struct { + name string + args args + want time.Duration + }{ + { + name: "first round duration", + args: args{ + baseRoundTimeout: time.Second, + additionalTimeout: time.Second, + round: 0, + }, + want: time.Second * 2, + }, + { + name: "zero round duration", + args: args{ + baseRoundTimeout: time.Second, + additionalTimeout: time.Second, + round: 1, + }, + want: time.Second * 3, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := getRoundTimeout(tt.args.baseRoundTimeout, tt.args.additionalTimeout, tt.args.round) + assert.Equalf(t, tt.want, got, "getRoundTimeout(%v, %v, %v)", tt.args.baseRoundTimeout, tt.args.additionalTimeout, tt.args.round) + }) + } +} diff --git a/core/mock_test.go b/core/mock_test.go index b62af7c..1e98596 100644 --- a/core/mock_test.go +++ b/core/mock_test.go @@ -299,9 +299,9 @@ type transportConfigCallback func(*mockTransport) // newMockCluster creates a new IBFT cluster func newMockCluster( numNodes uint64, - backendCallback func(backend *mockBackend, i int), - loggerCallback loggerConfigCallback, - transportCallback transportConfigCallback, + backendCallback func(*mockBackend, int), + loggerCallback func(*mockLogger, int), + transportCallback func(*mockTransport, int), ) *mockCluster { if numNodes < 1 { return nil @@ -323,11 +323,15 @@ func newMockCluster( } if transportCallback != nil { - transportCallbackMap[i] = transportCallback + transportCallbackMap[i] = func(backend *mockTransport) { + transportCallback(backend, i) + } } if loggerCallback != nil { - loggerCallbackMap[i] = loggerCallback + loggerCallbackMap[i] = func(backend *mockLogger) { + loggerCallback(backend, i) + } } } diff --git a/core/rapid_test.go b/core/rapid_test.go index bbd2e3e..064e556 100644 --- a/core/rapid_test.go +++ b/core/rapid_test.go @@ -57,183 +57,66 @@ func (m *mockInsertedProposals) insertProposal( m.currentProposals[nodeIndex]++ } -// TestProperty_AllHonestNodes is a property-based test -// that assures the cluster can reach consensus on any -// arbitrary number of valid nodes -func TestProperty_AllHonestNodes(t *testing.T) { - t.Parallel() +// propertyTestEvent contains randomly-generated data for rapid testing +type propertyTestEvent struct { + // nodes is the total number of nodes + nodes uint64 - rapid.Check(t, func(t *rapid.T) { - var multicastFn func(message *proto.Message) + // byzantineNodes is the total number of byzantine nodes + byzantineNodes uint64 - var ( - numNodes = rapid.Uint64Range(4, 30).Draw(t, "number of cluster nodes") - desiredHeight = rapid.Uint64Range(10, 20).Draw(t, "minimum height to be reached") - - nodes = generateNodeAddresses(numNodes) - insertedProposals = newMockInsertedProposals(numNodes) - ) - // commonTransportCallback is the common method modification - // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { - transport.multicastFn = func(message *proto.Message) { - multicastFn(message) - } - } + // silentByzantineNodes is the number of byzantine nodes + // that are going to be silent, i.e. do not respond + silentByzantineNodes uint64 - // commonBackendCallback is the common method modification required - // for the Backend, for all nodes - commonBackendCallback := func(backend *mockBackend, nodeIndex int) { - // Make sure the quorum function requires all nodes - backend.hasQuorumFn = commonHasQuorumFn(numNodes) - - // Make sure the node ID is properly relayed - backend.idFn = func() []byte { - return nodes[nodeIndex] - } - - // Make sure the only proposer is picked using Round Robin - backend.isProposerFn = func(from []byte, height uint64, _ uint64) bool { - return bytes.Equal(from, nodes[height%numNodes]) - } - - // Make sure the proposal is valid if it matches what node 0 proposed - backend.isValidBlockFn = func(newProposal []byte) bool { - return bytes.Equal(newProposal, correctRoundMessage.proposal) - } + // badByzantineNodes is the number of byzantine nodes + // that are going to send bad messages + badByzantineNodes uint64 - // Make sure the proposal hash matches - backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { - return bytes.Equal(p, correctRoundMessage.proposal) && bytes.Equal(ph, correctRoundMessage.hash) - } - - // Make sure the preprepare message is built correctly - backend.buildPrePrepareMessageFn = func( - proposal []byte, - certificate *proto.RoundChangeCertificate, - view *proto.View, - ) *proto.Message { - return buildBasicPreprepareMessage( - proposal, - correctRoundMessage.hash, - certificate, - nodes[nodeIndex], - view) - } - - // Make sure the prepare message is built correctly - backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicPrepareMessage(correctRoundMessage.hash, nodes[nodeIndex], view) - } - - // Make sure the commit message is built correctly - backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicCommitMessage(correctRoundMessage.hash, correctRoundMessage.seal, nodes[nodeIndex], view) - } - - // Make sure the round change message is built correctly - backend.buildRoundChangeMessageFn = func( - proposal []byte, - certificate *proto.PreparedCertificate, - view *proto.View, - ) *proto.Message { - return buildBasicRoundChangeMessage(proposal, certificate, view, nodes[nodeIndex]) - } - - // Make sure the inserted proposal is noted - backend.insertBlockFn = func(proposal []byte, _ []*messages.CommittedSeal) { - insertedProposals.insertProposal(nodeIndex, proposal) - } - - // Make sure the proposal can be built - backend.buildProposalFn = func(_ *proto.View) []byte { - return correctRoundMessage.proposal - } - } - - // Create default cluster for rapid tests - cluster := newMockCluster(numNodes, commonBackendCallback, nil, commonTransportCallback) - - // Set the multicast callback to relay the message - // to the entire cluster - multicastFn = func(message *proto.Message) { - cluster.pushMessage(message) - } - - // Run the sequence up until a certain height - for height := uint64(0); height < desiredHeight; height++ { - // Start the main run loops - cluster.runSequence(height) - - // Wait until the main run loops finish - cluster.awaitCompletion() - } - - // Make sure that the inserted proposal is valid for each height - for _, proposalMap := range insertedProposals.proposals { - // Make sure the node has the adequate number of inserted proposals - assert.Len(t, proposalMap, int(desiredHeight)) - - for _, insertedProposal := range proposalMap { - assert.Equal(t, correctRoundMessage.proposal, insertedProposal) - } - } - }) + // desiredHeight is the desired height number + desiredHeight uint64 } -// getByzantineNodes returns a random subset of -// byzantine nodes -func getByzantineNodes( - numNodes uint64, - set [][]byte, -) map[string]struct{} { - gen := rapid.SampledFrom(set) - byzantineNodes := make(map[string]struct{}) - - for i := 0; i < int(numNodes); i++ { - byzantineNodes[string(gen.Example(i))] = struct{}{} +// generatePropertyTestEvent generates propertyTestEvent model +func generatePropertyTestEvent(t *rapid.T) *propertyTestEvent { + var ( + numNodes = rapid.Uint64Range(4, 15).Draw(t, "number of cluster nodes") + numByzantineNodes = rapid.Uint64Range(0, maxFaulty(numNodes)).Draw(t, "number of byzantine nodes") + silentByzantineNodes = rapid.Uint64Range(0, numByzantineNodes).Draw(t, "number of silent byzantine nodes") + desiredHeight = rapid.Uint64Range(10, 20).Draw(t, "minimum height to be reached") + ) + + return &propertyTestEvent{ + nodes: numNodes, + byzantineNodes: numByzantineNodes, + silentByzantineNodes: silentByzantineNodes, + badByzantineNodes: numByzantineNodes - silentByzantineNodes, + desiredHeight: desiredHeight, } - - return byzantineNodes } -// TestProperty_MajorityHonestNodes is a property-based test -// that assures the cluster can reach consensus on any -// arbitrary number of valid nodes and byzantine nodes -func TestProperty_MajorityHonestNodes(t *testing.T) { +// TestProperty is a property-based test +// that assures the cluster can handle rounds properly in any cases. +func TestProperty(t *testing.T) { t.Parallel() rapid.Check(t, func(t *rapid.T) { var multicastFn func(message *proto.Message) var ( - numNodes = rapid.Uint64Range(4, 30).Draw(t, "number of cluster nodes") - numByzantineNodes = rapid.Uint64Range(1, maxFaulty(numNodes)).Draw(t, "number of byzantine nodes") - desiredHeight = rapid.Uint64Range(1, 5).Draw(t, "minimum height to be reached") - - nodes = generateNodeAddresses(numNodes) - insertedProposals = newMockInsertedProposals(numNodes) + testEvent = generatePropertyTestEvent(t) + currentQuorum = quorum(testEvent.nodes) + nodes = generateNodeAddresses(testEvent.nodes) + insertedProposals = newMockInsertedProposals(testEvent.nodes) ) - // Initialize the byzantine nodes - byzantineNodes := getByzantineNodes( - numByzantineNodes, - nodes, - ) - - isByzantineNode := func(from []byte) bool { - _, exists := byzantineNodes[string(from)] - - return exists - } - // commonTransportCallback is the common method modification // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { + commonTransportCallback := func(transport *mockTransport, nodeIndex int) { transport.multicastFn = func(message *proto.Message) { - if isByzantineNode(message.From) { - // If the node is byzantine, mock - // not sending out the message + // If node is silent, don't send a message + if uint64(nodeIndex) >= testEvent.byzantineNodes && + uint64(nodeIndex) < testEvent.silentByzantineNodes { return } @@ -244,159 +127,14 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { // commonBackendCallback is the common method modification required // for the Backend, for all nodes commonBackendCallback := func(backend *mockBackend, nodeIndex int) { - // Make sure the quorum function is Quorum optimal - backend.hasQuorumFn = commonHasQuorumFn(numNodes) - - // Make sure the node ID is properly relayed - backend.idFn = func() []byte { - return nodes[nodeIndex] - } - - // Make sure the only proposer is picked using Round Robin - backend.isProposerFn = func(from []byte, height uint64, round uint64) bool { - return bytes.Equal( - from, - nodes[int(height+round)%len(nodes)], - ) - } - - // Make sure the proposal is valid if it matches what node 0 proposed - backend.isValidBlockFn = func(newProposal []byte) bool { - return bytes.Equal(newProposal, correctRoundMessage.proposal) - } - - // Make sure the proposal hash matches - backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { - return bytes.Equal(p, correctRoundMessage.proposal) && bytes.Equal(ph, correctRoundMessage.hash) - } - - // Make sure the preprepare message is built correctly - backend.buildPrePrepareMessageFn = func( - proposal []byte, - certificate *proto.RoundChangeCertificate, - view *proto.View, - ) *proto.Message { - return buildBasicPreprepareMessage( - proposal, - correctRoundMessage.hash, - certificate, - nodes[nodeIndex], - view, - ) - } - - // Make sure the prepare message is built correctly - backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicPrepareMessage(correctRoundMessage.hash, nodes[nodeIndex], view) - } - - // Make sure the commit message is built correctly - backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicCommitMessage( - correctRoundMessage.hash, - correctRoundMessage.seal, - nodes[nodeIndex], - view, - ) - } - - // Make sure the round change message is built correctly - backend.buildRoundChangeMessageFn = func( - proposal []byte, - certificate *proto.PreparedCertificate, - view *proto.View, - ) *proto.Message { - return buildBasicRoundChangeMessage(proposal, certificate, view, nodes[nodeIndex]) - } - - // Make sure the inserted proposal is noted - backend.insertBlockFn = func(proposal []byte, _ []*messages.CommittedSeal) { - insertedProposals.insertProposal(nodeIndex, proposal) - } - - // Make sure the proposal can be built - backend.buildProposalFn = func(_ *proto.View) []byte { - return correctRoundMessage.proposal - } - } - - // Create default cluster for rapid tests - cluster := newMockCluster(numNodes, commonBackendCallback, nil, commonTransportCallback) - - // Set the multicast callback to relay the message - // to the entire cluster - multicastFn = func(message *proto.Message) { - cluster.pushMessage(message) - } - - // Create context timeout based on the bad nodes number - ctxTimeout := getRoundTimeout(testRoundTimeout, 0, numByzantineNodes+1) - - // Run the sequence up until a certain height - for height := uint64(0); height < desiredHeight; height++ { - // Start the main run loops - cluster.runSequence(height) - - // Wait until Quorum nodes finish their run loop - ctx, cancelFn := context.WithTimeout(context.Background(), ctxTimeout) - err := cluster.awaitNCompletions(ctx, int64(quorum(numNodes))) - assert.NoError(t, err, "unable to wait for nodes to complete") - - // Shutdown the remaining nodes that might be hanging - cluster.forceShutdown() - cancelFn() - } - - // Make sure proposals map is not empty - require.Len(t, insertedProposals.proposals, int(numNodes)) - - // Make sure that the inserted proposal is valid for each height - for _, proposalMap := range insertedProposals.proposals { - for _, insertedProposal := range proposalMap { - assert.Equal(t, correctRoundMessage.proposal, insertedProposal) - } - } - }) -} - -// TestProperty_MajorityHonestNodes_BroadcastBadMessage is a property-based test -// that assures the cluster can reach consensus on if "bad" nodes send -// wrong message during the round execution. Avoiding a scenario when -// a bad node is proposer -func TestProperty_MajorityHonestNodes_BroadcastBadMessage(t *testing.T) { - t.Parallel() - - rapid.Check(t, func(t *rapid.T) { - var multicastFn func(message *proto.Message) - - var ( - numNodes = rapid.Uint64Range(4, 15).Draw(t, "number of cluster nodes") - numByzantineNodes = rapid.Uint64Range(1, maxFaulty(numNodes)).Draw(t, "number of byzantine nodes") - desiredHeight = rapid.Uint64Range(1, 5).Draw(t, "minimum height to be reached") - - nodes = generateNodeAddresses(numNodes) - insertedProposals = newMockInsertedProposals(numNodes) - ) - - // commonTransportCallback is the common method modification - // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { - transport.multicastFn = func(message *proto.Message) { - multicastFn(message) - } - } - - // commonBackendCallback is the common method modification required - // for the Backend, for all nodes - commonBackendCallback := func(backend *mockBackend, nodeIndex int) { - // Use a bad message if the current node is a malicious one + // Use a bad message if the current node is a bad byzantine one message := correctRoundMessage - if uint64(nodeIndex) < numByzantineNodes { + if uint64(nodeIndex) < testEvent.byzantineNodes { message = badRoundMessage } // Make sure the quorum function is Quorum optimal - backend.hasQuorumFn = commonHasQuorumFn(numNodes) + backend.hasQuorumFn = commonHasQuorumFn(testEvent.nodes) // Make sure the node ID is properly relayed backend.idFn = func() []byte { @@ -467,41 +205,54 @@ func TestProperty_MajorityHonestNodes_BroadcastBadMessage(t *testing.T) { } // Create default cluster for rapid tests - cluster := newMockCluster(numNodes, commonBackendCallback, nil, commonTransportCallback) + cluster := newMockCluster(testEvent.nodes, commonBackendCallback, nil, commonTransportCallback) // Set the multicast callback to relay the message // to the entire cluster - multicastFn = func(message *proto.Message) { - cluster.pushMessage(message) + multicastFn = cluster.pushMessage + + // Minimum one round is required + minRounds := uint64(1) + if testEvent.byzantineNodes > minRounds { + minRounds = testEvent.byzantineNodes } // Create context timeout based on the bad nodes number - ctxTimeout := getRoundTimeout(testRoundTimeout, 0, numByzantineNodes+1) + ctxTimeout := getRoundTimeout(testRoundTimeout, testRoundTimeout, minRounds+1) // Run the sequence up until a certain height - for height := uint64(0); height < desiredHeight; height++ { + for height := uint64(0); height < testEvent.desiredHeight; height++ { // Start the main run loops cluster.runSequence(height) - // Wait until Quorum nodes finish their run loop - ctx, cancelFn := context.WithTimeout(context.Background(), ctxTimeout) - err := cluster.awaitNCompletions(ctx, int64(quorum(numNodes))) - assert.NoError(t, err, "unable to wait for nodes to complete") + if testEvent.byzantineNodes == 0 { + // Wait until all nodes propose messages + cluster.awaitCompletion() + } else { + // Wait until Quorum nodes finish their run loop + ctx, cancelFn := context.WithTimeout(context.Background(), ctxTimeout) + err := cluster.awaitNCompletions(ctx, int64(currentQuorum)) + assert.NoError(t, err, "unable to wait for nodes to complete on height %d", height) + cancelFn() + } // Shutdown the remaining nodes that might be hanging cluster.forceShutdown() - cancelFn() } // Make sure proposals map is not empty - require.Len(t, insertedProposals.proposals, int(numNodes)) + require.Len(t, insertedProposals.proposals, int(testEvent.nodes)) // Make sure that the inserted proposal is valid for each height for i, proposalMap := range insertedProposals.proposals { - if i < int(numByzantineNodes) { + if i < int(testEvent.byzantineNodes) { // Proposals map must be empty when a byzantine node is proposer assert.Empty(t, proposalMap) } else { + // Make sure the node has proposals + assert.NotEmpty(t, proposalMap) + + // Check values for _, insertedProposal := range proposalMap { assert.Equal(t, correctRoundMessage.proposal, insertedProposal) } From 8b1e3f20c4c3e1df86deeb356479340d37fe27c2 Mon Sep 17 00:00:00 2001 From: Igor Crevar Date: Mon, 28 Nov 2022 10:24:43 +0100 Subject: [PATCH 04/17] EVM-220 TestClusterBlockSync/BLS fails in voting power branch (#48) --- core/ibft.go | 12 ++++ core/ibft_test.go | 129 ++++++++++++++++++++++++++++++++++++++ core/mock_test.go | 7 +++ messages/messages.go | 4 ++ messages/messages_test.go | 1 + 5 files changed, 153 insertions(+) diff --git a/core/ibft.go b/core/ibft.go index 01f92ab..10fb96c 100644 --- a/core/ibft.go +++ b/core/ibft.go @@ -25,6 +25,8 @@ type Messages interface { AddMessage(message *proto.Message) PruneByHeight(height uint64) + SignalEvent(message *proto.Message) + // Messages fetchers // GetValidMessages( view *proto.View, @@ -999,6 +1001,14 @@ func (i *IBFT) AddMessage(message *proto.Message) { // Check if the message should even be considered if i.isAcceptableMessage(message) { i.messages.AddMessage(message) + + msgs := i.messages.GetValidMessages( + message.View, + message.Type, + func(_ *proto.Message) bool { return true }) + if i.backend.HasQuorum(message.View.Height, msgs, message.Type) { + i.messages.SignalEvent(message) + } } } @@ -1045,6 +1055,8 @@ func (i *IBFT) validPC( return false } + // Order of messages is important! + // Mesage with type of MessageType_PREPREPARE must be the first element of allMessages slice allMessages := append( []*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages..., diff --git a/core/ibft_test.go b/core/ibft_test.go index e31fd06..87387f0 100644 --- a/core/ibft_test.go +++ b/core/ibft_test.go @@ -2329,3 +2329,132 @@ func Test_getRoundTimeout(t *testing.T) { }) } } + +func TestIBFT_AddMessage(t *testing.T) { + t.Parallel() + + const ( + validHeight = uint64(10) + validRound = uint64(7) + validMsgType = proto.MessageType_PREPREPARE + ) + + var validSender = []byte{1, 2, 3} + + executeTest := func( + msg *proto.Message, + hasQuorum, shouldAddMessageCalled, shouldHasQuorumCalled, shouldSignalEventCalled bool) { + var ( + hasQuorumCalled = false + signalEventCalled = false + addMessageCalled = false + log = mockLogger{} + backend = mockBackend{} + transport = mockTransport{} + messages = mockMessages{} + ) + + backend.isValidSenderFn = func(m *proto.Message) bool { + return bytes.Equal(m.From, validSender) + } + + backend.hasQuorumFn = func(height uint64, _ []*proto.Message, msgType proto.MessageType) bool { + hasQuorumCalled = true + + assert.Equal(t, validHeight, height) + assert.Equal(t, validMsgType, msgType) + + return hasQuorum + } + + messages.addMessageFn = func(m *proto.Message) { + addMessageCalled = true + + assert.Equal(t, msg, m) + } + + messages.signalEventFn = func(*proto.Message) { + signalEventCalled = true + } + + i := NewIBFT(log, backend, transport) + i.messages = messages + i.state.view = &proto.View{Height: validHeight, Round: validRound} + + i.AddMessage(msg) + + assert.Equal(t, shouldAddMessageCalled, addMessageCalled) + assert.Equal(t, shouldHasQuorumCalled, hasQuorumCalled) + assert.Equal(t, shouldSignalEventCalled, signalEventCalled) + } + + t.Run("nil message case", func(t *testing.T) { + t.Parallel() + + executeTest(nil, true, false, false, false) + }) + + t.Run("!isAcceptableMessage - invalid sender", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + View: &proto.View{Height: validHeight, Round: validRound}, + Type: validMsgType, + } + executeTest(msg, true, false, false, false) + }) + + t.Run("!isAcceptableMessage - invalid view", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + } + executeTest(msg, true, false, false, false) + }) + + t.Run("!isAcceptableMessage - invalid height", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + View: &proto.View{Height: validHeight - 1, Round: validRound}, + } + executeTest(msg, true, false, false, false) + }) + + t.Run("!isAcceptableMessage - invalid round", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + View: &proto.View{Height: validHeight, Round: validRound - 1}, + } + executeTest(msg, true, false, false, false) + }) + + t.Run("correct - but quorum not reached", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + View: &proto.View{Height: validHeight, Round: validRound}, + } + executeTest(msg, false, true, true, false) + }) + + t.Run("correct - quorum reached", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + View: &proto.View{Height: validHeight, Round: validRound}, + } + executeTest(msg, true, true, true, true) + }) +} diff --git a/core/mock_test.go b/core/mock_test.go index 1e98596..5bb2db9 100644 --- a/core/mock_test.go +++ b/core/mock_test.go @@ -234,6 +234,7 @@ func (l mockLogger) Error(msg string, args ...interface{}) { type mockMessages struct { addMessageFn func(message *proto.Message) pruneByHeightFn func(height uint64) + signalEventFn func(message *proto.Message) getValidMessagesFn func( view *proto.View, @@ -284,6 +285,12 @@ func (m mockMessages) PruneByHeight(height uint64) { } } +func (m mockMessages) SignalEvent(msg *proto.Message) { + if m.signalEventFn != nil { + m.signalEventFn(msg) + } +} + func (m mockMessages) GetMostRoundChangeMessages(round, height uint64) []*proto.Message { if m.getMostRoundChangeMessagesFn != nil { return m.getMostRoundChangeMessagesFn(round, height) diff --git a/messages/messages.go b/messages/messages.go index 894834f..bf19ac1 100644 --- a/messages/messages.go +++ b/messages/messages.go @@ -72,7 +72,10 @@ func (ms *Messages) AddMessage(message *proto.Message) { // Append the message to the appropriate queue messages := heightMsgMap.getViewMessages(message.View) messages[string(message.From)] = message +} +// SignalEvent signals event +func (ms *Messages) SignalEvent(message *proto.Message) { ms.eventManager.signalEvent( message.Type, &proto.View{ @@ -82,6 +85,7 @@ func (ms *Messages) AddMessage(message *proto.Message) { ) } +// Close closes event manager func (ms *Messages) Close() { ms.eventManager.close() } diff --git a/messages/messages_test.go b/messages/messages_test.go index 54b6e85..aad99f3 100644 --- a/messages/messages_test.go +++ b/messages/messages_test.go @@ -341,6 +341,7 @@ func TestMessages_EventManager(t *testing.T) { randomMessages := generateRandomMessages(numMessages, baseView, messageType) for _, message := range randomMessages { messages.AddMessage(message) + messages.SignalEvent(message) } // Wait for the subscription event to happen From 50da8f5c4e92a58b401f0d4f90dfeeb4d320c552 Mon Sep 17 00:00:00 2001 From: Roman Behma <13855864+begmaroman@users.noreply.github.com> Date: Mon, 28 Nov 2022 10:28:50 +0000 Subject: [PATCH 05/17] Added per round event-based setup in rapid tests (#47) --- core/consensus_test.go | 4 +- core/mock_test.go | 7 +- core/rapid_test.go | 270 ++++++++++++++++++++++++++++++----------- 3 files changed, 199 insertions(+), 82 deletions(-) diff --git a/core/consensus_test.go b/core/consensus_test.go index 901f71d..c48bb89 100644 --- a/core/consensus_test.go +++ b/core/consensus_test.go @@ -388,9 +388,7 @@ func TestConsensus_InvalidBlock(t *testing.T) { // Set the multicast callback to relay the message // to the entire cluster - multicastFn = func(message *proto.Message) { - cluster.pushMessage(message) - } + multicastFn = cluster.pushMessage // Start the main run loops cluster.runSequence(1) diff --git a/core/mock_test.go b/core/mock_test.go index 5bb2db9..8e6b9e6 100644 --- a/core/mock_test.go +++ b/core/mock_test.go @@ -417,8 +417,8 @@ func (wg *mockNodeWg) Add(delta int) { } func (wg *mockNodeWg) Done() { - wg.WaitGroup.Done() atomic.AddInt64(&wg.count, 1) + wg.WaitGroup.Done() } func (wg *mockNodeWg) getDone() int64 { @@ -443,10 +443,7 @@ func (m *mockCluster) runSequence(height uint64) { for nodeIndex, node := range m.nodes { m.wg.Add(1) - go func( - ctx context.Context, - node *IBFT, - ) { + go func(ctx context.Context, node *IBFT) { // Start the main run loop for the node node.RunSequence(ctx, height) diff --git a/core/rapid_test.go b/core/rapid_test.go index 064e556..6adb2e0 100644 --- a/core/rapid_test.go +++ b/core/rapid_test.go @@ -3,11 +3,11 @@ package core import ( "bytes" "context" + "fmt" "sync" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "pgregory.net/rapid" "github.com/0xPolygon/go-ibft/messages" @@ -51,20 +51,18 @@ func (m *mockInsertedProposals) insertProposal( proposal []byte, ) { m.Lock() - defer m.Unlock() - m.proposals[nodeIndex][m.currentProposals[nodeIndex]] = proposal m.currentProposals[nodeIndex]++ + m.Unlock() } -// propertyTestEvent contains randomly-generated data for rapid testing -type propertyTestEvent struct { - // nodes is the total number of nodes - nodes uint64 - - // byzantineNodes is the total number of byzantine nodes - byzantineNodes uint64 +// getProposer returns proposer index +func getProposer(height, round, nodes uint64) uint64 { + return (height + round) % nodes +} +// propertyTestEvent is the behaviour setup per specific round +type propertyTestEvent struct { // silentByzantineNodes is the number of byzantine nodes // that are going to be silent, i.e. do not respond silentByzantineNodes uint64 @@ -72,27 +70,135 @@ type propertyTestEvent struct { // badByzantineNodes is the number of byzantine nodes // that are going to send bad messages badByzantineNodes uint64 +} + +func (e propertyTestEvent) badNodes() uint64 { + return e.silentByzantineNodes + e.badByzantineNodes +} + +func (e propertyTestEvent) isSilent(nodeIndex int) bool { + return uint64(nodeIndex) < e.silentByzantineNodes +} + +// getMessage returns bad message for byzantine bad node, +// correct message for non-byzantine nodes, and nil for silent nodes +func (e propertyTestEvent) getMessage(nodeIndex int) *roundMessage { + message := correctRoundMessage + if uint64(nodeIndex) < e.badNodes() { + message = badRoundMessage + } + + return &message +} + +// propertyTestSetup contains randomly-generated data for rapid testing +type propertyTestSetup struct { + sync.Mutex + + // nodes is the total number of nodes + nodes uint64 // desiredHeight is the desired height number desiredHeight uint64 + + // events is the mapping between the current height and its rounds + events [][]propertyTestEvent + + currentHeight map[int]uint64 + currentRound map[int]uint64 +} + +func (s *propertyTestSetup) setRound(nodeIndex int, round uint64) { + s.Lock() + s.currentRound[nodeIndex] = round + s.Unlock() +} + +func (s *propertyTestSetup) incHeight() { + s.Lock() + + for nodeIndex := 0; uint64(nodeIndex) < s.nodes; nodeIndex++ { + s.currentHeight[nodeIndex]++ + s.currentRound[nodeIndex] = 0 + } + + s.Unlock() +} + +func (s *propertyTestSetup) getEvent(nodeIndex int) propertyTestEvent { + s.Lock() + + var ( + height = int(s.currentHeight[nodeIndex]) + roundNumber = int(s.currentRound[nodeIndex]) + round propertyTestEvent + ) + + if roundNumber >= len(s.events[height]) { + round = s.events[height][len(s.events[height])-1] + } else { + round = s.events[height][roundNumber] + } + + s.Unlock() + + return round +} + +func (s *propertyTestSetup) lastRound(height uint64) propertyTestEvent { + return s.events[height][len(s.events[height])-1] } // generatePropertyTestEvent generates propertyTestEvent model -func generatePropertyTestEvent(t *rapid.T) *propertyTestEvent { +func generatePropertyTestEvent(t *rapid.T) *propertyTestSetup { + // Generate random setup of the nodes number, byzantine nodes number, and desired height var ( - numNodes = rapid.Uint64Range(4, 15).Draw(t, "number of cluster nodes") - numByzantineNodes = rapid.Uint64Range(0, maxFaulty(numNodes)).Draw(t, "number of byzantine nodes") - silentByzantineNodes = rapid.Uint64Range(0, numByzantineNodes).Draw(t, "number of silent byzantine nodes") - desiredHeight = rapid.Uint64Range(10, 20).Draw(t, "minimum height to be reached") + numNodes = rapid.Uint64Range(4, 30).Draw(t, "number of cluster nodes") + desiredHeight = rapid.Uint64Range(5, 20).Draw(t, "minimum height to be reached") + maxBadNodes = maxFaulty(numNodes) ) - return &propertyTestEvent{ - nodes: numNodes, - byzantineNodes: numByzantineNodes, - silentByzantineNodes: silentByzantineNodes, - badByzantineNodes: numByzantineNodes - silentByzantineNodes, - desiredHeight: desiredHeight, + setup := &propertyTestSetup{ + nodes: numNodes, + desiredHeight: desiredHeight, + events: make([][]propertyTestEvent, desiredHeight), + currentHeight: map[int]uint64{}, + currentRound: map[int]uint64{}, } + + // Go over the desired height and generate random number of rounds + // depending on the round result: success or fail. + for height := uint64(0); height < desiredHeight; height++ { + var round uint64 + + // Generate random rounds until we reach a state where to expect a successfully + // met consensus. Meaning >= 2/3 of all nodes would reach the consensus. + for { + numByzantineNodes := rapid. + Uint64Range(0, maxBadNodes). + Draw(t, fmt.Sprintf("number of byzantine nodes for height %d on round %d", height, round)) + silentByzantineNodes := rapid. + Uint64Range(0, numByzantineNodes). + Draw(t, fmt.Sprintf("number of silent byzantine nodes for height %d on round %d", height, round)) + proposerIdx := getProposer(height, round, numNodes) + + setup.events[height] = append(setup.events[height], propertyTestEvent{ + silentByzantineNodes: silentByzantineNodes, + badByzantineNodes: numByzantineNodes - silentByzantineNodes, + }) + + // If the proposer per the current round is not byzantine node, + // it is expected the consensus should be met, so the loop + // could be stopped for the running height. + if proposerIdx >= numByzantineNodes { + break + } + + round++ + } + } + + return setup } // TestProperty is a property-based test @@ -104,19 +210,21 @@ func TestProperty(t *testing.T) { var multicastFn func(message *proto.Message) var ( - testEvent = generatePropertyTestEvent(t) - currentQuorum = quorum(testEvent.nodes) - nodes = generateNodeAddresses(testEvent.nodes) - insertedProposals = newMockInsertedProposals(testEvent.nodes) + setup = generatePropertyTestEvent(t) + nodes = generateNodeAddresses(setup.nodes) + insertedProposals = newMockInsertedProposals(setup.nodes) ) // commonTransportCallback is the common method modification // required for Transport, for all nodes commonTransportCallback := func(transport *mockTransport, nodeIndex int) { transport.multicastFn = func(message *proto.Message) { + if message.Type == proto.MessageType_ROUND_CHANGE { + setup.setRound(nodeIndex, message.View.Round) + } + // If node is silent, don't send a message - if uint64(nodeIndex) >= testEvent.byzantineNodes && - uint64(nodeIndex) < testEvent.silentByzantineNodes { + if setup.getEvent(nodeIndex).isSilent(nodeIndex) { return } @@ -127,14 +235,8 @@ func TestProperty(t *testing.T) { // commonBackendCallback is the common method modification required // for the Backend, for all nodes commonBackendCallback := func(backend *mockBackend, nodeIndex int) { - // Use a bad message if the current node is a bad byzantine one - message := correctRoundMessage - if uint64(nodeIndex) < testEvent.byzantineNodes { - message = badRoundMessage - } - // Make sure the quorum function is Quorum optimal - backend.hasQuorumFn = commonHasQuorumFn(testEvent.nodes) + backend.hasQuorumFn = commonHasQuorumFn(setup.nodes) // Make sure the node ID is properly relayed backend.idFn = func() []byte { @@ -142,20 +244,24 @@ func TestProperty(t *testing.T) { } // Make sure the only proposer is picked using Round Robin - backend.isProposerFn = func(from []byte, height uint64, round uint64) bool { + backend.isProposerFn = func(from []byte, height, round uint64) bool { return bytes.Equal( from, - nodes[int(height+round)%len(nodes)], + nodes[getProposer(height, round, setup.nodes)], ) } // Make sure the proposal is valid if it matches what node 0 proposed backend.isValidBlockFn = func(newProposal []byte) bool { + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + return bytes.Equal(newProposal, message.proposal) } // Make sure the proposal hash matches backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + return bytes.Equal(p, message.proposal) && bytes.Equal(ph, message.hash) } @@ -165,6 +271,8 @@ func TestProperty(t *testing.T) { certificate *proto.RoundChangeCertificate, view *proto.View, ) *proto.Message { + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + return buildBasicPreprepareMessage( proposal, message.hash, @@ -176,11 +284,15 @@ func TestProperty(t *testing.T) { // Make sure the prepare message is built correctly backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + return buildBasicPrepareMessage(message.hash, nodes[nodeIndex], view) } // Make sure the commit message is built correctly backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + return buildBasicCommitMessage(message.hash, message.seal, nodes[nodeIndex], view) } @@ -199,64 +311,74 @@ func TestProperty(t *testing.T) { } // Make sure the proposal can be built - backend.buildProposalFn = func(_ *proto.View) []byte { + backend.buildProposalFn = func(view *proto.View) []byte { + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + return message.proposal } } // Create default cluster for rapid tests - cluster := newMockCluster(testEvent.nodes, commonBackendCallback, nil, commonTransportCallback) + cluster := newMockCluster( + setup.nodes, + commonBackendCallback, + nil, + commonTransportCallback, + ) // Set the multicast callback to relay the message // to the entire cluster multicastFn = cluster.pushMessage - // Minimum one round is required - minRounds := uint64(1) - if testEvent.byzantineNodes > minRounds { - minRounds = testEvent.byzantineNodes - } - - // Create context timeout based on the bad nodes number - ctxTimeout := getRoundTimeout(testRoundTimeout, testRoundTimeout, minRounds+1) - // Run the sequence up until a certain height - for height := uint64(0); height < testEvent.desiredHeight; height++ { + for height := uint64(0); height < setup.desiredHeight; height++ { + // Create context timeout based on the bad nodes number + rounds := uint64(len(setup.events[height])) + ctxTimeout := getRoundTimeout(testRoundTimeout, testRoundTimeout, rounds*2) + // Start the main run loops cluster.runSequence(height) - if testEvent.byzantineNodes == 0 { - // Wait until all nodes propose messages - cluster.awaitCompletion() - } else { - // Wait until Quorum nodes finish their run loop - ctx, cancelFn := context.WithTimeout(context.Background(), ctxTimeout) - err := cluster.awaitNCompletions(ctx, int64(currentQuorum)) - assert.NoError(t, err, "unable to wait for nodes to complete on height %d", height) - cancelFn() - } + ctx, cancelFn := context.WithTimeout(context.Background(), ctxTimeout) + err := cluster.awaitNCompletions(ctx, int64(quorum(setup.nodes))) + assert.NoError(t, err, "unable to wait for nodes to complete on height %d", height) + cancelFn() // Shutdown the remaining nodes that might be hanging cluster.forceShutdown() - } - // Make sure proposals map is not empty - require.Len(t, insertedProposals.proposals, int(testEvent.nodes)) - - // Make sure that the inserted proposal is valid for each height - for i, proposalMap := range insertedProposals.proposals { - if i < int(testEvent.byzantineNodes) { - // Proposals map must be empty when a byzantine node is proposer - assert.Empty(t, proposalMap) - } else { - // Make sure the node has proposals - assert.NotEmpty(t, proposalMap) - - // Check values - for _, insertedProposal := range proposalMap { - assert.Equal(t, correctRoundMessage.proposal, insertedProposal) + // Increment current height + setup.incHeight() + + // Make sure proposals map is not empty + assert.Len(t, insertedProposals.proposals, int(setup.nodes)) + + // Make sure bad nodes were out of the last round. + // Make sure we have inserted blocks >= quorum per round. + lastRound := setup.lastRound(height) + badNodes := lastRound.badNodes() + var proposalsNumber int + for nodeID, proposalMap := range insertedProposals.proposals { + if nodeID >= int(badNodes) { + // Only one inserted block per valid round + assert.LessOrEqual(t, len(proposalMap), 1) + proposalsNumber++ + + // Make sure inserted block value is correct + for _, val := range proposalMap { + assert.Equal(t, correctRoundMessage.proposal, val) + } + } else { + // There should not be inserted blocks in bad nodes + assert.Empty(t, proposalMap) } } + + // Make sure the total number of inserted blocks >= quorum + assert.GreaterOrEqual(t, proposalsNumber, int(quorum(setup.nodes))) + + // Reset proposals map for the next height + insertedProposals = newMockInsertedProposals(setup.nodes) } }) } From e2b5fd3bd5037e0aac1fdd9f38871224f4115910 Mon Sep 17 00:00:00 2001 From: Vuk Gavrilovic <114920311+trimixlover@users.noreply.github.com> Date: Fri, 2 Dec 2022 12:08:32 +0100 Subject: [PATCH 06/17] Remove redundant changeState (#49) --- core/ibft.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/ibft.go b/core/ibft.go index 10fb96c..cc009b7 100644 --- a/core/ibft.go +++ b/core/ibft.go @@ -569,10 +569,8 @@ func (i *IBFT) runNewRound(ctx context.Context) error { continue } - // Accept the proposal since it's valid - i.acceptProposal(proposalMessage) - // Multicast the PREPARE message + i.state.setProposalMessage(proposalMessage) i.sendPrepareMessage(view) i.log.Debug("prepare message multicasted") From eea78df818aecc3bb38cb5aacc9c2b740df15035 Mon Sep 17 00:00:00 2001 From: kourin Date: Fri, 16 Dec 2022 20:22:45 +0900 Subject: [PATCH 07/17] Fix Wrong Round Value in Validation of roundsAndPreparedBlockHashes (#51) Fix round in roundsAndP reparedBlockHashes --- core/ibft.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/ibft.go b/core/ibft.go index cc009b7..2c8ff79 100644 --- a/core/ibft.go +++ b/core/ibft.go @@ -689,7 +689,7 @@ func (i *IBFT) validateProposal(msg *proto.Message, view *proto.View) bool { hash := messages.ExtractProposalHash(cert.ProposalMessage) roundsAndPreparedBlockHashes = append(roundsAndPreparedBlockHashes, roundHashTuple{ - round: rcMessage.View.Round, + round: cert.ProposalMessage.View.Round, hash: hash, }) } From 730c09c2a8f2629924b24aed7da6aebfe83a08b2 Mon Sep 17 00:00:00 2001 From: Vuk Gavrilovic <114920311+trimixlover@users.noreply.github.com> Date: Fri, 16 Dec 2022 12:53:15 +0100 Subject: [PATCH 08/17] Audit improvements (#50) * Audit improvements --- core/ibft.go | 39 ++++++++- core/ibft_test.go | 207 +++++++++++++++++++++++++++++++++++++++++--- messages/helpers.go | 24 ++++- 3 files changed, 253 insertions(+), 17 deletions(-) diff --git a/core/ibft.go b/core/ibft.go index 2c8ff79..368573b 100644 --- a/core/ibft.go +++ b/core/ibft.go @@ -328,6 +328,7 @@ func (i *IBFT) RunSequence(ctx context.Context, h uint64) { i.moveToNewRound(ev.round) i.acceptProposal(ev.proposalMessage) i.state.setRoundStarted(true) + i.sendPrepareMessage(view) case round := <-i.roundCertificate: teardown() i.log.Info("received future RCC", "round", round) @@ -664,12 +665,31 @@ func (i *IBFT) validateProposal(msg *proto.Message, view *proto.View) bool { return false } + if !messages.HasUniqueSenders(certificate.RoundChangeMessages) { + return false + } + // Make sure all messages in the RCC are valid Round Change messages for _, rc := range certificate.RoundChangeMessages { // Make sure the message is a Round Change message if rc.Type != proto.MessageType_ROUND_CHANGE { return false } + + // Height of the message matches height of the proposal + if rc.View.Height != height { + return false + } + + // Round of the message matches round of the proposal + if rc.View.Round != round { + return false + } + + // Sender of RCC is valid + if !i.backend.IsValidSender(rc) { + return false + } } // Extract possible rounds and their corresponding @@ -706,7 +726,7 @@ func (i *IBFT) validateProposal(msg *proto.Message, view *proto.View) bool { ) for _, tuple := range roundsAndPreparedBlockHashes { - if tuple.round > maxRound { + if tuple.round >= maxRound { maxRound = tuple.round expectedHash = tuple.hash } @@ -1054,7 +1074,7 @@ func (i *IBFT) validPC( } // Order of messages is important! - // Mesage with type of MessageType_PREPREPARE must be the first element of allMessages slice + // Message with type of MessageType_PREPREPARE must be the first element of allMessages slice allMessages := append( []*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages..., @@ -1097,6 +1117,11 @@ func (i *IBFT) validPC( return false } + // Make sure all have the same round + if !messages.AllHaveSameRound(allMessages) { + return false + } + // Make sure the proposal message is sent by the proposer // for the round proposal := certificate.ProposalMessage @@ -1104,12 +1129,22 @@ func (i *IBFT) validPC( return false } + // Make sure that the proposal sender is valid + if !i.backend.IsValidSender(proposal) { + return false + } + // Make sure the Prepare messages are validators, apart from the proposer for _, message := range certificate.PrepareMessages { // Make sure the sender is part of the validator set if !i.backend.IsValidSender(message) { return false } + + // Make sure the current node is not the proposer + if i.backend.IsProposer(message.From, message.View.Height, message.View.Round) { + return false + } } return true diff --git a/core/ibft_test.go b/core/ibft_test.go index 87387f0..142eb59 100644 --- a/core/ibft_test.go +++ b/core/ibft_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "math/rand" "sync" "testing" "time" @@ -155,7 +156,7 @@ func generateFilledRCMessages( proposal, proposalHash []byte) []*proto.Message { // Generate random RC messages - roundChangeMessages := generateMessages(quorum, proto.MessageType_ROUND_CHANGE) + roundChangeMessages := generateMessagesWithUniqueSender(quorum, proto.MessageType_ROUND_CHANGE) prepareMessages := generateMessages(quorum-1, proto.MessageType_PREPARE) // Fill up the prepare message hashes @@ -456,6 +457,7 @@ func TestRunNewRound_Proposer(t *testing.T) { } var ( + proposerID = []byte("unique node") multicastedPreprepare *proto.Message = nil multicastedPrepare *proto.Message = nil proposal = []byte("proposal") @@ -472,9 +474,9 @@ func TestRunNewRound_Proposer(t *testing.T) { } }} backend = mockBackend{ - idFn: func() []byte { return nil }, - isProposerFn: func(_ []byte, _ uint64, _ uint64) bool { - return true + idFn: func() []byte { return proposerID }, + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return bytes.Equal(proposerID, proposer) }, hasQuorumFn: defaultHasQuorumFn(quorum), buildProposalFn: func(_ *proto.View) []byte { @@ -1266,9 +1268,9 @@ func TestIBFT_FutureProposal(t *testing.T) { proposer := []byte("proposer") quorum := uint64(4) - generateEmptyRCMessages := func(count uint64) []*proto.Message { + generateEmptyRCMessages := func(count uint64, round uint64) []*proto.Message { // Generate random RC messages - roundChangeMessages := generateMessages(count, proto.MessageType_ROUND_CHANGE) + roundChangeMessages := generateMessagesWithUniqueSender(count, proto.MessageType_ROUND_CHANGE) // Fill up their certificates for _, message := range roundChangeMessages { @@ -1278,11 +1280,20 @@ func TestIBFT_FutureProposal(t *testing.T) { LatestPreparedCertificate: nil, }, } + + message.View.Round = round } return roundChangeMessages } + generateFilledRCMessagesWithRound := func(quorum, round uint64) []*proto.Message { + messages := generateFilledRCMessages(quorum, correctRoundMessage.proposal, correctRoundMessage.hash) + setRoundForMessages(messages, round) + + return messages + } + testTable := []struct { name string proposalView *proto.View @@ -1295,7 +1306,7 @@ func TestIBFT_FutureProposal(t *testing.T) { Height: 0, Round: 1, }, - generateEmptyRCMessages(quorum), + generateEmptyRCMessages(quorum, 1), 1, }, { @@ -1304,11 +1315,7 @@ func TestIBFT_FutureProposal(t *testing.T) { Height: 0, Round: 2, }, - generateFilledRCMessages( - quorum, - correctRoundMessage.proposal, - correctRoundMessage.hash, - ), + generateFilledRCMessagesWithRound(quorum, 2), 2, }, } @@ -1672,6 +1679,49 @@ func TestIBFT_ValidPC(t *testing.T) { assert.False(t, i.validPC(certificate, rLimit, 0)) }) + t.Run("rounds are not the same", func(t *testing.T) { + t.Parallel() + + var ( + quorum = uint64(4) + rLimit = uint64(2) + sender = []byte("unique node") + + log = mockLogger{} + transport = mockTransport{} + backend = mockBackend{ + hasQuorumFn: defaultHasQuorumFn(quorum), + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return !bytes.Equal(proposer, sender) + }, + } + ) + + i := NewIBFT(log, backend, transport) + + proposal := generateMessagesWithSender(1, proto.MessageType_PREPREPARE, sender)[0] + + certificate := &proto.PreparedCertificate{ + ProposalMessage: proposal, + PrepareMessages: generateMessagesWithUniqueSender(quorum-1, proto.MessageType_PREPARE), + } + + // Make sure they all have the same proposal hash + allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) + appendProposalHash( + allMessages, + correctRoundMessage.hash, + ) + + setRoundForMessages(allMessages, rLimit-1) + // Make sure the round is invalid for some random message + randomIndex := rand.Intn(len(certificate.PrepareMessages)) + randomPrepareMessage := certificate.PrepareMessages[randomIndex] + randomPrepareMessage.View.Round = 0 + + assert.False(t, i.validPC(certificate, rLimit, 0)) + }) + t.Run("proposal not from proposer", func(t *testing.T) { t.Parallel() @@ -1754,6 +1804,88 @@ func TestIBFT_ValidPC(t *testing.T) { assert.False(t, i.validPC(certificate, rLimit, 0)) }) + t.Run("proposal is from an invalid sender", func(t *testing.T) { + t.Parallel() + + var ( + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") + + log = mockLogger{} + transport = mockTransport{} + backend = mockBackend{ + hasQuorumFn: defaultHasQuorumFn(quorum), + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return bytes.Equal(proposer, sender) + }, + isValidSenderFn: func(message *proto.Message) bool { + // Proposer is invalid + return !bytes.Equal(message.From, sender) + }, + } + ) + + i := NewIBFT(log, backend, transport) + + proposal := generateMessagesWithSender(1, proto.MessageType_PREPREPARE, sender)[0] + + certificate := &proto.PreparedCertificate{ + ProposalMessage: proposal, + PrepareMessages: generateMessagesWithUniqueSender(quorum-1, proto.MessageType_PREPARE), + } + + // Make sure they all have the same proposal hash + allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) + appendProposalHash( + allMessages, + correctRoundMessage.hash, + ) + + setRoundForMessages(allMessages, rLimit-1) + + assert.False(t, i.validPC(certificate, rLimit, 0)) + }) + + t.Run("prepare from proposer", func(t *testing.T) { + t.Parallel() + + var ( + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") + + log = mockLogger{} + transport = mockTransport{} + backend = mockBackend{ + hasQuorumFn: defaultHasQuorumFn(quorum), + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return true + }, + } + ) + + i := NewIBFT(log, backend, transport) + + proposal := generateMessagesWithSender(1, proto.MessageType_PREPREPARE, sender)[0] + + certificate := &proto.PreparedCertificate{ + ProposalMessage: proposal, + PrepareMessages: generateMessagesWithUniqueSender(quorum-1, proto.MessageType_PREPARE), + } + + // Make sure they all have the same proposal hash + allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) + appendProposalHash( + allMessages, + correctRoundMessage.hash, + ) + + setRoundForMessages(allMessages, rLimit-1) + + assert.False(t, i.validPC(certificate, rLimit, 0)) + }) + t.Run("completely valid PC", func(t *testing.T) { t.Parallel() @@ -1928,6 +2060,53 @@ func TestIBFT_ValidateProposal(t *testing.T) { assert.False(t, i.validateProposal(proposal, baseView)) }) + t.Run("non unique senders", func(t *testing.T) { + t.Parallel() + + var ( + quorum = uint64(4) + self = []byte("node id") + + log = mockLogger{} + transport = mockTransport{} + backend = mockBackend{ + idFn: func() []byte { + return self + }, + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return !bytes.Equal(proposer, self) + }, + } + ) + + i := NewIBFT(log, backend, transport) + + baseView := &proto.View{ + Height: 0, + Round: 0, + } + + // Make sure all rcc are from same node + messages := generateMessages(quorum, proto.MessageType_ROUND_CHANGE) + for _, msg := range messages { + msg.From = []byte("non unique node id") + } + + proposal := &proto.Message{ + View: baseView, + Type: proto.MessageType_PREPREPARE, + Payload: &proto.Message_PreprepareData{ + PreprepareData: &proto.PrePrepareMessage{ + Certificate: &proto.RoundChangeCertificate{ + RoundChangeMessages: messages, + }, + }, + }, + } + + assert.False(t, i.validateProposal(proposal, baseView)) + }) + t.Run("there are < quorum RC messages in the certificate", func(t *testing.T) { t.Parallel() @@ -2073,8 +2252,8 @@ func TestIBFT_WatchForFutureRCC(t *testing.T) { transport = mockTransport{} backend = mockBackend{ hasQuorumFn: defaultHasQuorumFn(quorum), - isProposerFn: func(_ []byte, _ uint64, _ uint64) bool { - return true + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return bytes.Equal(proposer, []byte("unique node")) }, } messages = mockMessages{ diff --git a/messages/helpers.go b/messages/helpers.go index cc11693..d30d621 100644 --- a/messages/helpers.go +++ b/messages/helpers.go @@ -6,6 +6,7 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) +// CommittedSeal Validator proof of signing a committed block type CommittedSeal struct { Signer []byte Signature []byte @@ -139,7 +140,7 @@ func HaveSameProposalHash(messages []*proto.Message) bool { return false } - var hash []byte = nil + var hash []byte for _, message := range messages { var extractedHash []byte @@ -149,6 +150,10 @@ func HaveSameProposalHash(messages []*proto.Message) bool { extractedHash = ExtractProposalHash(message) case proto.MessageType_PREPARE: extractedHash = ExtractPrepareHash(message) + case proto.MessageType_COMMIT: + return false + case proto.MessageType_ROUND_CHANGE: + return false default: return false } @@ -183,6 +188,23 @@ func AllHaveLowerRound(messages []*proto.Message, round uint64) bool { return true } +// AllHaveSameRound checks if all messages have the same round +func AllHaveSameRound(messages []*proto.Message) bool { + if len(messages) < 1 { + return false + } + + var round = messages[0].View.Round + + for _, message := range messages { + if message.View.Round != round { + return false + } + } + + return true +} + // AllHaveSameHeight checks if all messages have the same height func AllHaveSameHeight(messages []*proto.Message, height uint64) bool { if len(messages) < 1 { From cd560c306a0461b35d98c75f8fb722ebcc619bdb Mon Sep 17 00:00:00 2001 From: kourin Date: Mon, 9 Jan 2023 17:41:07 +0900 Subject: [PATCH 09/17] Add unit tests for EventManager to improve MSI --- messages/event_manager.go | 3 +- messages/event_manager_test.go | 206 ++++++++++++++++++++++++++++++--- 2 files changed, 194 insertions(+), 15 deletions(-) diff --git a/messages/event_manager.go b/messages/event_manager.go index 9edba60..d7b2423 100644 --- a/messages/event_manager.go +++ b/messages/event_manager.go @@ -99,8 +99,9 @@ func (em *eventManager) close() { em.subscriptionsLock.Lock() defer em.subscriptionsLock.Unlock() - for _, subscription := range em.subscriptions { + for id, subscription := range em.subscriptions { subscription.close() + delete(em.subscriptions, id) } atomic.StoreInt64(&em.numSubscriptions, 0) diff --git a/messages/event_manager_test.go b/messages/event_manager_test.go index ce0d12a..f7ce2dc 100644 --- a/messages/event_manager_test.go +++ b/messages/event_manager_test.go @@ -1,19 +1,17 @@ package messages import ( + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/0xPolygon/go-ibft/messages/proto" ) -func TestEventManager_SubscribeCancel(t *testing.T) { - t.Parallel() - - numSubscriptions := 10 - subscriptions := make([]*Subscription, numSubscriptions) - baseDetails := SubscriptionDetails{ +var ( + baseDetails = SubscriptionDetails{ MessageType: proto.MessageType_PREPARE, View: &proto.View{ Height: 0, @@ -21,6 +19,194 @@ func TestEventManager_SubscribeCancel(t *testing.T) { }, MinNumMessages: 1, } +) + +func TestEventManager_signalEvent(t *testing.T) { + t.Parallel() + + var ( + baseEventType = baseDetails.MessageType + baseEventView = &proto.View{ + Height: baseDetails.View.Height, + Round: baseDetails.View.Round, + } + ) + + // setupEventManagerAndSubscription creates new eventManager and a subscription + setupEventManagerAndSubscription := func(t *testing.T) (*eventManager, *Subscription) { + t.Helper() + + em := newEventManager() + t.Cleanup(func() { + em.close() + }) + + subscription := em.subscribe(baseDetails) + t.Cleanup(func() { + em.cancelSubscription(subscription.ID) + }) + + return em, subscription + } + + // emitEvent sends a event to eventManager and close doneCh after signalEvent completes + emitEvent := func( + t *testing.T, + em *eventManager, + eventType proto.MessageType, + eventView *proto.View, + ) <-chan struct{} { + doneCh := make(chan struct{}) + + go func() { + t.Helper() + defer close(doneCh) + + em.signalEvent( + eventType, + eventView, + ) + }() + + return doneCh + } + + // testSubscriptionData checks the data sent to subscription + testSubscriptionData := func( + t *testing.T, + sub *Subscription, + expectedSignals []uint64, + ) <-chan struct{} { + doneCh := make(chan struct{}) + + go func() { + t.Helper() + defer close(doneCh) + + actualSignals := make([]uint64, 0) + for sig := range sub.SubCh { + actualSignals = append(actualSignals, sig) + } + + assert.Equal(t, expectedSignals, actualSignals) + }() + + return doneCh + } + + // closeSubscription closes subscription manually + // because cancelSubscription might be unable + // due to mutex locking during tests + closeSubscription := func( + t *testing.T, + em *eventManager, + sub *Subscription, + ) { + t.Helper() + + close(em.subscriptions[sub.ID].doneCh) + delete(em.subscriptions, sub.ID) + } + + t.Run("should exit before locking subscriptionsLock if numSubscriptions is zero", func(t *testing.T) { + t.Parallel() + + em, sub := setupEventManagerAndSubscription(t) + + // overwrite numSubscription + atomic.StoreInt64(&em.numSubscriptions, 0) + // shouldn't be locked by mutex thanks to early return + em.subscriptionsLock.Lock() + t.Cleanup(func() { + em.subscriptionsLock.Unlock() + }) + + doneEmitCh := emitEvent(t, em, baseEventType, baseEventView) + doneTestSubCh := testSubscriptionData(t, sub, []uint64{}) + + // should exit by early return + select { + case <-doneEmitCh: + case <-time.After(5 * time.Second): + t.Errorf("signalEvent shouldn't be lock, but it was locked") + } + + closeSubscription(t, em, sub) + <-doneTestSubCh + }) + + t.Run("should be locked by other write lock", func(t *testing.T) { + t.Parallel() + + em, sub := setupEventManagerAndSubscription(t) + + // should be locked by other write lock + em.subscriptionsLock.Lock() + t.Cleanup(func() { + em.subscriptionsLock.Unlock() + }) + + doneCh := emitEvent(t, em, baseEventType, baseEventView) + doneTestSubCh := testSubscriptionData(t, sub, []uint64{}) + + select { + case <-doneCh: + t.Errorf("signalEvent is not locked") + case <-time.After(5 * time.Second): + } + + closeSubscription(t, em, sub) + <-doneTestSubCh + }) + + t.Run("should not be locked by other read lock", func(t *testing.T) { + t.Parallel() + + em, sub := setupEventManagerAndSubscription(t) + + // shouldn't be locked by mutex of read-lock + em.subscriptionsLock.RLock() + t.Cleanup(func() { + em.subscriptionsLock.RUnlock() + }) + + doneCh := emitEvent(t, em, baseEventType, baseEventView) + doneTestSubCh := testSubscriptionData(t, sub, []uint64{0}) + + select { + case <-doneCh: + return + case <-time.After(5 * time.Second): + t.Errorf("signalEvent is locked") + } + + <-doneTestSubCh + }) + + t.Run("should not notify if the event is different the one expected by subscription", func(t *testing.T) { + t.Parallel() + + em, sub := setupEventManagerAndSubscription(t) + + doneCh := emitEvent(t, em, proto.MessageType_COMMIT, baseEventView) + doneTestSubCh := testSubscriptionData(t, sub, []uint64{}) + + select { + case <-doneCh: + return + case <-time.After(5 * time.Second): + t.Errorf("signalEvent is locked") + } + + <-doneTestSubCh + }) +} + +func TestEventManager_SubscribeCancel(t *testing.T) { + t.Parallel() + + numSubscriptions := 10 + subscriptions := make([]*Subscription, numSubscriptions) IDMap := make(map[SubscriptionID]bool) @@ -73,14 +259,6 @@ func TestEventManager_SubscribeClose(t *testing.T) { numSubscriptions := 10 subscriptions := make([]*Subscription, numSubscriptions) - baseDetails := SubscriptionDetails{ - MessageType: proto.MessageType_PREPARE, - View: &proto.View{ - Height: 0, - Round: 0, - }, - MinNumMessages: 1, - } em := newEventManager() From 1b12c7f278f75f2cc9329af77676948c551c9a28 Mon Sep 17 00:00:00 2001 From: kourin Date: Mon, 9 Jan 2023 17:41:20 +0900 Subject: [PATCH 10/17] Add unit tests for message helper to improve MSI --- messages/helpers_test.go | 203 +++++++++++++++++++++++++++++++++------ 1 file changed, 174 insertions(+), 29 deletions(-) diff --git a/messages/helpers_test.go b/messages/helpers_test.go index 979cc1c..527df3e 100644 --- a/messages/helpers_test.go +++ b/messages/helpers_test.go @@ -11,37 +11,52 @@ import ( func TestMessages_ExtractCommittedSeals(t *testing.T) { t.Parallel() - signer := []byte("signer") - committedSeal := []byte("committed seal") - - commitMessage := &proto.Message{ - Type: proto.MessageType_COMMIT, - Payload: &proto.Message_CommitData{ - CommitData: &proto.CommitMessage{ - CommittedSeal: committedSeal, + newCommitMessage := func(from, committedSeal []byte) *proto.Message { + return &proto.Message{ + Type: proto.MessageType_COMMIT, + Payload: &proto.Message_CommitData{ + CommitData: &proto.CommitMessage{ + CommittedSeal: committedSeal, + }, }, - }, - From: signer, + From: from, + } } - invalidMessage := &proto.Message{ - Type: proto.MessageType_PREPARE, + + newInvalidMessage := func() *proto.Message { + return &proto.Message{ + Type: proto.MessageType_PREPARE, + } } - seals := ExtractCommittedSeals([]*proto.Message{ - commitMessage, - invalidMessage, - }) + var ( + signer1 = []byte("signer 1") + committedSeal1 = []byte("committed seal 1") - if len(seals) != 1 { - t.Fatalf("Seals not extracted") - } + signer2 = []byte("signer 2") + committedSeal2 = []byte("committed seal 2") + ) - expected := &CommittedSeal{ - Signer: signer, - Signature: committedSeal, - } + seals := ExtractCommittedSeals([]*proto.Message{ + newCommitMessage(signer1, committedSeal1), + newInvalidMessage(), + newCommitMessage(signer2, committedSeal2), + }) - assert.Equal(t, expected, seals[0]) + assert.Equal( + t, + []*CommittedSeal{ + { + Signer: signer1, + Signature: committedSeal1, + }, + { + Signer: signer2, + Signature: committedSeal2, + }, + }, + seals, + ) } func TestMessages_ExtractCommitHash(t *testing.T) { @@ -384,6 +399,15 @@ func TestMessages_HasUniqueSenders(t *testing.T) { nil, false, }, + { + "only one message", + []*proto.Message{ + { + From: []byte("node 1"), + }, + }, + true, + }, { "non unique senders", []*proto.Message{ @@ -440,9 +464,31 @@ func TestMessages_HaveSameProposalHash(t *testing.T) { nil, false, }, + { + "only one message", + []*proto.Message{ + { + Type: proto.MessageType_PREPARE, + Payload: &proto.Message_PrepareData{ + PrepareData: &proto.PrepareMessage{ + ProposalHash: []byte("hash"), + }, + }, + }, + }, + true, + }, { "invalid message type", []*proto.Message{ + { + Type: proto.MessageType_PREPARE, + Payload: &proto.Message_PrepareData{ + PrepareData: &proto.PrepareMessage{ + ProposalHash: []byte("differing hash"), + }, + }, + }, { Type: proto.MessageType_ROUND_CHANGE, }, @@ -471,6 +517,20 @@ func TestMessages_HaveSameProposalHash(t *testing.T) { }, false, }, + { + "only one message", + []*proto.Message{ + { + Type: proto.MessageType_PREPREPARE, + Payload: &proto.Message_PreprepareData{ + PreprepareData: &proto.PrePrepareMessage{ + ProposalHash: proposalHash, + }, + }, + }, + }, + true, + }, { "hash match", []*proto.Message{ @@ -528,7 +588,20 @@ func TestMessages_AllHaveLowerRond(t *testing.T) { false, }, { - "not same lower round", + "true if message's round is less than threshold", + []*proto.Message{ + { + View: &proto.View{ + Height: 0, + Round: round - 1, + }, + }, + }, + round, + true, + }, + { + "false if message's round equals to threshold", []*proto.Message{ { View: &proto.View{ @@ -536,6 +609,13 @@ func TestMessages_AllHaveLowerRond(t *testing.T) { Round: round, }, }, + }, + round, + false, + }, + { + "false if message's round is bigger than threshold", + []*proto.Message{ { View: &proto.View{ Height: 0, @@ -546,6 +626,25 @@ func TestMessages_AllHaveLowerRond(t *testing.T) { round, false, }, + { + "some of messages are not higher round", + []*proto.Message{ + { + View: &proto.View{ + Height: 0, + Round: round + 1, + }, + }, + { + View: &proto.View{ + Height: 0, + Round: round, + }, + }, + }, + round, + false, + }, { "same higher round", []*proto.Message{ @@ -566,20 +665,33 @@ func TestMessages_AllHaveLowerRond(t *testing.T) { false, }, { - "lower round match", + "1 message is lower round", []*proto.Message{ { View: &proto.View{ - Height: 0, + Height: 1, Round: round, }, }, + }, + 2, + true, + }, + { + "all of messages is lower round", + []*proto.Message{ { View: &proto.View{ Height: 0, Round: round, }, }, + { + View: &proto.View{ + Height: 1, + Round: round, + }, + }, }, 2, true, @@ -620,7 +732,40 @@ func TestMessages_AllHaveSameHeight(t *testing.T) { false, }, { - "not same height", + "false if message's height is less than the given height", + []*proto.Message{ + { + View: &proto.View{ + Height: height - 1, + }, + }, + }, + false, + }, + { + "true if message's height equals to the given height", + []*proto.Message{ + { + View: &proto.View{ + Height: height, + }, + }, + }, + true, + }, + { + "false if message's height is bigger than the given height", + []*proto.Message{ + { + View: &proto.View{ + Height: height + 1, + }, + }, + }, + false, + }, + { + "some of messages' heights are not same to the given height", []*proto.Message{ { View: &proto.View{ @@ -636,7 +781,7 @@ func TestMessages_AllHaveSameHeight(t *testing.T) { false, }, { - "same height", + "all of messages' heights is same to the given height", []*proto.Message{ { View: &proto.View{ From d6431af3acc5a83ab81f51624d02781cf3ce0ac0 Mon Sep 17 00:00:00 2001 From: Vuk Gavrilovic <114920311+trimixlover@users.noreply.github.com> Date: Mon, 9 Jan 2023 09:41:33 +0100 Subject: [PATCH 11/17] Byzantine tests (#56) * Byzantine tests --- core/byzantine_test.go | 459 +++++++++++++++++++++++++++++++++++++++++ core/helpers_test.go | 13 +- 2 files changed, 469 insertions(+), 3 deletions(-) create mode 100644 core/byzantine_test.go diff --git a/core/byzantine_test.go b/core/byzantine_test.go new file mode 100644 index 0000000..179b23f --- /dev/null +++ b/core/byzantine_test.go @@ -0,0 +1,459 @@ +package core + +import ( + "bytes" + "github.com/0xPolygon/go-ibft/messages/proto" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestByzantineBehaviour(t *testing.T) { + t.Parallel() + + t.Run("malicious hash in proposal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrePrepareMessageFn(createBadHashPrePrepareMessageFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(20*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious hash in prepare", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(c.isProposer) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrepareMessageFn(createBadHashPrepareMessageFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(10*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(10*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in proposal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrePrepareMessageFn(createBadRoundPrePrepareMessageFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + // Max tolerant byzantine + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(40*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc and in proposal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrePrepareMessageFn(createBadRoundPrePrepareMessageFn(currentNode)) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc and bad hash in proposal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrePrepareMessageFn(createBadHashPrePrepareMessageFn(currentNode)) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc and bad hash in prepare", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrepareMessageFn(createBadHashPrepareMessageFn(currentNode)) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc and bad commit seal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildCommitMessageFn(createBadCommitMessageFn(currentNode)) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) +} + +func createBadRoundRoundChangeFn(node *node) buildRoundChangeMessageDelegate { + return func(proposal []byte, + rcc *proto.PreparedCertificate, + view *proto.View) *proto.Message { + if node.byzantine { + view.Round++ + } + + return buildBasicRoundChangeMessage( + proposal, + rcc, + view, + node.address, + ) + } +} + +func createBadRoundPrePrepareMessageFn(node *node) buildPrePrepareMessageDelegate { + return func( + proposal []byte, + certificate *proto.RoundChangeCertificate, + view *proto.View, + ) *proto.Message { + if node.byzantine { + view.Round++ + } + + return buildBasicPreprepareMessage( + proposal, + validProposalHash, + certificate, + node.address, + view, + ) + } +} + +func createBadHashPrePrepareMessageFn(node *node) buildPrePrepareMessageDelegate { + return func(proposal []byte, + rcc *proto.RoundChangeCertificate, + view *proto.View) *proto.Message { + proposalHash := validProposalHash + if node.byzantine { + proposalHash = []byte("invalid proposal hash") + } + + return buildBasicPreprepareMessage( + proposal, + proposalHash, + rcc, + node.address, + view, + ) + } +} + +func createBadHashPrepareMessageFn(node *node) buildPrepareMessageDelegate { + return func(_ []byte, view *proto.View) *proto.Message { + proposalHash := validProposalHash + if node.byzantine { + proposalHash = []byte("invalid proposal hash") + } + + return buildBasicPrepareMessage( + proposalHash, + node.address, + view, + ) + } +} + +func createForcedRCProposerFn(c *cluster) isProposerDelegate { + return func(from []byte, height uint64, round uint64) bool { + if round == 0 { + return false + } + + return bytes.Equal( + from, + c.addresses()[int(round)%len(c.addresses())], + ) + } +} + +func createBadCommitMessageFn(node *node) buildCommitMessageDelegate { + return func(_ []byte, view *proto.View) *proto.Message { + committedSeal := validCommittedSeal + if node.byzantine { + committedSeal = []byte("invalid committed seal") + } + + return buildBasicCommitMessage( + validProposalHash, + committedSeal, + node.address, + view, + ) + } +} + +type mockBackendBuilder struct { + isProposerFn isProposerDelegate + + idFn idDelegate + + buildPrePrepareMessageFn buildPrePrepareMessageDelegate + buildPrepareMessageFn buildPrepareMessageDelegate + buildCommitMessageFn buildCommitMessageDelegate + buildRoundChangeMessageFn buildRoundChangeMessageDelegate + + hasQuorumFn hasQuorumDelegate +} + +func (b *mockBackendBuilder) withProposerFn(f isProposerDelegate) { + b.isProposerFn = f +} + +func (b *mockBackendBuilder) withBuildPrePrepareMessageFn(f buildPrePrepareMessageDelegate) { + b.buildPrePrepareMessageFn = f +} + +func (b *mockBackendBuilder) withBuildPrepareMessageFn(f buildPrepareMessageDelegate) { + b.buildPrepareMessageFn = f +} + +func (b *mockBackendBuilder) withBuildCommitMessageFn(f buildCommitMessageDelegate) { + b.buildCommitMessageFn = f +} + +func (b *mockBackendBuilder) withBuildRoundChangeMessageFn(f buildRoundChangeMessageDelegate) { + b.buildRoundChangeMessageFn = f +} + +func (b *mockBackendBuilder) withIDFn(f idDelegate) { + b.idFn = f +} + +func (b *mockBackendBuilder) withHasQuorumFn(f hasQuorumDelegate) { + b.hasQuorumFn = f +} + +func (b *mockBackendBuilder) build(node *node) *mockBackend { + if b.buildPrePrepareMessageFn == nil { + b.buildPrePrepareMessageFn = node.buildPrePrepare + } + + if b.buildPrepareMessageFn == nil { + b.buildPrepareMessageFn = node.buildPrepare + } + + if b.buildCommitMessageFn == nil { + b.buildCommitMessageFn = node.buildCommit + } + + if b.buildRoundChangeMessageFn == nil { + b.buildRoundChangeMessageFn = node.buildRoundChange + } + + return &mockBackend{ + isValidBlockFn: isValidProposal, + isValidProposalHashFn: isValidProposalHash, + isValidSenderFn: nil, + isValidCommittedSealFn: nil, + isProposerFn: b.isProposerFn, + idFn: b.idFn, + + buildProposalFn: buildValidProposal, + buildPrePrepareMessageFn: b.buildPrePrepareMessageFn, + buildPrepareMessageFn: b.buildPrepareMessageFn, + buildCommitMessageFn: b.buildCommitMessageFn, + buildRoundChangeMessageFn: b.buildRoundChangeMessageFn, + insertBlockFn: nil, + hasQuorumFn: b.hasQuorumFn, + } +} diff --git a/core/helpers_test.go b/core/helpers_test.go index 14de1b3..1d4c6a0 100644 --- a/core/helpers_test.go +++ b/core/helpers_test.go @@ -37,9 +37,10 @@ func isValidProposalHash(proposal, proposalHash []byte) bool { } type node struct { - core *IBFT - address []byte - offline bool + core *IBFT + address []byte + offline bool + byzantine bool } func (n *node) addr() []byte { @@ -217,6 +218,12 @@ func (c *cluster) maxFaulty() uint64 { return (uint64(len(c.nodes)) - 1) / 3 } +func (c *cluster) makeNByzantine(num int) { + for i := 0; i < num; i++ { + c.nodes[i].byzantine = true + } +} + func (c *cluster) stopN(num int) { for i := 0; i < num; i++ { c.nodes[i].offline = true From 21dab925111b129fc9f9f6bc887202a913748f38 Mon Sep 17 00:00:00 2001 From: kourin Date: Mon, 9 Jan 2023 17:41:38 +0900 Subject: [PATCH 12/17] Add unit tests for Messages to improve MSI --- messages/messages_test.go | 493 +++++++++++++++++++++++++++++++++----- 1 file changed, 430 insertions(+), 63 deletions(-) diff --git a/messages/messages_test.go b/messages/messages_test.go index 54b6e85..b9408aa 100644 --- a/messages/messages_test.go +++ b/messages/messages_test.go @@ -131,20 +131,39 @@ func TestMessages_AddDuplicates(t *testing.T) { func TestMessages_Prune(t *testing.T) { t.Parallel() - numMessages := 5 - messageType := proto.MessageType_PREPARE + var ( + numMessages = 5 + messageType = proto.MessageType_PREPARE + + height uint64 = 2 + ) + messages := NewMessages() t.Cleanup(func() { messages.Close() }) - views := make([]*proto.View, 0) - for index := uint64(1); index <= 3; index++ { - views = append(views, &proto.View{ - Height: 1, - Round: index, - }) + views := []*proto.View{ + { + Height: height - 1, + Round: 1, + }, + { + Height: height, + Round: 2, + }, + { + Height: height + 1, + Round: 3, + }, + } + + // expected number of message for each view after pruning + expectedNumMessages := []int{ + 0, + numMessages, + numMessages, } // Append random message types @@ -165,19 +184,22 @@ func TestMessages_Prune(t *testing.T) { } // Prune out the messages from this view - messages.PruneByHeight(views[1].Height + 1) - - // Make sure the round 1 messages are pruned out - assert.Equal(t, 0, messages.numMessages(views[0], messageType)) - - // Make sure the round 2 messages are pruned out - assert.Equal(t, 0, messages.numMessages(views[1], messageType)) - - // Make sure the round 3 messages are pruned out - assert.Equal(t, 0, messages.numMessages(views[2], messageType)) + messages.PruneByHeight(height) + + // check numbers of messages + for idx, expected := range expectedNumMessages { + assert.Equal( + t, + expected, + messages.numMessages( + views[idx], + messageType, + ), + ) + } } -// TestMessages_GetMessage makes sure +// TestMessages_GetValidMessagesMessage_InvalidMessages makes sure // that messages are fetched correctly for the // corresponding message type func TestMessages_GetValidMessagesMessage(t *testing.T) { @@ -188,7 +210,9 @@ func TestMessages_GetValidMessagesMessage(t *testing.T) { Height: 1, Round: 0, } - numMessages = 5 + + numMessages = 10 + numValidMessages = 5 ) testTable := []struct { @@ -213,8 +237,14 @@ func TestMessages_GetValidMessagesMessage(t *testing.T) { }, } - alwaysInvalidFn := func(_ *proto.Message) bool { - return false + newIsValid := func(numValidMessages int) func(_ *proto.Message) bool { + calls := 0 + + return func(_ *proto.Message) bool { + calls++ + + return calls <= numValidMessages + } } for _, testCase := range testTable { @@ -247,20 +277,23 @@ func TestMessages_GetValidMessagesMessage(t *testing.T) { ) // Start fetching messages and making sure they're not cleared - switch testCase.messageType { - case proto.MessageType_PREPREPARE: - messages.GetValidMessages(defaultView, proto.MessageType_PREPREPARE, alwaysInvalidFn) - case proto.MessageType_PREPARE: - messages.GetValidMessages(defaultView, proto.MessageType_PREPARE, alwaysInvalidFn) - case proto.MessageType_COMMIT: - messages.GetValidMessages(defaultView, proto.MessageType_COMMIT, alwaysInvalidFn) - case proto.MessageType_ROUND_CHANGE: - messages.GetValidMessages(defaultView, proto.MessageType_ROUND_CHANGE, alwaysInvalidFn) - } + validMessages := messages.GetValidMessages( + defaultView, + testCase.messageType, + newIsValid(numValidMessages), + ) + // make sure only valid messages are returned + assert.Len( + t, + validMessages, + numValidMessages, + ) + + // make sure invalid messages are pruned assert.Equal( t, - 0, + numMessages-numValidMessages, messages.numMessages(defaultView, testCase.messageType), ) }) @@ -273,42 +306,129 @@ func TestMessages_GetValidMessagesMessage(t *testing.T) { func TestMessages_GetMostRoundChangeMessages(t *testing.T) { t.Parallel() - messages := NewMessages() - defer messages.Close() + tests := []struct { + name string + messages [][]*proto.Message + minRound uint64 + height uint64 + expectedNum int + expectedRound uint64 + }{ + { + name: "should return nil if not found", + messages: [][]*proto.Message{ + generateRandomMessages(3, &proto.View{ + Height: 0, + Round: 1, // smaller than minRound + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 2, + height: 0, + expectedNum: 0, + }, + { + name: "should return round change messages if messages' round is greater than/equal to minRound", + messages: [][]*proto.Message{ + generateRandomMessages(1, &proto.View{ + Height: 0, + Round: 2, + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 1, + height: 0, + expectedNum: 1, + expectedRound: 2, + }, + { + name: "should return most round change messages (the round is equals to minRound)", + messages: [][]*proto.Message{ + generateRandomMessages(1, &proto.View{ + Height: 0, + Round: 4, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(2, &proto.View{ + Height: 0, + Round: 2, + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 2, + height: 0, + expectedNum: 2, + expectedRound: 2, + }, + { + name: "should return most round change messages (the round is bigger than minRound)", + messages: [][]*proto.Message{ + generateRandomMessages(3, &proto.View{ + Height: 0, + Round: 1, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(2, &proto.View{ + Height: 0, + Round: 3, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(1, &proto.View{ + Height: 0, + Round: 4, + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 2, + height: 0, + expectedNum: 2, + expectedRound: 3, + }, + { + name: "should return the first of most round change messages", + messages: [][]*proto.Message{ + generateRandomMessages(3, &proto.View{ + Height: 0, + Round: 1, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(2, &proto.View{ + Height: 0, + Round: 4, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(2, &proto.View{ + Height: 0, + Round: 3, + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 2, + height: 0, + expectedNum: 2, + expectedRound: 4, + }, + } - mostMessageCount := 3 - mostMessagesRound := uint64(2) + for _, test := range tests { + test := test - // Generate round messages - randomMessages := map[uint64][]*proto.Message{ - 0: generateRandomMessages(mostMessageCount-2, &proto.View{ - Height: 0, - Round: 0, - }, proto.MessageType_ROUND_CHANGE), - 1: generateRandomMessages(mostMessageCount-1, &proto.View{ - Height: 0, - Round: 1, - }, proto.MessageType_ROUND_CHANGE), - mostMessagesRound: generateRandomMessages(mostMessageCount, &proto.View{ - Height: 0, - Round: mostMessagesRound, - }, proto.MessageType_ROUND_CHANGE), - } + t.Run(test.name, func(t *testing.T) { + t.Parallel() - // Add the messages - for _, roundMessages := range randomMessages { - for _, message := range roundMessages { - messages.AddMessage(message) - } - } + messages := NewMessages() + defer messages.Close() - roundChangeMessages := messages.GetMostRoundChangeMessages(0, 0) + // Add the messages + for _, roundMessages := range test.messages { + for _, message := range roundMessages { + messages.AddMessage(message) + } + } - if len(roundChangeMessages) != mostMessageCount { - t.Fatalf("Invalid number of round change messages, %d", len(roundChangeMessages)) - } + roundChangeMessages := messages.GetMostRoundChangeMessages(test.minRound, test.height) - assert.Equal(t, mostMessagesRound, roundChangeMessages[0].View.Round) + if test.expectedNum == 0 { + assert.Nil(t, roundChangeMessages, "should be nil but not nil") + } else { + assert.Len(t, roundChangeMessages, test.expectedNum, "invalid number of round change messages") + } + + for _, msg := range roundChangeMessages { + assert.Equal(t, test.expectedRound, msg.View.Round) + } + }) + } } // TestMessages_EventManager checks that the event manager @@ -352,3 +472,250 @@ func TestMessages_EventManager(t *testing.T) { // Make sure the number of messages is actually accurate assert.Equal(t, numMessages, messages.numMessages(baseView, messageType)) } + +// TestMessages_Unsubscribe checks Messages calls eventManager.cancelSubscription +// in Unsubscribe method +func TestMessages_Unsubscribe(t *testing.T) { + t.Parallel() + + messages := NewMessages() + defer messages.Close() + + numMessages := 10 + messageType := proto.MessageType_PREPARE + baseView := &proto.View{ + Height: 0, + Round: 0, + } + + // Create the subscription + subscription := messages.Subscribe(SubscriptionDetails{ + MessageType: messageType, + View: baseView, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + assert.Equal(t, int64(1), messages.eventManager.numSubscriptions) + + messages.Unsubscribe(subscription.ID) + + assert.Equal(t, int64(0), messages.eventManager.numSubscriptions) +} + +// TestMessages_Unsubscribe checks Messages calls eventManager.close +// in Close method +func TestMessages_Close(t *testing.T) { + t.Parallel() + + messages := NewMessages() + defer messages.Close() + + numMessages := 10 + baseView := &proto.View{ + Height: 0, + Round: 0, + } + + // Create 2 subscriptions + _ = messages.Subscribe(SubscriptionDetails{ + MessageType: proto.MessageType_PREPARE, + View: baseView, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + _ = messages.Subscribe(SubscriptionDetails{ + MessageType: proto.MessageType_COMMIT, + View: baseView, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + assert.Equal(t, int64(2), messages.eventManager.numSubscriptions) + + messages.Close() + + assert.Equal(t, int64(0), messages.eventManager.numSubscriptions) +} + +func TestMessages_getProtoMessage(t *testing.T) { + t.Parallel() + + messages := NewMessages() + defer messages.Close() + + var ( + numMessages = 10 + messageType = proto.MessageType_COMMIT + view = &proto.View{ + Height: 0, + Round: 0, + } + ) + + // Create the subscription + subscription := messages.Subscribe(SubscriptionDetails{ + MessageType: messageType, + View: view, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + defer messages.Unsubscribe(subscription.ID) + + // Push random messages + generatedMessages := generateRandomMessages(numMessages, view, messageType) + messageMap := map[string]*proto.Message{} + + for _, message := range generatedMessages { + messages.AddMessage(message) + messageMap[string(message.From)] = message + } + + // Wait for the subscription event to happen + select { + case <-subscription.SubCh: + case <-time.After(5 * time.Second): + } + + tests := []struct { + name string + view *proto.View + messageType proto.MessageType + expected protoMessages + }{ + { + name: "should return messages for same view and type", + view: view, + messageType: messageType, + expected: messageMap, + }, + { + name: "should return nil for different type", + view: view, + messageType: proto.MessageType_PREPARE, + expected: nil, + }, + { + name: "should return nil for same type and round but different height", + view: &proto.View{ + Height: view.Height + 1, + Round: view.Round, + }, + messageType: messageType, + expected: nil, + }, + { + name: "should return nil for same type and height but different round", + view: &proto.View{ + Height: view.Height, + Round: view.Round + 1, + }, + messageType: messageType, + expected: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + assert.Equal( + t, + test.expected, + messages.getProtoMessages(test.view, test.messageType), + ) + }) + } +} + +func TestMessages_numMessages(t *testing.T) { + t.Parallel() + + messages := NewMessages() + defer messages.Close() + + var ( + numMessages = 10 + messageType = proto.MessageType_COMMIT + view = &proto.View{ + Height: 3, + Round: 5, + } + ) + + // Create the subscription + subscription := messages.Subscribe(SubscriptionDetails{ + MessageType: messageType, + View: view, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + defer messages.Unsubscribe(subscription.ID) + + // Push random messages + for _, message := range generateRandomMessages(numMessages, view, messageType) { + messages.AddMessage(message) + } + + // Wait for the subscription event to happen + select { + case <-subscription.SubCh: + case <-time.After(5 * time.Second): + } + + tests := []struct { + name string + view *proto.View + messageType proto.MessageType + expected int + }{ + { + name: "should return number of messages", + view: view, + messageType: messageType, + expected: numMessages, + }, + { + name: "should return zero if message type is different", + view: view, + messageType: proto.MessageType_PREPARE, + expected: 0, + }, + { + name: "should return zero if height is different", + view: &proto.View{ + Height: 1, + Round: view.Round, + }, + messageType: messageType, + expected: 0, + }, + { + name: "should return zero if round is different", + view: &proto.View{ + Height: view.Height, + Round: 1, + }, + messageType: messageType, + expected: 0, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expected, messages.numMessages(test.view, test.messageType)) + }) + } +} From 3fdb385d65d2389ed3fc30ec06838bdd1584a754 Mon Sep 17 00:00:00 2001 From: kourin Date: Tue, 10 Jan 2023 00:23:23 +0900 Subject: [PATCH 13/17] Fix lint error --- messages/event_manager_test.go | 46 ++++++++++++++++++++++++++-------- messages/messages_test.go | 2 ++ 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/messages/event_manager_test.go b/messages/event_manager_test.go index f7ce2dc..6695100 100644 --- a/messages/event_manager_test.go +++ b/messages/event_manager_test.go @@ -10,21 +10,19 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) -var ( - baseDetails = SubscriptionDetails{ - MessageType: proto.MessageType_PREPARE, - View: &proto.View{ - Height: 0, - Round: 0, - }, - MinNumMessages: 1, - } -) - func TestEventManager_signalEvent(t *testing.T) { t.Parallel() var ( + baseDetails = SubscriptionDetails{ + MessageType: proto.MessageType_PREPARE, + View: &proto.View{ + Height: 0, + Round: 0, + }, + MinNumMessages: 1, + } + baseEventType = baseDetails.MessageType baseEventView = &proto.View{ Height: baseDetails.View.Height, @@ -37,11 +35,13 @@ func TestEventManager_signalEvent(t *testing.T) { t.Helper() em := newEventManager() + t.Cleanup(func() { em.close() }) subscription := em.subscribe(baseDetails) + t.Cleanup(func() { em.cancelSubscription(subscription.ID) }) @@ -56,10 +56,13 @@ func TestEventManager_signalEvent(t *testing.T) { eventType proto.MessageType, eventView *proto.View, ) <-chan struct{} { + t.Helper() + doneCh := make(chan struct{}) go func() { t.Helper() + defer close(doneCh) em.signalEvent( @@ -77,10 +80,13 @@ func TestEventManager_signalEvent(t *testing.T) { sub *Subscription, expectedSignals []uint64, ) <-chan struct{} { + t.Helper() + doneCh := make(chan struct{}) go func() { t.Helper() + defer close(doneCh) actualSignals := make([]uint64, 0) @@ -205,6 +211,15 @@ func TestEventManager_signalEvent(t *testing.T) { func TestEventManager_SubscribeCancel(t *testing.T) { t.Parallel() + baseDetails := SubscriptionDetails{ + MessageType: proto.MessageType_PREPARE, + View: &proto.View{ + Height: 0, + Round: 0, + }, + MinNumMessages: 1, + } + numSubscriptions := 10 subscriptions := make([]*Subscription, numSubscriptions) @@ -257,6 +272,15 @@ func TestEventManager_SubscribeCancel(t *testing.T) { func TestEventManager_SubscribeClose(t *testing.T) { t.Parallel() + baseDetails := SubscriptionDetails{ + MessageType: proto.MessageType_PREPARE, + View: &proto.View{ + Height: 0, + Round: 0, + }, + MinNumMessages: 1, + } + numSubscriptions := 10 subscriptions := make([]*Subscription, numSubscriptions) diff --git a/messages/messages_test.go b/messages/messages_test.go index b9408aa..d10c386 100644 --- a/messages/messages_test.go +++ b/messages/messages_test.go @@ -715,6 +715,8 @@ func TestMessages_numMessages(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, test.expected, messages.numMessages(test.view, test.messageType)) }) } From bf899c83fa8e3f81ef4cfca35694f79a7399cf07 Mon Sep 17 00:00:00 2001 From: kourin Date: Tue, 10 Jan 2023 17:49:19 +0900 Subject: [PATCH 14/17] fix lint errors only for the codes that changed in the PR --- messages/event_manager.go | 3 ++- messages/event_manager_test.go | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/messages/event_manager.go b/messages/event_manager.go index d7b2423..7107c0f 100644 --- a/messages/event_manager.go +++ b/messages/event_manager.go @@ -4,8 +4,9 @@ import ( "sync" "sync/atomic" - "github.com/0xPolygon/go-ibft/messages/proto" "github.com/google/uuid" + + "github.com/0xPolygon/go-ibft/messages/proto" ) type eventManager struct { diff --git a/messages/event_manager_test.go b/messages/event_manager_test.go index 6695100..424c9ff 100644 --- a/messages/event_manager_test.go +++ b/messages/event_manager_test.go @@ -223,7 +223,7 @@ func TestEventManager_SubscribeCancel(t *testing.T) { numSubscriptions := 10 subscriptions := make([]*Subscription, numSubscriptions) - IDMap := make(map[SubscriptionID]bool) + idMap := make(map[SubscriptionID]bool) em := newEventManager() defer em.close() @@ -236,10 +236,10 @@ func TestEventManager_SubscribeCancel(t *testing.T) { assert.Equal(t, int64(i+1), em.numSubscriptions) // Check if a duplicate ID has been issued - if _, ok := IDMap[subscriptions[i].ID]; ok { + if _, ok := idMap[subscriptions[i].ID]; ok { t.Fatalf("Duplicate ID entry") } else { - IDMap[subscriptions[i].ID] = true + idMap[subscriptions[i].ID] = true } } From dda3271ae14a2b76aba3ed5d96d849dd4f5260d0 Mon Sep 17 00:00:00 2001 From: kourin Date: Tue, 10 Jan 2023 21:39:50 +0900 Subject: [PATCH 15/17] fixed some stuck test --- messages/event_manager_test.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/messages/event_manager_test.go b/messages/event_manager_test.go index 424c9ff..909c326 100644 --- a/messages/event_manager_test.go +++ b/messages/event_manager_test.go @@ -135,6 +135,8 @@ func TestEventManager_signalEvent(t *testing.T) { case <-doneEmitCh: case <-time.After(5 * time.Second): t.Errorf("signalEvent shouldn't be lock, but it was locked") + + return } closeSubscription(t, em, sub) @@ -158,6 +160,8 @@ func TestEventManager_signalEvent(t *testing.T) { select { case <-doneCh: t.Errorf("signalEvent is not locked") + + return case <-time.After(5 * time.Second): } @@ -181,11 +185,13 @@ func TestEventManager_signalEvent(t *testing.T) { select { case <-doneCh: - return case <-time.After(5 * time.Second): t.Errorf("signalEvent is locked") + + return } + closeSubscription(t, em, sub) <-doneTestSubCh }) @@ -199,11 +205,13 @@ func TestEventManager_signalEvent(t *testing.T) { select { case <-doneCh: - return case <-time.After(5 * time.Second): t.Errorf("signalEvent is locked") + + return } + closeSubscription(t, em, sub) <-doneTestSubCh }) } From 415633e8c011bf5481448ecdf9f6b31ee1118a62 Mon Sep 17 00:00:00 2001 From: kourin Date: Tue, 17 Jan 2023 14:08:36 +0200 Subject: [PATCH 16/17] Revert disabling function-length check in golangci --- .golangci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index ad47552..15614c9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -308,7 +308,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#function-length - name: function-length severity: warning - disabled: true + disabled: false arguments: [ 10, 0 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#get-return - name: get-return From 97bd746ae578d6d132ad00ffa9a508bdf9bb0acd Mon Sep 17 00:00:00 2001 From: kourin Date: Tue, 17 Jan 2023 14:12:46 +0200 Subject: [PATCH 17/17] Revert "Revert disabling function-length check in golangci" This reverts commit 415633e8c011bf5481448ecdf9f6b31ee1118a62. --- .golangci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index 15614c9..ad47552 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -308,7 +308,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#function-length - name: function-length severity: warning - disabled: false + disabled: true arguments: [ 10, 0 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#get-return - name: get-return