From 16bf0f6641001461638e4f410dae219e17d735da Mon Sep 17 00:00:00 2001 From: Marek Siarkowicz Date: Thu, 25 May 2023 22:11:11 +0200 Subject: [PATCH] tests/robustness: Use traffic.RecordingClient in watch Signed-off-by: Marek Siarkowicz --- tests/robustness/traffic/client.go | 15 ++++++++-- tests/robustness/traffic/kubernetes.go | 2 +- tests/robustness/watch.go | 41 ++++++++------------------ 3 files changed, 27 insertions(+), 31 deletions(-) diff --git a/tests/robustness/traffic/client.go b/tests/robustness/traffic/client.go index af5c99b8b..52ffb0445 100644 --- a/tests/robustness/traffic/client.go +++ b/tests/robustness/traffic/client.go @@ -216,7 +216,7 @@ func (c *RecordingClient) Defragment(ctx context.Context) error { return err } -func (c *RecordingClient) Watch(ctx context.Context, key string, rev int64, withPrefix bool) clientv3.WatchChan { +func (c *RecordingClient) Watch(ctx context.Context, key string, rev int64, withPrefix bool, withProgressNotify bool) clientv3.WatchChan { ops := []clientv3.OpOption{clientv3.WithProgressNotify()} if withPrefix { ops = append(ops, clientv3.WithPrefix()) @@ -224,6 +224,9 @@ func (c *RecordingClient) Watch(ctx context.Context, key string, rev int64, with if rev != 0 { ops = append(ops, clientv3.WithRev(rev)) } + if withProgressNotify { + ops = append(ops, clientv3.WithProgressNotify()) + } respCh := make(chan clientv3.WatchResponse) go func() { defer close(respCh) @@ -231,12 +234,20 @@ func (c *RecordingClient) Watch(ctx context.Context, key string, rev int64, with c.watchMux.Lock() c.watchResponses = append(c.watchResponses, ToWatchResponse(r, c.baseTime)) c.watchMux.Unlock() - respCh <- r + select { + case respCh <- r: + case <-ctx.Done(): + return + } } }() return respCh } +func (c *RecordingClient) RequestProgress(ctx context.Context) error { + return c.client.RequestProgress(ctx) +} + func ToWatchResponse(r clientv3.WatchResponse, baseTime time.Time) WatchResponse { // using time.Since time-measuring operation to get monotonic clock reading // see https://github.com/golang/go/blob/master/src/time/time.go#L17 diff --git a/tests/robustness/traffic/kubernetes.go b/tests/robustness/traffic/kubernetes.go index c3df6db07..d3d26d91d 100644 --- a/tests/robustness/traffic/kubernetes.go +++ b/tests/robustness/traffic/kubernetes.go @@ -84,7 +84,7 @@ func (t kubernetesTraffic) Run(ctx context.Context, c *RecordingClient, limiter s.Reset(resp) limiter.Wait(ctx) watchCtx, cancel := context.WithTimeout(ctx, WatchTimeout) - for e := range c.Watch(watchCtx, keyPrefix, resp.Header.Revision+1, true) { + for e := range c.Watch(watchCtx, keyPrefix, resp.Header.Revision+1, true, true) { s.Update(e) } cancel() diff --git a/tests/robustness/watch.go b/tests/robustness/watch.go index 844d14384..e2e865e9e 100644 --- a/tests/robustness/watch.go +++ b/tests/robustness/watch.go @@ -22,9 +22,7 @@ import ( "github.com/anishathalye/porcupine" "github.com/google/go-cmp/cmp" - "go.uber.org/zap" - clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/tests/v3/framework/e2e" "go.etcd.io/etcd/tests/v3/robustness/identity" "go.etcd.io/etcd/tests/v3/robustness/model" @@ -37,27 +35,19 @@ func collectClusterWatchEvents(ctx context.Context, t *testing.T, clus *e2e.Etcd reports := make([]traffic.ClientReport, len(clus.Procs)) memberMaxRevisionChans := make([]chan int64, len(clus.Procs)) for i, member := range clus.Procs { - c, err := clientv3.New(clientv3.Config{ - Endpoints: member.EndpointsGRPC(), - Logger: zap.NewNop(), - DialKeepAliveTime: 10 * time.Second, - DialKeepAliveTimeout: 100 * time.Millisecond, - }) + c, err := traffic.NewClient(member.EndpointsGRPC(), ids, baseTime) if err != nil { t.Fatal(err) } - memberChan := make(chan int64, 1) - memberMaxRevisionChans[i] = memberChan + memberMaxRevisionChan := make(chan int64, 1) + memberMaxRevisionChans[i] = memberMaxRevisionChan wg.Add(1) - go func(i int, c *clientv3.Client) { + go func(i int, c *traffic.RecordingClient) { defer wg.Done() defer c.Close() - responses := watchMember(ctx, t, c, memberChan, cfg, baseTime) + watchUntilRevision(ctx, t, c, memberMaxRevisionChan, cfg) mux.Lock() - reports[i] = traffic.ClientReport{ - ClientId: ids.NewClientId(), - Watch: responses, - } + reports[i] = c.Report() mux.Unlock() }(i, c) } @@ -78,25 +68,23 @@ type watchConfig struct { expectUniqueRevision bool } -// watchMember collects all responses until context is cancelled, it has observed revision provided via maxRevisionChan or maxRevisionChan was closed. -// TODO: Use traffic.RecordingClient instead of clientv3.Client -func watchMember(ctx context.Context, t *testing.T, c *clientv3.Client, maxRevisionChan <-chan int64, cfg watchConfig, baseTime time.Time) (resps []traffic.WatchResponse) { +// watchUntilRevision watches all changes until context is cancelled, it has observed revision provided via maxRevisionChan or maxRevisionChan was closed. +func watchUntilRevision(ctx context.Context, t *testing.T, c *traffic.RecordingClient, maxRevisionChan <-chan int64, cfg watchConfig) { var maxRevision int64 = 0 var lastRevision int64 = 0 ctx, cancel := context.WithCancel(ctx) defer cancel() - watch := c.Watch(ctx, "", clientv3.WithPrefix(), clientv3.WithRev(1), clientv3.WithProgressNotify()) + watch := c.Watch(ctx, "", 1, true, true) for { select { case <-ctx.Done(): - revision := watchResponsesMaxRevision(resps) if maxRevision == 0 { t.Errorf("Client didn't collect all events, max revision not set") } - if revision < maxRevision { - t.Errorf("Client didn't collect all events, revision got %d, expected: %d", revision, maxRevision) + if lastRevision < maxRevision { + t.Errorf("Client didn't collect all events, revision got %d, expected: %d", lastRevision, maxRevision) } - return resps + return case revision, ok := <-maxRevisionChan: if ok { maxRevision = revision @@ -113,12 +101,9 @@ func watchMember(ctx context.Context, t *testing.T, c *clientv3.Client, maxRevis if cfg.requestProgress { c.RequestProgress(ctx) } - if resp.Err() == nil { - resps = append(resps, traffic.ToWatchResponse(resp, baseTime)) - } else if !resp.Canceled { + if resp.Err() != nil && !resp.Canceled { t.Errorf("Watch stream received error, err %v", resp.Err()) } - // Assumes that we track all events as we watch all keys. if len(resp.Events) > 0 { lastRevision = resp.Events[len(resp.Events)-1].Kv.ModRevision }