Merge pull request #626 from xiangli-cmu/refactor_listener

refactor(listener) refactor listener related code
release-0.4
Xiang Li 2014-03-13 20:05:54 -07:00
commit 79e4c838f4
4 changed files with 61 additions and 75 deletions

View File

@ -390,8 +390,8 @@ func (c *Config) Sanitize() error {
} }
// EtcdTLSInfo retrieves a TLSInfo object for the etcd server // EtcdTLSInfo retrieves a TLSInfo object for the etcd server
func (c *Config) EtcdTLSInfo() server.TLSInfo { func (c *Config) EtcdTLSInfo() *server.TLSInfo {
return server.TLSInfo{ return &server.TLSInfo{
CAFile: c.CAFile, CAFile: c.CAFile,
CertFile: c.CertFile, CertFile: c.CertFile,
KeyFile: c.KeyFile, KeyFile: c.KeyFile,
@ -399,8 +399,8 @@ func (c *Config) EtcdTLSInfo() server.TLSInfo {
} }
// PeerRaftInfo retrieves a TLSInfo object for the peer server. // PeerRaftInfo retrieves a TLSInfo object for the peer server.
func (c *Config) PeerTLSInfo() server.TLSInfo { func (c *Config) PeerTLSInfo() *server.TLSInfo {
return server.TLSInfo{ return &server.TLSInfo{
CAFile: c.Peer.CAFile, CAFile: c.Peer.CAFile,
CertFile: c.Peer.CertFile, CertFile: c.Peer.CertFile,
KeyFile: c.Peer.KeyFile, KeyFile: c.Peer.KeyFile,

48
etcd.go
View File

@ -18,7 +18,6 @@ package main
import ( import (
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -126,24 +125,6 @@ func main() {
} }
ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats) ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
var psListener net.Listener
if psConfig.Scheme == "https" {
peerServerTLSConfig, err := config.PeerTLSInfo().ServerConfig()
if err != nil {
log.Fatal("peer server TLS error: ", err)
}
psListener, err = server.NewTLSListener(config.Peer.BindAddr, peerServerTLSConfig)
if err != nil {
log.Fatal("Failed to create peer listener: ", err)
}
} else {
psListener, err = server.NewListener(config.Peer.BindAddr)
if err != nil {
log.Fatal("Failed to create peer listener: ", err)
}
}
// Create raft transporter and server // Create raft transporter and server
raftTransporter := server.NewTransporter(followersStats, serverStats, registry, heartbeatInterval, dialTimeout, responseHeaderTimeout) raftTransporter := server.NewTransporter(followersStats, serverStats, registry, heartbeatInterval, dialTimeout, responseHeaderTimeout)
if psConfig.Scheme == "https" { if psConfig.Scheme == "https" {
@ -168,34 +149,19 @@ func main() {
s.EnableTracing() s.EnableTracing()
} }
var sListener net.Listener
if config.EtcdTLSInfo().Scheme() == "https" {
etcdServerTLSConfig, err := config.EtcdTLSInfo().ServerConfig()
if err != nil {
log.Fatal("etcd TLS error: ", err)
}
sListener, err = server.NewTLSListener(config.BindAddr, etcdServerTLSConfig)
if err != nil {
log.Fatal("Failed to create TLS etcd listener: ", err)
}
} else {
sListener, err = server.NewListener(config.BindAddr)
if err != nil {
log.Fatal("Failed to create etcd listener: ", err)
}
}
ps.SetServer(s) ps.SetServer(s)
ps.Start(config.Snapshot, config.Discovery, config.Peers) ps.Start(config.Snapshot, config.Discovery, config.Peers)
go func() { go func() {
log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL) log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, config.Peer.BindAddr, ps.Config.URL)
l := server.NewListener(psConfig.Scheme, config.Peer.BindAddr, config.PeerTLSInfo())
sHTTP := &ehttp.CORSHandler{ps.HTTPHandler(), corsInfo} sHTTP := &ehttp.CORSHandler{ps.HTTPHandler(), corsInfo}
log.Fatal(http.Serve(psListener, sHTTP)) log.Fatal(http.Serve(l, sHTTP))
}() }()
log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, sListener.Addr(), s.URL()) log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, config.BindAddr, s.URL())
l := server.NewListener(config.EtcdTLSInfo().Scheme(), config.BindAddr, config.EtcdTLSInfo())
sHTTP := &ehttp.CORSHandler{s.HTTPHandler(), corsInfo} sHTTP := &ehttp.CORSHandler{s.HTTPHandler(), corsInfo}
log.Fatal(http.Serve(sListener, sHTTP)) log.Fatal(http.Serve(l, sHTTP))
} }

View File

@ -3,9 +3,35 @@ package server
import ( import (
"crypto/tls" "crypto/tls"
"net" "net"
"github.com/coreos/etcd/log"
) )
func NewListener(addr string) (net.Listener, error) { // NewListener creates a net.Listener
// If the given scheme is "https", it will generate TLS configuration based on TLSInfo.
// If any error happens, this function will call log.Fatal
func NewListener(scheme, addr string, tlsInfo *TLSInfo) net.Listener {
if scheme == "https" {
cfg, err := tlsInfo.ServerConfig()
if err != nil {
log.Fatal("TLS info error: ", err)
}
l, err := newTLSListener(addr, cfg)
if err != nil {
log.Fatal("Failed to create TLS listener: ", err)
}
return l
}
l, err := newListener(addr)
if err != nil {
log.Fatal("Failed to create listener: ", err)
}
return l
}
func newListener(addr string) (net.Listener, error) {
if addr == "" { if addr == "" {
addr = ":http" addr = ":http"
} }
@ -16,7 +42,7 @@ func NewListener(addr string) (net.Listener, error) {
return l, nil return l, nil
} }
func NewTLSListener(addr string, cfg *tls.Config) (net.Listener, error) { func newTLSListener(addr string, cfg *tls.Config) (net.Listener, error) {
if addr == "" { if addr == "" {
addr = ":https" addr = ":https"
} }

View File

@ -15,12 +15,12 @@ import (
) )
const ( const (
testName = "ETCDTEST" testName = "ETCDTEST"
testClientURL = "localhost:4401" testClientURL = "localhost:4401"
testRaftURL = "localhost:7701" testRaftURL = "localhost:7701"
testSnapshotCount = 10000 testSnapshotCount = 10000
testHeartbeatInterval = time.Duration(50) * time.Millisecond testHeartbeatInterval = time.Duration(50) * time.Millisecond
testElectionTimeout = time.Duration(200) * time.Millisecond testElectionTimeout = time.Duration(200) * time.Millisecond
) )
// Starts a server in a temporary directory. // Starts a server in a temporary directory.
@ -35,20 +35,17 @@ func RunServer(f func(*server.Server)) {
followersStats := server.NewRaftFollowersStats(testName) followersStats := server.NewRaftFollowersStats(testName)
psConfig := server.PeerServerConfig{ psConfig := server.PeerServerConfig{
Name: testName, Name: testName,
URL: "http://" + testRaftURL, URL: "http://" + testRaftURL,
Scheme: "http", Scheme: "http",
SnapshotCount: testSnapshotCount, SnapshotCount: testSnapshotCount,
MaxClusterSize: 9, MaxClusterSize: 9,
} }
mb := metrics.NewBucket("") mb := metrics.NewBucket("")
ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats) ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
psListener, err := server.NewListener(testRaftURL) psListener := server.NewListener("http", testRaftURL, nil)
if err != nil {
panic(err)
}
// Create Raft transporter and server // Create Raft transporter and server
dialTimeout := (3 * testHeartbeatInterval) + testElectionTimeout dialTimeout := (3 * testHeartbeatInterval) + testElectionTimeout
@ -63,10 +60,7 @@ func RunServer(f func(*server.Server)) {
ps.SetRaftServer(raftServer) ps.SetRaftServer(raftServer)
s := server.New(testName, "http://"+testClientURL, ps, registry, store, nil) s := server.New(testName, "http://"+testClientURL, ps, registry, store, nil)
sListener, err := server.NewListener(testClientURL) sListener := server.NewListener("http", testClientURL, nil)
if err != nil {
panic(err)
}
ps.SetServer(s) ps.SetServer(s)
@ -104,16 +98,16 @@ func RunServer(f func(*server.Server)) {
} }
type waitHandler struct { type waitHandler struct {
wg *sync.WaitGroup wg *sync.WaitGroup
handler http.Handler handler http.Handler
} }
func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){ func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.wg.Add(1) h.wg.Add(1)
defer h.wg.Done() defer h.wg.Done()
h.handler.ServeHTTP(w, r) h.handler.ServeHTTP(w, r)
//important to flush before decrementing the wait group. //important to flush before decrementing the wait group.
//we won't get a chance to once main() ends. //we won't get a chance to once main() ends.
w.(http.Flusher).Flush() w.(http.Flusher).Flush()
} }