transport: wrap net.Listener with TLSInfo
parent
a782a1a7d1
commit
17459c7bfc
6
main.go
6
main.go
|
@ -168,7 +168,7 @@ func startEtcd() {
|
||||||
Info: cors,
|
Info: cors,
|
||||||
}
|
}
|
||||||
|
|
||||||
l, err := transport.NewListener(*paddr)
|
l, err := transport.NewListener(*paddr, transport.TLSInfo{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -182,7 +182,7 @@ func startEtcd() {
|
||||||
// Start a client server goroutine for each listen address
|
// Start a client server goroutine for each listen address
|
||||||
for _, addr := range *addrs {
|
for _, addr := range *addrs {
|
||||||
addr := addr
|
addr := addr
|
||||||
l, err := transport.NewListener(addr)
|
l, err := transport.NewListener(addr, transport.TLSInfo{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -212,7 +212,7 @@ func startProxy() {
|
||||||
// Start a proxy server goroutine for each listen address
|
// Start a proxy server goroutine for each listen address
|
||||||
for _, addr := range *addrs {
|
for _, addr := range *addrs {
|
||||||
addr := addr
|
addr := addr
|
||||||
l, err := transport.NewListener(addr)
|
l, err := transport.NewListener(addr, transport.TLSInfo{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
2
test
2
test
|
@ -15,7 +15,7 @@ COVER=${COVER:-"-cover"}
|
||||||
source ./build
|
source ./build
|
||||||
|
|
||||||
# Hack: gofmt ./ will recursively check the .git directory. So use *.go for gofmt.
|
# Hack: gofmt ./ will recursively check the .git directory. So use *.go for gofmt.
|
||||||
TESTABLE_AND_FORMATTABLE="client etcdserver etcdserver/etcdhttp etcdserver/etcdserverpb functional proxy raft snap store wait wal"
|
TESTABLE_AND_FORMATTABLE="client etcdserver etcdserver/etcdhttp etcdserver/etcdserverpb functional proxy raft snap store wait wal transport"
|
||||||
TESTABLE="$TESTABLE_AND_FORMATTABLE ./"
|
TESTABLE="$TESTABLE_AND_FORMATTABLE ./"
|
||||||
FORMATTABLE="$TESTABLE_AND_FORMATTABLE *.go"
|
FORMATTABLE="$TESTABLE_AND_FORMATTABLE *.go"
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,92 @@
|
||||||
package transport
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewListener(addr string) (net.Listener, error) {
|
func NewListener(addr string, info TLSInfo) (net.Listener, error) {
|
||||||
return net.Listen("tcp", addr)
|
l, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !info.Empty() {
|
||||||
|
cfg, err := info.ServerConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
l = tls.NewListener(l, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TLSInfo struct {
|
||||||
|
CertFile string
|
||||||
|
KeyFile string
|
||||||
|
CAFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info TLSInfo) Empty() bool {
|
||||||
|
return info.CertFile == "" && info.KeyFile == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates a tls.Config object for a server from the given files.
|
||||||
|
func (info TLSInfo) ServerConfig() (*tls.Config, error) {
|
||||||
|
// Both the key and cert must be present.
|
||||||
|
if info.KeyFile == "" || info.CertFile == "" {
|
||||||
|
return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg tls.Config
|
||||||
|
|
||||||
|
tlsCert, err := tls.LoadX509KeyPair(info.CertFile, info.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Certificates = []tls.Certificate{tlsCert}
|
||||||
|
|
||||||
|
if info.CAFile != "" {
|
||||||
|
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
cp, err := newCertPool(info.CAFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.RootCAs = cp
|
||||||
|
cfg.ClientCAs = cp
|
||||||
|
} else {
|
||||||
|
cfg.ClientAuth = tls.NoClientCert
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCertPool creates x509 certPool with provided CA file
|
||||||
|
func newCertPool(CAFile string) (*x509.CertPool, error) {
|
||||||
|
certPool := x509.NewCertPool()
|
||||||
|
pemByte, err := ioutil.ReadFile(CAFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
var block *pem.Block
|
||||||
|
block, pemByte = pem.Decode(pemByte)
|
||||||
|
if block == nil {
|
||||||
|
return certPool, nil
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
certPool.AddCert(cert)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue