pkg/transport: generate TLS client config w/ only CAFile

release-2.0
Brian Waldon 2014-11-06 12:10:04 -08:00
parent f4ea274555
commit 902f06c5c4
2 changed files with 34 additions and 46 deletions

View File

@ -46,6 +46,11 @@ func NewListener(addr string, info TLSInfo) (net.Listener, error) {
}
func NewTransport(info TLSInfo) (*http.Transport, error) {
cfg, err := info.ClientConfig()
if err != nil {
return nil, err
}
t := &http.Transport{
// timeouts taken from http.DefaultTransport
Dial: (&net.Dialer{
@ -53,14 +58,7 @@ func NewTransport(info TLSInfo) (*http.Transport, error) {
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
}
if !info.Empty() {
tlsCfg, err := info.ClientConfig()
if err != nil {
return nil, err
}
t.TLSClientConfig = tlsCfg
TLSClientConfig: cfg,
}
return t, nil
@ -134,22 +132,24 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) {
}
// ClientConfig generates a tls.Config object for use by an HTTP client
func (info TLSInfo) ClientConfig() (*tls.Config, error) {
cfg, err := info.baseConfig()
if err != nil {
return nil, err
}
if info.CAFile != "" {
cp, err := newCertPool(info.CAFile)
func (info TLSInfo) ClientConfig() (cfg *tls.Config, err error) {
if !info.Empty() {
cfg, err = info.baseConfig()
if err != nil {
return nil, err
}
cfg.RootCAs = cp
} else {
cfg = &tls.Config{}
}
return cfg, nil
if info.CAFile != "" {
cfg.RootCAs, err = newCertPool(info.CAFile)
if err != nil {
return
}
}
return
}
// newCertPool creates x509 certPool with provided CA file

View File

@ -51,41 +51,31 @@ func TestNewTransportTLSInfo(t *testing.T) {
}
defer os.Remove(tmp)
tests := []struct {
info TLSInfo
wantTLSClientConfig bool
}{
{
info: TLSInfo{},
wantTLSClientConfig: false,
tests := []TLSInfo{
TLSInfo{},
TLSInfo{
CertFile: tmp,
KeyFile: tmp,
},
{
info: TLSInfo{
CertFile: tmp,
KeyFile: tmp,
},
wantTLSClientConfig: true,
TLSInfo{
CertFile: tmp,
KeyFile: tmp,
CAFile: tmp,
},
{
info: TLSInfo{
CertFile: tmp,
KeyFile: tmp,
CAFile: tmp,
},
wantTLSClientConfig: true,
TLSInfo{
CAFile: tmp,
},
}
for i, tt := range tests {
tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
trans, err := NewTransport(tt.info)
tt.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
trans, err := NewTransport(tt)
if err != nil {
t.Fatalf("Received unexpected error from NewTransport: %v", err)
}
gotTLSClientConfig := trans.TLSClientConfig != nil
if tt.wantTLSClientConfig != gotTLSClientConfig {
t.Fatalf("#%d: wantTLSClientConfig=%t but gotTLSClientConfig=%t", i, tt.wantTLSClientConfig, gotTLSClientConfig)
if trans.TLSClientConfig == nil {
t.Fatalf("#%d: want non-nil TLSClientConfig", i)
}
}
}
@ -121,8 +111,6 @@ func TestTLSInfoMissingFields(t *testing.T) {
defer os.Remove(tmp)
tests := []TLSInfo{
TLSInfo{},
TLSInfo{CAFile: tmp},
TLSInfo{CertFile: tmp},
TLSInfo{KeyFile: tmp},
TLSInfo{CertFile: tmp, CAFile: tmp},