Merge pull request #7687 from heyitsanthony/deny-tls-ipsan
transport: deny incoming peer certs with wrong IP SANrelease-3.2
commit
1153e1e7d9
|
@ -201,7 +201,6 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for i, u := range cfg.LPUrls {
|
for i, u := range cfg.LPUrls {
|
||||||
var tlscfg *tls.Config
|
|
||||||
if u.Scheme == "http" {
|
if u.Scheme == "http" {
|
||||||
if !cfg.PeerTLSInfo.Empty() {
|
if !cfg.PeerTLSInfo.Empty() {
|
||||||
plog.Warningf("The scheme of peer url %s is HTTP while peer key/cert files are presented. Ignored peer key/cert files.", u.String())
|
plog.Warningf("The scheme of peer url %s is HTTP while peer key/cert files are presented. Ignored peer key/cert files.", u.String())
|
||||||
|
@ -210,12 +209,7 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
|
||||||
plog.Warningf("The scheme of peer url %s is HTTP while client cert auth (--peer-client-cert-auth) is enabled. Ignored client cert auth for this url.", u.String())
|
plog.Warningf("The scheme of peer url %s is HTTP while client cert auth (--peer-client-cert-auth) is enabled. Ignored client cert auth for this url.", u.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !cfg.PeerTLSInfo.Empty() {
|
if plns[i], err = rafthttp.NewListener(u, &cfg.PeerTLSInfo); err != nil {
|
||||||
if tlscfg, err = cfg.PeerTLSInfo.ServerConfig(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if plns[i], err = rafthttp.NewListener(u, tlscfg); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
plog.Info("listening for peers on ", u.String())
|
plog.Info("listening for peers on ", u.String())
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -305,18 +304,7 @@ func startProxy(cfg *config) error {
|
||||||
}
|
}
|
||||||
// Start a proxy server goroutine for each listen address
|
// Start a proxy server goroutine for each listen address
|
||||||
for _, u := range cfg.LCUrls {
|
for _, u := range cfg.LCUrls {
|
||||||
var (
|
l, err := transport.NewListener(u.Host, u.Scheme, &cfg.ClientTLSInfo)
|
||||||
l net.Listener
|
|
||||||
tlscfg *tls.Config
|
|
||||||
)
|
|
||||||
if !cfg.ClientTLSInfo.Empty() {
|
|
||||||
tlscfg, err = cfg.ClientTLSInfo.ServerConfig()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
l, err := transport.NewListener(u.Host, u.Scheme, tlscfg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -369,6 +357,11 @@ func identifyDataDirOrDie(dir string) dirType {
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupLogging(cfg *config) {
|
func setupLogging(cfg *config) {
|
||||||
|
cfg.ClientTLSInfo.HandshakeFailure = func(conn *tls.Conn, err error) {
|
||||||
|
plog.Infof("rejected connection from %q (%v)", conn.RemoteAddr().String(), err)
|
||||||
|
}
|
||||||
|
cfg.PeerTLSInfo.HandshakeFailure = cfg.ClientTLSInfo.HandshakeFailure
|
||||||
|
|
||||||
capnslog.SetGlobalLogLevel(capnslog.INFO)
|
capnslog.SetGlobalLogLevel(capnslog.INFO)
|
||||||
if cfg.Debug {
|
if cfg.Debug {
|
||||||
capnslog.SetGlobalLogLevel(capnslog.DEBUG)
|
capnslog.SetGlobalLogLevel(capnslog.DEBUG)
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -50,12 +49,12 @@ func TestNewKeepAliveListener(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// tls
|
// tls
|
||||||
tmp, err := createTempFile([]byte("XXX"))
|
tlsinfo, del, err := createSelfCert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create tmpfile: %v", err)
|
t.Fatalf("unable to create tmpfile: %v", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmp)
|
defer del()
|
||||||
tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
|
tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}
|
||||||
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
|
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
|
||||||
tlscfg, err := tlsInfo.ServerConfig()
|
tlscfg, err := tlsInfo.ServerConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -33,11 +33,11 @@ import (
|
||||||
"github.com/coreos/etcd/pkg/tlsutil"
|
"github.com/coreos/etcd/pkg/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener(addr, scheme string, tlscfg *tls.Config) (l net.Listener, err error) {
|
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
|
||||||
if l, err = newListener(addr, scheme); err != nil {
|
if l, err = newListener(addr, scheme); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return wrapTLS(addr, scheme, tlscfg, l)
|
return wrapTLS(addr, scheme, tlsinfo, l)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newListener(addr string, scheme string) (net.Listener, error) {
|
func newListener(addr string, scheme string) (net.Listener, error) {
|
||||||
|
@ -48,15 +48,11 @@ func newListener(addr string, scheme string) (net.Listener, error) {
|
||||||
return net.Listen("tcp", addr)
|
return net.Listen("tcp", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapTLS(addr, scheme string, tlscfg *tls.Config, l net.Listener) (net.Listener, error) {
|
func wrapTLS(addr, scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
|
||||||
if scheme != "https" && scheme != "unixs" {
|
if scheme != "https" && scheme != "unixs" {
|
||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
if tlscfg == nil {
|
return newTLSListener(l, tlsinfo)
|
||||||
l.Close()
|
|
||||||
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
|
|
||||||
}
|
|
||||||
return tls.NewListener(l, tlscfg), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TLSInfo struct {
|
type TLSInfo struct {
|
||||||
|
@ -69,6 +65,10 @@ type TLSInfo struct {
|
||||||
// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
|
// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
|
||||||
ServerName string
|
ServerName string
|
||||||
|
|
||||||
|
// HandshakeFailure is optinally called when a connection fails to handshake. The
|
||||||
|
// connection will be closed immediately afterwards.
|
||||||
|
HandshakeFailure func(*tls.Conn, error)
|
||||||
|
|
||||||
selfCert bool
|
selfCert bool
|
||||||
|
|
||||||
// parseFunc exists to simplify testing. Typically, parseFunc
|
// parseFunc exists to simplify testing. Typically, parseFunc
|
||||||
|
|
|
@ -24,18 +24,16 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createTempFile(b []byte) (string, error) {
|
func createSelfCert() (*TLSInfo, func(), error) {
|
||||||
f, err := ioutil.TempFile("", "etcd-test-tls-")
|
d, terr := ioutil.TempDir("", "etcd-test-tls-")
|
||||||
|
if terr != nil {
|
||||||
|
return nil, nil, terr
|
||||||
|
}
|
||||||
|
info, err := SelfCert(d, []string{"127.0.0.1"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
return &info, func() { os.RemoveAll(d) }, nil
|
||||||
|
|
||||||
if _, err = f.Write(b); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Name(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
|
func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
|
||||||
|
@ -47,28 +45,25 @@ func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBloc
|
||||||
// TestNewListenerTLSInfo tests that NewListener with valid TLSInfo returns
|
// TestNewListenerTLSInfo tests that NewListener with valid TLSInfo returns
|
||||||
// a TLS listener that accepts TLS connections.
|
// a TLS listener that accepts TLS connections.
|
||||||
func TestNewListenerTLSInfo(t *testing.T) {
|
func TestNewListenerTLSInfo(t *testing.T) {
|
||||||
tmp, err := createTempFile([]byte("XXX"))
|
tlsInfo, del, err := createSelfCert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create tmpfile: %v", err)
|
t.Fatalf("unable to create cert: %v", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmp)
|
defer del()
|
||||||
tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
|
testNewListenerTLSInfoAccept(t, *tlsInfo)
|
||||||
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
|
|
||||||
testNewListenerTLSInfoAccept(t, tlsInfo)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
|
func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
|
||||||
tlscfg, err := tlsInfo.ServerConfig()
|
ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo)
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected serverConfig error: %v", err)
|
|
||||||
}
|
|
||||||
ln, err := NewListener("127.0.0.1:0", "https", tlscfg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected NewListener error: %v", err)
|
t.Fatalf("unexpected NewListener error: %v", err)
|
||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
go http.Get("https://" + ln.Addr().String())
|
tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
|
cli := &http.Client{Transport: tr}
|
||||||
|
go cli.Get("https://" + ln.Addr().String())
|
||||||
|
|
||||||
conn, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected Accept error: %v", err)
|
t.Fatalf("unexpected Accept error: %v", err)
|
||||||
|
@ -87,25 +82,25 @@ func TestNewListenerTLSEmptyInfo(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewTransportTLSInfo(t *testing.T) {
|
func TestNewTransportTLSInfo(t *testing.T) {
|
||||||
tmp, err := createTempFile([]byte("XXX"))
|
tlsinfo, del, err := createSelfCert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to prepare tmpfile: %v", err)
|
t.Fatalf("unable to create cert: %v", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmp)
|
defer del()
|
||||||
|
|
||||||
tests := []TLSInfo{
|
tests := []TLSInfo{
|
||||||
{},
|
{},
|
||||||
{
|
{
|
||||||
CertFile: tmp,
|
CertFile: tlsinfo.CertFile,
|
||||||
KeyFile: tmp,
|
KeyFile: tlsinfo.KeyFile,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
CertFile: tmp,
|
CertFile: tlsinfo.CertFile,
|
||||||
KeyFile: tmp,
|
KeyFile: tlsinfo.KeyFile,
|
||||||
CAFile: tmp,
|
CAFile: tlsinfo.CAFile,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
CAFile: tmp,
|
CAFile: tlsinfo.CAFile,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,17 +154,17 @@ func TestTLSInfoEmpty(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSInfoMissingFields(t *testing.T) {
|
func TestTLSInfoMissingFields(t *testing.T) {
|
||||||
tmp, err := createTempFile([]byte("XXX"))
|
tlsinfo, del, err := createSelfCert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to prepare tmpfile: %v", err)
|
t.Fatalf("unable to create cert: %v", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmp)
|
defer del()
|
||||||
|
|
||||||
tests := []TLSInfo{
|
tests := []TLSInfo{
|
||||||
{CertFile: tmp},
|
{CertFile: tlsinfo.CertFile},
|
||||||
{KeyFile: tmp},
|
{KeyFile: tlsinfo.KeyFile},
|
||||||
{CertFile: tmp, CAFile: tmp},
|
{CertFile: tlsinfo.CertFile, CAFile: tlsinfo.CAFile},
|
||||||
{KeyFile: tmp, CAFile: tmp},
|
{KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CAFile},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, info := range tests {
|
for i, info := range tests {
|
||||||
|
@ -184,30 +179,29 @@ func TestTLSInfoMissingFields(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSInfoParseFuncError(t *testing.T) {
|
func TestTLSInfoParseFuncError(t *testing.T) {
|
||||||
tmp, err := createTempFile([]byte("XXX"))
|
tlsinfo, del, err := createSelfCert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to prepare tmpfile: %v", err)
|
t.Fatalf("unable to create cert: %v", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmp)
|
defer del()
|
||||||
|
|
||||||
info := TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp}
|
tlsinfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))
|
||||||
info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))
|
|
||||||
|
|
||||||
if _, err = info.ServerConfig(); err == nil {
|
if _, err = tlsinfo.ServerConfig(); err == nil {
|
||||||
t.Errorf("expected non-nil error from ServerConfig()")
|
t.Errorf("expected non-nil error from ServerConfig()")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = info.ClientConfig(); err == nil {
|
if _, err = tlsinfo.ClientConfig(); err == nil {
|
||||||
t.Errorf("expected non-nil error from ClientConfig()")
|
t.Errorf("expected non-nil error from ClientConfig()")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSInfoConfigFuncs(t *testing.T) {
|
func TestTLSInfoConfigFuncs(t *testing.T) {
|
||||||
tmp, err := createTempFile([]byte("XXX"))
|
tlsinfo, del, err := createSelfCert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unable to prepare tmpfile: %v", err)
|
t.Fatalf("unable to create cert: %v", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmp)
|
defer del()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
info TLSInfo
|
info TLSInfo
|
||||||
|
@ -215,13 +209,13 @@ func TestTLSInfoConfigFuncs(t *testing.T) {
|
||||||
wantCAs bool
|
wantCAs bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
info: TLSInfo{CertFile: tmp, KeyFile: tmp},
|
info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile},
|
||||||
clientAuth: tls.NoClientCert,
|
clientAuth: tls.NoClientCert,
|
||||||
wantCAs: false,
|
wantCAs: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
info: TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp},
|
info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CertFile},
|
||||||
clientAuth: tls.RequireAndVerifyClientCert,
|
clientAuth: tls.RequireAndVerifyClientCert,
|
||||||
wantCAs: true,
|
wantCAs: true,
|
||||||
},
|
},
|
||||||
|
|
|
@ -0,0 +1,137 @@
|
||||||
|
// Copyright 2017 The etcd Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tlsListener overrides a TLS listener so it will reject client
|
||||||
|
// certificates with insufficient SAN credentials.
|
||||||
|
type tlsListener struct {
|
||||||
|
net.Listener
|
||||||
|
connc chan net.Conn
|
||||||
|
donec chan struct{}
|
||||||
|
err error
|
||||||
|
handshakeFailure func(*tls.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
||||||
|
if tlsinfo == nil || tlsinfo.Empty() {
|
||||||
|
l.Close()
|
||||||
|
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
|
||||||
|
}
|
||||||
|
tlscfg, err := tlsinfo.ServerConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsl := &tlsListener{
|
||||||
|
Listener: tls.NewListener(l, tlscfg),
|
||||||
|
connc: make(chan net.Conn),
|
||||||
|
donec: make(chan struct{}),
|
||||||
|
handshakeFailure: tlsinfo.HandshakeFailure,
|
||||||
|
}
|
||||||
|
go tlsl.acceptLoop()
|
||||||
|
return tlsl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *tlsListener) Accept() (net.Conn, error) {
|
||||||
|
select {
|
||||||
|
case conn := <-l.connc:
|
||||||
|
return conn, nil
|
||||||
|
case <-l.donec:
|
||||||
|
return nil, l.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptLoop launches each TLS handshake in a separate goroutine
|
||||||
|
// to prevent a hanging TLS connection from blocking other connections.
|
||||||
|
func (l *tlsListener) acceptLoop() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var pendingMu sync.Mutex
|
||||||
|
|
||||||
|
pending := make(map[net.Conn]struct{})
|
||||||
|
stopc := make(chan struct{})
|
||||||
|
defer func() {
|
||||||
|
close(stopc)
|
||||||
|
pendingMu.Lock()
|
||||||
|
for c := range pending {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
pendingMu.Unlock()
|
||||||
|
wg.Wait()
|
||||||
|
close(l.donec)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := l.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
l.err = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingMu.Lock()
|
||||||
|
pending[conn] = struct{}{}
|
||||||
|
pendingMu.Unlock()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if conn != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
tlsConn := conn.(*tls.Conn)
|
||||||
|
herr := tlsConn.Handshake()
|
||||||
|
pendingMu.Lock()
|
||||||
|
delete(pending, conn)
|
||||||
|
pendingMu.Unlock()
|
||||||
|
if herr != nil {
|
||||||
|
if l.handshakeFailure != nil {
|
||||||
|
l.handshakeFailure(tlsConn, herr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
st := tlsConn.ConnectionState()
|
||||||
|
if len(st.PeerCertificates) > 0 {
|
||||||
|
cert := st.PeerCertificates[0]
|
||||||
|
if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 {
|
||||||
|
addr := tlsConn.RemoteAddr().String()
|
||||||
|
h, _, herr := net.SplitHostPort(addr)
|
||||||
|
if herr != nil || cert.VerifyHostname(h) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case l.connc <- tlsConn:
|
||||||
|
conn = nil
|
||||||
|
case <-stopc:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *tlsListener) Close() error {
|
||||||
|
err := l.Listener.Close()
|
||||||
|
<-l.donec
|
||||||
|
return err
|
||||||
|
}
|
|
@ -15,7 +15,6 @@
|
||||||
package transport
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -23,7 +22,7 @@ import (
|
||||||
// NewTimeoutListener returns a listener that listens on the given address.
|
// NewTimeoutListener returns a listener that listens on the given address.
|
||||||
// If read/write on the accepted connection blocks longer than its time limit,
|
// If read/write on the accepted connection blocks longer than its time limit,
|
||||||
// it will return timeout error.
|
// it will return timeout error.
|
||||||
func NewTimeoutListener(addr string, scheme string, tlscfg *tls.Config, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
|
func NewTimeoutListener(addr string, scheme string, tlsinfo *TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
|
||||||
ln, err := newListener(addr, scheme)
|
ln, err := newListener(addr, scheme)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -33,7 +32,7 @@ func NewTimeoutListener(addr string, scheme string, tlscfg *tls.Config, rdtimeou
|
||||||
rdtimeoutd: rdtimeoutd,
|
rdtimeoutd: rdtimeoutd,
|
||||||
wtimeoutd: wtimeoutd,
|
wtimeoutd: wtimeoutd,
|
||||||
}
|
}
|
||||||
if ln, err = wrapTLS(addr, scheme, tlscfg, ln); err != nil {
|
if ln, err = wrapTLS(addr, scheme, tlsinfo, ln); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return ln, nil
|
return ln, nil
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
package rafthttp
|
package rafthttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -37,8 +36,8 @@ var (
|
||||||
|
|
||||||
// NewListener returns a listener for raft message transfer between peers.
|
// NewListener returns a listener for raft message transfer between peers.
|
||||||
// It uses timeout listener to identify broken streams promptly.
|
// It uses timeout listener to identify broken streams promptly.
|
||||||
func NewListener(u url.URL, tlscfg *tls.Config) (net.Listener, error) {
|
func NewListener(u url.URL, tlsinfo *transport.TLSInfo) (net.Listener, error) {
|
||||||
return transport.NewTimeoutListener(u.Host, u.Scheme, tlscfg, ConnReadTimeout, ConnWriteTimeout)
|
return transport.NewTimeoutListener(u.Host, u.Scheme, tlsinfo, ConnReadTimeout, ConnWriteTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRoundTripper returns a roundTripper used to send requests
|
// NewRoundTripper returns a roundTripper used to send requests
|
||||||
|
|
Loading…
Reference in New Issue