Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add raft label filtering capabilities at the RPC check layer #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type RPCHeader struct {
ID []byte
// Addr is the ServerAddr of the node sending the RPC Request or Response
Addr []byte
// Label is the label the sender is configured with
Label string
}

// WithRPCHeader is an interface that exposes the RPC header.
Expand Down
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ type Config struct {
// PreVoteDisabled deactivate the pre-vote feature when set to true
PreVoteDisabled bool

Label string
SkipLabelCheck bool

// skipStartup allows NewRaft() to bypass all background work goroutines
skipStartup bool
}
Expand Down
9 changes: 7 additions & 2 deletions raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func (r *Raft) getRPCHeader() RPCHeader {
ProtocolVersion: r.config().ProtocolVersion,
ID: []byte(r.config().LocalID),
Addr: r.trans.EncodePeer(r.config().LocalID, r.localAddr),
Label: r.config().Label,
}
}

Expand Down Expand Up @@ -67,6 +68,12 @@ func (r *Raft) checkRPCHeader(rpc RPC) error {
return ErrUnsupportedProtocol
}

if !r.config().SkipLabelCheck {
if header.Label != r.config().Label {
return fmt.Errorf("RPC has wrong label: got: %s expected %s", header.Label, r.config().Label)
}
}

return nil
}

Expand Down Expand Up @@ -2048,7 +2055,6 @@ func (r *Raft) electSelf() <-chan *voteResult {
// vote for ourself).
// This must only be called from the main thread.
func (r *Raft) preElectSelf() <-chan *preVoteResult {

// At this point transport should support pre-vote
// but check just in case
prevoteTrans, prevoteTransSupported := r.trans.(WithPreVote)
Expand Down Expand Up @@ -2097,7 +2103,6 @@ func (r *Raft) preElectSelf() <-chan *preVoteResult {
resp.Granted = false
}
respCh <- resp

})
}

Expand Down
116 changes: 111 additions & 5 deletions raft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2066,7 +2066,6 @@ func TestRaft_AppendEntry(t *testing.T) {
// Once the cluster is created, we force an election by partioning the leader
// and verify that the cluster regain stability.
func TestRaft_PreVoteMixedCluster(t *testing.T) {

tcs := []struct {
name string
prevoteNum int
Expand All @@ -2081,7 +2080,6 @@ func TestRaft_PreVoteMixedCluster(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {

// Make majority cluster.
majority := tc.prevoteNum
minority := tc.noprevoteNum
Expand Down Expand Up @@ -2120,7 +2118,6 @@ func TestRaft_PreVoteMixedCluster(t *testing.T) {
require.NotEqual(t, leader.leaderID, leaderOld.leaderID)
})
}

}

func TestRaft_PreVoteAvoidElectionWithPartition(t *testing.T) {
Expand All @@ -2134,7 +2131,7 @@ func TestRaft_PreVoteAvoidElectionWithPartition(t *testing.T) {
followers := c.Followers()
require.Len(t, followers, 4)

//Partition a node and wait enough for it to increase its term
// Partition a node and wait enough for it to increase its term
c.Partition([]ServerAddress{followers[0].localAddr})
time.Sleep(10 * c.propagateTimeout)

Expand All @@ -2151,7 +2148,6 @@ func TestRaft_PreVoteAvoidElectionWithPartition(t *testing.T) {
require.Len(t, c.Followers(), 4)
leaderTerm = c.Leader().getCurrentTerm()
require.Equal(t, leaderTerm, oldLeaderTerm)

}

func TestRaft_VotingGrant_WhenLeaderAvailable(t *testing.T) {
Expand Down Expand Up @@ -2201,6 +2197,116 @@ func TestRaft_VotingGrant_WhenLeaderAvailable(t *testing.T) {
}
}

func TestRaft_LabelFiltering(t *testing.T) {
tests := []struct {
Name string
ReceiverLabel string
ReceiverSkipLabelCheck bool
SenderLabel string
ExpectFailure bool
}{
{
Name: "Receiver skip label check, labeled receiver, incorrectly labeled sender",
ReceiverLabel: "cluster-0",
ReceiverSkipLabelCheck: true,
SenderLabel: "cluster-1",
ExpectFailure: false,
},
{
Name: "Receiver skip label check, labeled receiver, correctly labeled sender",
ReceiverLabel: "cluster-0",
ReceiverSkipLabelCheck: true,
SenderLabel: "cluster-0",
ExpectFailure: false,
},
{
Name: "Receiver skip label check, unlabeled receiver, labeled sender",
ReceiverLabel: "",
ReceiverSkipLabelCheck: true,
SenderLabel: "cluster-1",
ExpectFailure: false,
},
{
Name: "Receiver skip label check, unlabeled receiver, unlabeled sender",
ReceiverLabel: "",
ReceiverSkipLabelCheck: true,
SenderLabel: "",
ExpectFailure: false,
},
{
Name: "Receiver do label check, unlabeled receiver, unlabeled sender",
ReceiverLabel: "",
ReceiverSkipLabelCheck: false,
SenderLabel: "",
ExpectFailure: false,
},
{
Name: "Receiver do label check, labeled receiver, unlabeled sender",
ReceiverLabel: "cluster-0",
ReceiverSkipLabelCheck: false,
SenderLabel: "",
ExpectFailure: true,
},
{
Name: "Receiver do label check, unlabeled receiver, labeled sender",
ReceiverLabel: "",
ReceiverSkipLabelCheck: false,
SenderLabel: "cluster-0",
ExpectFailure: true,
},
{
Name: "Receiver do label check, labeled receiver, incorreclty labeled sender",
ReceiverLabel: "cluster-0",
ReceiverSkipLabelCheck: false,
SenderLabel: "cluster-1",
ExpectFailure: true,
},
{
Name: "Receiver do label check, labeled receiver, correclty labeled sender",
ReceiverLabel: "cluster-0",
ReceiverSkipLabelCheck: false,
SenderLabel: "cluster-0",
ExpectFailure: false,
},
}
for _, test := range tests {
test := test
t.Run(test.Name, func(t *testing.T) {
config := inmemConfig(t)
config.Label = test.ReceiverLabel
config.SkipLabelCheck = test.ReceiverSkipLabelCheck

c := MakeCluster(3, t, config)
defer c.Close()
followers := c.Followers()
ldr := c.Leader()
ldrT := c.trans[c.IndexOf(ldr)]

reqVote := RequestVoteRequest{
RPCHeader: RPCHeader{
ProtocolVersion: ProtocolVersionMax,
Addr: ldrT.EncodePeer(ldr.localID, ldr.localAddr),
Label: test.SenderLabel,
},
Term: ldr.getCurrentTerm() + 10,
LastLogIndex: ldr.LastIndex(),
LastLogTerm: ldr.getCurrentTerm(),
}
var resp RequestVoteResponse
err := ldrT.RequestVote(followers[0].localID, followers[0].localAddr, &reqVote, &resp)
if test.ExpectFailure {
if err == nil || !strings.Contains(err.Error(), "RPC has wrong label") {
t.Fatalf("unexpected RPC did not get rejected: %v", err)
}
} else {
if err != nil {
t.Fatalf("unexpected RPC got rejected rejected: %v", err)
}
}
})
}
}

func TestRaft_ProtocolVersion_RejectRPC(t *testing.T) {
c := MakeCluster(3, t, nil)
defer c.Close()
Expand Down