etcd/raft_server.go

342 lines
7.9 KiB
Go

/*
Copyright 2013 CoreOS Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"bytes"
"crypto/tls"
"encoding/binary"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
etcdErr "github.com/coreos/etcd/error"
"github.com/coreos/go-raft"
)
// raftServer wraps the embedded go-raft server together with this node's
// identity, network/TLS settings and the statistics it reports.
type raftServer struct {
*raft.Server
// version is the internal raft version (initialized from raftVersion);
// joinByMachine refuses to join a peer whose /version reply differs.
version string
// joinIndex is decoded (uvarint) from the body of a successful /join reply.
joinIndex uint64
// name is this node's raft server name (passed to raft.NewServer).
name string
// url is the advertised raft URL (logged by startTransport).
url string
// listenHost is the address the peer HTTP transport binds to.
listenHost string
// tlsConf supplies the scheme and the client/server TLS configs.
tlsConf *TLSConfig
// tlsInfo holds the certificate/key file paths used when serving TLS.
tlsInfo *TLSInfo
// followersStats is marshaled by PeerStats while this node is leader.
followersStats *raftFollowersStats
// serverStats is refreshed and marshaled by Stats.
serverStats *raftServerStats
}
// r is the package-level raft server used by the start/join helpers in this file.
var r *raftServer
// newRaftServer builds a raftServer for the node called name.
//
// It creates the peer transporter from tlsConf, constructs the underlying
// go-raft server rooted at dirPath and backed by etcdStore (any error is
// passed to check), and initializes empty leader/server statistics. The
// server is not started; call ListenAndServe for that.
func newRaftServer(name string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer {
	peerTransport := newTransporter(tlsConf.Scheme, tlsConf.Client)

	underlying, err := raft.NewServer(name, dirPath, peerTransport, etcdStore, nil, "")
	check(err)

	// Until another leader is known this node reports itself as leader,
	// with no followers yet.
	peers := &raftFollowersStats{
		Leader:    name,
		Followers: make(map[string]*raftFollowerStats),
	}

	// Both rate queues start empty (back == -1).
	ownStats := &raftServerStats{
		StartTime:     time.Now(),
		sendRateQueue: &statsQueue{back: -1},
		recvRateQueue: &statsQueue{back: -1},
	}

	return &raftServer{
		Server:         underlying,
		version:        raftVersion,
		name:           name,
		url:            url,
		listenHost:     listenHost,
		tlsConf:        tlsConf,
		tlsInfo:        tlsInfo,
		followersStats: peers,
		serverStats:    ownStats,
	}
}
// ListenAndServe starts the raft server and brings this node into a cluster.
//
// Steps, in order:
//   - register the raft command types;
//   - if snapshotting is enabled, try to load the last snapshot (a failure
//     is only logged);
//   - configure election/heartbeat timeouts and start the raft server;
//   - with an empty log, bootstrap a new cluster (no peers given) or join
//     the given peers; otherwise try to rejoin the peers recorded in the log;
//   - if snapshotting is enabled, launch monitorSnapshot in the background;
//   - launch the peer HTTP transport in the background.
//
// It returns once the background goroutines have been started.
func (r *raftServer) ListenAndServe() {
	// Setup commands.
	registerCommands()

	// Load the snapshot; failure is not fatal.
	if snapshot {
		if err := r.LoadSnapshot(); err == nil {
			debugf("%s finished load snapshot", r.name)
		} else {
			debug(err)
		}
	}

	r.SetElectionTimeout(ElectionTimeout)
	r.SetHeartbeatTimeout(HeartbeatTimeout)
	r.Start()

	if r.IsLogEmpty() {
		// Fresh node: start as leader of a new cluster when no peers were
		// given, otherwise join the peers listed in cluster.
		if len(cluster) == 0 {
			startAsLeader()
		} else {
			startAsFollower()
		}
	} else {
		// Rejoin the previous cluster. Each recorded machine is parsed as a
		// URL and only its host part is kept, because joinByMachine uses each
		// entry as a URL host. Entries that fail to parse are skipped rather
		// than dereferencing a nil *url.URL.
		machines := getMachines(nameToRaftURL)
		hosts := make([]string, 0, len(machines))
		for _, machine := range machines {
			u, err := url.Parse(machine)
			if err != nil {
				debug("rejoin cannot parse url: ", err)
				continue
			}
			hosts = append(hosts, u.Host)
		}
		cluster = hosts

		if ok := joinCluster(cluster); !ok {
			warn("the entire cluster is down! this machine will restart the cluster.")
		}
		debugf("%s restart as a follower", r.name)
	}

	// Start the snapshot monitor when snapshotting is enabled.
	if snapshot {
		go monitorSnapshot()
	}

	// Start responding to raft requests from peers.
	go r.startTransport(r.tlsConf.Scheme, r.tlsConf.Server)
}
// startAsLeader bootstraps a brand-new cluster. The first node must add
// itself as a peer, so the join command is submitted through r.Do until it
// succeeds.
//
// NOTE(review): the retry loop has no backoff and discards the error of each
// failed attempt.
func startAsLeader() {
	_, err := r.Do(newJoinCommand())
	for err != nil {
		_, err = r.Do(newJoinCommand())
	}
	debugf("%s start as a leader", r.name)
}
// startAsFollower joins the existing cluster listed in the package-level
// cluster variable. It makes up to retryTimes attempts, sleeping
// RetryInterval seconds between them, and reports a fatal error via fatalf
// if none succeeds.
func startAsFollower() {
	// start as a follower in an existing cluster
	for i := 0; i < retryTimes; i++ {
		if joinCluster(cluster) {
			return
		}
		warnf("cannot join to cluster via given machines, retry in %d seconds", RetryInterval)
		time.Sleep(time.Second * RetryInterval)
	}
	// %d, not %x: the retry count is meant to be read as a decimal number.
	fatalf("Cannot join the cluster via given machines after %d retries", retryTimes)
}
// startTransport registers the internal peer-to-peer HTTP endpoints and
// blocks serving them on r.listenHost. Any error from the listener is passed
// to fatal.
//
// scheme "http" serves plain HTTP. Any other value serves TLS using tlsConf
// together with the certificate and key files named in r.tlsInfo.
//
// NOTE(review): tls.Config is passed by value, which go vet's copylocks check
// may flag; changing that would alter the signature callers use.
func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) {
	infof("raft server [name %s, listen on %s, advertised url %s]", r.name, r.listenHost, r.url)

	// internal commands
	routes := []struct {
		pattern string
		handler http.Handler
	}{
		{"/name", http.HandlerFunc(NameHttpHandler)},
		{"/version", http.HandlerFunc(RaftVersionHttpHandler)},
		{"/join", errorHandler(JoinHttpHandler)},
		{"/remove/", http.HandlerFunc(RemoveHttpHandler)},
		{"/vote", http.HandlerFunc(VoteHttpHandler)},
		{"/log", http.HandlerFunc(GetLogHttpHandler)},
		{"/log/append", http.HandlerFunc(AppendEntriesHttpHandler)},
		{"/snapshot", http.HandlerFunc(SnapshotHttpHandler)},
		{"/snapshotRecovery", http.HandlerFunc(SnapshotRecoveryHttpHandler)},
		{"/etcdURL", http.HandlerFunc(EtcdURLHttpHandler)},
	}

	mux := http.NewServeMux()
	for _, rt := range routes {
		mux.Handle(rt.pattern, rt.handler)
	}

	srv := &http.Server{
		Addr:      r.listenHost,
		Handler:   mux,
		TLSConfig: &tlsConf,
	}

	if scheme != "http" {
		fatal(srv.ListenAndServeTLS(r.tlsInfo.CertFile, r.tlsInfo.KeyFile))
		return
	}
	fatal(srv.ListenAndServe())
}
// getVersion fetches the raft version of a peer by GETting versionURL and
// returning the response body verbatim. This works for now, but something
// more sophisticated will be needed once mixed-version clusters are allowed.
//
// It returns an error when:
//   - the request fails;
//   - the peer answers with a non-200 status (otherwise an error page would
//     be mistaken for a version string);
//   - the body cannot be read.
func getVersion(t *transporter, versionURL url.URL) (string, error) {
	resp, req, err := t.Get(versionURL.String())
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	t.CancelWhenTimeout(req)

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("version request to %s returned status %d", versionURL.String(), resp.StatusCode)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	return string(body), nil
}
// joinCluster tries each non-empty entry of cluster in order, using
// joinByMachine, and reports whether any join succeeded.
//
// An etcdErr.Error from a peer is passed to fatal. Any other error is logged
// and the next machine is tried.
func joinCluster(cluster []string) bool {
	for _, peer := range cluster {
		if peer == "" {
			continue
		}

		err := joinByMachine(r.Server, peer, r.tlsConf.Scheme)
		if err == nil {
			debugf("%s success join to the cluster via machine %s", r.name, peer)
			return true
		}

		if _, isEtcdErr := err.(etcdErr.Error); isEtcdErr {
			fatal(err)
		}
		debugf("cannot join to cluster via machine %s %s", peer, err)
	}
	return false
}
// joinByMachine asks the peer at machine (used as the URL host) to add this
// server to its cluster.
//
// It proceeds as follows:
//   - check that the peer's /version matches r.version;
//   - POST a JoinCommand to /join;
//   - follow 307 Temporary Redirect responses until the join is accepted
//     (200) or rejected.
//
// On success, r.joinIndex is set from the response body. A 400 Bad Request
// is decoded into an etcdErr.Error and returned as that value, so callers
// can recognize it. Every other failure is a plain error.
//
// The s parameter is currently unused; the package-level r is used instead.
func joinByMachine(s *raft.Server, machine string, scheme string) error {
	// t must be ok
	t, _ := r.Transporter().(*transporter)

	// Our version must match the leader's version.
	versionURL := url.URL{Host: machine, Scheme: scheme, Path: "/version"}
	version, err := getVersion(t, versionURL)
	if err != nil {
		return fmt.Errorf("Unable to join: %v", err)
	}

	// TODO: versioning of the internal protocol. See:
	// Documentation/internal-protocol-versioning.md
	if version != r.version {
		return fmt.Errorf("Unable to join: internal version mismatch, entire cluster must be running identical versions of etcd")
	}

	joinURL := url.URL{Host: machine, Scheme: scheme, Path: "/join"}
	address := joinURL.String()
	for {
		next, err := sendJoinRequest(t, address)
		if err != nil {
			return err
		}
		if next == "" {
			return nil // join accepted
		}
		address = next
	}
}

// sendJoinRequest POSTs a freshly encoded JoinCommand to address and handles
// exactly one response. Its results are:
//   - (location, nil) for a redirect;
//   - ("", nil) once the join has been accepted;
//   - a non-nil error otherwise.
//
// The response body is closed before it returns, so redirects do not keep
// earlier bodies open.
func sendJoinRequest(t *transporter, address string) (string, error) {
	// Use a new buffer for each attempt so a redirected request never carries
	// leftovers from the previous body.
	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(newJoinCommand()); err != nil {
		return "", fmt.Errorf("Unable to join: %v", err)
	}

	debugf("Send Join Request to %s", address)
	resp, req, err := t.Post(address, &b)
	if err != nil {
		return "", fmt.Errorf("Unable to join: %v", err)
	}
	if resp == nil {
		// Defensive: with no response there is nothing to inspect, so fail
		// instead of retrying forever.
		return "", fmt.Errorf("Unable to join")
	}
	defer resp.Body.Close()
	t.CancelWhenTimeout(req)

	switch resp.StatusCode {
	case http.StatusOK:
		// Best effort: the body carries the join index as a uvarint.
		body, _ := ioutil.ReadAll(resp.Body)
		r.joinIndex, _ = binary.Uvarint(body)
		return "", nil
	case http.StatusTemporaryRedirect:
		location := resp.Header.Get("Location")
		if location == "" {
			return "", fmt.Errorf("Unable to join: redirect without Location header")
		}
		return location, nil
	case http.StatusBadRequest:
		debug("Reach max number machines in the cluster")
		// If decoding fails, the zero etcdErr.Error is returned.
		joinErr := &etcdErr.Error{}
		json.NewDecoder(resp.Body).Decode(joinErr)
		return "", *joinErr
	default:
		return "", fmt.Errorf("Unable to join")
	}
}
// Stats refreshes the derived fields of r.serverStats and returns the whole
// struct as JSON. The refreshed fields are the leader uptime and the
// send/receive packet and bandwidth rates. A marshaling error yields a nil
// slice.
func (r *raftServer) Stats() []byte {
	st := r.serverStats

	st.LeaderInfo.Uptime = time.Since(st.LeaderInfo.startTime).String()
	st.SendingPkgRate, st.SendingBandwidthRate = st.sendRateQueue.Rate()
	st.RecvingPkgRate, st.RecvingBandwidthRate = st.recvRateQueue.Rate()

	encoded, _ := json.Marshal(st)
	return encoded
}
// PeerStats returns the followers' statistics as JSON while this node is
// the raft leader. It returns nil when this node is not the leader, or when
// marshaling fails.
func (r *raftServer) PeerStats() []byte {
	if r.State() != raft.Leader {
		return nil
	}
	encoded, _ := json.Marshal(r.followersStats)
	return encoded
}
// registerCommands registers each etcd command type with
// raft.RegisterCommand, in a fixed order.
func registerCommands() {
	commands := []raft.Command{
		&JoinCommand{},
		&RemoveCommand{},
		&SetCommand{},
		&GetCommand{},
		&DeleteCommand{},
		&WatchCommand{},
		&TestAndSetCommand{},
	}
	for _, cmd := range commands {
		raft.RegisterCommand(cmd)
	}
}