diff --git a/etcdserver/server_test.go b/etcdserver/server_test.go index 542aae80c..7b6fadd00 100644 --- a/etcdserver/server_test.go +++ b/etcdserver/server_test.go @@ -36,7 +36,7 @@ func testServer(t *testing.T, ns int64) { } for i := int64(0); i < ns; i++ { - n := raft.Start(i, peers, 1, 10) + n := raft.Start(i, peers, 10, 1) tk := time.NewTicker(10 * time.Millisecond) defer tk.Stop() srv := &Server{ @@ -47,16 +47,12 @@ func testServer(t *testing.T, ns int64) { Ticker: tk.C, } Start(srv) - + // TODO(xiangli): randomize election timeout + // then remove this sleep. + time.Sleep(1 * time.Millisecond) ss[i] = srv } - for i := int64(0); i < ns; i++ { - if err := ss[i].Node.Campaign(ctx); err != nil { - t.Fatal(err) - } - } - for i := 1; i <= 10; i++ { r := pb.Request{ Method: "PUT", diff --git a/raft/raft.go b/raft/raft.go index cea3de8b3..b835f27f7 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -56,12 +56,6 @@ var stmap = [...]string{ stateLeader: "stateLeader", } -var stepmap = [...]stepFunc{ - stateFollower: stepFollower, - stateCandidate: stepCandidate, - stateLeader: stepLeader, -} - func (st stateType) String() string { return stmap[int64(st)] } @@ -126,6 +120,7 @@ type raft struct { heartbeatTimeout int electionTimeout int tick func() + step stepFunc } func newRaft(id int64, peers []int64, election, heartbeat int) *raft { @@ -249,6 +244,7 @@ func (r *raft) reset(term int64) { r.Term = term r.lead = none r.Vote = none + r.elapsed = 0 r.votes = make(map[int64]bool) for i := range r.prs { r.prs[i] = &progress{next: r.raftLog.lastIndex() + 1} @@ -272,9 +268,10 @@ func (r *raft) appendEntry(e pb.Entry) { func (r *raft) tickElection() { r.elapsed++ + // TODO (xiangli): elctionTimeout should be randomized. if r.elapsed > r.electionTimeout { r.elapsed = 0 - r.campaign() + r.Step(pb.Message{From: r.id, Type: msgHup}) } } @@ -282,41 +279,39 @@ func (r *raft) tickHeartbeat() { r.elapsed++ if r.elapsed > r.heartbeatTimeout { r.elapsed = 0 - r.bcastHeartbeat() + r.Step(pb.Message{From: r.id, Type: msgBeat}) } } -func (r *raft) setTick(f func()) { - r.elapsed = 0 - r.tick = f -} - func (r *raft) becomeFollower(term int64, lead int64) { - r.setTick(r.tickElection) + r.step = stepFollower r.reset(term) + r.tick = r.tickElection r.lead = lead r.state = stateFollower r.configuring = false } func (r *raft) becomeCandidate() { - r.setTick(r.tickElection) // TODO(xiangli) remove the panic when the raft implementation is stable if r.state == stateLeader { panic("invalid transition [leader -> candidate]") } + r.step = stepCandidate r.reset(r.Term + 1) + r.tick = r.tickElection r.Vote = r.id r.state = stateCandidate } func (r *raft) becomeLeader() { - r.setTick(r.tickHeartbeat) // TODO(xiangli) remove the panic when the raft implementation is stable if r.state == stateFollower { panic("invalid transition [follower -> leader]") } + r.step = stepLeader r.reset(r.Term) + r.tick = r.tickElection r.lead = r.id r.state = stateLeader @@ -370,8 +365,7 @@ func (r *raft) Step(m pb.Message) error { case m.Term < r.Term: // ignore } - - stepmap[r.state](r, m) + r.step(r, m) return nil } diff --git a/raft/raft_test.go b/raft/raft_test.go index d3e549a24..960fd06ca 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -549,11 +549,18 @@ func TestRecvMsgVote(t *testing.T) { } for i, tt := range tests { - sm := &raft{ - state: tt.state, - State: pb.State{Vote: tt.voteFor}, - raftLog: &raftLog{ents: []pb.Entry{{}, {Term: 2}, {Term: 2}}}, + sm := newRaft(0, []int64{0}, 0, 0) + sm.state = tt.state + switch tt.state { + case stateFollower: + sm.step = stepFollower + case stateCandidate: + sm.step = stepCandidate + case stateLeader: + sm.step = stepLeader } + sm.State = pb.State{Vote: tt.voteFor} + sm.raftLog = &raftLog{ents: []pb.Entry{{}, {Term: 2}, {Term: 2}}} sm.Step(pb.Message{Type: msgVote, From: 1, Index: tt.i, LogTerm: tt.term}) @@ -778,6 +785,14 @@ func TestRecvMsgBeat(t *testing.T) { sm.raftLog = &raftLog{ents: []pb.Entry{{}, {Term: 0}, {Term: 1}}} sm.Term = 1 sm.state = tt.state + switch tt.state { + case stateFollower: + sm.step = stepFollower + case stateCandidate: + sm.step = stepCandidate + case stateLeader: + sm.step = stepLeader + } sm.Step(pb.Message{From: 0, To: 0, Type: msgBeat}) msgs := sm.ReadMessages()