diff --git a/commands.go b/commands.go index 1ec76cb2..f964c822 100644 --- a/commands.go +++ b/commands.go @@ -15,6 +15,8 @@ type RPCHeader struct { ID []byte // Addr is the ServerAddr of the node sending the RPC Request or Response Addr []byte + // Label is the label the sender is configured with + Label string } // WithRPCHeader is an interface that exposes the RPC header. diff --git a/config.go b/config.go index d14392fc..75b405ca 100644 --- a/config.go +++ b/config.go @@ -235,6 +235,9 @@ type Config struct { // PreVoteDisabled deactivate the pre-vote feature when set to true PreVoteDisabled bool + Label string + SkipLabelCheck bool + // skipStartup allows NewRaft() to bypass all background work goroutines skipStartup bool } diff --git a/raft.go b/raft.go index cbc9a59a..c43f2489 100644 --- a/raft.go +++ b/raft.go @@ -37,6 +37,7 @@ func (r *Raft) getRPCHeader() RPCHeader { ProtocolVersion: r.config().ProtocolVersion, ID: []byte(r.config().LocalID), Addr: r.trans.EncodePeer(r.config().LocalID, r.localAddr), + Label: r.config().Label, } } @@ -67,6 +68,12 @@ func (r *Raft) checkRPCHeader(rpc RPC) error { return ErrUnsupportedProtocol } + if !r.config().SkipLabelCheck { + if header.Label != r.config().Label { + return fmt.Errorf("RPC has wrong label: got: %s expected %s", header.Label, r.config().Label) + } + } + return nil } @@ -2048,7 +2055,6 @@ func (r *Raft) electSelf() <-chan *voteResult { // vote for ourself). // This must only be called from the main thread. func (r *Raft) preElectSelf() <-chan *preVoteResult { - // At this point transport should support pre-vote // but check just in case prevoteTrans, prevoteTransSupported := r.trans.(WithPreVote) @@ -2097,7 +2103,6 @@ func (r *Raft) preElectSelf() <-chan *preVoteResult { resp.Granted = false } respCh <- resp - }) } diff --git a/raft_test.go b/raft_test.go index 2db115b6..5fe19660 100644 --- a/raft_test.go +++ b/raft_test.go @@ -2066,7 +2066,6 @@ func TestRaft_AppendEntry(t *testing.T) { // Once the cluster is created, we force an election by partioning the leader // and verify that the cluster regain stability. func TestRaft_PreVoteMixedCluster(t *testing.T) { - tcs := []struct { name string prevoteNum int @@ -2081,7 +2080,6 @@ func TestRaft_PreVoteMixedCluster(t *testing.T) { } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - // Make majority cluster. majority := tc.prevoteNum minority := tc.noprevoteNum @@ -2120,7 +2118,6 @@ func TestRaft_PreVoteMixedCluster(t *testing.T) { require.NotEqual(t, leader.leaderID, leaderOld.leaderID) }) } - } func TestRaft_PreVoteAvoidElectionWithPartition(t *testing.T) { @@ -2134,7 +2131,7 @@ func TestRaft_PreVoteAvoidElectionWithPartition(t *testing.T) { followers := c.Followers() require.Len(t, followers, 4) - //Partition a node and wait enough for it to increase its term + // Partition a node and wait enough for it to increase its term c.Partition([]ServerAddress{followers[0].localAddr}) time.Sleep(10 * c.propagateTimeout) @@ -2151,7 +2148,6 @@ func TestRaft_PreVoteAvoidElectionWithPartition(t *testing.T) { require.Len(t, c.Followers(), 4) leaderTerm = c.Leader().getCurrentTerm() require.Equal(t, leaderTerm, oldLeaderTerm) - } func TestRaft_VotingGrant_WhenLeaderAvailable(t *testing.T) { @@ -2201,6 +2197,116 @@ func TestRaft_VotingGrant_WhenLeaderAvailable(t *testing.T) { } } +func TestRaft_LabelFiltering(t *testing.T) { + tests := []struct { + Name string + ReceiverLabel string + ReceiverSkipLabelCheck bool + SenderLabel string + ExpectFailure bool + }{ + { + Name: "Receiver skip label check, labeled receiver, incorrectly labeled sender", + ReceiverLabel: "cluster-0", + ReceiverSkipLabelCheck: true, + SenderLabel: "cluster-1", + ExpectFailure: false, + }, + { + Name: "Receiver skip label check, labeled receiver, correctly labeled sender", + ReceiverLabel: "cluster-0", + ReceiverSkipLabelCheck: true, + SenderLabel: "cluster-0", + ExpectFailure: false, + }, + { + Name: "Receiver skip label check, unlabeled receiver, labeled sender", + ReceiverLabel: "", + ReceiverSkipLabelCheck: true, + SenderLabel: "cluster-1", + ExpectFailure: false, + }, + { + Name: "Receiver skip label check, unlabeled receiver, unlabeled sender", + ReceiverLabel: "", + ReceiverSkipLabelCheck: true, + SenderLabel: "", + ExpectFailure: false, + }, + { + Name: "Receiver do label check, unlabeled receiver, unlabeled sender", + ReceiverLabel: "", + ReceiverSkipLabelCheck: false, + SenderLabel: "", + ExpectFailure: false, + }, + { + Name: "Receiver do label check, labeled receiver, unlabeled sender", + ReceiverLabel: "cluster-0", + ReceiverSkipLabelCheck: false, + SenderLabel: "", + ExpectFailure: true, + }, + { + Name: "Receiver do label check, unlabeled receiver, labeled sender", + ReceiverLabel: "", + ReceiverSkipLabelCheck: false, + SenderLabel: "cluster-0", + ExpectFailure: true, + }, + { + Name: "Receiver do label check, labeled receiver, incorreclty labeled sender", + ReceiverLabel: "cluster-0", + ReceiverSkipLabelCheck: false, + SenderLabel: "cluster-1", + ExpectFailure: true, + }, + { + Name: "Receiver do label check, labeled receiver, correclty labeled sender", + ReceiverLabel: "cluster-0", + ReceiverSkipLabelCheck: false, + SenderLabel: "cluster-0", + ExpectFailure: false, + }, + } + for _, test := range tests { + test := test + t.Run(test.Name, func(t *testing.T) { + config := inmemConfig(t) + config.Label = test.ReceiverLabel + config.SkipLabelCheck = test.ReceiverSkipLabelCheck + + c := MakeCluster(3, t, config) + defer c.Close() + followers := c.Followers() + ldr := c.Leader() + ldrT := c.trans[c.IndexOf(ldr)] + + reqVote := RequestVoteRequest{ + RPCHeader: RPCHeader{ + ProtocolVersion: ProtocolVersionMax, + Addr: ldrT.EncodePeer(ldr.localID, ldr.localAddr), + Label: test.SenderLabel, + }, + Term: ldr.getCurrentTerm() + 10, + LastLogIndex: ldr.LastIndex(), + LastLogTerm: ldr.getCurrentTerm(), + } + var resp RequestVoteResponse + err := ldrT.RequestVote(followers[0].localID, followers[0].localAddr, &reqVote, &resp) + if test.ExpectFailure { + if err == nil || !strings.Contains(err.Error(), "RPC has wrong label") { + t.Fatalf("unexpected RPC did not get rejected: %v", err) + } + } else { + if err != nil { + t.Fatalf("unexpected RPC got rejected rejected: %v", err) + } + } + }) + } +} + func TestRaft_ProtocolVersion_RejectRPC(t *testing.T) { c := MakeCluster(3, t, nil) defer c.Close()