etcd/raft_server.go

321 lines
7.3 KiB
Go

package main
import (
"bytes"
"crypto/tls"
"encoding/binary"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
etcdErr "github.com/coreos/etcd/error"
"github.com/coreos/go-raft"
)
// raftServer wraps the embedded go-raft server together with the
// configuration and statistics this file keeps alongside it.
type raftServer struct {
*raft.Server
// version is initialised from raftVersion and compared with a peer's
// /version response before joining.
version string
// joinIndex is decoded from the body of a successful join response.
joinIndex uint64
// name is the node name handed to raft.NewServer.
name string
// url is logged as the advertised url when the transport starts.
url string
// listenHost is the address the raft HTTP server listens on.
listenHost string
// tlsConf supplies the scheme plus the client and server TLS settings.
tlsConf *TLSConfig
// tlsInfo supplies the certificate and key files used when serving TLS.
tlsInfo *TLSInfo
// peersStats is marshalled by Stats when this node is the leader.
peersStats map[string]*raftPeerStats
// serverStats holds this node's own statistics, marshalled by Stats.
serverStats *raftServerStats
}
// r is the package-level raft server; the free functions in this file
// (startAsLeader, joinCluster, joinByMachine) operate on it.
// NOTE(review): where r is assigned is not visible in this part of the file.
var r *raftServer
// newRaftServer builds the raftServer for this node.
//
// It creates a transporter from the TLS settings, creates the underlying
// go-raft server on top of it and wraps both together with empty statistics.
// The error from raft.NewServer is passed to check.
func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer {
	// Transporter handed to the raft server.
	transport := newTransporter(tlsConf.Scheme, tlsConf.Client, ElectionTimeout)

	// Underlying go-raft server.
	inner, err := raft.NewServer(name, dirPath, transport, etcdStore, nil)
	check(err)

	// Both rate queues start with back set to -1.
	stats := &raftServerStats{
		StartTime:     time.Now(),
		sendRateQueue: &statsQueue{back: -1},
		recvRateQueue: &statsQueue{back: -1},
	}

	return &raftServer{
		Server:      inner,
		version:     raftVersion,
		name:        name,
		url:         url,
		listenHost:  listenHost,
		tlsConf:     tlsConf,
		tlsInfo:     tlsInfo,
		peersStats:  make(map[string]*raftPeerStats),
		serverStats: stats,
	}
}
// ListenAndServe starts the raft server and brings this node into a cluster.
//
// It registers the raft commands, loads a snapshot when snapshotting is
// enabled, applies the election and heartbeat timeouts and starts the
// embedded raft server. With an empty log the node either starts a new
// cluster as leader (no machines configured) or joins the configured
// machines as a follower; with a non-empty log it tries to rejoin the
// machines returned by getMachines. Afterwards the snapshot monitor (when
// enabled) and the HTTP transport are each started in their own goroutine.
func (r *raftServer) ListenAndServe() {
	// Setup commands.
	registerCommands()

	// Load the snapshot; a failure is only logged.
	if snapshot {
		if err := r.LoadSnapshot(); err != nil {
			debug(err)
		} else {
			debugf("%s finished load snapshot", r.name)
		}
	}

	r.SetElectionTimeout(ElectionTimeout)
	r.SetHeartbeatTimeout(HeartbeatTimeout)

	r.Start()

	if r.IsLogEmpty() {
		if len(cluster) == 0 {
			// start as a leader in a new cluster
			startAsLeader()
		} else {
			startAsFollower()
		}
	} else {
		// rejoin the previous cluster
		cluster = getMachines(nameToRaftURL)
		for i := 0; i < len(cluster); i++ {
			u, err := url.Parse(cluster[i])
			if err != nil {
				debug("rejoin cannot parse url: ", err)
				// u is nil on error, so it must not be dereferenced. Blank
				// the entry instead: joinCluster skips empty machines.
				cluster[i] = ""
				continue
			}
			cluster[i] = u.Host
		}

		if ok := joinCluster(cluster); !ok {
			warn("the entire cluster is down! this machine will restart the cluster.")
		}

		debugf("%s restart as a follower", r.name)
	}

	// open the snapshot
	if snapshot {
		go monitorSnapshot()
	}

	// start to respond to raft requests
	go r.startTransport(r.tlsConf.Scheme, r.tlsConf.Server)
}
// startAsLeader makes this node the leader of a new cluster by committing a
// join command for itself through the package-level raft server r.
//
// NOTE(review): the command is resubmitted immediately, without backoff or
// logging, for as long as r.Do returns an error.
func startAsLeader() {
	// The leader has to join itself as a peer; repeat until Do succeeds.
	_, err := r.Do(newJoinCommand())
	for err != nil {
		_, err = r.Do(newJoinCommand())
	}

	debugf("%s start as a leader", r.name)
}
// startAsFollower joins this node to the existing cluster listed in the
// package-level cluster slice.
//
// It calls joinCluster up to retryTimes times, sleeping RetryInterval seconds
// after each failed attempt, and returns as soon as one attempt succeeds. If
// every attempt fails the failure is reported through fatalf.
func startAsFollower() {
	for i := 0; i < retryTimes; i++ {
		if joinCluster(cluster) {
			return
		}
		warnf("cannot join to cluster via given machines, retry in %d seconds", RetryInterval)
		time.Sleep(time.Second * RetryInterval)
	}

	// The retry count is printed in decimal; %x would render it in hex
	// (e.g. 10 retries shown as "a").
	fatalf("Cannot join the cluster via given machines after %d retries", retryTimes)
}
// startTransport serves the raft HTTP endpoints on r.listenHost.
//
// With scheme "http" the server is started in plain text; for any other
// scheme it is started with TLS using the certificate and key files from
// r.tlsInfo. Whatever the listen call returns is handed to fatal.
func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) {
	infof("raft server [name %s, listen on %s, advertised url %s]", r.name, r.listenHost, r.url)

	mux := http.NewServeMux()

	// internal commands
	mux.HandleFunc("/name", NameHttpHandler)
	mux.HandleFunc("/version", RaftVersionHttpHandler)
	mux.Handle("/join", errorHandler(JoinHttpHandler))
	mux.HandleFunc("/remove/", RemoveHttpHandler)
	mux.HandleFunc("/vote", VoteHttpHandler)
	mux.HandleFunc("/log", GetLogHttpHandler)
	mux.HandleFunc("/log/append", AppendEntriesHttpHandler)
	mux.HandleFunc("/snapshot", SnapshotHttpHandler)
	mux.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler)
	mux.HandleFunc("/etcdURL", EtcdURLHttpHandler)

	srv := &http.Server{
		Addr:      r.listenHost,
		Handler:   mux,
		TLSConfig: &tlsConf,
	}

	if scheme != "http" {
		fatal(srv.ListenAndServeTLS(r.tlsInfo.CertFile, r.tlsInfo.KeyFile))
		return
	}
	fatal(srv.ListenAndServe())
}
// getVersion fetches the raft version of a peer. This works for now but we
// will need to do something more sophisticated later when we allow mixed
// version clusters.
//
// It issues a GET to versionURL through t and returns the whole response body
// as a string. An error from the request or from reading the body is returned
// with an empty version. The HTTP status code is not inspected.
func getVersion(t *transporter, versionURL url.URL) (string, error) {
	resp, err := t.Get(versionURL.String())
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		// Do not hand back a truncated version string as if it were valid.
		return "", err
	}

	return string(body), nil
}
// joinCluster tries to join through each machine in cluster, in order, and
// reports whether one of them accepted the join.
//
// Empty entries are skipped. An etcdErr.Error returned by joinByMachine is
// passed to fatal; any other error is logged and the next machine is tried.
func joinCluster(cluster []string) bool {
	for _, machine := range cluster {
		if machine == "" {
			continue
		}

		joinErr := joinByMachine(r.Server, machine, r.tlsConf.Scheme)
		if joinErr == nil {
			debugf("%s success join to the cluster via machine %s", r.name, machine)
			return true
		}

		if _, isEtcdErr := joinErr.(etcdErr.Error); isEtcdErr {
			fatal(joinErr)
		}
		debugf("cannot join to cluster via machine %s %s", machine, joinErr)
	}

	return false
}
// joinByMachine sends a join request for this node to machine.
//
// It first fetches the peer's /version and refuses to join when it differs
// from r.version. It then POSTs a join command to /join and follows
// 307 Temporary Redirect responses to the address in the Location header.
//
//   - 200 OK: the body is decoded as a uvarint into r.joinIndex; returns nil.
//   - 400 Bad Request: the body is decoded into an etcdErr.Error, which is
//     returned by value so callers can type-assert on etcdErr.Error.
//   - anything else, or a transport error: a plain error is returned.
//
// Every response body is closed before the next request is issued or the
// function returns.
//
// NOTE(review): s is not used; the package-level r is used instead.
// NOTE(review): the number of redirects followed is not bounded.
func joinByMachine(s *raft.Server, machine string, scheme string) error {
	var b bytes.Buffer

	// The raft server is expected to have been built with a *transporter.
	t, ok := r.Transporter().(*transporter)
	if !ok {
		return fmt.Errorf("Unable to join: unexpected transporter type %T", r.Transporter())
	}

	// Our version must match the leader's version.
	versionURL := url.URL{Host: machine, Scheme: scheme, Path: "/version"}
	version, err := getVersion(t, versionURL)
	if err != nil {
		return fmt.Errorf("Unable to join: %v", err)
	}

	// TODO: versioning of the internal protocol. See:
	// Documentation/internal-protocol-versioning.md
	if version != r.version {
		return fmt.Errorf("Unable to join: internal version mismatch, entire cluster must be running identical versions of etcd")
	}

	if err := json.NewEncoder(&b).Encode(newJoinCommand()); err != nil {
		return fmt.Errorf("Unable to join: %v", err)
	}

	joinURL := url.URL{Host: machine, Scheme: scheme, Path: "/join"}
	debugf("Send Join Request to %s", joinURL.String())

	resp, err := t.Post(joinURL.String(), &b)

	for {
		if err != nil {
			return fmt.Errorf("Unable to join: %v", err)
		}
		if resp == nil {
			// Without a response there is nothing to inspect; fail instead
			// of looping forever.
			return fmt.Errorf("Unable to join")
		}

		switch resp.StatusCode {
		case http.StatusOK:
			body, readErr := ioutil.ReadAll(resp.Body)
			resp.Body.Close()
			if readErr != nil {
				// The peer already answered 200, so the join is still
				// reported as successful; only the index may be incomplete.
				warnf("cannot read join response from %s: %v", machine, readErr)
			}
			// An empty or malformed body leaves joinIndex at 0.
			r.joinIndex, _ = binary.Uvarint(body)
			return nil

		case http.StatusTemporaryRedirect:
			address := resp.Header.Get("Location")
			resp.Body.Close()
			debugf("Send Join Request to %s", address)

			// Start from an empty buffer so the retried request carries
			// exactly one encoded command.
			b.Reset()
			if err := json.NewEncoder(&b).Encode(newJoinCommand()); err != nil {
				return fmt.Errorf("Unable to join: %v", err)
			}
			resp, err = t.Post(address, &b)

		case http.StatusBadRequest:
			debug("Reach max number machines in the cluster")
			joinErr := &etcdErr.Error{}
			decodeErr := json.NewDecoder(resp.Body).Decode(joinErr)
			resp.Body.Close()
			if decodeErr != nil {
				debugf("cannot decode join error from %s: %v", machine, decodeErr)
			}
			// Returned by value even when decoding failed, so callers'
			// etcdErr.Error type assertion keeps matching.
			return *joinErr

		default:
			resp.Body.Close()
			return fmt.Errorf("Unable to join")
		}
	}
}
// Stats refreshes the derived fields of r.serverStats (leader uptime and the
// send/receive rates) and returns the stats encoded as JSON.
//
// When this node is the leader, the JSON encoding of r.peersStats is appended
// directly after the server stats, with no separator between the two values.
// A marshalling error for the server stats is passed to warn; one for the
// peer stats is ignored.
func (r *raftServer) Stats() []byte {
	stats := r.serverStats

	stats.LeaderUptime = time.Since(stats.leaderStartTime).String()
	stats.SendingPkgRate, stats.SendingBandwidthRate = stats.sendRateQueue.Rate()
	stats.RecvingPkgRate, stats.RecvingBandwidthRate = stats.recvRateQueue.Rate()

	serverJSON, err := json.Marshal(stats)
	if err != nil {
		warn(err)
	}

	if r.State() != raft.Leader {
		return serverJSON
	}

	peersJSON, _ := json.Marshal(r.peersStats)
	return append(serverJSON, peersJSON...)
}
// registerCommands registers every command type used by this server with the
// raft package, in the order listed.
func registerCommands() {
	commands := []raft.Command{
		&JoinCommand{},
		&RemoveCommand{},
		&SetCommand{},
		&GetCommand{},
		&DeleteCommand{},
		&WatchCommand{},
		&TestAndSetCommand{},
	}

	for _, cmd := range commands {
		raft.RegisterCommand(cmd)
	}
}