From 8569b9c782f9598455b349649ec8b7c9facdc966 Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Tue, 1 May 2018 17:07:37 -0700 Subject: [PATCH] clientv3: Fix endpoint resolver to create a new resolver for each grpc client connection --- clientv3/balancer/balancer_test.go | 42 ++-- .../balancer/resolver/endpoint/endpoint.go | 180 ++++++++++-------- clientv3/client.go | 44 ++--- clientv3/integration/server_shutdown_test.go | 15 ++ clientv3/integration/watch_test.go | 10 +- 5 files changed, 165 insertions(+), 126 deletions(-) diff --git a/clientv3/balancer/balancer_test.go b/clientv3/balancer/balancer_test.go index ac0ff303f..be9d5ed24 100644 --- a/clientv3/balancer/balancer_test.go +++ b/clientv3/balancer/balancer_test.go @@ -30,7 +30,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" - "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" ) @@ -58,14 +57,17 @@ func TestRoundRobinBalancedResolvableNoFailover(t *testing.T) { } defer ms.Stop() - var resolvedAddrs []resolver.Address + var eps []string for _, svr := range ms.Servers { - resolvedAddrs = append(resolvedAddrs, svr.ResolverAddress()) + eps = append(eps, svr.ResolverAddress().Addr) } - rsv := endpoint.EndpointResolver("nofailover") + rsv, err := endpoint.NewResolverGroup("nofailover") + if err != nil { + t.Fatal(err) + } defer rsv.Close() - rsv.InitialAddrs(resolvedAddrs) + rsv.SetEndpoints(eps) name := genName() cfg := Config{ @@ -121,14 +123,17 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) { t.Fatalf("failed to start mock servers: %s", err) } defer ms.Stop() - var resolvedAddrs []resolver.Address + var eps []string for _, svr := range ms.Servers { - resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address}) + eps = append(eps, svr.ResolverAddress().Addr) } - rsv := endpoint.EndpointResolver("serverfail") + rsv, err := endpoint.NewResolverGroup("serverfail") + if err != nil { + t.Fatal(err) + } defer rsv.Close() - rsv.InitialAddrs(resolvedAddrs) + rsv.SetEndpoints(eps) name := genName() cfg := Config{ @@ -158,7 +163,7 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) { ms.StopAt(0) available := make(map[string]struct{}) for i := 1; i < serverCount; i++ { - available[resolvedAddrs[i].Addr] = struct{}{} + available[eps[i]] = struct{}{} } reqN := 10 @@ -169,8 +174,8 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) { continue } if prev == "" { // first failover - if resolvedAddrs[0].Addr == picked { - t.Fatalf("expected failover from %q, picked %q", resolvedAddrs[0].Addr, picked) + if eps[0] == picked { + t.Fatalf("expected failover from %q, picked %q", eps[0], picked) } prev = picked continue @@ -194,7 +199,7 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) { time.Sleep(time.Second) prev, switches = "", 0 - recoveredAddr, recovered := resolvedAddrs[0].Addr, 0 + recoveredAddr, recovered := eps[0], 0 available[recoveredAddr] = struct{}{} for i := 0; i < 2*reqN; i++ { @@ -234,15 +239,18 @@ func TestRoundRobinBalancedResolvableFailoverFromRequestFail(t *testing.T) { t.Fatalf("failed to start mock servers: %s", err) } defer ms.Stop() - var resolvedAddrs []resolver.Address + var eps []string available := make(map[string]struct{}) for _, svr := range ms.Servers { - resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address}) + eps = append(eps, svr.ResolverAddress().Addr) available[svr.Address] = struct{}{} } - rsv := endpoint.EndpointResolver("requestfail") + rsv, err := endpoint.NewResolverGroup("requestfail") + if err != nil { + t.Fatal(err) + } defer rsv.Close() - rsv.InitialAddrs(resolvedAddrs) + rsv.SetEndpoints(eps) name := genName() cfg := Config{ diff --git a/clientv3/balancer/resolver/endpoint/endpoint.go b/clientv3/balancer/resolver/endpoint/endpoint.go index 120a0b9c5..679f92e9a 100644 --- a/clientv3/balancer/resolver/endpoint/endpoint.go +++ b/clientv3/balancer/resolver/endpoint/endpoint.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package endpoint resolves etcd entpoints using grpc targets of the form 'endpoint:///'. +// Package endpoint resolves etcd entpoints using grpc targets of the form 'endpoint:///'. package endpoint import ( @@ -36,91 +36,140 @@ var ( func init() { bldr = &builder{ - clientResolvers: make(map[string]*Resolver), + resolverGroups: make(map[string]*ResolverGroup), } resolver.Register(bldr) } type builder struct { - clientResolvers map[string]*Resolver + resolverGroups map[string]*ResolverGroup sync.RWMutex } +// NewResolverGroup creates a new ResolverGroup with the given id. +func NewResolverGroup(id string) (*ResolverGroup, error) { + return bldr.newResolverGroup(id) +} + +// ResolverGroup keeps all endpoints of resolvers using a common endpoint:/// target +// up-to-date. +type ResolverGroup struct { + id string + endpoints []string + resolvers []*Resolver + sync.RWMutex +} + +func (e *ResolverGroup) addResolver(r *Resolver) { + e.Lock() + addrs := epsToAddrs(e.endpoints...) + e.resolvers = append(e.resolvers, r) + e.Unlock() + r.cc.NewAddress(addrs) +} + +func (e *ResolverGroup) removeResolver(r *Resolver) { + e.Lock() + for i, er := range e.resolvers { + if er == r { + e.resolvers = append(e.resolvers[:i], e.resolvers[i+1:]...) + break + } + } + e.Unlock() +} + +// SetEndpoints updates the endpoints for ResolverGroup. All registered resolver are updated +// immediately with the new endpoints. +func (e *ResolverGroup) SetEndpoints(endpoints []string) { + addrs := epsToAddrs(endpoints...) + e.Lock() + e.endpoints = endpoints + for _, r := range e.resolvers { + r.cc.NewAddress(addrs) + } + e.Unlock() +} + +// Target constructs a endpoint target using the endpoint id of the ResolverGroup. +func (e *ResolverGroup) Target(endpoint string) string { + return Target(e.id, endpoint) +} + +// Target constructs a endpoint resolver target. +func Target(id, endpoint string) string { + return fmt.Sprintf("%s://%s/%s", scheme, id, endpoint) +} + +// IsTarget checks if a given target string in an endpoint resolver target. +func IsTarget(target string) bool { + return strings.HasPrefix(target, "endpoint://") +} + +func (e *ResolverGroup) Close() { + bldr.close(e.id) +} + // Build creates or reuses an etcd resolver for the etcd cluster name identified by the authority part of the target. func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) { if len(target.Authority) < 1 { return nil, fmt.Errorf("'etcd' target scheme requires non-empty authority identifying etcd cluster being routed to") } - r := b.getResolver(target.Authority) - r.cc = cc - if r.addrs != nil { - r.NewAddress(r.addrs) + id := target.Authority + es, err := b.getResolverGroup(id) + if err != nil { + return nil, fmt.Errorf("failed to build resolver: %v", err) } + r := &Resolver{ + endpointId: id, + cc: cc, + } + es.addResolver(r) return r, nil } -func (b *builder) getResolver(clientId string) *Resolver { +func (b *builder) newResolverGroup(id string) (*ResolverGroup, error) { b.RLock() - r, ok := b.clientResolvers[clientId] + es, ok := b.resolverGroups[id] b.RUnlock() if !ok { - r = &Resolver{ - clientId: clientId, - } + es = &ResolverGroup{id: id} b.Lock() - b.clientResolvers[clientId] = r + b.resolverGroups[id] = es b.Unlock() + } else { + return nil, fmt.Errorf("Endpoint already exists for id: %s", id) } - return r + return es, nil } -func (b *builder) addResolver(r *Resolver) { - bldr.Lock() - bldr.clientResolvers[r.clientId] = r - bldr.Unlock() +func (b *builder) getResolverGroup(id string) (*ResolverGroup, error) { + b.RLock() + es, ok := b.resolverGroups[id] + b.RUnlock() + if !ok { + return nil, fmt.Errorf("ResolverGroup not found for id: %s", id) + } + return es, nil } -func (b *builder) removeResolver(r *Resolver) { - bldr.Lock() - delete(bldr.clientResolvers, r.clientId) - bldr.Unlock() +func (b *builder) close(id string) { + b.Lock() + delete(b.resolverGroups, id) + b.Unlock() } func (r *builder) Scheme() string { return scheme } -// EndpointResolver gets the resolver for given etcd cluster name. -func EndpointResolver(clientId string) *Resolver { - return bldr.getResolver(clientId) -} - // Resolver provides a resolver for a single etcd cluster, identified by name. type Resolver struct { - clientId string - cc resolver.ClientConn - addrs []resolver.Address + endpointId string + cc resolver.ClientConn sync.RWMutex } -// InitialAddrs sets the initial endpoint addresses for the resolver. -func (r *Resolver) InitialAddrs(addrs []resolver.Address) { - r.Lock() - r.addrs = addrs - r.Unlock() -} - -// InitialEndpoints sets the initial endpoints to for the resolver. -// This should be called before dialing. The endpoints may be updated after the dial using NewAddress. -// At least one endpoint is required. -func (r *Resolver) InitialEndpoints(eps []string) error { - if len(eps) < 1 { - return fmt.Errorf("At least one endpoint is required, but got: %v", eps) - } - r.InitialAddrs(epsToAddrs(eps...)) - return nil -} - // TODO: use balancer.epsToAddrs func epsToAddrs(eps ...string) (addrs []resolver.Address) { addrs = make([]resolver.Address, 0, len(eps)) @@ -130,35 +179,14 @@ func epsToAddrs(eps ...string) (addrs []resolver.Address) { return addrs } -// NewAddress updates the addresses of the resolver. -func (r *Resolver) NewAddress(addrs []resolver.Address) { - r.Lock() - r.addrs = addrs - r.Unlock() - if r.cc != nil { - r.cc.NewAddress(addrs) - } -} - func (*Resolver) ResolveNow(o resolver.ResolveNowOption) {} func (r *Resolver) Close() { - bldr.removeResolver(r) -} - -// Target constructs a endpoint target with current resolver's clientId. -func (r *Resolver) Target(endpoint string) string { - return Target(r.clientId, endpoint) -} - -// Target constructs a endpoint resolver target. -func Target(clientId, endpoint string) string { - return fmt.Sprintf("%s://%s/%s", scheme, clientId, endpoint) -} - -// IsTarget checks if a given target string in an endpoint resolver target. -func IsTarget(target string) bool { - return strings.HasPrefix(target, "endpoint://") + es, err := bldr.getResolverGroup(r.endpointId) + if err != nil { + return + } + es.removeResolver(r) } // Parse endpoint parses a endpoint of the form (http|https)://*|(unix|unixs)://) and returns a @@ -185,7 +213,7 @@ func ParseEndpoint(endpoint string) (proto string, host string, scheme string) { return proto, host, scheme } -// ParseTarget parses a endpoint:/// string and returns the parsed clientId and endpoint. +// ParseTarget parses a endpoint:/// string and returns the parsed id and endpoint. // If the target is malformed, an error is returned. func ParseTarget(target string) (string, string, error) { noPrefix := strings.TrimPrefix(target, targetPrefix) @@ -194,7 +222,7 @@ func ParseTarget(target string) (string, string, error) { } parts := strings.SplitN(noPrefix, "/", 2) if len(parts) != 2 { - return "", "", fmt.Errorf("malformed target, expected %s:///, but got %s", scheme, target) + return "", "", fmt.Errorf("malformed target, expected %s:///, but got %s", scheme, target) } return parts[0], parts[1], nil } diff --git a/clientv3/client.go b/clientv3/client.go index b2870f4ee..d861ef27b 100644 --- a/clientv3/client.go +++ b/clientv3/client.go @@ -37,7 +37,6 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" ) @@ -68,11 +67,11 @@ type Client struct { conn *grpc.ClientConn dialerrc chan error - cfg Config - creds *credentials.TransportCredentials - balancer balancer.Balancer - resolver *endpoint.Resolver - mu *sync.Mutex + cfg Config + creds *credentials.TransportCredentials + balancer balancer.Balancer + resolverGroup *endpoint.ResolverGroup + mu *sync.Mutex ctx context.Context cancel context.CancelFunc @@ -119,12 +118,12 @@ func (c *Client) Close() error { c.cancel() c.Watcher.Close() c.Lease.Close() + if c.resolverGroup != nil { + c.resolverGroup.Close() + } if c.conn != nil { return toErr(c.ctx, c.conn.Close()) } - if c.resolver != nil { - c.resolver.Close() - } return c.ctx.Err() } @@ -143,22 +142,10 @@ func (c *Client) Endpoints() (eps []string) { // SetEndpoints updates client's endpoints. func (c *Client) SetEndpoints(eps ...string) { - var addrs []resolver.Address - for _, ep := range eps { - addrs = append(addrs, resolver.Address{Addr: ep}) - } - c.mu.Lock() defer c.mu.Unlock() c.cfg.Endpoints = eps - c.resolver.NewAddress(addrs) - // TODO: Does the new grpc balancer provide a way to block until the endpoint changes are propagated? - /*if c.balancer.NeedUpdate() { - select { - case c.balancer.UpdateAddrsC() <- balancer.NotifyNext: - case <-c.balancer.StopC(): - } - }*/ + c.resolverGroup.SetEndpoints(eps) } // Sync synchronizes client's endpoints with the known endpoints from the etcd membership. @@ -301,12 +288,13 @@ func (c *Client) getToken(ctx context.Context) error { // use dial options without dopts to avoid reusing the client balancer var dOpts []grpc.DialOption _, host, _ := endpoint.ParseEndpoint(ep) - target := c.resolver.Target(host) + target := c.resolverGroup.Target(host) dOpts, err = c.dialSetupOpts(target, c.cfg.DialOptions...) if err != nil { err = fmt.Errorf("failed to configure auth dialer: %v", err) continue } + dOpts = append(dOpts, grpc.WithBalancerName(roundRobinBalancerName)) auth, err = newAuthenticator(ctx, target, dOpts, c) if err != nil { continue @@ -333,7 +321,7 @@ func (c *Client) dial(ep string, dopts ...grpc.DialOption) (*grpc.ClientConn, er // We pass a target to DialContext of the form: endpoint:/// that // does not include scheme (http/https/unix/unixs) or path parts. _, host, _ := endpoint.ParseEndpoint(ep) - target := c.resolver.Target(host) + target := c.resolverGroup.Target(host) opts, err := c.dialSetupOpts(target, dopts...) if err != nil { @@ -439,13 +427,13 @@ func newClient(cfg *Config) (*Client, error) { // Prepare a 'endpoint:///' resolver for the client and create a endpoint target to pass // to dial so the client knows to use this resolver. - client.resolver = endpoint.EndpointResolver(fmt.Sprintf("client-%s", strconv.FormatInt(time.Now().UnixNano(), 36))) - err := client.resolver.InitialEndpoints(cfg.Endpoints) + var err error + client.resolverGroup, err = endpoint.NewResolverGroup(fmt.Sprintf("client-%s", strconv.FormatInt(time.Now().UnixNano(), 36))) if err != nil { client.cancel() - client.resolver.Close() return nil, err } + client.resolverGroup.SetEndpoints(cfg.Endpoints) if len(cfg.Endpoints) < 1 { return nil, fmt.Errorf("at least one Endpoint must is required in client config") @@ -457,7 +445,7 @@ func newClient(cfg *Config) (*Client, error) { conn, err := client.dial(dialEndpoint, grpc.WithBalancerName(roundRobinBalancerName)) if err != nil { client.cancel() - client.resolver.Close() + client.resolverGroup.Close() return nil, err } // TODO: With the old grpc balancer interface, we waited until the dial timeout diff --git a/clientv3/integration/server_shutdown_test.go b/clientv3/integration/server_shutdown_test.go index fdca6e3e6..1d20f25d1 100644 --- a/clientv3/integration/server_shutdown_test.go +++ b/clientv3/integration/server_shutdown_test.go @@ -403,3 +403,18 @@ func isServerUnavailable(err error) bool { code := ev.Code() return code == codes.Unavailable } + +func isCanceled(err error) bool { + if err == nil { + return false + } + if err == context.Canceled { + return true + } + ev, ok := status.FromError(err) + if !ok { + return false + } + code := ev.Code() + return code == codes.Canceled +} diff --git a/clientv3/integration/watch_test.go b/clientv3/integration/watch_test.go index f4879149f..efe4387ab 100644 --- a/clientv3/integration/watch_test.go +++ b/clientv3/integration/watch_test.go @@ -30,7 +30,6 @@ import ( mvccpb "github.com/coreos/etcd/mvcc/mvccpb" "github.com/coreos/etcd/pkg/testutil" - "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -667,8 +666,9 @@ func TestWatchErrConnClosed(t *testing.T) { go func() { defer close(donec) ch := cli.Watch(context.TODO(), "foo") - if wr := <-ch; wr.Err() != grpc.ErrClientConnClosing { - t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, wr.Err()) + + if wr := <-ch; !isCanceled(wr.Err()) { + t.Fatalf("expected context canceled, got %v", wr.Err()) } }() @@ -699,8 +699,8 @@ func TestWatchAfterClose(t *testing.T) { donec := make(chan struct{}) go func() { cli.Watch(context.TODO(), "foo") - if err := cli.Close(); err != nil && err != grpc.ErrClientConnClosing { - t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, err) + if err := cli.Close(); err != nil && err != context.Canceled { + t.Fatalf("expected %v, got %v", context.Canceled, err) } close(donec) }()