diff --git a/internal/raft/raft.go b/internal/raft/raft.go index 5ec64e895..a1fec3e99 100644 --- a/internal/raft/raft.go +++ b/internal/raft/raft.go @@ -408,8 +408,7 @@ func (r *raft) restore(ss pb.Snapshot) bool { func (r *raft) restoreRemotes(ss pb.Snapshot) { r.remotes = make(map[uint64]*remote) for id := range ss.Membership.Addresses { - _, ok := r.observers[id] - if ok { + if id == r.nodeID && r.state == observer { r.becomeFollower(r.term, r.leaderID) } match := uint64(0) diff --git a/internal/raft/raft_etcd_test.go b/internal/raft/raft_etcd_test.go index 3ba68ff94..9cfd5e24c 100644 --- a/internal/raft/raft_etcd_test.go +++ b/internal/raft/raft_etcd_test.go @@ -3012,6 +3012,15 @@ func newRateLimitedTestRaft(id uint64, peers []uint64, election, heartbeat int, } func newTestObserver(id uint64, peers []uint64, observers []uint64, election, heartbeat int, logdb ILogDB) *raft { + found := false + for _, p := range observers { + if p == id { + found = true + } + } + if !found { + panic("observer node id not included in the observers list") + } cfg := newTestConfig(id, election, heartbeat, logdb) cfg.IsObserver = true r := newRaft(cfg, logdb) diff --git a/internal/raft/raft_test.go b/internal/raft/raft_test.go index c0f79f023..9444069a1 100644 --- a/internal/raft/raft_test.go +++ b/internal/raft/raft_test.go @@ -432,7 +432,7 @@ func TestObserverCanReceiveSnapshot(t *testing.T) { Term: 20, Membership: members, } - p1 := newTestObserver(3, []uint64{1}, []uint64{2}, 10, 1, NewTestLogDB()) + p1 := newTestObserver(3, []uint64{1}, []uint64{2, 3}, 10, 1, NewTestLogDB()) if !p1.isObserver() { t.Errorf("not an observer") } @@ -508,7 +508,7 @@ func TestObserverCanBePromotedBySnapshot(t *testing.T) { Term: 20, Membership: members, } - p1 := newTestObserver(1, []uint64{1}, []uint64{2}, 10, 1, NewTestLogDB()) + p1 := newTestObserver(1, nil, []uint64{1, 2}, 10, 1, NewTestLogDB()) if !p1.isObserver() { t.Errorf("not an observer") } @@ -521,6 +521,38 @@ func TestObserverCanBePromotedBySnapshot(t *testing.T) { } } +func TestCorrectObserverCanBePromotedBySnapshot(t *testing.T) { + members := pb.Membership{ + Addresses: make(map[uint64]string), + Observers: make(map[uint64]string), + Removed: make(map[uint64]bool), + } + members.Observers[1] = "a1" + members.Addresses[2] = "a2" + members.Addresses[3] = "a3" + ss := pb.Snapshot{ + Index: 20, + Term: 20, + Membership: members, + } + p1 := newTestObserver(1, []uint64{2}, []uint64{1, 3}, 10, 1, NewTestLogDB()) + if !p1.isObserver() { + t.Errorf("not an observer") + } + _, ok := p1.observers[1] + if !ok { + t.Errorf("not an observer") + } + _, ok = p1.observers[3] + if !ok { + t.Errorf("not an observer") + } + p1.restoreRemotes(ss) + if !p1.isObserver() { + t.Errorf("observer p1 unexpectedly promoted") + } +} + func TestObserverCanNotMoveNodeBackToObserverBySnapshot(t *testing.T) { members := pb.Membership{ Addresses: make(map[uint64]string), @@ -561,14 +593,18 @@ func TestObserverCanBeAdded(t *testing.T) { } func TestObserverCanBeRemoved(t *testing.T) { - p1 := newTestObserver(1, []uint64{1}, []uint64{2}, 10, 1, NewTestLogDB()) - if len(p1.observers) != 1 { + p1 := newTestObserver(1, nil, []uint64{1, 2}, 10, 1, NewTestLogDB()) + if len(p1.observers) != 2 { t.Errorf("unexpected observer count") } p1.removeNode(2) - if len(p1.observers) != 0 { + if len(p1.observers) != 1 { t.Errorf("observer not removed") } + _, ok := p1.observers[2] + if ok { + t.Errorf("observer node 2 not removed") + } } func TestFollowerTick(t *testing.T) {