raft: handle snapshot message
parent
2a11c1487c
commit
5651272ec8
38
raft/raft.go
38
raft/raft.go
|
@ -17,6 +17,7 @@ const (
|
|||
msgAppResp
|
||||
msgVote
|
||||
msgVoteResp
|
||||
msgSnap
|
||||
)
|
||||
|
||||
var mtmap = [...]string{
|
||||
|
@ -27,6 +28,7 @@ var mtmap = [...]string{
|
|||
msgAppResp: "msgAppResp",
|
||||
msgVote: "msgVote",
|
||||
msgVoteResp: "msgVoteResp",
|
||||
msgSnap: "msgSnap",
|
||||
}
|
||||
|
||||
func (mt messageType) String() string {
|
||||
|
@ -69,6 +71,7 @@ type Message struct {
|
|||
PrevTerm int
|
||||
Entries []Entry
|
||||
Commit int
|
||||
Snapshot Snapshot
|
||||
}
|
||||
|
||||
type index struct {
|
||||
|
@ -151,12 +154,17 @@ func (sm *stateMachine) send(m Message) {
|
|||
func (sm *stateMachine) sendAppend(to int) {
|
||||
in := sm.ins[to]
|
||||
m := Message{}
|
||||
m.Type = msgApp
|
||||
m.To = to
|
||||
m.Index = in.next - 1
|
||||
m.LogTerm = sm.log.term(in.next - 1)
|
||||
m.Entries = sm.log.entries(in.next)
|
||||
m.Commit = sm.log.committed
|
||||
if sm.needSnapshot(m.Index) {
|
||||
m.Type = msgSnap
|
||||
m.Snapshot = sm.snapshoter.GetSnap()
|
||||
} else {
|
||||
m.Type = msgApp
|
||||
m.LogTerm = sm.log.term(in.next - 1)
|
||||
m.Entries = sm.log.entries(in.next)
|
||||
m.Commit = sm.log.committed
|
||||
}
|
||||
sm.send(m)
|
||||
}
|
||||
|
||||
|
@ -244,7 +252,7 @@ func (sm *stateMachine) becomeLeader() {
|
|||
sm.lead = sm.id
|
||||
sm.state = stateLeader
|
||||
|
||||
for _, e := range sm.log.ents[sm.log.committed:] {
|
||||
for _, e := range sm.log.entries(sm.log.committed + 1) {
|
||||
if e.isConfig() {
|
||||
sm.pendingConf = true
|
||||
}
|
||||
|
@ -298,6 +306,11 @@ func (sm *stateMachine) handleAppendEntries(m Message) {
|
|||
}
|
||||
}
|
||||
|
||||
func (sm *stateMachine) handleSnapshot(m Message) {
|
||||
sm.restore(m.Snapshot)
|
||||
sm.send(Message{To: m.From, Type: msgAppResp, Index: sm.log.lastIndex()})
|
||||
}
|
||||
|
||||
func (sm *stateMachine) addNode(id int) {
|
||||
sm.ins[id] = &index{next: sm.log.lastIndex() + 1}
|
||||
sm.pendingConf = false
|
||||
|
@ -350,6 +363,9 @@ func stepCandidate(sm *stateMachine, m Message) bool {
|
|||
case msgApp:
|
||||
sm.becomeFollower(sm.term, m.From)
|
||||
sm.handleAppendEntries(m)
|
||||
case msgSnap:
|
||||
sm.becomeFollower(m.Term, m.From)
|
||||
sm.handleSnapshot(m)
|
||||
case msgVote:
|
||||
sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1})
|
||||
case msgVoteResp:
|
||||
|
@ -375,6 +391,8 @@ func stepFollower(sm *stateMachine, m Message) bool {
|
|||
sm.send(m)
|
||||
case msgApp:
|
||||
sm.handleAppendEntries(m)
|
||||
case msgSnap:
|
||||
sm.handleSnapshot(m)
|
||||
case msgVote:
|
||||
if (sm.vote == none || sm.vote == m.From) && sm.log.isUpToDate(m.Index, m.LogTerm) {
|
||||
sm.vote = m.From
|
||||
|
@ -417,6 +435,16 @@ func (sm *stateMachine) restore(s Snapshot) {
|
|||
sm.snapshoter.Restore(s)
|
||||
}
|
||||
|
||||
func (sm *stateMachine) needSnapshot(i int) bool {
|
||||
if i < sm.log.offset {
|
||||
if sm.snapshoter == nil {
|
||||
panic("need snapshot but snapshoter is nil")
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (sm *stateMachine) nodes() []int {
|
||||
nodes := make([]int, 0, len(sm.ins))
|
||||
for k := range sm.ins {
|
||||
|
|
|
@ -802,6 +802,92 @@ func TestRestore(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestProvideSnap(t *testing.T) {
|
||||
s := Snapshot{
|
||||
Index: defaultCompactThreshold + 1,
|
||||
Term: defaultCompactThreshold + 1,
|
||||
Nodes: []int{0, 1},
|
||||
}
|
||||
sm := newStateMachine(0, []int{0})
|
||||
sm.setSnapshoter(new(logSnapshoter))
|
||||
// restore the statemachin from a snapshot
|
||||
// so it has a compacted log and a snapshot
|
||||
sm.restore(s)
|
||||
|
||||
sm.becomeCandidate()
|
||||
sm.becomeLeader()
|
||||
|
||||
sm.Step(Message{Type: msgBeat})
|
||||
msgs := sm.Msgs()
|
||||
if len(msgs) != 1 {
|
||||
t.Errorf("len(msgs) = %d, want 1", len(msgs))
|
||||
}
|
||||
m := msgs[0]
|
||||
if m.Type != msgApp {
|
||||
t.Errorf("m.Type = %v, want %v", m.Type, msgApp)
|
||||
}
|
||||
|
||||
// force set the next of node 1, so that
|
||||
// node 1 needs a snapshot
|
||||
sm.ins[1].next = sm.log.offset
|
||||
|
||||
sm.Step(Message{Type: msgBeat})
|
||||
msgs = sm.Msgs()
|
||||
if len(msgs) != 1 {
|
||||
t.Errorf("len(msgs) = %d, want 1", len(msgs))
|
||||
}
|
||||
m = msgs[0]
|
||||
if m.Type != msgSnap {
|
||||
t.Errorf("m.Type = %v, want %v", m.Type, msgSnap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreFromSnapMsg(t *testing.T) {
|
||||
s := Snapshot{
|
||||
Index: defaultCompactThreshold + 1,
|
||||
Term: defaultCompactThreshold + 1,
|
||||
Nodes: []int{0, 1},
|
||||
}
|
||||
m := Message{Type: msgSnap, From: 0, Term: 1, Snapshot: s}
|
||||
|
||||
sm := newStateMachine(1, []int{0, 1})
|
||||
sm.setSnapshoter(new(logSnapshoter))
|
||||
sm.Step(m)
|
||||
|
||||
if !reflect.DeepEqual(sm.snapshoter.GetSnap(), s) {
|
||||
t.Errorf("snapshot = %+v, want %+v", sm.snapshoter.GetSnap(), s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlowNodeRestore(t *testing.T) {
|
||||
nt := newNetwork(nil, nil, nil)
|
||||
nt.send(Message{To: 0, Type: msgHup})
|
||||
|
||||
nt.isolate(2)
|
||||
for j := 0; j < defaultCompactThreshold+1; j++ {
|
||||
nt.send(Message{To: 0, Type: msgProp, Entries: []Entry{{}}})
|
||||
}
|
||||
lead := nt.peers[0].(*stateMachine)
|
||||
lead.nextEnts()
|
||||
if !lead.maybeCompact() {
|
||||
t.Errorf("compacted = false, want true")
|
||||
}
|
||||
|
||||
nt.recover()
|
||||
nt.send(Message{To: 0, Type: msgBeat})
|
||||
|
||||
follower := nt.peers[2].(*stateMachine)
|
||||
if !reflect.DeepEqual(follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap()) {
|
||||
t.Errorf("follower.snap = %+v, want %+v", follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap())
|
||||
}
|
||||
|
||||
committed := follower.log.lastIndex()
|
||||
nt.send(Message{To: 0, Type: msgProp, Entries: []Entry{{}}})
|
||||
if follower.log.committed != committed+1 {
|
||||
t.Errorf("follower.comitted = %d, want %d", follower.log.committed, committed+1)
|
||||
}
|
||||
}
|
||||
|
||||
func ents(terms ...int) *stateMachine {
|
||||
ents := []Entry{{}}
|
||||
for _, term := range terms {
|
||||
|
@ -836,6 +922,7 @@ func newNetwork(peers ...Interface) *network {
|
|||
switch v := p.(type) {
|
||||
case nil:
|
||||
sm := newStateMachine(id, defaultPeerAddrs)
|
||||
sm.setSnapshoter(new(logSnapshoter))
|
||||
npeers[id] = sm
|
||||
case *stateMachine:
|
||||
v.id = id
|
||||
|
|
Loading…
Reference in New Issue