diff --git a/discovery/discovery.go b/discovery/discovery.go index fc3ea46fb..752ba02c9 100644 --- a/discovery/discovery.go +++ b/discovery/discovery.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "net/url" + "os" "path" "sort" "strconv" @@ -26,6 +27,8 @@ var ( ) const ( + // Environment variable used to configure an HTTP proxy for discovery + DiscoveryProxyEnv = "ETCD_DISCOVERY_PROXY" // Number of retries discovery will attempt before giving up and erroring out. nRetries = uint(3) ) @@ -46,6 +49,35 @@ type discovery struct { timeoutTimescale time.Duration } +// proxyFuncFromEnv builds a proxy function if the appropriate environment +// variable is set. It performs basic sanitization of the environment variable +// and returns any error encountered. +func proxyFuncFromEnv() (func(*http.Request) (*url.URL, error), error) { + proxy := os.Getenv(DiscoveryProxyEnv) + if proxy == "" { + return nil, nil + } + // Do a small amount of URL sanitization to help the user + // Derived from net/http.ProxyFromEnvironment + proxyURL, err := url.Parse(proxy) + if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") { + // proxy was bogus. Try prepending "http://" to it and + // see if that parses correctly. If not, we ignore the + // error and complain about the original one + var err2 error + proxyURL, err2 = url.Parse("http://" + proxy) + if err2 == nil { + err = nil + } + } + if err != nil { + return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err) + } + + log.Printf("discovery: using proxy %q", proxyURL.String()) + return http.ProxyURL(proxyURL), nil +} + func New(durl string, id uint64, config string) (Discoverer, error) { u, err := url.Parse(durl) if err != nil { @@ -53,7 +85,11 @@ func New(durl string, id uint64, config string) (Discoverer, error) { } token := u.Path u.Path = "" - c, err := client.NewHTTPClient(&http.Transport{}, u.String(), time.Second*5) + pf, err := proxyFuncFromEnv() + if err != nil { + return nil, err + } + c, err := client.NewHTTPClient(&http.Transport{Proxy: pf}, u.String(), time.Second*5) if err != nil { return nil, err } diff --git a/discovery/discovery_test.go b/discovery/discovery_test.go index bbcc84137..7912672e4 100644 --- a/discovery/discovery_test.go +++ b/discovery/discovery_test.go @@ -3,6 +3,8 @@ package discovery import ( "errors" "math/rand" + "net/http" + "os" "sort" "strconv" @@ -13,6 +15,61 @@ import ( "github.com/coreos/etcd/client" ) +func TestProxyFuncFromEnvUnset(t *testing.T) { + os.Setenv(DiscoveryProxyEnv, "") + pf, err := proxyFuncFromEnv() + if pf != nil { + t.Fatal("unexpected non-nil proxyFunc") + } + if err != nil { + t.Fatalf("unexpected non-nil err: %v", err) + } +} + +func TestProxyFuncFromEnvBad(t *testing.T) { + tests := []string{ + "%%", + "http://foo.com/%1", + } + for i, in := range tests { + os.Setenv(DiscoveryProxyEnv, in) + pf, err := proxyFuncFromEnv() + if pf != nil { + t.Errorf("#%d: unexpected non-nil proxyFunc", i) + } + if err == nil { + t.Errorf("#%d: unexpected nil err", i) + } + } +} + +func TestProxyFuncFromEnv(t *testing.T) { + tests := map[string]string{ + "bar.com": "http://bar.com", + "http://disco.foo.bar": "http://disco.foo.bar", + } + for in, w := range tests { + os.Setenv(DiscoveryProxyEnv, in) + pf, err := proxyFuncFromEnv() + if pf == nil { + t.Errorf("%s: unexpected nil proxyFunc", in) + continue + } + if err != nil { + t.Errorf("%s: unexpected non-nil err: %v", in, err) + continue + } + g, err := pf(&http.Request{}) + if err != nil { + t.Errorf("%s: unexpected non-nil err: %v", in, err) + } + if g.String() != w { + t.Errorf("%s: proxyURL=%q, want %q", g, w) + } + + } +} + func TestCheckCluster(t *testing.T) { cluster := "1000" self := "/1000/1"