clientv3: Fix endpoint resolver to create a new resolver for each grpc client connection

release-3.4
Joe Betz 2018-05-01 17:07:37 -07:00 committed by Gyuho Lee
parent 9304d1abd1
commit 8569b9c782
5 changed files with 165 additions and 126 deletions

View File

@ -30,7 +30,6 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@ -58,14 +57,17 @@ func TestRoundRobinBalancedResolvableNoFailover(t *testing.T) {
} }
defer ms.Stop() defer ms.Stop()
var resolvedAddrs []resolver.Address var eps []string
for _, svr := range ms.Servers { for _, svr := range ms.Servers {
resolvedAddrs = append(resolvedAddrs, svr.ResolverAddress()) eps = append(eps, svr.ResolverAddress().Addr)
} }
rsv := endpoint.EndpointResolver("nofailover") rsv, err := endpoint.NewResolverGroup("nofailover")
if err != nil {
t.Fatal(err)
}
defer rsv.Close() defer rsv.Close()
rsv.InitialAddrs(resolvedAddrs) rsv.SetEndpoints(eps)
name := genName() name := genName()
cfg := Config{ cfg := Config{
@ -121,14 +123,17 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
t.Fatalf("failed to start mock servers: %s", err) t.Fatalf("failed to start mock servers: %s", err)
} }
defer ms.Stop() defer ms.Stop()
var resolvedAddrs []resolver.Address var eps []string
for _, svr := range ms.Servers { for _, svr := range ms.Servers {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address}) eps = append(eps, svr.ResolverAddress().Addr)
} }
rsv := endpoint.EndpointResolver("serverfail") rsv, err := endpoint.NewResolverGroup("serverfail")
if err != nil {
t.Fatal(err)
}
defer rsv.Close() defer rsv.Close()
rsv.InitialAddrs(resolvedAddrs) rsv.SetEndpoints(eps)
name := genName() name := genName()
cfg := Config{ cfg := Config{
@ -158,7 +163,7 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
ms.StopAt(0) ms.StopAt(0)
available := make(map[string]struct{}) available := make(map[string]struct{})
for i := 1; i < serverCount; i++ { for i := 1; i < serverCount; i++ {
available[resolvedAddrs[i].Addr] = struct{}{} available[eps[i]] = struct{}{}
} }
reqN := 10 reqN := 10
@ -169,8 +174,8 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
continue continue
} }
if prev == "" { // first failover if prev == "" { // first failover
if resolvedAddrs[0].Addr == picked { if eps[0] == picked {
t.Fatalf("expected failover from %q, picked %q", resolvedAddrs[0].Addr, picked) t.Fatalf("expected failover from %q, picked %q", eps[0], picked)
} }
prev = picked prev = picked
continue continue
@ -194,7 +199,7 @@ func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
prev, switches = "", 0 prev, switches = "", 0
recoveredAddr, recovered := resolvedAddrs[0].Addr, 0 recoveredAddr, recovered := eps[0], 0
available[recoveredAddr] = struct{}{} available[recoveredAddr] = struct{}{}
for i := 0; i < 2*reqN; i++ { for i := 0; i < 2*reqN; i++ {
@ -234,15 +239,18 @@ func TestRoundRobinBalancedResolvableFailoverFromRequestFail(t *testing.T) {
t.Fatalf("failed to start mock servers: %s", err) t.Fatalf("failed to start mock servers: %s", err)
} }
defer ms.Stop() defer ms.Stop()
var resolvedAddrs []resolver.Address var eps []string
available := make(map[string]struct{}) available := make(map[string]struct{})
for _, svr := range ms.Servers { for _, svr := range ms.Servers {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: svr.Address}) eps = append(eps, svr.ResolverAddress().Addr)
available[svr.Address] = struct{}{} available[svr.Address] = struct{}{}
} }
rsv := endpoint.EndpointResolver("requestfail") rsv, err := endpoint.NewResolverGroup("requestfail")
if err != nil {
t.Fatal(err)
}
defer rsv.Close() defer rsv.Close()
rsv.InitialAddrs(resolvedAddrs) rsv.SetEndpoints(eps)
name := genName() name := genName()
cfg := Config{ cfg := Config{

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package endpoint resolves etcd endpoints using grpc targets of the form 'endpoint://<clientId>/<endpoint>'. // Package endpoint resolves etcd endpoints using grpc targets of the form 'endpoint://<id>/<endpoint>'.
package endpoint package endpoint
import ( import (
@ -36,91 +36,140 @@ var (
func init() { func init() {
bldr = &builder{ bldr = &builder{
clientResolvers: make(map[string]*Resolver), resolverGroups: make(map[string]*ResolverGroup),
} }
resolver.Register(bldr) resolver.Register(bldr)
} }
type builder struct { type builder struct {
clientResolvers map[string]*Resolver resolverGroups map[string]*ResolverGroup
sync.RWMutex sync.RWMutex
} }
// NewResolverGroup asks the package-level builder to create and register a
// ResolverGroup under the given id, returning whatever the builder returns
// (the new group, or an error).
func NewResolverGroup(id string) (*ResolverGroup, error) {
	group, err := bldr.newResolverGroup(id)
	return group, err
}
// ResolverGroup keeps all endpoints of resolvers using a common endpoint://<id>/ target
// up-to-date.
//
// The embedded RWMutex is taken by the group's methods around every access to
// endpoints and resolvers.
type ResolverGroup struct {
// id is the authority part of the endpoint://<id>/ target; Target and Close use it.
id string
// endpoints is the list most recently passed to SetEndpoints; addResolver
// hands it to each newly registered resolver.
endpoints []string
// resolvers are the resolvers currently registered through addResolver and
// not yet removed by removeResolver; SetEndpoints notifies each of them.
resolvers []*Resolver
sync.RWMutex
}
// addResolver registers r with the group and then hands the group's current
// endpoints to r's client connection.
//
// The endpoint snapshot and the registration happen under the group lock; the
// NewAddress notification is issued only after the lock has been released.
func (e *ResolverGroup) addResolver(r *Resolver) {
	e.Lock()
	e.resolvers = append(e.resolvers, r)
	current := epsToAddrs(e.endpoints...)
	e.Unlock()

	r.cc.NewAddress(current)
}
// removeResolver unregisters r from the group so that later SetEndpoints calls
// no longer notify it. Resolvers are matched by pointer identity; only the
// first match is removed, and the call is a no-op if r is not registered.
func (e *ResolverGroup) removeResolver(r *Resolver) {
	e.Lock()
	defer e.Unlock()

	for i, er := range e.resolvers {
		if er != r {
			continue
		}
		last := len(e.resolvers) - 1
		copy(e.resolvers[i:], e.resolvers[i+1:])
		// Clear the vacated tail slot. Without this the backing array keeps
		// a pointer to a resolver that is no longer in the slice, which
		// prevents it (and the ClientConn it holds) from being collected.
		e.resolvers[last] = nil
		e.resolvers = e.resolvers[:last]
		return
	}
}
// SetEndpoints updates the endpoints for ResolverGroup. All registered resolvers
// are updated immediately with the new endpoints, and the list is remembered so
// that resolvers added later receive it too.
//
// The group stores its own copy of endpoints, so the caller may reuse or modify
// the slice after the call returns without affecting the group.
func (e *ResolverGroup) SetEndpoints(endpoints []string) {
	addrs := epsToAddrs(endpoints...)
	// Detach from the caller's backing array: e.endpoints is read later under
	// the group lock, and the caller holds no such lock.
	stored := append([]string(nil), endpoints...)

	e.Lock()
	defer e.Unlock()
	e.endpoints = stored
	for _, r := range e.resolvers {
		r.cc.NewAddress(addrs)
	}
}
// Target builds an endpoint resolver target for the given endpoint, using this
// group's id as the target's id component. It delegates to the package-level
// Target function.
func (e *ResolverGroup) Target(endpoint string) string {
	groupID := e.id
	return Target(groupID, endpoint)
}
// Target constructs an endpoint resolver target of the form
// <scheme>://<id>/<endpoint>, where <scheme> is this package's resolver scheme.
func Target(id, endpoint string) string {
	const layout = "%s://%s/%s"
	return fmt.Sprintf(layout, scheme, id, endpoint)
}
// IsTarget reports whether the given target string is an endpoint resolver
// target, i.e. whether it begins with "endpoint://".
func IsTarget(target string) bool {
	const prefix = "endpoint://"
	return strings.HasPrefix(target, prefix)
}
// Close asks the package-level builder to close this group, identifying the
// group by its id.
func (e *ResolverGroup) Close() {
	groupID := e.id
	bldr.close(groupID)
}
// Build creates or reuses an etcd resolver for the etcd cluster name identified by the authority part of the target. // Build creates or reuses an etcd resolver for the etcd cluster name identified by the authority part of the target.
func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) { func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
if len(target.Authority) < 1 { if len(target.Authority) < 1 {
return nil, fmt.Errorf("'etcd' target scheme requires non-empty authority identifying etcd cluster being routed to") return nil, fmt.Errorf("'etcd' target scheme requires non-empty authority identifying etcd cluster being routed to")
} }
r := b.getResolver(target.Authority) id := target.Authority
r.cc = cc es, err := b.getResolverGroup(id)
if r.addrs != nil { if err != nil {
r.NewAddress(r.addrs) return nil, fmt.Errorf("failed to build resolver: %v", err)
} }
r := &Resolver{
endpointId: id,
cc: cc,
}
es.addResolver(r)
return r, nil return r, nil
} }
func (b *builder) getResolver(clientId string) *Resolver { func (b *builder) newResolverGroup(id string) (*ResolverGroup, error) {
b.RLock() b.RLock()
r, ok := b.clientResolvers[clientId] es, ok := b.resolverGroups[id]
b.RUnlock() b.RUnlock()
if !ok { if !ok {
r = &Resolver{ es = &ResolverGroup{id: id}
clientId: clientId,
}
b.Lock() b.Lock()
b.clientResolvers[clientId] = r b.resolverGroups[id] = es
b.Unlock() b.Unlock()
} else {
return nil, fmt.Errorf("Endpoint already exists for id: %s", id)
} }
return r return es, nil
} }
func (b *builder) addResolver(r *Resolver) { func (b *builder) getResolverGroup(id string) (*ResolverGroup, error) {
bldr.Lock() b.RLock()
bldr.clientResolvers[r.clientId] = r es, ok := b.resolverGroups[id]
bldr.Unlock() b.RUnlock()
if !ok {
return nil, fmt.Errorf("ResolverGroup not found for id: %s", id)
}
return es, nil
} }
func (b *builder) removeResolver(r *Resolver) { func (b *builder) close(id string) {
bldr.Lock() b.Lock()
delete(bldr.clientResolvers, r.clientId) delete(b.resolverGroups, id)
bldr.Unlock() b.Unlock()
} }
func (r *builder) Scheme() string { func (r *builder) Scheme() string {
return scheme return scheme
} }
// EndpointResolver gets the resolver for given etcd cluster name.
func EndpointResolver(clientId string) *Resolver {
return bldr.getResolver(clientId)
}
// Resolver provides a resolver for a single etcd cluster, identified by name. // Resolver provides a resolver for a single etcd cluster, identified by name.
type Resolver struct { type Resolver struct {
clientId string endpointId string
cc resolver.ClientConn cc resolver.ClientConn
addrs []resolver.Address
sync.RWMutex sync.RWMutex
} }
// InitialAddrs sets the initial endpoint addresses for the resolver.
func (r *Resolver) InitialAddrs(addrs []resolver.Address) {
r.Lock()
r.addrs = addrs
r.Unlock()
}
// InitialEndpoints sets the initial endpoints to for the resolver.
// This should be called before dialing. The endpoints may be updated after the dial using NewAddress.
// At least one endpoint is required.
func (r *Resolver) InitialEndpoints(eps []string) error {
if len(eps) < 1 {
return fmt.Errorf("At least one endpoint is required, but got: %v", eps)
}
r.InitialAddrs(epsToAddrs(eps...))
return nil
}
// TODO: use balancer.epsToAddrs // TODO: use balancer.epsToAddrs
func epsToAddrs(eps ...string) (addrs []resolver.Address) { func epsToAddrs(eps ...string) (addrs []resolver.Address) {
addrs = make([]resolver.Address, 0, len(eps)) addrs = make([]resolver.Address, 0, len(eps))
@ -130,35 +179,14 @@ func epsToAddrs(eps ...string) (addrs []resolver.Address) {
return addrs return addrs
} }
// NewAddress updates the addresses of the resolver.
func (r *Resolver) NewAddress(addrs []resolver.Address) {
r.Lock()
r.addrs = addrs
r.Unlock()
if r.cc != nil {
r.cc.NewAddress(addrs)
}
}
func (*Resolver) ResolveNow(o resolver.ResolveNowOption) {} func (*Resolver) ResolveNow(o resolver.ResolveNowOption) {}
func (r *Resolver) Close() { func (r *Resolver) Close() {
bldr.removeResolver(r) es, err := bldr.getResolverGroup(r.endpointId)
} if err != nil {
return
// Target constructs a endpoint target with current resolver's clientId. }
func (r *Resolver) Target(endpoint string) string { es.removeResolver(r)
return Target(r.clientId, endpoint)
}
// Target constructs a endpoint resolver target.
func Target(clientId, endpoint string) string {
return fmt.Sprintf("%s://%s/%s", scheme, clientId, endpoint)
}
// IsTarget checks if a given target string in an endpoint resolver target.
func IsTarget(target string) bool {
return strings.HasPrefix(target, "endpoint://")
} }
// Parse endpoint parses a endpoint of the form (http|https)://<host>*|(unix|unixs)://<path>) and returns a // Parse endpoint parses a endpoint of the form (http|https)://<host>*|(unix|unixs)://<path>) and returns a
@ -185,7 +213,7 @@ func ParseEndpoint(endpoint string) (proto string, host string, scheme string) {
return proto, host, scheme return proto, host, scheme
} }
// ParseTarget parses a endpoint://<clientId>/<endpoint> string and returns the parsed clientId and endpoint. // ParseTarget parses a endpoint://<id>/<endpoint> string and returns the parsed id and endpoint.
// If the target is malformed, an error is returned. // If the target is malformed, an error is returned.
func ParseTarget(target string) (string, string, error) { func ParseTarget(target string) (string, string, error) {
noPrefix := strings.TrimPrefix(target, targetPrefix) noPrefix := strings.TrimPrefix(target, targetPrefix)
@ -194,7 +222,7 @@ func ParseTarget(target string) (string, string, error) {
} }
parts := strings.SplitN(noPrefix, "/", 2) parts := strings.SplitN(noPrefix, "/", 2)
if len(parts) != 2 { if len(parts) != 2 {
return "", "", fmt.Errorf("malformed target, expected %s://<clientId>/<endpoint>, but got %s", scheme, target) return "", "", fmt.Errorf("malformed target, expected %s://<id>/<endpoint>, but got %s", scheme, target)
} }
return parts[0], parts[1], nil return parts[0], parts[1], nil
} }

View File

@ -37,7 +37,6 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@ -68,11 +67,11 @@ type Client struct {
conn *grpc.ClientConn conn *grpc.ClientConn
dialerrc chan error dialerrc chan error
cfg Config cfg Config
creds *credentials.TransportCredentials creds *credentials.TransportCredentials
balancer balancer.Balancer balancer balancer.Balancer
resolver *endpoint.Resolver resolverGroup *endpoint.ResolverGroup
mu *sync.Mutex mu *sync.Mutex
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@ -119,12 +118,12 @@ func (c *Client) Close() error {
c.cancel() c.cancel()
c.Watcher.Close() c.Watcher.Close()
c.Lease.Close() c.Lease.Close()
if c.resolverGroup != nil {
c.resolverGroup.Close()
}
if c.conn != nil { if c.conn != nil {
return toErr(c.ctx, c.conn.Close()) return toErr(c.ctx, c.conn.Close())
} }
if c.resolver != nil {
c.resolver.Close()
}
return c.ctx.Err() return c.ctx.Err()
} }
@ -143,22 +142,10 @@ func (c *Client) Endpoints() (eps []string) {
// SetEndpoints updates client's endpoints. // SetEndpoints updates client's endpoints.
func (c *Client) SetEndpoints(eps ...string) { func (c *Client) SetEndpoints(eps ...string) {
var addrs []resolver.Address
for _, ep := range eps {
addrs = append(addrs, resolver.Address{Addr: ep})
}
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.cfg.Endpoints = eps c.cfg.Endpoints = eps
c.resolver.NewAddress(addrs) c.resolverGroup.SetEndpoints(eps)
// TODO: Does the new grpc balancer provide a way to block until the endpoint changes are propagated?
/*if c.balancer.NeedUpdate() {
select {
case c.balancer.UpdateAddrsC() <- balancer.NotifyNext:
case <-c.balancer.StopC():
}
}*/
} }
// Sync synchronizes client's endpoints with the known endpoints from the etcd membership. // Sync synchronizes client's endpoints with the known endpoints from the etcd membership.
@ -301,12 +288,13 @@ func (c *Client) getToken(ctx context.Context) error {
// use dial options without dopts to avoid reusing the client balancer // use dial options without dopts to avoid reusing the client balancer
var dOpts []grpc.DialOption var dOpts []grpc.DialOption
_, host, _ := endpoint.ParseEndpoint(ep) _, host, _ := endpoint.ParseEndpoint(ep)
target := c.resolver.Target(host) target := c.resolverGroup.Target(host)
dOpts, err = c.dialSetupOpts(target, c.cfg.DialOptions...) dOpts, err = c.dialSetupOpts(target, c.cfg.DialOptions...)
if err != nil { if err != nil {
err = fmt.Errorf("failed to configure auth dialer: %v", err) err = fmt.Errorf("failed to configure auth dialer: %v", err)
continue continue
} }
dOpts = append(dOpts, grpc.WithBalancerName(roundRobinBalancerName))
auth, err = newAuthenticator(ctx, target, dOpts, c) auth, err = newAuthenticator(ctx, target, dOpts, c)
if err != nil { if err != nil {
continue continue
@ -333,7 +321,7 @@ func (c *Client) dial(ep string, dopts ...grpc.DialOption) (*grpc.ClientConn, er
// We pass a target to DialContext of the form: endpoint://<clusterName>/<host-part> that // We pass a target to DialContext of the form: endpoint://<clusterName>/<host-part> that
// does not include scheme (http/https/unix/unixs) or path parts. // does not include scheme (http/https/unix/unixs) or path parts.
_, host, _ := endpoint.ParseEndpoint(ep) _, host, _ := endpoint.ParseEndpoint(ep)
target := c.resolver.Target(host) target := c.resolverGroup.Target(host)
opts, err := c.dialSetupOpts(target, dopts...) opts, err := c.dialSetupOpts(target, dopts...)
if err != nil { if err != nil {
@ -439,13 +427,13 @@ func newClient(cfg *Config) (*Client, error) {
// Prepare a 'endpoint://<unique-client-id>/' resolver for the client and create a endpoint target to pass // Prepare a 'endpoint://<unique-client-id>/' resolver for the client and create a endpoint target to pass
// to dial so the client knows to use this resolver. // to dial so the client knows to use this resolver.
client.resolver = endpoint.EndpointResolver(fmt.Sprintf("client-%s", strconv.FormatInt(time.Now().UnixNano(), 36))) var err error
err := client.resolver.InitialEndpoints(cfg.Endpoints) client.resolverGroup, err = endpoint.NewResolverGroup(fmt.Sprintf("client-%s", strconv.FormatInt(time.Now().UnixNano(), 36)))
if err != nil { if err != nil {
client.cancel() client.cancel()
client.resolver.Close()
return nil, err return nil, err
} }
client.resolverGroup.SetEndpoints(cfg.Endpoints)
if len(cfg.Endpoints) < 1 { if len(cfg.Endpoints) < 1 {
return nil, fmt.Errorf("at least one Endpoint must is required in client config") return nil, fmt.Errorf("at least one Endpoint must is required in client config")
@ -457,7 +445,7 @@ func newClient(cfg *Config) (*Client, error) {
conn, err := client.dial(dialEndpoint, grpc.WithBalancerName(roundRobinBalancerName)) conn, err := client.dial(dialEndpoint, grpc.WithBalancerName(roundRobinBalancerName))
if err != nil { if err != nil {
client.cancel() client.cancel()
client.resolver.Close() client.resolverGroup.Close()
return nil, err return nil, err
} }
// TODO: With the old grpc balancer interface, we waited until the dial timeout // TODO: With the old grpc balancer interface, we waited until the dial timeout

View File

@ -403,3 +403,18 @@ func isServerUnavailable(err error) bool {
code := ev.Code() code := ev.Code()
return code == codes.Unavailable return code == codes.Unavailable
} }
// isCanceled reports whether err represents a cancellation: either err is
// context.Canceled itself, or status.FromError can interpret it and the
// resulting status code is codes.Canceled. A nil error is never a cancellation.
func isCanceled(err error) bool {
	switch err {
	case nil:
		return false
	case context.Canceled:
		return true
	}
	if st, ok := status.FromError(err); ok {
		return st.Code() == codes.Canceled
	}
	return false
}

View File

@ -30,7 +30,6 @@ import (
mvccpb "github.com/coreos/etcd/mvcc/mvccpb" mvccpb "github.com/coreos/etcd/mvcc/mvccpb"
"github.com/coreos/etcd/pkg/testutil" "github.com/coreos/etcd/pkg/testutil"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
@ -667,8 +666,9 @@ func TestWatchErrConnClosed(t *testing.T) {
go func() { go func() {
defer close(donec) defer close(donec)
ch := cli.Watch(context.TODO(), "foo") ch := cli.Watch(context.TODO(), "foo")
if wr := <-ch; wr.Err() != grpc.ErrClientConnClosing {
t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, wr.Err()) if wr := <-ch; !isCanceled(wr.Err()) {
t.Fatalf("expected context canceled, got %v", wr.Err())
} }
}() }()
@ -699,8 +699,8 @@ func TestWatchAfterClose(t *testing.T) {
donec := make(chan struct{}) donec := make(chan struct{})
go func() { go func() {
cli.Watch(context.TODO(), "foo") cli.Watch(context.TODO(), "foo")
if err := cli.Close(); err != nil && err != grpc.ErrClientConnClosing { if err := cli.Close(); err != nil && err != context.Canceled {
t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, err) t.Fatalf("expected %v, got %v", context.Canceled, err)
} }
close(donec) close(donec)
}() }()