rafthttp: simplify streamReader initilization

release-3.0
Xiang Li 2016-05-31 11:54:13 -07:00
parent 310ebdd3e1
commit 86269ab5bf
3 changed files with 69 additions and 52 deletions

View File

@ -177,8 +177,26 @@ func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r
} }
}() }()
p.msgAppV2Reader = startStreamReader(transport, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc) p.msgAppV2Reader = &streamReader{
p.msgAppReader = startStreamReader(transport, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc) typ: streamTypeMsgAppV2,
tr: transport,
picker: picker,
to: to,
status: status,
recvc: p.recvc,
propc: p.propc,
}
p.msgAppReader = &streamReader{
typ: streamTypeMessage,
tr: transport,
picker: picker,
to: to,
status: status,
recvc: p.recvc,
propc: p.propc,
}
p.msgAppV2Reader.start()
p.msgAppReader.start()
return p return p
} }

View File

@ -244,46 +244,39 @@ func (cw *streamWriter) stop() {
// streamReader is a long-running go-routine that dials to the remote stream // streamReader is a long-running go-routine that dials to the remote stream
// endpoint and reads messages from the response body returned. // endpoint and reads messages from the response body returned.
type streamReader struct { type streamReader struct {
tr *Transport typ streamType
picker *urlPicker
t streamType tr *Transport
local, remote types.ID picker *urlPicker
cid types.ID to types.ID
status *peerStatus status *peerStatus
recvc chan<- raftpb.Message recvc chan<- raftpb.Message
propc chan<- raftpb.Message propc chan<- raftpb.Message
errorc chan<- error
errorc chan<- error
mu sync.Mutex mu sync.Mutex
paused bool paused bool
cancel func() cancel func()
closer io.Closer closer io.Closer
stopc chan struct{}
done chan struct{} stopc chan struct{}
done chan struct{}
} }
func startStreamReader(tr *Transport, picker *urlPicker, t streamType, local, remote, cid types.ID, status *peerStatus, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader { func (r *streamReader) start() {
r := &streamReader{ r.stopc = make(chan struct{})
tr: tr, r.done = make(chan struct{})
picker: picker, if r.errorc != nil {
t: t, r.errorc = r.tr.ErrorC
local: local,
remote: remote,
cid: cid,
status: status,
recvc: recvc,
propc: propc,
errorc: errorc,
stopc: make(chan struct{}),
done: make(chan struct{}),
} }
go r.run() go r.run()
return r
} }
func (cr *streamReader) run() { func (cr *streamReader) run() {
for { for {
t := cr.t t := cr.typ
rc, err := cr.dial(t) rc, err := cr.dial(t)
if err != nil { if err != nil {
if err != errUnsupportedStreamType { if err != errUnsupportedStreamType {
@ -317,7 +310,7 @@ func (cr *streamReader) decodeLoop(rc io.ReadCloser, t streamType) error {
cr.mu.Lock() cr.mu.Lock()
switch t { switch t {
case streamTypeMsgAppV2: case streamTypeMsgAppV2:
dec = newMsgAppV2Decoder(rc, cr.local, cr.remote) dec = newMsgAppV2Decoder(rc, cr.tr.ID, cr.to)
case streamTypeMessage: case streamTypeMessage:
dec = &messageDecoder{r: rc} dec = &messageDecoder{r: rc}
default: default:
@ -382,18 +375,18 @@ func (cr *streamReader) stop() {
func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) { func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
u := cr.picker.pick() u := cr.picker.pick()
uu := u uu := u
uu.Path = path.Join(t.endpoint(), cr.local.String()) uu.Path = path.Join(t.endpoint(), cr.tr.ID.String())
req, err := http.NewRequest("GET", uu.String(), nil) req, err := http.NewRequest("GET", uu.String(), nil)
if err != nil { if err != nil {
cr.picker.unreachable(u) cr.picker.unreachable(u)
return nil, fmt.Errorf("failed to make http request to %v (%v)", u, err) return nil, fmt.Errorf("failed to make http request to %v (%v)", u, err)
} }
req.Header.Set("X-Server-From", cr.local.String()) req.Header.Set("X-Server-From", cr.tr.ID.String())
req.Header.Set("X-Server-Version", version.Version) req.Header.Set("X-Server-Version", version.Version)
req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion) req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
req.Header.Set("X-Etcd-Cluster-ID", cr.cid.String()) req.Header.Set("X-Etcd-Cluster-ID", cr.tr.ClusterID.String())
req.Header.Set("X-Raft-To", cr.remote.String()) req.Header.Set("X-Raft-To", cr.to.String())
setPeerURLsHeader(req, cr.tr.URLs) setPeerURLsHeader(req, cr.tr.URLs)
@ -436,7 +429,7 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
case http.StatusNotFound: case http.StatusNotFound:
httputil.GracefulClose(resp) httputil.GracefulClose(resp)
cr.picker.unreachable(u) cr.picker.unreachable(u)
return nil, fmt.Errorf("remote member %s could not recognize local member", cr.remote) return nil, fmt.Errorf("remote member %s could not recognize local member", cr.to)
case http.StatusPreconditionFailed: case http.StatusPreconditionFailed:
b, err := ioutil.ReadAll(resp.Body) b, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
@ -448,11 +441,11 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
switch strings.TrimSuffix(string(b), "\n") { switch strings.TrimSuffix(string(b), "\n") {
case errIncompatibleVersion.Error(): case errIncompatibleVersion.Error():
plog.Errorf("request sent was ignored by peer %s (server version incompatible)", cr.remote) plog.Errorf("request sent was ignored by peer %s (server version incompatible)", cr.to)
return nil, errIncompatibleVersion return nil, errIncompatibleVersion
case errClusterIDMismatch.Error(): case errClusterIDMismatch.Error():
plog.Errorf("request sent was ignored (cluster ID mismatch: remote[%s]=%s, local=%s)", plog.Errorf("request sent was ignored (cluster ID mismatch: remote[%s]=%s, local=%s)",
cr.remote, resp.Header.Get("X-Etcd-Cluster-ID"), cr.cid) cr.to, resp.Header.Get("X-Etcd-Cluster-ID"), cr.tr.ClusterID)
return nil, errClusterIDMismatch return nil, errClusterIDMismatch
default: default:
return nil, fmt.Errorf("unhandled error %q when precondition failed", string(b)) return nil, fmt.Errorf("unhandled error %q when precondition failed", string(b))

View File

@ -116,11 +116,9 @@ func TestStreamReaderDialRequest(t *testing.T) {
for i, tt := range []streamType{streamTypeMessage, streamTypeMsgAppV2} { for i, tt := range []streamType{streamTypeMessage, streamTypeMsgAppV2} {
tr := &roundTripperRecorder{} tr := &roundTripperRecorder{}
sr := &streamReader{ sr := &streamReader{
tr: &Transport{streamRt: tr}, tr: &Transport{streamRt: tr, ClusterID: types.ID(1), ID: types.ID(1)},
picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
local: types.ID(1), to: types.ID(2),
remote: types.ID(2),
cid: types.ID(1),
} }
sr.dial(tt) sr.dial(tt)
@ -166,11 +164,9 @@ func TestStreamReaderDialResult(t *testing.T) {
err: tt.err, err: tt.err,
} }
sr := &streamReader{ sr := &streamReader{
tr: &Transport{streamRt: tr}, tr: &Transport{streamRt: tr, ClusterID: types.ID(1)},
picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
local: types.ID(1), to: types.ID(2),
remote: types.ID(2),
cid: types.ID(1),
errorc: make(chan error, 1), errorc: make(chan error, 1),
} }
@ -194,11 +190,9 @@ func TestStreamReaderDialDetectUnsupport(t *testing.T) {
header: http.Header{}, header: http.Header{},
} }
sr := &streamReader{ sr := &streamReader{
tr: &Transport{streamRt: tr}, tr: &Transport{streamRt: tr, ClusterID: types.ID(1)},
picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), picker: mustNewURLPicker(t, []string{"http://localhost:2380"}),
local: types.ID(1), to: types.ID(2),
remote: types.ID(2),
cid: types.ID(1),
} }
_, err := sr.dial(typ) _, err := sr.dial(typ)
@ -254,9 +248,19 @@ func TestStream(t *testing.T) {
h.sw = sw h.sw = sw
picker := mustNewURLPicker(t, []string{srv.URL}) picker := mustNewURLPicker(t, []string{srv.URL})
tr := &Transport{streamRt: &http.Transport{}} tr := &Transport{streamRt: &http.Transport{}, ClusterID: types.ID(1)}
sr := startStreamReader(tr, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), newPeerStatus(types.ID(1)), recvc, propc, nil)
defer sr.stop() sr := &streamReader{
typ: tt.t,
tr: tr,
picker: picker,
to: types.ID(2),
status: newPeerStatus(types.ID(1)),
recvc: recvc,
propc: propc,
}
sr.start()
// wait for stream to work // wait for stream to work
var writec chan<- raftpb.Message var writec chan<- raftpb.Message
for { for {
@ -277,6 +281,8 @@ func TestStream(t *testing.T) {
if !reflect.DeepEqual(m, tt.m) { if !reflect.DeepEqual(m, tt.m) {
t.Fatalf("#%d: message = %+v, want %+v", i, m, tt.m) t.Fatalf("#%d: message = %+v, want %+v", i, m, tt.m)
} }
sr.stop()
} }
} }