diff --git a/config/config.go b/config/config.go index b34e56017..dcba45037 100644 --- a/config/config.go +++ b/config/config.go @@ -390,8 +390,8 @@ func (c *Config) Sanitize() error { } // EtcdTLSInfo retrieves a TLSInfo object for the etcd server -func (c *Config) EtcdTLSInfo() server.TLSInfo { - return server.TLSInfo{ +func (c *Config) EtcdTLSInfo() *server.TLSInfo { + return &server.TLSInfo{ CAFile: c.CAFile, CertFile: c.CertFile, KeyFile: c.KeyFile, @@ -399,8 +399,8 @@ func (c *Config) EtcdTLSInfo() server.TLSInfo { } // PeerRaftInfo retrieves a TLSInfo object for the peer server. -func (c *Config) PeerTLSInfo() server.TLSInfo { - return server.TLSInfo{ +func (c *Config) PeerTLSInfo() *server.TLSInfo { + return &server.TLSInfo{ CAFile: c.Peer.CAFile, CertFile: c.Peer.CertFile, KeyFile: c.Peer.KeyFile, diff --git a/etcd.go b/etcd.go index 4d074fc5d..d6eafda4a 100644 --- a/etcd.go +++ b/etcd.go @@ -18,7 +18,6 @@ package main import ( "fmt" - "net" "net/http" "os" "path/filepath" @@ -126,24 +125,6 @@ func main() { } ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats) - var psListener net.Listener - if psConfig.Scheme == "https" { - peerServerTLSConfig, err := config.PeerTLSInfo().ServerConfig() - if err != nil { - log.Fatal("peer server TLS error: ", err) - } - - psListener, err = server.NewTLSListener(config.Peer.BindAddr, peerServerTLSConfig) - if err != nil { - log.Fatal("Failed to create peer listener: ", err) - } - } else { - psListener, err = server.NewListener(config.Peer.BindAddr) - if err != nil { - log.Fatal("Failed to create peer listener: ", err) - } - } - // Create raft transporter and server raftTransporter := server.NewTransporter(followersStats, serverStats, registry, heartbeatInterval, dialTimeout, responseHeaderTimeout) if psConfig.Scheme == "https" { @@ -168,34 +149,19 @@ func main() { s.EnableTracing() } - var sListener net.Listener - if config.EtcdTLSInfo().Scheme() == "https" { - etcdServerTLSConfig, err := config.EtcdTLSInfo().ServerConfig() - if err != nil { - log.Fatal("etcd TLS error: ", err) - } - - sListener, err = server.NewTLSListener(config.BindAddr, etcdServerTLSConfig) - if err != nil { - log.Fatal("Failed to create TLS etcd listener: ", err) - } - } else { - sListener, err = server.NewListener(config.BindAddr) - if err != nil { - log.Fatal("Failed to create etcd listener: ", err) - } - } - ps.SetServer(s) ps.Start(config.Snapshot, config.Discovery, config.Peers) go func() { - log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL) + log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, config.Peer.BindAddr, ps.Config.URL) + l := server.NewListener(psConfig.Scheme, config.Peer.BindAddr, config.PeerTLSInfo()) + sHTTP := &ehttp.CORSHandler{ps.HTTPHandler(), corsInfo} - log.Fatal(http.Serve(psListener, sHTTP)) + log.Fatal(http.Serve(l, sHTTP)) }() - log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, sListener.Addr(), s.URL()) + log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, config.BindAddr, s.URL()) + l := server.NewListener(config.EtcdTLSInfo().Scheme(), config.BindAddr, config.EtcdTLSInfo()) sHTTP := &ehttp.CORSHandler{s.HTTPHandler(), corsInfo} - log.Fatal(http.Serve(sListener, sHTTP)) + log.Fatal(http.Serve(l, sHTTP)) } diff --git a/server/listener.go b/server/listener.go index 93527d66c..343677881 100644 --- a/server/listener.go +++ b/server/listener.go @@ -3,9 +3,35 @@ package server import ( "crypto/tls" "net" + + "github.com/coreos/etcd/log" ) -func NewListener(addr string) (net.Listener, error) { +// NewListener creates a net.Listener +// If the given scheme is "https", it will generate TLS configuration based on TLSInfo. +// If any error happens, this function will call log.Fatal +func NewListener(scheme, addr string, tlsInfo *TLSInfo) net.Listener { + if scheme == "https" { + cfg, err := tlsInfo.ServerConfig() + if err != nil { + log.Fatal("TLS info error: ", err) + } + + l, err := newTLSListener(addr, cfg) + if err != nil { + log.Fatal("Failed to create TLS listener: ", err) + } + return l + } + + l, err := newListener(addr) + if err != nil { + log.Fatal("Failed to create listener: ", err) + } + return l +} + +func newListener(addr string) (net.Listener, error) { if addr == "" { addr = ":http" } @@ -16,7 +42,7 @@ func NewListener(addr string) (net.Listener, error) { return l, nil } -func NewTLSListener(addr string, cfg *tls.Config) (net.Listener, error) { +func newTLSListener(addr string, cfg *tls.Config) (net.Listener, error) { if addr == "" { addr = ":https" } diff --git a/tests/server_utils.go b/tests/server_utils.go index 70e784398..7fd383be1 100644 --- a/tests/server_utils.go +++ b/tests/server_utils.go @@ -15,12 +15,12 @@ import ( ) const ( - testName = "ETCDTEST" - testClientURL = "localhost:4401" - testRaftURL = "localhost:7701" - testSnapshotCount = 10000 - testHeartbeatInterval = time.Duration(50) * time.Millisecond - testElectionTimeout = time.Duration(200) * time.Millisecond + testName = "ETCDTEST" + testClientURL = "localhost:4401" + testRaftURL = "localhost:7701" + testSnapshotCount = 10000 + testHeartbeatInterval = time.Duration(50) * time.Millisecond + testElectionTimeout = time.Duration(200) * time.Millisecond ) // Starts a server in a temporary directory. @@ -35,20 +35,17 @@ func RunServer(f func(*server.Server)) { followersStats := server.NewRaftFollowersStats(testName) psConfig := server.PeerServerConfig{ - Name: testName, - URL: "http://" + testRaftURL, - Scheme: "http", - SnapshotCount: testSnapshotCount, - MaxClusterSize: 9, + Name: testName, + URL: "http://" + testRaftURL, + Scheme: "http", + SnapshotCount: testSnapshotCount, + MaxClusterSize: 9, } mb := metrics.NewBucket("") ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats) - psListener, err := server.NewListener(testRaftURL) - if err != nil { - panic(err) - } + psListener := server.NewListener("http", testRaftURL, nil) // Create Raft transporter and server dialTimeout := (3 * testHeartbeatInterval) + testElectionTimeout @@ -63,10 +60,7 @@ func RunServer(f func(*server.Server)) { ps.SetRaftServer(raftServer) s := server.New(testName, "http://"+testClientURL, ps, registry, store, nil) - sListener, err := server.NewListener(testClientURL) - if err != nil { - panic(err) - } + sListener := server.NewListener("http", testClientURL, nil) ps.SetServer(s) @@ -104,16 +98,16 @@ func RunServer(f func(*server.Server)) { } type waitHandler struct { - wg *sync.WaitGroup - handler http.Handler + wg *sync.WaitGroup + handler http.Handler } -func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){ - h.wg.Add(1) - defer h.wg.Done() - h.handler.ServeHTTP(w, r) +func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.wg.Add(1) + defer h.wg.Done() + h.handler.ServeHTTP(w, r) - //important to flush before decrementing the wait group. - //we won't get a chance to once main() ends. - w.(http.Flusher).Flush() + //important to flush before decrementing the wait group. + //we won't get a chance to once main() ends. + w.(http.Flusher).Flush() }