// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transport

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"net"
	"strings"
	"sync"
)

// tlsListener overrides a TLS listener so it will reject client
// certificates with insufficient SAN credentials or CRL revoked
// certificates.
type tlsListener struct {
	// Listener is the wrapped listener that acceptLoop accepts from.
	net.Listener
	// connc carries connections that passed the handshake and check from
	// acceptLoop to Accept.
	connc chan net.Conn
	// donec is closed by acceptLoop once it has fully shut down.
	donec chan struct{}
	// err is the accept error that stopped acceptLoop; it is written before
	// donec is closed and read by Accept only after donec is closed.
	err error
	// handshakeFailure is invoked when the TLS handshake or check fails.
	handshakeFailure func(*tls.Conn, error)
	// check is run on every connection after a successful handshake.
	check tlsCheckFunc
}

// tlsCheckFunc is an additional verification step run on a connection after
// its TLS handshake has completed; a non-nil error rejects the connection.
type tlsCheckFunc func(context.Context, *tls.Conn) error

// NewTLSListener handshakes TLS connections and performs optional CRL checking.
func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) { check := func(context.Context, *tls.Conn) error { return nil } return newTLSListener(l, tlsinfo, check) } func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) { if tlsinfo == nil || tlsinfo.Empty() { l.Close() return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String()) } tlscfg, err := tlsinfo.ServerConfig() if err != nil { return nil, err } hf := tlsinfo.HandshakeFailure if hf == nil { hf = func(*tls.Conn, error) {} } if len(tlsinfo.CRLFile) > 0 { prevCheck := check check = func(ctx context.Context, tlsConn *tls.Conn) error { if err := prevCheck(ctx, tlsConn); err != nil { return err } st := tlsConn.ConnectionState() if certs := st.PeerCertificates; len(certs) > 0 { return checkCRL(tlsinfo.CRLFile, certs) } return nil } } tlsl := &tlsListener{ Listener: tls.NewListener(l, tlscfg), connc: make(chan net.Conn), donec: make(chan struct{}), handshakeFailure: hf, check: check, } go tlsl.acceptLoop() return tlsl, nil } func (l *tlsListener) Accept() (net.Conn, error) { select { case conn := <-l.connc: return conn, nil case <-l.donec: return nil, l.err } } func checkSAN(ctx context.Context, tlsConn *tls.Conn) error { st := tlsConn.ConnectionState() if certs := st.PeerCertificates; len(certs) > 0 { addr := tlsConn.RemoteAddr().String() return checkCertSAN(ctx, certs[0], addr) } return nil } // acceptLoop launches each TLS handshake in a separate goroutine // to prevent a hanging TLS connection from blocking other connections. 
func (l *tlsListener) acceptLoop() { var wg sync.WaitGroup var pendingMu sync.Mutex pending := make(map[net.Conn]struct{}) ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() pendingMu.Lock() for c := range pending { c.Close() } pendingMu.Unlock() wg.Wait() close(l.donec) }() for { conn, err := l.Listener.Accept() if err != nil { l.err = err return } pendingMu.Lock() pending[conn] = struct{}{} pendingMu.Unlock() wg.Add(1) go func() { defer func() { if conn != nil { conn.Close() } wg.Done() }() tlsConn := conn.(*tls.Conn) herr := tlsConn.Handshake() pendingMu.Lock() delete(pending, conn) pendingMu.Unlock() if herr != nil { l.handshakeFailure(tlsConn, herr) return } if err := l.check(ctx, tlsConn); err != nil { l.handshakeFailure(tlsConn, err) return } select { case l.connc <- tlsConn: conn = nil case <-ctx.Done(): } }() } } func checkCRL(crlPath string, cert []*x509.Certificate) error { // TODO: cache crlBytes, err := ioutil.ReadFile(crlPath) if err != nil { return err } certList, err := x509.ParseCRL(crlBytes) if err != nil { return err } revokedSerials := make(map[string]struct{}) for _, rc := range certList.TBSCertList.RevokedCertificates { revokedSerials[string(rc.SerialNumber.Bytes())] = struct{}{} } for _, c := range cert { serial := string(c.SerialNumber.Bytes()) if _, ok := revokedSerials[serial]; ok { return fmt.Errorf("transport: certificate serial %x revoked", serial) } } return nil } func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error { if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 { return nil } h, _, herr := net.SplitHostPort(remoteAddr) if herr != nil { return herr } if len(cert.IPAddresses) > 0 { cerr := cert.VerifyHostname(h) if cerr == nil { return nil } if len(cert.DNSNames) == 0 { return cerr } } if len(cert.DNSNames) > 0 { ok, err := isHostInDNS(ctx, h, cert.DNSNames) if ok { return nil } errStr := "" if err != nil { errStr = " (" + err.Error() + ")" } return 
fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames) } return nil } func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) { // reverse lookup wildcards, names := []string{}, []string{} for _, dns := range dnsNames { if strings.HasPrefix(dns, "*.") { wildcards = append(wildcards, dns[1:]) } else { names = append(names, dns) } } lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host) for _, name := range lnames { // strip trailing '.' from PTR record if name[len(name)-1] == '.' { name = name[:len(name)-1] } for _, wc := range wildcards { if strings.HasSuffix(name, wc) { return true, nil } } for _, n := range names { if n == name { return true, nil } } } err = lerr // forward lookup for _, dns := range names { addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns) if lerr != nil { err = lerr continue } for _, addr := range addrs { if addr == host { return true, nil } } } return false, err } func (l *tlsListener) Close() error { err := l.Listener.Close() <-l.donec return err }