diff --git a/client/discover.go b/client/discover.go index 30473148b..580c25626 100644 --- a/client/discover.go +++ b/client/discover.go @@ -21,7 +21,7 @@ import ( // Discoverer is an interface that wraps the Discover method. type Discoverer interface { // Discover looks up the etcd servers for the domain. - Discover(domain string) ([]string, error) + Discover(domain string, serviceName string) ([]string, error) } type srvDiscover struct{} @@ -31,8 +31,8 @@ func NewSRVDiscover() Discoverer { return &srvDiscover{} } -func (d *srvDiscover) Discover(domain string) ([]string, error) { - srvs, err := srv.GetClient("etcd-client", domain) +func (d *srvDiscover) Discover(domain string, serviceName string) ([]string, error) { + srvs, err := srv.GetClient("etcd-client", domain, serviceName) if err != nil { return nil, err } diff --git a/etcdctl/ctlv2/command/util.go b/etcdctl/ctlv2/command/util.go index 4ac8481e8..d19cd40e3 100644 --- a/etcdctl/ctlv2/command/util.go +++ b/etcdctl/ctlv2/command/util.go @@ -86,7 +86,7 @@ func getPeersFlagValue(c *cli.Context) []string { } func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) { - domainstr, insecure := getDiscoveryDomain(c) + domainstr, insecure, serviceName := getDiscoveryDomain(c) // If we still don't have domain discovery, return nothing if domainstr == "" { @@ -94,7 +94,7 @@ func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) { } discoverer := client.NewSRVDiscover() - eps, err := discoverer.Discover(domainstr) + eps, err := discoverer.Discover(domainstr, serviceName) if err != nil { return nil, err } @@ -113,7 +113,7 @@ func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) { return ret, err } -func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) { +func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool, serviceName string) { domainstr = c.GlobalString("discovery-srv") // Use an environment variable if nothing was supplied on the // command line @@ -121,7 +121,11 @@ func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) { domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV") } insecure = c.GlobalBool("insecure-discovery") || (os.Getenv("ETCDCTL_INSECURE_DISCOVERY") != "") - return domainstr, insecure + serviceName = c.GlobalString("discovery-srv-name") + if serviceName == "" { + serviceName = os.Getenv("ETCDCTL_DISCOVERY_SRV_NAME") + } + return domainstr, insecure, serviceName } func getEndpoints(c *cli.Context) ([]string, error) { @@ -168,7 +172,7 @@ func getTransport(c *cli.Context) (*http.Transport, error) { keyfile = os.Getenv("ETCDCTL_KEY_FILE") } - discoveryDomain, insecure := getDiscoveryDomain(c) + discoveryDomain, insecure, _ := getDiscoveryDomain(c) if insecure { discoveryDomain = "" } diff --git a/etcdctl/ctlv3/command/global.go b/etcdctl/ctlv3/command/global.go index 4400420c1..101a2d209 100644 --- a/etcdctl/ctlv3/command/global.go +++ b/etcdctl/ctlv3/command/global.go @@ -39,14 +39,15 @@ import ( // GlobalFlags are flags that defined globally // and are inherited to all sub-commands. type GlobalFlags struct { - Insecure bool - InsecureSkipVerify bool - InsecureDiscovery bool - Endpoints []string - DialTimeout time.Duration - CommandTimeOut time.Duration - KeepAliveTime time.Duration - KeepAliveTimeout time.Duration + Insecure bool + InsecureSkipVerify bool + InsecureDiscovery bool + Endpoints []string + DialTimeout time.Duration + CommandTimeOut time.Duration + KeepAliveTime time.Duration + KeepAliveTimeout time.Duration + DNSClusterServiceName string TLS transport.TLSInfo @@ -75,8 +76,9 @@ type authCfg struct { } type discoveryCfg struct { - domain string - insecure bool + domain string + insecure bool + serviceName string } var display printer = &simplePrinter{} @@ -390,10 +392,19 @@ func discoverySrvFromCmd(cmd *cobra.Command) string { return domainStr } +func discoveryDNSClusterServiceNameFromCmd(cmd *cobra.Command) string { + serviceNameStr, err := cmd.Flags().GetString("discovery-srv-name") + if err != nil { + ExitWithError(ExitBadArgs, err) + } + return serviceNameStr +} + func discoveryCfgFromCmd(cmd *cobra.Command) *discoveryCfg { return &discoveryCfg{ - domain: discoverySrvFromCmd(cmd), - insecure: insecureDiscoveryFromCmd(cmd), + domain: discoverySrvFromCmd(cmd), + insecure: insecureDiscoveryFromCmd(cmd), + serviceName: discoveryDNSClusterServiceNameFromCmd(cmd), } } @@ -422,7 +433,7 @@ func endpointsFromFlagValue(cmd *cobra.Command) ([]string, error) { return []string{}, nil } - srvs, err := srv.GetClient("etcd-client", discoveryCfg.domain) + srvs, err := srv.GetClient("etcd-client", discoveryCfg.domain, discoveryCfg.serviceName) if err != nil { return nil, err } diff --git a/etcdctl/ctlv3/ctl.go b/etcdctl/ctlv3/ctl.go index bf64f8b9e..40154a5da 100644 --- a/etcdctl/ctlv3/ctl.go +++ b/etcdctl/ctlv3/ctl.go @@ -67,6 +67,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&globalFlags.User, "user", "", "username[:password] for authentication (prompt if password is not supplied)") rootCmd.PersistentFlags().StringVar(&globalFlags.Password, "password", "", "password for authentication (if this option is used, --user option shouldn't include password)") rootCmd.PersistentFlags().StringVarP(&globalFlags.TLS.ServerName, "discovery-srv", "d", "", "domain name to query for SRV records describing cluster endpoints") + rootCmd.PersistentFlags().StringVarP(&globalFlags.DNSClusterServiceName, "discovery-srv-name", "", "", "service name to query when using DNS discovery") rootCmd.AddCommand( command.NewGetCommand(), diff --git a/etcdmain/gateway.go b/etcdmain/gateway.go index 201878081..fece05f1c 100644 --- a/etcdmain/gateway.go +++ b/etcdmain/gateway.go @@ -28,12 +28,13 @@ import ( ) var ( - gatewayListenAddr string - gatewayEndpoints []string - gatewayDNSCluster string - gatewayInsecureDiscovery bool - getewayRetryDelay time.Duration - gatewayCA string + gatewayListenAddr string + gatewayEndpoints []string + gatewayDNSCluster string + gatewayDNSClusterServiceName string + gatewayInsecureDiscovery bool + getewayRetryDelay time.Duration + gatewayCA string ) var ( @@ -68,6 +69,7 @@ func newGatewayStartCommand() *cobra.Command { cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address") cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster") + cmd.Flags().StringVar(&gatewayDNSClusterServiceName, "discovery-srv-name", "", "service name to query when using DNS discovery") cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records") cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.") @@ -97,7 +99,7 @@ func startGateway(cmd *cobra.Command, args []string) { os.Exit(1) } - srvs := discoverEndpoints(lg, gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery) + srvs := discoverEndpoints(lg, gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery, gatewayDNSClusterServiceName) if len(srvs.Endpoints) == 0 { // no endpoints discovered, fall back to provided endpoints srvs.Endpoints = gatewayEndpoints diff --git a/etcdmain/grpc_proxy.go b/etcdmain/grpc_proxy.go index 588094bfe..425ba8fc8 100644 --- a/etcdmain/grpc_proxy.go +++ b/etcdmain/grpc_proxy.go @@ -49,14 +49,15 @@ import ( ) var ( - grpcProxyListenAddr string - grpcProxyMetricsListenAddr string - grpcProxyEndpoints []string - grpcProxyDNSCluster string - grpcProxyInsecureDiscovery bool - grpcProxyDataDir string - grpcMaxCallSendMsgSize int - grpcMaxCallRecvMsgSize int + grpcProxyListenAddr string + grpcProxyMetricsListenAddr string + grpcProxyEndpoints []string + grpcProxyDNSCluster string + grpcProxyDNSClusterServiceName string + grpcProxyInsecureDiscovery bool + grpcProxyDataDir string + grpcMaxCallSendMsgSize int + grpcMaxCallRecvMsgSize int // tls for connecting to etcd @@ -111,7 +112,8 @@ func newGRPCProxyStartCommand() *cobra.Command { } cmd.Flags().StringVar(&grpcProxyListenAddr, "listen-addr", "127.0.0.1:23790", "listen address") - cmd.Flags().StringVar(&grpcProxyDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster") + cmd.Flags().StringVar(&grpcProxyDNSCluster, "discovery-srv", "", "domain name to query for SRV records describing cluster endpoints") + cmd.Flags().StringVar(&grpcProxyDNSClusterServiceName, "discovery-srv-name", "", "service name to query when using DNS discovery") cmd.Flags().StringVar(&grpcProxyMetricsListenAddr, "metrics-addr", "", "listen for /metrics requests on an additional interface") cmd.Flags().BoolVar(&grpcProxyInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records") cmd.Flags().StringSliceVar(&grpcProxyEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints") @@ -249,7 +251,7 @@ func checkArgs() { } func mustNewClient(lg *zap.Logger) *clientv3.Client { - srvs := discoverEndpoints(lg, grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery) + srvs := discoverEndpoints(lg, grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery, grpcProxyDNSClusterServiceName) eps := srvs.Endpoints if len(eps) == 0 { eps = grpcProxyEndpoints diff --git a/etcdmain/util.go b/etcdmain/util.go index 15c686ab6..463ada65b 100644 --- a/etcdmain/util.go +++ b/etcdmain/util.go @@ -24,11 +24,11 @@ import ( "go.uber.org/zap" ) -func discoverEndpoints(lg *zap.Logger, dns string, ca string, insecure bool) (s srv.SRVClients) { +func discoverEndpoints(lg *zap.Logger, dns string, ca string, insecure bool, serviceName string) (s srv.SRVClients) { if dns == "" { return s } - srvs, err := srv.GetClient("etcd-client", dns) + srvs, err := srv.GetClient("etcd-client", dns, serviceName) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) diff --git a/pkg/srv/srv.go b/pkg/srv/srv.go index 9914104a7..c3560026d 100644 --- a/pkg/srv/srv.go +++ b/pkg/srv/srv.go @@ -96,7 +96,7 @@ type SRVClients struct { } // GetClient looks up the client endpoints for a service and domain. -func GetClient(service, domain string) (*SRVClients, error) { +func GetClient(service, domain string, serviceName string) (*SRVClients, error) { var urls []*url.URL var srvs []*net.SRV @@ -115,8 +115,8 @@ func GetClient(service, domain string) (*SRVClients, error) { return nil } - errHTTPS := updateURLs(service+"-ssl", "https") - errHTTP := updateURLs(service, "http") + errHTTPS := updateURLs(GetSRVService(service, serviceName, "https"), "https") + errHTTP := updateURLs(GetSRVService(service, serviceName, "http"), "http") if errHTTPS != nil && errHTTP != nil { return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP) @@ -128,3 +128,15 @@ func GetClient(service, domain string) (*SRVClients, error) { } return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil } + +// GetSRVService generates a SRV service including an optional suffix. +func GetSRVService(service, serviceName string, scheme string) (SRVService string) { + if scheme == "https" { + service = fmt.Sprintf("%s-ssl", service) + } + + if serviceName != "" { + return fmt.Sprintf("%s-%s", service, serviceName) + } + return service +} diff --git a/pkg/srv/srv_test.go b/pkg/srv/srv_test.go index 4ac2744ce..24a7cf22d 100644 --- a/pkg/srv/srv_test.go +++ b/pkg/srv/srv_test.go @@ -188,7 +188,7 @@ func TestSRVDiscover(t *testing.T) { return "", nil, errors.New("Unknown service in mock") } - srvs, err := GetClient("etcd-client", "example.com") + srvs, err := GetClient("etcd-client", "example.com", "") if err != nil { t.Fatalf("%d: err: %#v", i, err) } @@ -199,3 +199,40 @@ func TestSRVDiscover(t *testing.T) { } } + +func TestGetSRVService(t *testing.T) { + tests := []struct { + scheme string + serviceName string + + expected string + }{ + { + "https", + "", + "etcd-client-ssl", + }, + { + "http", + "", + "etcd-client", + }, + { + "https", + "foo", + "etcd-client-ssl-foo", + }, + { + "http", + "bar", + "etcd-client-bar", + }, + } + + for i, tt := range tests { + service := GetSRVService("etcd-client", tt.serviceName, tt.scheme) + if strings.Compare(service, tt.expected) != 0 { + t.Errorf("#%d: service = %s, want %s", i, service, tt.expected) + } + } +}