diff --git a/etcdserver/api/v3rpc/watch.go b/etcdserver/api/v3rpc/watch.go index de613dad3..dde1ba0f3 100644 --- a/etcdserver/api/v3rpc/watch.go +++ b/etcdserver/api/v3rpc/watch.go @@ -37,6 +37,8 @@ type watchServer struct { clusterID int64 memberID int64 + maxRequestBytes int + sg etcdserver.RaftStatusGetter watchable mvcc.WatchableKV ag AuthGetter @@ -50,6 +52,8 @@ func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer { clusterID: int64(s.Cluster().ID()), memberID: int64(s.ID()), + maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes), + sg: s, watchable: s.Watchable(), ag: s, @@ -102,6 +106,8 @@ type serverWatchStream struct { clusterID int64 memberID int64 + maxRequestBytes int + sg etcdserver.RaftStatusGetter watchable mvcc.WatchableKV ag AuthGetter @@ -110,13 +116,15 @@ type serverWatchStream struct { watchStream mvcc.WatchStream ctrlStream chan *pb.WatchResponse - // mu protects progress, prevKV + // mu protects progress, prevKV, fragment mu sync.RWMutex // tracks the watchID that stream might need to send progress to // TODO: combine progress and prevKV into a single struct? progress map[mvcc.WatchID]bool // record watch IDs that need return previous key-value pair prevKV map[mvcc.WatchID]bool + // records fragmented watch IDs + fragment map[mvcc.WatchID]bool // closec indicates the stream is closed. closec chan struct{} @@ -132,6 +140,8 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) { clusterID: ws.clusterID, memberID: ws.memberID, + maxRequestBytes: ws.maxRequestBytes, + sg: ws.sg, watchable: ws.watchable, ag: ws.ag, @@ -143,6 +153,7 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) { progress: make(map[mvcc.WatchID]bool), prevKV: make(map[mvcc.WatchID]bool), + fragment: make(map[mvcc.WatchID]bool), closec: make(chan struct{}), } @@ -268,6 +279,9 @@ func (sws *serverWatchStream) recvLoop() error { if creq.PrevKv { sws.prevKV[id] = true } + if creq.Fragment { + sws.fragment[id] = true + } sws.mu.Unlock() } wr := &pb.WatchResponse{ @@ -298,6 +312,7 @@ func (sws *serverWatchStream) recvLoop() error { sws.mu.Lock() delete(sws.progress, mvcc.WatchID(id)) delete(sws.prevKV, mvcc.WatchID(id)) + delete(sws.fragment, mvcc.WatchID(id)) sws.mu.Unlock() } } @@ -376,18 +391,30 @@ func (sws *serverWatchStream) sendLoop() { } mvcc.ReportEventReceived(len(evs)) - if err := sws.gRPCStream.Send(wr); err != nil { - if isClientCtxErr(sws.gRPCStream.Context().Err(), err) { + + sws.mu.RLock() + fragmented, ok := sws.fragment[wresp.WatchID] + sws.mu.RUnlock() + + var serr error + if !fragmented && !ok { + serr = sws.gRPCStream.Send(wr) + } else { + serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send) + } + + if serr != nil { + if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) { if sws.lg != nil { - sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(err)) + sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(serr)) } else { - plog.Debugf("failed to send watch response to gRPC stream (%q)", err.Error()) + plog.Debugf("failed to send watch response to gRPC stream (%q)", serr.Error()) } } else { if sws.lg != nil { - sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(err)) + sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(serr)) } else { - plog.Warningf("failed to send watch response to gRPC stream (%q)", err.Error()) + plog.Warningf("failed to send watch response to gRPC stream (%q)", serr.Error()) } } return @@ -469,6 +496,45 @@ func (sws *serverWatchStream) sendLoop() { } } +func sendFragments( + wr *pb.WatchResponse, + maxRequestBytes int, + sendFunc func(*pb.WatchResponse) error) error { + // no need to fragment if total request size is smaller + // than max request limit or response contains only one event + if wr.Size() < maxRequestBytes || len(wr.Events) < 2 { + return sendFunc(wr) + } + + ow := *wr + ow.Events = make([]*mvccpb.Event, 0) + ow.Fragment = true + + var idx int + for { + cur := ow + for _, ev := range wr.Events[idx:] { + cur.Events = append(cur.Events, ev) + if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes { + cur.Events = cur.Events[:len(cur.Events)-1] + break + } + idx++ + } + if idx == len(wr.Events) { + // last response has no more fragment + cur.Fragment = false + } + if err := sendFunc(&cur); err != nil { + return err + } + if !cur.Fragment { + break + } + } + return nil +} + func (sws *serverWatchStream) close() { sws.watchStream.Close() close(sws.closec) diff --git a/etcdserver/api/v3rpc/watch_test.go b/etcdserver/api/v3rpc/watch_test.go new file mode 100644 index 000000000..15850ab41 --- /dev/null +++ b/etcdserver/api/v3rpc/watch_test.go @@ -0,0 +1,95 @@ +// Copyright 2018 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v3rpc + +import ( + "bytes" + "math" + "testing" + + pb "github.com/coreos/etcd/etcdserver/etcdserverpb" + "github.com/coreos/etcd/mvcc/mvccpb" +) + +func TestSendFragment(t *testing.T) { + tt := []struct { + wr *pb.WatchResponse + maxRequestBytes int + fragments int + werr error + }{ + { // large limit should not fragment + wr: createResponse(100, 1), + maxRequestBytes: math.MaxInt32, + fragments: 1, + }, + { // large limit for two messages, expect no fragment + wr: createResponse(10, 2), + maxRequestBytes: 50, + fragments: 1, + }, + { // limit is small but only one message, expect no fragment + wr: createResponse(1024, 1), + maxRequestBytes: 1, + fragments: 1, + }, + { // exceed limit only when combined, expect fragments + wr: createResponse(11, 5), + maxRequestBytes: 20, + fragments: 5, + }, + { // 5 events with each event exceeding limits, expect fragments + wr: createResponse(15, 5), + maxRequestBytes: 10, + fragments: 5, + }, + { // 4 events with some combined events exceeding limits + wr: createResponse(10, 4), + maxRequestBytes: 35, + fragments: 2, + }, + } + + for i := range tt { + fragmentedResp := make([]*pb.WatchResponse, 0) + testSend := func(wr *pb.WatchResponse) error { + fragmentedResp = append(fragmentedResp, wr) + return nil + } + err := sendFragments(tt[i].wr, tt[i].maxRequestBytes, testSend) + if err != tt[i].werr { + t.Errorf("#%d: expected error %v, got %v", i, tt[i].werr, err) + } + got := len(fragmentedResp) + if got != tt[i].fragments { + t.Errorf("#%d: expected response number %d, got %d", i, tt[i].fragments, got) + } + if got > 0 && fragmentedResp[got-1].Fragment { + t.Errorf("#%d: expected fragment=false in last response, got %+v", i, fragmentedResp[got-1]) + } + } +} + +func createResponse(dataSize, events int) (resp *pb.WatchResponse) { + resp = &pb.WatchResponse{Events: make([]*mvccpb.Event, events)} + for i := range resp.Events { + resp.Events[i] = &mvccpb.Event{ + Kv: &mvccpb.KeyValue{ + Key: bytes.Repeat([]byte("a"), dataSize), + }, + } + } + return resp +}