diff --git a/raft/log.go b/raft/log.go
index e6bd2eb3e..7fca9c842 100644
--- a/raft/log.go
+++ b/raft/log.go
@@ -19,12 +19,13 @@ func (e *Entry) isConfig() bool {
 }
 
 type raftLog struct {
-	ents      []Entry
-	unstable  int64
-	committed int64
-	applied   int64
-	offset    int64
-	snapshot  Snapshot
+	ents             []Entry
+	unstable         int64
+	committed        int64
+	applied          int64
+	offset           int64
+	snapshot         Snapshot
+	unstableSnapshot Snapshot
 
 	// want a compact after the number of entries exceeds the threshold
 	// TODO(xiangli) size might be a better criteria
@@ -163,12 +164,14 @@ func (l *raftLog) shouldCompact() bool {
 	return (l.applied - l.offset) > l.compactThreshold
 }
 
-func (l *raftLog) restore(index, term int64) {
-	l.ents = []Entry{{Term: term}}
-	l.unstable = index + 1
-	l.committed = index
-	l.applied = index
-	l.offset = index
+func (l *raftLog) restore(s Snapshot) {
+	l.ents = []Entry{{Term: s.Term}}
+	l.unstable = s.Index + 1
+	l.committed = s.Index
+	l.applied = s.Index
+	l.offset = s.Index
+	l.snapshot = s
+	l.unstableSnapshot = s
 }
 
 func (l *raftLog) at(i int64) *Entry {
diff --git a/raft/log_test.go b/raft/log_test.go
index 229956022..46cd8dcaf 100644
--- a/raft/log_test.go
+++ b/raft/log_test.go
@@ -192,7 +192,7 @@ func TestLogRestore(t *testing.T) {
 
 	index := int64(1000)
 	term := int64(1000)
-	raftLog.restore(index, term)
+	raftLog.restore(Snapshot{Index: index, Term: term})
 
 	// only has the guard entry
 	if len(raftLog.ents) != 1 {
diff --git a/raft/node.go b/raft/node.go
index cee1bbfbc..cca9b87e8 100644
--- a/raft/node.go
+++ b/raft/node.go
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"log"
 	"math/rand"
+	"sort"
 	"time"
 )
 
@@ -76,6 +77,15 @@ func (n *Node) Leader() int64 { return n.sm.lead.Get() }
 
 func (n *Node) IsRemoved() bool { return n.removed }
 
+func (n *Node) Nodes() []int64 {
+	nodes := make(int64Slice, 0, len(n.sm.ins))
+	for k := range n.sm.ins {
+		nodes = append(nodes, k)
+	}
+	sort.Sort(nodes)
+	return nodes
+}
+
 // Propose asynchronously proposes data be applied to the underlying state machine.
 func (n *Node) Propose(data []byte) { n.propose(Normal, data) }
 
@@ -232,6 +242,15 @@ func (n *Node) UnstableState() State {
 	return s
 }
 
+func (n *Node) UnstableSnapshot() Snapshot {
+	if n.sm.raftLog.unstableSnapshot.IsEmpty() {
+		return emptySnapshot
+	}
+	s := n.sm.raftLog.unstableSnapshot
+	n.sm.raftLog.unstableSnapshot = emptySnapshot
+	return s
+}
+
 func (n *Node) GetSnap() Snapshot {
 	return n.sm.raftLog.snapshot
 }
diff --git a/raft/raft.go b/raft/raft.go
index a84caadc0..128ad3b28 100644
--- a/raft/raft.go
+++ b/raft/raft.go
@@ -157,8 +157,6 @@ type stateMachine struct {
 	// pending reconfiguration
 	pendingConf bool
 
-	snapshoter Snapshoter
-
 	unstableState State
 }
 
@@ -187,10 +185,6 @@ func (sm *stateMachine) String() string {
 	return s
 }
 
-func (sm *stateMachine) setSnapshoter(snapshoter Snapshoter) {
-	sm.snapshoter = snapshoter
-}
-
 func (sm *stateMachine) poll(id int64, v bool) (granted int) {
 	if _, ok := sm.votes[id]; !ok {
 		sm.votes[id] = v
@@ -220,7 +214,7 @@ func (sm *stateMachine) sendAppend(to int64) {
 	m.Index = in.next - 1
 	if sm.needSnapshot(m.Index) {
 		m.Type = msgSnap
-		m.Snapshot = sm.snapshoter.GetSnap()
+		m.Snapshot = sm.raftLog.snapshot
 	} else {
 		m.Type = msgApp
 		m.LogTerm = sm.raftLog.term(in.next - 1)
@@ -502,31 +496,15 @@ func stepFollower(sm *stateMachine, m Message) bool {
 	return true
 }
 
-// maybeCompact tries to compact the log. It calls the snapshoter to take a snapshot and
-// then compact the log up-to the index at which the snapshot was taken.
-func (sm *stateMachine) maybeCompact() bool {
-	if sm.snapshoter == nil || !sm.raftLog.shouldCompact() {
-		return false
-	}
-	sm.snapshoter.Snap(sm.raftLog.applied, sm.raftLog.term(sm.raftLog.applied), sm.nodes())
-	sm.raftLog.compact(sm.raftLog.applied)
-	return true
-}
-
 func (sm *stateMachine) compact(d []byte) {
 	sm.raftLog.snap(d, sm.raftLog.applied, sm.raftLog.term(sm.raftLog.applied), sm.nodes())
 	sm.raftLog.compact(sm.raftLog.applied)
 }
 
 // restore recovers the statemachine from a snapshot. It restores the log and the
-// configuration of statemachine. It calls the snapshoter to restore from the given
-// snapshot.
+// configuration of statemachine.
 func (sm *stateMachine) restore(s Snapshot) {
-	if sm.snapshoter == nil {
-		panic("try to restore from snapshot, but snapshoter is nil")
-	}
-
-	sm.raftLog.restore(s.Index, s.Term)
+	sm.raftLog.restore(s)
 	sm.index.Set(sm.raftLog.lastIndex())
 	sm.ins = make(map[int64]*index)
 	for _, n := range s.Nodes {
@@ -537,13 +515,12 @@ func (sm *stateMachine) restore(s Snapshot) {
 		}
 	}
 	sm.pendingConf = false
-	sm.snapshoter.Restore(s)
 }
 
 func (sm *stateMachine) needSnapshot(i int64) bool {
 	if i < sm.raftLog.offset {
-		if sm.snapshoter == nil {
-			panic("need snapshot but snapshoter is nil")
+		if sm.raftLog.snapshot.IsEmpty() {
+			panic("need non-empty snapshot")
 		}
 		return true
 	}
diff --git a/raft/raft_test.go b/raft/raft_test.go
index 795cc0a05..bcc396f58 100644
--- a/raft/raft_test.go
+++ b/raft/raft_test.go
@@ -2,7 +2,6 @@ package raft
 
 import (
 	"bytes"
-	"fmt"
 	"math/rand"
 	"reflect"
 	"sort"
@@ -781,45 +780,24 @@ func TestRestore(t *testing.T) {
 		Nodes: []int64{0, 1, 2},
 	}
 
-	tests := []struct {
-		snapshoter Snapshoter
-		wallow     bool
-	}{
-		{nil, false},
-		{new(logSnapshoter), true},
+	sm := newStateMachine(0, []int64{0, 1})
+	sm.restore(s)
+
+	if sm.raftLog.lastIndex() != s.Index {
+		t.Errorf("log.lastIndex = %d, want %d", sm.raftLog.lastIndex(), s.Index)
 	}
-
-	for i, tt := range tests {
-		func() {
-			defer func() {
-				if r := recover(); r != nil {
-					if tt.wallow == true {
-						t.Errorf("%d: allow = %v, want %v", i, false, true)
-					}
-				}
-			}()
-
-			sm := newStateMachine(0, []int64{0, 1})
-			sm.setSnapshoter(tt.snapshoter)
-			sm.restore(s)
-
-			if sm.raftLog.lastIndex() != s.Index {
-				t.Errorf("#%d: log.lastIndex = %d, want %d", i, sm.raftLog.lastIndex(), s.Index)
-			}
-			if sm.raftLog.term(s.Index) != s.Term {
-				t.Errorf("#%d: log.lastTerm = %d, want %d", i, sm.raftLog.term(s.Index), s.Term)
-			}
-			sg := int64Slice(sm.nodes())
-			sw := int64Slice(s.Nodes)
-			sort.Sort(sg)
-			sort.Sort(sw)
-			if !reflect.DeepEqual(sg, sw) {
-				t.Errorf("#%d: sm.Nodes = %+v, want %+v", i, sg, sw)
-			}
-			if !reflect.DeepEqual(sm.snapshoter.GetSnap(), s) {
-				t.Errorf("%d: snapshoter.getSnap = %+v, want %+v", sm.snapshoter.GetSnap(), s)
-			}
-		}()
+	if sm.raftLog.term(s.Index) != s.Term {
+		t.Errorf("log.lastTerm = %d, want %d", sm.raftLog.term(s.Index), s.Term)
+	}
+	sg := int64Slice(sm.nodes())
+	sw := int64Slice(s.Nodes)
+	sort.Sort(sg)
+	sort.Sort(sw)
+	if !reflect.DeepEqual(sg, sw) {
+		t.Errorf("sm.Nodes = %+v, want %+v", sg, sw)
+	}
+	if !reflect.DeepEqual(sm.raftLog.snapshot, s) {
+		t.Errorf("snapshot = %+v, want %+v", sm.raftLog.snapshot, s)
 	}
 }
 
@@ -830,7 +808,6 @@ func TestProvideSnap(t *testing.T) {
 		Nodes: []int64{0, 1},
 	}
 	sm := newStateMachine(0, []int64{0})
-	sm.setSnapshoter(new(logSnapshoter))
 	// restore the statemachin from a snapshot
 	// so it has a compacted log and a snapshot
 	sm.restore(s)
@@ -872,11 +849,10 @@ func TestRestoreFromSnapMsg(t *testing.T) {
 	m := Message{Type: msgSnap, From: 0, Term: 1, Snapshot: s}
 
 	sm := newStateMachine(1, []int64{0, 1})
-	sm.setSnapshoter(new(logSnapshoter))
 	sm.Step(m)
 
-	if !reflect.DeepEqual(sm.snapshoter.GetSnap(), s) {
-		t.Errorf("snapshot = %+v, want %+v", sm.snapshoter.GetSnap(), s)
+	if !reflect.DeepEqual(sm.raftLog.snapshot, s) {
+		t.Errorf("snapshot = %+v, want %+v", sm.raftLog.snapshot, s)
 	}
 }
 
@@ -890,16 +866,14 @@ func TestSlowNodeRestore(t *testing.T) {
 	}
 	lead := nt.peers[0].(*stateMachine)
 	lead.nextEnts()
-	if !lead.maybeCompact() {
-		t.Errorf("compacted = false, want true")
-	}
+	lead.compact(nil)
 
 	nt.recover()
 	nt.send(Message{From: 0, To: 0, Type: msgBeat})
 
 	follower := nt.peers[2].(*stateMachine)
-	if !reflect.DeepEqual(follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap()) {
-		t.Errorf("follower.snap = %+v, want %+v", follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap())
+	if !reflect.DeepEqual(follower.raftLog.snapshot, lead.raftLog.snapshot) {
+		t.Errorf("follower.snap = %+v, want %+v", follower.raftLog.snapshot, lead.raftLog.snapshot)
 	}
 
 	committed := follower.raftLog.lastIndex()
@@ -979,7 +953,6 @@ func newNetwork(peers ...Interface) *network {
 		switch v := p.(type) {
 		case nil:
 			sm := newStateMachine(nid, defaultPeerAddrs)
-			sm.setSnapshoter(new(logSnapshoter))
 			npeers[nid] = sm
 		case *stateMachine:
 			v.id = nid
@@ -1070,22 +1043,3 @@ func (blackHole) Step(Message) bool { return true }
 func (blackHole) Msgs() []Message   { return nil }
 
 var nopStepper = &blackHole{}
-
-type logSnapshoter struct {
-	snapshot Snapshot
-}
-
-func (s *logSnapshoter) Snap(index, term int64, nodes []int64) {
-	s.snapshot = Snapshot{
-		Index: index,
-		Term:  term,
-		Nodes: nodes,
-		Data:  []byte(fmt.Sprintf("%d:%d", term, index)),
-	}
-}
-func (s *logSnapshoter) Restore(ss Snapshot) {
-	s.snapshot = ss
-}
-func (s *logSnapshoter) GetSnap() Snapshot {
-	return s.snapshot
-}
diff --git a/raft/snapshot.go b/raft/snapshot.go
index 4ba3c0f92..f56a5fed8 100644
--- a/raft/snapshot.go
+++ b/raft/snapshot.go
@@ -1,5 +1,7 @@
 package raft
 
+var emptySnapshot = Snapshot{}
+
 type Snapshot struct {
 	Data []byte
 
@@ -11,10 +13,6 @@ type Snapshot struct {
 	Term int64
 }
 
-// A snapshoter can make a snapshot of its current state atomically.
-// It can restore from a snapshot and get the latest snapshot it took.
-type Snapshoter interface {
-	Snap(index, term int64, nodes []int64)
-	Restore(snap Snapshot)
-	GetSnap() Snapshot
+func (s Snapshot) IsEmpty() bool {
+	return s.Term == 0
 }