diff --git a/rafthttp/peer.go b/rafthttp/peer.go index 9dd0207d0..b978f1ad3 100644 --- a/rafthttp/peer.go +++ b/rafthttp/peer.go @@ -177,8 +177,26 @@ func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r } }() - p.msgAppV2Reader = startStreamReader(transport, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc) - p.msgAppReader = startStreamReader(transport, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc) + p.msgAppV2Reader = &streamReader{ + typ: streamTypeMsgAppV2, + tr: transport, + picker: picker, + to: to, + status: status, + recvc: p.recvc, + propc: p.propc, + } + p.msgAppReader = &streamReader{ + typ: streamTypeMessage, + tr: transport, + picker: picker, + to: to, + status: status, + recvc: p.recvc, + propc: p.propc, + } + p.msgAppV2Reader.start() + p.msgAppReader.start() return p } diff --git a/rafthttp/stream.go b/rafthttp/stream.go index a5bb32e33..5ed4a2fc8 100644 --- a/rafthttp/stream.go +++ b/rafthttp/stream.go @@ -244,46 +244,39 @@ func (cw *streamWriter) stop() { // streamReader is a long-running go-routine that dials to the remote stream // endpoint and reads messages from the response body returned. type streamReader struct { - tr *Transport - picker *urlPicker - t streamType - local, remote types.ID - cid types.ID - status *peerStatus - recvc chan<- raftpb.Message - propc chan<- raftpb.Message - errorc chan<- error + typ streamType + + tr *Transport + picker *urlPicker + to types.ID + status *peerStatus + recvc chan<- raftpb.Message + propc chan<- raftpb.Message + + errorc chan<- error mu sync.Mutex paused bool cancel func() closer io.Closer - stopc chan struct{} - done chan struct{} + + stopc chan struct{} + done chan struct{} } -func startStreamReader(tr *Transport, picker *urlPicker, t streamType, local, remote, cid types.ID, status *peerStatus, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader { - r := &streamReader{ - tr: tr, - picker: picker, - t: t, - local: local, - remote: remote, - cid: cid, - status: status, - recvc: recvc, - propc: propc, - errorc: errorc, - stopc: make(chan struct{}), - done: make(chan struct{}), +func (r *streamReader) start() { + r.stopc = make(chan struct{}) + r.done = make(chan struct{}) + if r.errorc != nil { + r.errorc = r.tr.ErrorC } + go r.run() - return r } func (cr *streamReader) run() { for { - t := cr.t + t := cr.typ rc, err := cr.dial(t) if err != nil { if err != errUnsupportedStreamType { @@ -317,7 +310,7 @@ func (cr *streamReader) decodeLoop(rc io.ReadCloser, t streamType) error { cr.mu.Lock() switch t { case streamTypeMsgAppV2: - dec = newMsgAppV2Decoder(rc, cr.local, cr.remote) + dec = newMsgAppV2Decoder(rc, cr.tr.ID, cr.to) case streamTypeMessage: dec = &messageDecoder{r: rc} default: @@ -382,18 +375,18 @@ func (cr *streamReader) stop() { func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) { u := cr.picker.pick() uu := u - uu.Path = path.Join(t.endpoint(), cr.local.String()) + uu.Path = path.Join(t.endpoint(), cr.tr.ID.String()) req, err := http.NewRequest("GET", uu.String(), nil) if err != nil { cr.picker.unreachable(u) return nil, fmt.Errorf("failed to make http request to %v (%v)", u, err) } - req.Header.Set("X-Server-From", cr.local.String()) + req.Header.Set("X-Server-From", cr.tr.ID.String()) req.Header.Set("X-Server-Version", version.Version) req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion) - req.Header.Set("X-Etcd-Cluster-ID", cr.cid.String()) - req.Header.Set("X-Raft-To", cr.remote.String()) + req.Header.Set("X-Etcd-Cluster-ID", cr.tr.ClusterID.String()) + req.Header.Set("X-Raft-To", cr.to.String()) setPeerURLsHeader(req, cr.tr.URLs) @@ -436,7 +429,7 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) { case http.StatusNotFound: httputil.GracefulClose(resp) cr.picker.unreachable(u) - return nil, fmt.Errorf("remote member %s could not recognize local member", cr.remote) + return nil, fmt.Errorf("remote member %s could not recognize local member", cr.to) case http.StatusPreconditionFailed: b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -448,11 +441,11 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) { switch strings.TrimSuffix(string(b), "\n") { case errIncompatibleVersion.Error(): - plog.Errorf("request sent was ignored by peer %s (server version incompatible)", cr.remote) + plog.Errorf("request sent was ignored by peer %s (server version incompatible)", cr.to) return nil, errIncompatibleVersion case errClusterIDMismatch.Error(): plog.Errorf("request sent was ignored (cluster ID mismatch: remote[%s]=%s, local=%s)", - cr.remote, resp.Header.Get("X-Etcd-Cluster-ID"), cr.cid) + cr.to, resp.Header.Get("X-Etcd-Cluster-ID"), cr.tr.ClusterID) return nil, errClusterIDMismatch default: return nil, fmt.Errorf("unhandled error %q when precondition failed", string(b)) diff --git a/rafthttp/stream_test.go b/rafthttp/stream_test.go index 12c109383..54fe86f09 100644 --- a/rafthttp/stream_test.go +++ b/rafthttp/stream_test.go @@ -116,11 +116,9 @@ func TestStreamReaderDialRequest(t *testing.T) { for i, tt := range []streamType{streamTypeMessage, streamTypeMsgAppV2} { tr := &roundTripperRecorder{} sr := &streamReader{ - tr: &Transport{streamRt: tr}, + tr: &Transport{streamRt: tr, ClusterID: types.ID(1), ID: types.ID(1)}, picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), - local: types.ID(1), - remote: types.ID(2), - cid: types.ID(1), + to: types.ID(2), } sr.dial(tt) @@ -166,11 +164,9 @@ func TestStreamReaderDialResult(t *testing.T) { err: tt.err, } sr := &streamReader{ - tr: &Transport{streamRt: tr}, + tr: &Transport{streamRt: tr, ClusterID: types.ID(1)}, picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), - local: types.ID(1), - remote: types.ID(2), - cid: types.ID(1), + to: types.ID(2), errorc: make(chan error, 1), } @@ -194,11 +190,9 @@ func TestStreamReaderDialDetectUnsupport(t *testing.T) { header: http.Header{}, } sr := &streamReader{ - tr: &Transport{streamRt: tr}, + tr: &Transport{streamRt: tr, ClusterID: types.ID(1)}, picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), - local: types.ID(1), - remote: types.ID(2), - cid: types.ID(1), + to: types.ID(2), } _, err := sr.dial(typ) @@ -254,9 +248,19 @@ func TestStream(t *testing.T) { h.sw = sw picker := mustNewURLPicker(t, []string{srv.URL}) - tr := &Transport{streamRt: &http.Transport{}} - sr := startStreamReader(tr, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), newPeerStatus(types.ID(1)), recvc, propc, nil) - defer sr.stop() + tr := &Transport{streamRt: &http.Transport{}, ClusterID: types.ID(1)} + + sr := &streamReader{ + typ: tt.t, + tr: tr, + picker: picker, + to: types.ID(2), + status: newPeerStatus(types.ID(1)), + recvc: recvc, + propc: propc, + } + sr.start() + // wait for stream to work var writec chan<- raftpb.Message for { @@ -277,6 +281,8 @@ func TestStream(t *testing.T) { if !reflect.DeepEqual(m, tt.m) { t.Fatalf("#%d: message = %+v, want %+v", i, m, tt.m) } + + sr.stop() } }