diff --git a/raft/raft.go b/raft/raft.go index f6077227b..658d2d02e 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -547,11 +547,11 @@ func (r *raft) campaign(t CampaignType) { r.logger.Infof("%x [logterm: %d, index: %d] sent vote request to %x at term %d", r.id, r.raftLog.lastTerm(), r.raftLog.lastIndex(), id, r.Term) - var entries []pb.Entry + var ctx []byte if t == campaignTransfer { - entries = []pb.Entry{{Data: []byte(t)}} + ctx = []byte(t) } - r.send(pb.Message{To: id, Type: pb.MsgVote, Index: r.raftLog.lastIndex(), LogTerm: r.raftLog.lastTerm(), Entries: entries}) + r.send(pb.Message{To: id, Type: pb.MsgVote, Index: r.raftLog.lastIndex(), LogTerm: r.raftLog.lastTerm(), Context: ctx}) } } @@ -594,7 +594,7 @@ func (r *raft) Step(m pb.Message) error { case m.Term > r.Term: lead := m.From if m.Type == pb.MsgVote { - force := len(m.Entries) == 1 && bytes.Equal(m.Entries[0].Data, []byte(campaignTransfer)) + force := bytes.Equal(m.Context, []byte(campaignTransfer)) inLease := r.checkQuorum && r.state != StateCandidate && r.electionElapsed < r.electionTimeout if !force && inLease { // If a server receives a RequestVote request within the minimum election timeout diff --git a/raft/raftpb/raft.pb.go b/raft/raftpb/raft.pb.go index 479a1c683..6176c3d2d 100644 --- a/raft/raftpb/raft.pb.go +++ b/raft/raftpb/raft.pb.go @@ -236,6 +236,7 @@ type Message struct { Snapshot Snapshot `protobuf:"bytes,9,opt,name=snapshot" json:"snapshot"` Reject bool `protobuf:"varint,10,opt,name=reject" json:"reject"` RejectHint uint64 `protobuf:"varint,11,opt,name=rejectHint" json:"rejectHint"` + Context []byte `protobuf:"bytes,12,opt,name=context" json:"context,omitempty"` XXX_unrecognized []byte `json:"-"` } @@ -464,6 +465,12 @@ func (m *Message) MarshalTo(data []byte) (int, error) { data[i] = 0x58 i++ i = encodeVarintRaft(data, i, uint64(m.RejectHint)) + if m.Context != nil { + data[i] = 0x62 + i++ + i = encodeVarintRaft(data, i, uint64(len(m.Context))) + i += copy(data[i:], m.Context) + } if m.XXX_unrecognized != nil { i += copy(data[i:], m.XXX_unrecognized) } @@ -655,6 +662,10 @@ func (m *Message) Size() (n int) { n += 1 + l + sovRaft(uint64(l)) n += 2 n += 1 + sovRaft(uint64(m.RejectHint)) + if m.Context != nil { + l = len(m.Context) + n += 1 + l + sovRaft(uint64(l)) + } if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) } @@ -1348,6 +1359,37 @@ func (m *Message) Unmarshal(data []byte) error { break } } + case 12: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Context", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRaft + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRaft + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Context = append(m.Context[:0], data[iNdEx:postIndex]...) + if m.Context == nil { + m.Context = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipRaft(data[iNdEx:]) diff --git a/raft/raftpb/raft.proto b/raft/raftpb/raft.proto index 1948fc1e4..18f4cefae 100644 --- a/raft/raftpb/raft.proto +++ b/raft/raftpb/raft.proto @@ -64,6 +64,7 @@ message Message { optional Snapshot snapshot = 9 [(gogoproto.nullable) = false]; optional bool reject = 10 [(gogoproto.nullable) = false]; optional uint64 rejectHint = 11 [(gogoproto.nullable) = false]; + optional bytes context = 12; } message HardState {