package main import ( "bytes" "crypto/tls" "encoding/binary" "encoding/json" "fmt" "io/ioutil" "net/http" "net/url" "time" etcdErr "github.com/coreos/etcd/error" "github.com/coreos/go-raft" ) type raftServer struct { *raft.Server version string joinIndex uint64 name string url string listenHost string tlsConf *TLSConfig tlsInfo *TLSInfo peersStats map[string]*raftPeerStats serverStats *raftServerStats } var r *raftServer func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer { // Create transporter for raft raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client, ElectionTimeout) // Create raft server server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil) check(err) return &raftServer{ Server: server, version: raftVersion, name: name, url: url, listenHost: listenHost, tlsConf: tlsConf, tlsInfo: tlsInfo, peersStats: make(map[string]*raftPeerStats), serverStats: &raftServerStats{ StartTime: time.Now(), sendRateQueue: &statsQueue{ back: -1, }, recvRateQueue: &statsQueue{ back: -1, }, }, } } // Start the raft server func (r *raftServer) ListenAndServe() { // Setup commands. registerCommands() // LoadSnapshot if snapshot { err := r.LoadSnapshot() if err == nil { debugf("%s finished load snapshot", r.name) } else { debug(err) } } r.SetElectionTimeout(ElectionTimeout) r.SetHeartbeatTimeout(HeartbeatTimeout) r.Start() if r.IsLogEmpty() { // start as a leader in a new cluster if len(cluster) == 0 { startAsLeader() } else { startAsFollower() } } else { // rejoin the previous cluster cluster = getMachines(nameToRaftURL) for i := 0; i < len(cluster); i++ { u, err := url.Parse(cluster[i]) if err != nil { debug("rejoin cannot parse url: ", err) } cluster[i] = u.Host } ok := joinCluster(cluster) if !ok { warn("the entire cluster is down! this machine will restart the cluster.") } debugf("%s restart as a follower", r.name) } // open the snapshot if snapshot { go monitorSnapshot() } // start to response to raft requests go r.startTransport(r.tlsConf.Scheme, r.tlsConf.Server) } func startAsLeader() { // leader need to join self as a peer for { _, err := r.Do(newJoinCommand()) if err == nil { break } } debugf("%s start as a leader", r.name) } func startAsFollower() { // start as a follower in a existing cluster for i := 0; i < retryTimes; i++ { ok := joinCluster(cluster) if ok { return } warnf("cannot join to cluster via given machines, retry in %d seconds", RetryInterval) time.Sleep(time.Second * RetryInterval) } fatalf("Cannot join the cluster via given machines after %x retries", retryTimes) } // Start to listen and response raft command func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) { infof("raft server [name %s, listen on %s, advertised url %s]", r.name, r.listenHost, r.url) raftMux := http.NewServeMux() server := &http.Server{ Handler: raftMux, TLSConfig: &tlsConf, Addr: r.listenHost, } // internal commands raftMux.HandleFunc("/name", NameHttpHandler) raftMux.HandleFunc("/version", RaftVersionHttpHandler) raftMux.Handle("/join", errorHandler(JoinHttpHandler)) raftMux.HandleFunc("/remove/", RemoveHttpHandler) raftMux.HandleFunc("/vote", VoteHttpHandler) raftMux.HandleFunc("/log", GetLogHttpHandler) raftMux.HandleFunc("/log/append", AppendEntriesHttpHandler) raftMux.HandleFunc("/snapshot", SnapshotHttpHandler) raftMux.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler) raftMux.HandleFunc("/etcdURL", EtcdURLHttpHandler) if scheme == "http" { fatal(server.ListenAndServe()) } else { fatal(server.ListenAndServeTLS(r.tlsInfo.CertFile, r.tlsInfo.KeyFile)) } } // getVersion fetches the raft version of a peer. This works for now but we // will need to do something more sophisticated later when we allow mixed // version clusters. func getVersion(t *transporter, versionURL url.URL) (string, error) { resp, err := t.Get(versionURL.String()) if err != nil { return "", err } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) return string(body), nil } func joinCluster(cluster []string) bool { for _, machine := range cluster { if len(machine) == 0 { continue } err := joinByMachine(r.Server, machine, r.tlsConf.Scheme) if err == nil { debugf("%s success join to the cluster via machine %s", r.name, machine) return true } else { if _, ok := err.(etcdErr.Error); ok { fatal(err) } debugf("cannot join to cluster via machine %s %s", machine, err) } } return false } // Send join requests to machine. func joinByMachine(s *raft.Server, machine string, scheme string) error { var b bytes.Buffer // t must be ok t, _ := r.Transporter().(*transporter) // Our version must match the leaders version versionURL := url.URL{Host: machine, Scheme: scheme, Path: "/version"} version, err := getVersion(t, versionURL) if err != nil { return fmt.Errorf("Unable to join: %v", err) } // TODO: versioning of the internal protocol. See: // Documentation/internatl-protocol-versioning.md if version != r.version { return fmt.Errorf("Unable to join: internal version mismatch, entire cluster must be running identical versions of etcd") } json.NewEncoder(&b).Encode(newJoinCommand()) joinURL := url.URL{Host: machine, Scheme: scheme, Path: "/join"} debugf("Send Join Request to %s", joinURL.String()) resp, err := t.Post(joinURL.String(), &b) for { if err != nil { return fmt.Errorf("Unable to join: %v", err) } if resp != nil { defer resp.Body.Close() if resp.StatusCode == http.StatusOK { b, _ := ioutil.ReadAll(resp.Body) r.joinIndex, _ = binary.Uvarint(b) return nil } if resp.StatusCode == http.StatusTemporaryRedirect { address := resp.Header.Get("Location") debugf("Send Join Request to %s", address) json.NewEncoder(&b).Encode(newJoinCommand()) resp, err = t.Post(address, &b) } else if resp.StatusCode == http.StatusBadRequest { debug("Reach max number machines in the cluster") decoder := json.NewDecoder(resp.Body) err := &etcdErr.Error{} decoder.Decode(err) return *err } else { return fmt.Errorf("Unable to join") } } } return fmt.Errorf("Unable to join: %v", err) } func (r *raftServer) Stats() []byte { r.serverStats.LeaderUptime = time.Now().Sub(r.serverStats.leaderStartTime).String() queue := r.serverStats.sendRateQueue r.serverStats.SendingPkgRate, r.serverStats.SendingBandwidthRate = queue.Rate() queue = r.serverStats.recvRateQueue r.serverStats.RecvingPkgRate, r.serverStats.RecvingBandwidthRate = queue.Rate() sBytes, err := json.Marshal(r.serverStats) if err != nil { warn(err) } if r.State() == raft.Leader { pBytes, _ := json.Marshal(r.peersStats) b := append(sBytes, pBytes...) return b } return sBytes } // Register commands to raft server func registerCommands() { raft.RegisterCommand(&JoinCommand{}) raft.RegisterCommand(&RemoveCommand{}) raft.RegisterCommand(&SetCommand{}) raft.RegisterCommand(&GetCommand{}) raft.RegisterCommand(&DeleteCommand{}) raft.RegisterCommand(&WatchCommand{}) raft.RegisterCommand(&TestAndSetCommand{}) }