Merge pull request #626 from xiangli-cmu/refactor_listener

refactor(listener) refactor listener related code
release-0.4
Xiang Li 2014-03-13 20:05:54 -07:00
commit 79e4c838f4
4 changed files with 61 additions and 75 deletions

View File

@ -390,8 +390,8 @@ func (c *Config) Sanitize() error {
}
// EtcdTLSInfo retrieves a TLSInfo object for the etcd server
func (c *Config) EtcdTLSInfo() server.TLSInfo {
return server.TLSInfo{
func (c *Config) EtcdTLSInfo() *server.TLSInfo {
return &server.TLSInfo{
CAFile: c.CAFile,
CertFile: c.CertFile,
KeyFile: c.KeyFile,
@ -399,8 +399,8 @@ func (c *Config) EtcdTLSInfo() server.TLSInfo {
}
// PeerRaftInfo retrieves a TLSInfo object for the peer server.
func (c *Config) PeerTLSInfo() server.TLSInfo {
return server.TLSInfo{
func (c *Config) PeerTLSInfo() *server.TLSInfo {
return &server.TLSInfo{
CAFile: c.Peer.CAFile,
CertFile: c.Peer.CertFile,
KeyFile: c.Peer.KeyFile,

48
etcd.go
View File

@ -18,7 +18,6 @@ package main
import (
"fmt"
"net"
"net/http"
"os"
"path/filepath"
@ -126,24 +125,6 @@ func main() {
}
ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
var psListener net.Listener
if psConfig.Scheme == "https" {
peerServerTLSConfig, err := config.PeerTLSInfo().ServerConfig()
if err != nil {
log.Fatal("peer server TLS error: ", err)
}
psListener, err = server.NewTLSListener(config.Peer.BindAddr, peerServerTLSConfig)
if err != nil {
log.Fatal("Failed to create peer listener: ", err)
}
} else {
psListener, err = server.NewListener(config.Peer.BindAddr)
if err != nil {
log.Fatal("Failed to create peer listener: ", err)
}
}
// Create raft transporter and server
raftTransporter := server.NewTransporter(followersStats, serverStats, registry, heartbeatInterval, dialTimeout, responseHeaderTimeout)
if psConfig.Scheme == "https" {
@ -168,34 +149,19 @@ func main() {
s.EnableTracing()
}
var sListener net.Listener
if config.EtcdTLSInfo().Scheme() == "https" {
etcdServerTLSConfig, err := config.EtcdTLSInfo().ServerConfig()
if err != nil {
log.Fatal("etcd TLS error: ", err)
}
sListener, err = server.NewTLSListener(config.BindAddr, etcdServerTLSConfig)
if err != nil {
log.Fatal("Failed to create TLS etcd listener: ", err)
}
} else {
sListener, err = server.NewListener(config.BindAddr)
if err != nil {
log.Fatal("Failed to create etcd listener: ", err)
}
}
ps.SetServer(s)
ps.Start(config.Snapshot, config.Discovery, config.Peers)
go func() {
log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL)
log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, config.Peer.BindAddr, ps.Config.URL)
l := server.NewListener(psConfig.Scheme, config.Peer.BindAddr, config.PeerTLSInfo())
sHTTP := &ehttp.CORSHandler{ps.HTTPHandler(), corsInfo}
log.Fatal(http.Serve(psListener, sHTTP))
log.Fatal(http.Serve(l, sHTTP))
}()
log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, sListener.Addr(), s.URL())
log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, config.BindAddr, s.URL())
l := server.NewListener(config.EtcdTLSInfo().Scheme(), config.BindAddr, config.EtcdTLSInfo())
sHTTP := &ehttp.CORSHandler{s.HTTPHandler(), corsInfo}
log.Fatal(http.Serve(sListener, sHTTP))
log.Fatal(http.Serve(l, sHTTP))
}

View File

@ -3,9 +3,35 @@ package server
import (
"crypto/tls"
"net"
"github.com/coreos/etcd/log"
)
func NewListener(addr string) (net.Listener, error) {
// NewListener creates a net.Listener
// If the given scheme is "https", it will generate TLS configuration based on TLSInfo.
// If any error happens, this function will call log.Fatal
func NewListener(scheme, addr string, tlsInfo *TLSInfo) net.Listener {
if scheme == "https" {
cfg, err := tlsInfo.ServerConfig()
if err != nil {
log.Fatal("TLS info error: ", err)
}
l, err := newTLSListener(addr, cfg)
if err != nil {
log.Fatal("Failed to create TLS listener: ", err)
}
return l
}
l, err := newListener(addr)
if err != nil {
log.Fatal("Failed to create listener: ", err)
}
return l
}
func newListener(addr string) (net.Listener, error) {
if addr == "" {
addr = ":http"
}
@ -16,7 +42,7 @@ func NewListener(addr string) (net.Listener, error) {
return l, nil
}
func NewTLSListener(addr string, cfg *tls.Config) (net.Listener, error) {
func newTLSListener(addr string, cfg *tls.Config) (net.Listener, error) {
if addr == "" {
addr = ":https"
}

View File

@ -15,12 +15,12 @@ import (
)
const (
testName = "ETCDTEST"
testClientURL = "localhost:4401"
testRaftURL = "localhost:7701"
testSnapshotCount = 10000
testHeartbeatInterval = time.Duration(50) * time.Millisecond
testElectionTimeout = time.Duration(200) * time.Millisecond
testName = "ETCDTEST"
testClientURL = "localhost:4401"
testRaftURL = "localhost:7701"
testSnapshotCount = 10000
testHeartbeatInterval = time.Duration(50) * time.Millisecond
testElectionTimeout = time.Duration(200) * time.Millisecond
)
// Starts a server in a temporary directory.
@ -35,20 +35,17 @@ func RunServer(f func(*server.Server)) {
followersStats := server.NewRaftFollowersStats(testName)
psConfig := server.PeerServerConfig{
Name: testName,
URL: "http://" + testRaftURL,
Scheme: "http",
SnapshotCount: testSnapshotCount,
MaxClusterSize: 9,
Name: testName,
URL: "http://" + testRaftURL,
Scheme: "http",
SnapshotCount: testSnapshotCount,
MaxClusterSize: 9,
}
mb := metrics.NewBucket("")
ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
psListener, err := server.NewListener(testRaftURL)
if err != nil {
panic(err)
}
psListener := server.NewListener("http", testRaftURL, nil)
// Create Raft transporter and server
dialTimeout := (3 * testHeartbeatInterval) + testElectionTimeout
@ -63,10 +60,7 @@ func RunServer(f func(*server.Server)) {
ps.SetRaftServer(raftServer)
s := server.New(testName, "http://"+testClientURL, ps, registry, store, nil)
sListener, err := server.NewListener(testClientURL)
if err != nil {
panic(err)
}
sListener := server.NewListener("http", testClientURL, nil)
ps.SetServer(s)
@ -104,16 +98,16 @@ func RunServer(f func(*server.Server)) {
}
type waitHandler struct {
wg *sync.WaitGroup
handler http.Handler
wg *sync.WaitGroup
handler http.Handler
}
func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){
h.wg.Add(1)
defer h.wg.Done()
h.handler.ServeHTTP(w, r)
func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.wg.Add(1)
defer h.wg.Done()
h.handler.ServeHTTP(w, r)
//important to flush before decrementing the wait group.
//we won't get a chance to once main() ends.
w.(http.Flusher).Flush()
//important to flush before decrementing the wait group.
//we won't get a chance to once main() ends.
w.(http.Flusher).Flush()
}