*: add client support for discovery-srv-name

Signed-off-by: Sam Batschelet <sbatsche@redhat.com>
release-3.4
Sam Batschelet 2018-11-09 09:48:35 -05:00
parent 83304cfc80
commit fa35126ef8
No known key found for this signature in database
GPG Key ID: B3F3B946C6C228FC
9 changed files with 113 additions and 44 deletions

View File

@ -21,7 +21,7 @@ import (
// Discoverer is an interface that wraps the Discover method. // Discoverer is an interface that wraps the Discover method.
type Discoverer interface { type Discoverer interface {
// Discover looks up the etcd servers for the domain. // Discover looks up the etcd servers for the domain.
Discover(domain string) ([]string, error) Discover(domain string, serviceName string) ([]string, error)
} }
type srvDiscover struct{} type srvDiscover struct{}
@ -31,8 +31,8 @@ func NewSRVDiscover() Discoverer {
return &srvDiscover{} return &srvDiscover{}
} }
func (d *srvDiscover) Discover(domain string) ([]string, error) { func (d *srvDiscover) Discover(domain string, serviceName string) ([]string, error) {
srvs, err := srv.GetClient("etcd-client", domain) srvs, err := srv.GetClient("etcd-client", domain, serviceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -86,7 +86,7 @@ func getPeersFlagValue(c *cli.Context) []string {
} }
func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) { func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
domainstr, insecure := getDiscoveryDomain(c) domainstr, insecure, serviceName := getDiscoveryDomain(c)
// If we still don't have domain discovery, return nothing // If we still don't have domain discovery, return nothing
if domainstr == "" { if domainstr == "" {
@ -94,7 +94,7 @@ func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
} }
discoverer := client.NewSRVDiscover() discoverer := client.NewSRVDiscover()
eps, err := discoverer.Discover(domainstr) eps, err := discoverer.Discover(domainstr, serviceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -113,7 +113,7 @@ func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
return ret, err return ret, err
} }
func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) { func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool, serviceName string) {
domainstr = c.GlobalString("discovery-srv") domainstr = c.GlobalString("discovery-srv")
// Use an environment variable if nothing was supplied on the // Use an environment variable if nothing was supplied on the
// command line // command line
@ -121,7 +121,11 @@ func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) {
domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV") domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV")
} }
insecure = c.GlobalBool("insecure-discovery") || (os.Getenv("ETCDCTL_INSECURE_DISCOVERY") != "") insecure = c.GlobalBool("insecure-discovery") || (os.Getenv("ETCDCTL_INSECURE_DISCOVERY") != "")
return domainstr, insecure serviceName = c.GlobalString("discovery-srv-name")
if serviceName == "" {
serviceName = os.Getenv("ETCDCTL_DISCOVERY_SRV_NAME")
}
return domainstr, insecure, serviceName
} }
func getEndpoints(c *cli.Context) ([]string, error) { func getEndpoints(c *cli.Context) ([]string, error) {
@ -168,7 +172,7 @@ func getTransport(c *cli.Context) (*http.Transport, error) {
keyfile = os.Getenv("ETCDCTL_KEY_FILE") keyfile = os.Getenv("ETCDCTL_KEY_FILE")
} }
discoveryDomain, insecure := getDiscoveryDomain(c) discoveryDomain, insecure, _ := getDiscoveryDomain(c)
if insecure { if insecure {
discoveryDomain = "" discoveryDomain = ""
} }

View File

@ -39,14 +39,15 @@ import (
// GlobalFlags are flags that defined globally // GlobalFlags are flags that defined globally
// and are inherited to all sub-commands. // and are inherited to all sub-commands.
type GlobalFlags struct { type GlobalFlags struct {
Insecure bool Insecure bool
InsecureSkipVerify bool InsecureSkipVerify bool
InsecureDiscovery bool InsecureDiscovery bool
Endpoints []string Endpoints []string
DialTimeout time.Duration DialTimeout time.Duration
CommandTimeOut time.Duration CommandTimeOut time.Duration
KeepAliveTime time.Duration KeepAliveTime time.Duration
KeepAliveTimeout time.Duration KeepAliveTimeout time.Duration
DNSClusterServiceName string
TLS transport.TLSInfo TLS transport.TLSInfo
@ -75,8 +76,9 @@ type authCfg struct {
} }
type discoveryCfg struct { type discoveryCfg struct {
domain string domain string
insecure bool insecure bool
serviceName string
} }
var display printer = &simplePrinter{} var display printer = &simplePrinter{}
@ -390,10 +392,19 @@ func discoverySrvFromCmd(cmd *cobra.Command) string {
return domainStr return domainStr
} }
func discoveryDNSClusterServiceNameFromCmd(cmd *cobra.Command) string {
serviceNameStr, err := cmd.Flags().GetString("discovery-srv-name")
if err != nil {
ExitWithError(ExitBadArgs, err)
}
return serviceNameStr
}
func discoveryCfgFromCmd(cmd *cobra.Command) *discoveryCfg { func discoveryCfgFromCmd(cmd *cobra.Command) *discoveryCfg {
return &discoveryCfg{ return &discoveryCfg{
domain: discoverySrvFromCmd(cmd), domain: discoverySrvFromCmd(cmd),
insecure: insecureDiscoveryFromCmd(cmd), insecure: insecureDiscoveryFromCmd(cmd),
serviceName: discoveryDNSClusterServiceNameFromCmd(cmd),
} }
} }
@ -422,7 +433,7 @@ func endpointsFromFlagValue(cmd *cobra.Command) ([]string, error) {
return []string{}, nil return []string{}, nil
} }
srvs, err := srv.GetClient("etcd-client", discoveryCfg.domain) srvs, err := srv.GetClient("etcd-client", discoveryCfg.domain, discoveryCfg.serviceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -67,6 +67,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&globalFlags.User, "user", "", "username[:password] for authentication (prompt if password is not supplied)") rootCmd.PersistentFlags().StringVar(&globalFlags.User, "user", "", "username[:password] for authentication (prompt if password is not supplied)")
rootCmd.PersistentFlags().StringVar(&globalFlags.Password, "password", "", "password for authentication (if this option is used, --user option shouldn't include password)") rootCmd.PersistentFlags().StringVar(&globalFlags.Password, "password", "", "password for authentication (if this option is used, --user option shouldn't include password)")
rootCmd.PersistentFlags().StringVarP(&globalFlags.TLS.ServerName, "discovery-srv", "d", "", "domain name to query for SRV records describing cluster endpoints") rootCmd.PersistentFlags().StringVarP(&globalFlags.TLS.ServerName, "discovery-srv", "d", "", "domain name to query for SRV records describing cluster endpoints")
rootCmd.PersistentFlags().StringVarP(&globalFlags.DNSClusterServiceName, "discovery-srv-name", "", "", "service name to query when using DNS discovery")
rootCmd.AddCommand( rootCmd.AddCommand(
command.NewGetCommand(), command.NewGetCommand(),

View File

@ -28,12 +28,13 @@ import (
) )
var ( var (
gatewayListenAddr string gatewayListenAddr string
gatewayEndpoints []string gatewayEndpoints []string
gatewayDNSCluster string gatewayDNSCluster string
gatewayInsecureDiscovery bool gatewayDNSClusterServiceName string
getewayRetryDelay time.Duration gatewayInsecureDiscovery bool
gatewayCA string getewayRetryDelay time.Duration
gatewayCA string
) )
var ( var (
@ -68,6 +69,7 @@ func newGatewayStartCommand() *cobra.Command {
cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address") cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address")
cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster") cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster")
cmd.Flags().StringVar(&gatewayDNSClusterServiceName, "discovery-srv-name", "", "service name to query when using DNS discovery")
cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records") cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records")
cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.") cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.")
@ -97,7 +99,7 @@ func startGateway(cmd *cobra.Command, args []string) {
os.Exit(1) os.Exit(1)
} }
srvs := discoverEndpoints(lg, gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery) srvs := discoverEndpoints(lg, gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery, gatewayDNSClusterServiceName)
if len(srvs.Endpoints) == 0 { if len(srvs.Endpoints) == 0 {
// no endpoints discovered, fall back to provided endpoints // no endpoints discovered, fall back to provided endpoints
srvs.Endpoints = gatewayEndpoints srvs.Endpoints = gatewayEndpoints

View File

@ -49,14 +49,15 @@ import (
) )
var ( var (
grpcProxyListenAddr string grpcProxyListenAddr string
grpcProxyMetricsListenAddr string grpcProxyMetricsListenAddr string
grpcProxyEndpoints []string grpcProxyEndpoints []string
grpcProxyDNSCluster string grpcProxyDNSCluster string
grpcProxyInsecureDiscovery bool grpcProxyDNSClusterServiceName string
grpcProxyDataDir string grpcProxyInsecureDiscovery bool
grpcMaxCallSendMsgSize int grpcProxyDataDir string
grpcMaxCallRecvMsgSize int grpcMaxCallSendMsgSize int
grpcMaxCallRecvMsgSize int
// tls for connecting to etcd // tls for connecting to etcd
@ -111,7 +112,8 @@ func newGRPCProxyStartCommand() *cobra.Command {
} }
cmd.Flags().StringVar(&grpcProxyListenAddr, "listen-addr", "127.0.0.1:23790", "listen address") cmd.Flags().StringVar(&grpcProxyListenAddr, "listen-addr", "127.0.0.1:23790", "listen address")
cmd.Flags().StringVar(&grpcProxyDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster") cmd.Flags().StringVar(&grpcProxyDNSCluster, "discovery-srv", "", "domain name to query for SRV records describing cluster endpoints")
cmd.Flags().StringVar(&grpcProxyDNSClusterServiceName, "discovery-srv-name", "", "service name to query when using DNS discovery")
cmd.Flags().StringVar(&grpcProxyMetricsListenAddr, "metrics-addr", "", "listen for /metrics requests on an additional interface") cmd.Flags().StringVar(&grpcProxyMetricsListenAddr, "metrics-addr", "", "listen for /metrics requests on an additional interface")
cmd.Flags().BoolVar(&grpcProxyInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records") cmd.Flags().BoolVar(&grpcProxyInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records")
cmd.Flags().StringSliceVar(&grpcProxyEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints") cmd.Flags().StringSliceVar(&grpcProxyEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints")
@ -249,7 +251,7 @@ func checkArgs() {
} }
func mustNewClient(lg *zap.Logger) *clientv3.Client { func mustNewClient(lg *zap.Logger) *clientv3.Client {
srvs := discoverEndpoints(lg, grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery) srvs := discoverEndpoints(lg, grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery, grpcProxyDNSClusterServiceName)
eps := srvs.Endpoints eps := srvs.Endpoints
if len(eps) == 0 { if len(eps) == 0 {
eps = grpcProxyEndpoints eps = grpcProxyEndpoints

View File

@ -24,11 +24,11 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
func discoverEndpoints(lg *zap.Logger, dns string, ca string, insecure bool) (s srv.SRVClients) { func discoverEndpoints(lg *zap.Logger, dns string, ca string, insecure bool, serviceName string) (s srv.SRVClients) {
if dns == "" { if dns == "" {
return s return s
} }
srvs, err := srv.GetClient("etcd-client", dns) srvs, err := srv.GetClient("etcd-client", dns, serviceName)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
os.Exit(1) os.Exit(1)

View File

@ -96,7 +96,7 @@ type SRVClients struct {
} }
// GetClient looks up the client endpoints for a service and domain. // GetClient looks up the client endpoints for a service and domain.
func GetClient(service, domain string) (*SRVClients, error) { func GetClient(service, domain string, serviceName string) (*SRVClients, error) {
var urls []*url.URL var urls []*url.URL
var srvs []*net.SRV var srvs []*net.SRV
@ -115,8 +115,8 @@ func GetClient(service, domain string) (*SRVClients, error) {
return nil return nil
} }
errHTTPS := updateURLs(service+"-ssl", "https") errHTTPS := updateURLs(GetSRVService(service, serviceName, "https"), "https")
errHTTP := updateURLs(service, "http") errHTTP := updateURLs(GetSRVService(service, serviceName, "http"), "http")
if errHTTPS != nil && errHTTP != nil { if errHTTPS != nil && errHTTP != nil {
return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP) return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
@ -128,3 +128,15 @@ func GetClient(service, domain string) (*SRVClients, error) {
} }
return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
} }
// GetSRVService generates a SRV service including an optional suffix.
func GetSRVService(service, serviceName string, scheme string) (SRVService string) {
if scheme == "https" {
service = fmt.Sprintf("%s-ssl", service)
}
if serviceName != "" {
return fmt.Sprintf("%s-%s", service, serviceName)
}
return service
}

View File

@ -188,7 +188,7 @@ func TestSRVDiscover(t *testing.T) {
return "", nil, errors.New("Unknown service in mock") return "", nil, errors.New("Unknown service in mock")
} }
srvs, err := GetClient("etcd-client", "example.com") srvs, err := GetClient("etcd-client", "example.com", "")
if err != nil { if err != nil {
t.Fatalf("%d: err: %#v", i, err) t.Fatalf("%d: err: %#v", i, err)
} }
@ -199,3 +199,40 @@ func TestSRVDiscover(t *testing.T) {
} }
} }
func TestGetSRVService(t *testing.T) {
tests := []struct {
scheme string
serviceName string
expected string
}{
{
"https",
"",
"etcd-client-ssl",
},
{
"http",
"",
"etcd-client",
},
{
"https",
"foo",
"etcd-client-ssl-foo",
},
{
"http",
"bar",
"etcd-client-bar",
},
}
for i, tt := range tests {
service := GetSRVService("etcd-client", tt.serviceName, tt.scheme)
if strings.Compare(service, tt.expected) != 0 {
t.Errorf("#%d: service = %s, want %s", i, service, tt.expected)
}
}
}