diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index 8a23abab6..b4172dde7 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -46,6 +46,11 @@ func NewListener(addr string, info TLSInfo) (net.Listener, error) { } func NewTransport(info TLSInfo) (*http.Transport, error) { + cfg, err := info.ClientConfig() + if err != nil { + return nil, err + } + t := &http.Transport{ // timeouts taken from http.DefaultTransport Dial: (&net.Dialer{ @@ -53,14 +58,7 @@ func NewTransport(info TLSInfo) (*http.Transport, error) { KeepAlive: 30 * time.Second, }).Dial, TLSHandshakeTimeout: 10 * time.Second, - } - - if !info.Empty() { - tlsCfg, err := info.ClientConfig() - if err != nil { - return nil, err - } - t.TLSClientConfig = tlsCfg + TLSClientConfig: cfg, } return t, nil @@ -134,22 +132,24 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) { } // ClientConfig generates a tls.Config object for use by an HTTP client -func (info TLSInfo) ClientConfig() (*tls.Config, error) { - cfg, err := info.baseConfig() - if err != nil { - return nil, err - } - - if info.CAFile != "" { - cp, err := newCertPool(info.CAFile) +func (info TLSInfo) ClientConfig() (cfg *tls.Config, err error) { + if !info.Empty() { + cfg, err = info.baseConfig() if err != nil { return nil, err } - - cfg.RootCAs = cp + } else { + cfg = &tls.Config{} } - return cfg, nil + if info.CAFile != "" { + cfg.RootCAs, err = newCertPool(info.CAFile) + if err != nil { + return + } + } + + return } // newCertPool creates x509 certPool with provided CA file diff --git a/pkg/transport/listener_test.go b/pkg/transport/listener_test.go index 9745b0900..8d18460b1 100644 --- a/pkg/transport/listener_test.go +++ b/pkg/transport/listener_test.go @@ -51,41 +51,31 @@ func TestNewTransportTLSInfo(t *testing.T) { } defer os.Remove(tmp) - tests := []struct { - info TLSInfo - wantTLSClientConfig bool - }{ - { - info: TLSInfo{}, - wantTLSClientConfig: false, + tests := []TLSInfo{ + TLSInfo{}, + TLSInfo{ + CertFile: tmp, + KeyFile: tmp, }, - { - info: TLSInfo{ - CertFile: tmp, - KeyFile: tmp, - }, - wantTLSClientConfig: true, + TLSInfo{ + CertFile: tmp, + KeyFile: tmp, + CAFile: tmp, }, - { - info: TLSInfo{ - CertFile: tmp, - KeyFile: tmp, - CAFile: tmp, - }, - wantTLSClientConfig: true, + TLSInfo{ + CAFile: tmp, }, } for i, tt := range tests { - tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) - trans, err := NewTransport(tt.info) + tt.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) + trans, err := NewTransport(tt) if err != nil { t.Fatalf("Received unexpected error from NewTransport: %v", err) } - gotTLSClientConfig := trans.TLSClientConfig != nil - if tt.wantTLSClientConfig != gotTLSClientConfig { - t.Fatalf("#%d: wantTLSClientConfig=%t but gotTLSClientConfig=%t", i, tt.wantTLSClientConfig, gotTLSClientConfig) + if trans.TLSClientConfig == nil { + t.Fatalf("#%d: want non-nil TLSClientConfig", i) } } } @@ -121,8 +111,6 @@ func TestTLSInfoMissingFields(t *testing.T) { defer os.Remove(tmp) tests := []TLSInfo{ - TLSInfo{}, - TLSInfo{CAFile: tmp}, TLSInfo{CertFile: tmp}, TLSInfo{KeyFile: tmp}, TLSInfo{CertFile: tmp, CAFile: tmp},