diff --git a/etcd.go b/etcd.go index 96ab64957..5f6222149 100644 --- a/etcd.go +++ b/etcd.go @@ -114,8 +114,8 @@ func main() { psConfig := server.PeerServerConfig{ Name: info.Name, Path: config.DataDir, + Scheme: peerTLSConfig.Scheme, URL: info.RaftURL, - BindAddr: info.RaftListenHost, SnapshotCount: config.SnapshotCount, HeartbeatTimeout: time.Duration(config.Peer.HeartbeatTimeout) * time.Millisecond, ElectionTimeout: time.Duration(config.Peer.ElectionTimeout) * time.Millisecond, @@ -125,6 +125,16 @@ func main() { } ps := server.NewPeerServer(psConfig, &peerTLSConfig, &info.RaftTLS, registry, store, &mb) + var psListener net.Listener + if psConfig.Scheme == "https" { + psListener, err = server.NewTLSListener(info.RaftListenHost, info.RaftTLS.CertFile, info.RaftTLS.KeyFile) + } else { + psListener, err = server.NewListener(info.RaftListenHost) + } + if err != nil { + panic(err) + } + // Create client server. sConfig := server.ServerConfig{ Name: info.Name, @@ -151,7 +161,7 @@ func main() { // Run peer server in separate thread while the client server blocks. go func() { - log.Fatal(ps.ListenAndServe(config.Snapshot, config.Peers)) + log.Fatal(ps.Serve(psListener, config.Snapshot, config.Peers)) }() log.Fatal(s.Serve(sListener)) } diff --git a/server/peer_server.go b/server/peer_server.go index b5d8fde38..5eea6b363 100644 --- a/server/peer_server.go +++ b/server/peer_server.go @@ -2,7 +2,6 @@ package server import ( "bytes" - "crypto/tls" "encoding/binary" "encoding/json" "fmt" @@ -29,8 +28,8 @@ const ThresholdMonitorTimeout = 5 * time.Second type PeerServerConfig struct { Name string Path string + Scheme string URL string - BindAddr string SnapshotCount int HeartbeatTimeout time.Duration ElectionTimeout time.Duration @@ -43,8 +42,6 @@ type PeerServer struct { Config PeerServerConfig raftServer raft.Server server *Server - httpServer *http.Server - listener net.Listener joinIndex uint64 tlsConf *TLSConfig tlsInfo *TLSInfo @@ -54,6 +51,8 @@ type PeerServer struct { store store.Store snapConf *snapshotConf + listener net.Listener + closeChan chan bool timeoutThresholdChan chan interface{} @@ -77,8 +76,6 @@ func NewPeerServer(psConfig PeerServerConfig, tlsConf *TLSConfig, tlsInfo *TLSIn s := &PeerServer{ Config: psConfig, - tlsConf: tlsConf, - tlsInfo: tlsInfo, registry: registry, store: store, followersStats: &raftFollowersStats{ @@ -132,7 +129,7 @@ func NewPeerServer(psConfig PeerServerConfig, tlsConf *TLSConfig, tlsInfo *TLSIn } // Start the raft server -func (s *PeerServer) ListenAndServe(snapshot bool, cluster []string) error { +func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []string) error { // LoadSnapshot if snapshot { err := s.raftServer.LoadSnapshot() @@ -185,56 +182,29 @@ func (s *PeerServer) ListenAndServe(snapshot bool, cluster []string) error { go s.monitorSnapshot() } - // start to response to raft requests - return s.startTransport(s.tlsConf.Scheme, s.tlsConf.Server) + router := mux.NewRouter() + httpServer := &http.Server{Handler: router} + + // internal commands + router.HandleFunc("/name", s.NameHttpHandler) + router.HandleFunc("/version", s.VersionHttpHandler) + router.HandleFunc("/version/{version:[0-9]+}/check", s.VersionCheckHttpHandler) + router.HandleFunc("/upgrade", s.UpgradeHttpHandler) + router.HandleFunc("/join", s.JoinHttpHandler) + router.HandleFunc("/remove/{name:.+}", s.RemoveHttpHandler) + router.HandleFunc("/vote", s.VoteHttpHandler) + router.HandleFunc("/log", s.GetLogHttpHandler) + router.HandleFunc("/log/append", s.AppendEntriesHttpHandler) + router.HandleFunc("/snapshot", s.SnapshotHttpHandler) + router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler) + router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler) + + s.listener = listener + log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, listener.Addr(), s.Config.URL) + httpServer.Serve(listener) + return nil } -// Overridden version of net/http added so we can manage the listener. -func (s *PeerServer) listenAndServe() error { - addr := s.httpServer.Addr - if addr == "" { - addr = ":http" - } - l, e := net.Listen("tcp", addr) - if e != nil { - return e - } - s.listener = l - return s.httpServer.Serve(l) -} - -// Overridden version of net/http added so we can manage the listener. -func (s *PeerServer) listenAndServeTLS(certFile, keyFile string) error { - addr := s.httpServer.Addr - if addr == "" { - addr = ":https" - } - config := &tls.Config{} - if s.httpServer.TLSConfig != nil { - *config = *s.httpServer.TLSConfig - } - if config.NextProtos == nil { - config.NextProtos = []string{"http/1.1"} - } - - var err error - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return err - } - - conn, err := net.Listen("tcp", addr) - if err != nil { - return err - } - - tlsListener := tls.NewListener(conn, config) - s.listener = tlsListener - return s.httpServer.Serve(tlsListener) -} - -// Stops the server. func (s *PeerServer) Close() { if s.closeChan != nil { close(s.closeChan) @@ -281,40 +251,6 @@ func (s *PeerServer) startAsFollower(cluster []string) { log.Fatalf("Cannot join the cluster via given peers after %x retries", s.Config.RetryTimes) } -// Start to listen and response raft command -func (s *PeerServer) startTransport(scheme string, tlsConf tls.Config) error { - log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, s.Config.BindAddr, s.Config.URL) - - router := mux.NewRouter() - - s.httpServer = &http.Server{ - Handler: router, - TLSConfig: &tlsConf, - Addr: s.Config.BindAddr, - } - - // internal commands - router.HandleFunc("/name", s.NameHttpHandler) - router.HandleFunc("/version", s.VersionHttpHandler) - router.HandleFunc("/version/{version:[0-9]+}/check", s.VersionCheckHttpHandler) - router.HandleFunc("/upgrade", s.UpgradeHttpHandler) - router.HandleFunc("/join", s.JoinHttpHandler) - router.HandleFunc("/remove/{name:.+}", s.RemoveHttpHandler) - router.HandleFunc("/vote", s.VoteHttpHandler) - router.HandleFunc("/log", s.GetLogHttpHandler) - router.HandleFunc("/log/append", s.AppendEntriesHttpHandler) - router.HandleFunc("/snapshot", s.SnapshotHttpHandler) - router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler) - router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler) - - if scheme == "http" { - return s.listenAndServe() - } else { - return s.listenAndServeTLS(s.tlsInfo.CertFile, s.tlsInfo.KeyFile) - } - -} - // getVersion fetches the peer version of a cluster. func getVersion(t *transporter, versionURL url.URL) (int, error) { resp, req, err := t.Get(versionURL.String()) @@ -344,7 +280,7 @@ func (s *PeerServer) Upgradable() error { } t, _ := s.raftServer.Transporter().(*transporter) - checkURL := (&url.URL{Host: u.Host, Scheme: s.tlsConf.Scheme, Path: fmt.Sprintf("/version/%d/check", nextVersion)}).String() + checkURL := (&url.URL{Host: u.Host, Scheme: s.Config.Scheme, Path: fmt.Sprintf("/version/%d/check", nextVersion)}).String() resp, _, err := t.Get(checkURL) if err != nil { return fmt.Errorf("PeerServer: Cannot check version compatibility: %s", u.Host) @@ -363,7 +299,7 @@ func (s *PeerServer) joinCluster(cluster []string) bool { continue } - err := s.joinByPeer(s.raftServer, peer, s.tlsConf.Scheme) + err := s.joinByPeer(s.raftServer, peer, s.Config.Scheme) if err == nil { log.Debugf("%s success join to the cluster via peer %s", s.Config.Name, peer) return true diff --git a/tests/server_utils.go b/tests/server_utils.go index 2977224e9..596f960a0 100644 --- a/tests/server_utils.go +++ b/tests/server_utils.go @@ -31,7 +31,7 @@ func RunServer(f func(*server.Server)) { Name: testName, Path: path, URL: "http://"+testRaftURL, - BindAddr: testRaftURL, + Scheme: "http", SnapshotCount: testSnapshotCount, HeartbeatTimeout: testHeartbeatTimeout, ElectionTimeout: testElectionTimeout, @@ -39,6 +39,10 @@ func RunServer(f func(*server.Server)) { CORS: corsInfo, } ps := server.NewPeerServer(psConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, registry, store, nil) + psListener, err := server.NewListener(testRaftURL) + if err != nil { + panic(err) + } sConfig := server.ServerConfig{ Name: testName, @@ -57,7 +61,7 @@ func RunServer(f func(*server.Server)) { c := make(chan bool) go func() { c <- true - ps.ListenAndServe(false, []string{}) + ps.Serve(psListener, false, []string{}) }() <-c