feat add snapCount parameter

release-0.4
Xiang Li 2013-10-30 17:36:15 -07:00
parent 107762e82a
commit 9d0de611a7
70 changed files with 4413 additions and 536 deletions

View File

@ -0,0 +1,104 @@
# Etcd Configuration
Configuration options can be set in three places:
1. Command line flags
2. Environment variables
3. Configuration file
Options set on the command line take precedence over all other sources.
Options set in environment variables take precedence over options set in
configuration files.
## Command Line Flags
### Required
* `-n` - The node name. Defaults to `default-name`.
### Optional
* `-c` - The advertised public hostname:port for client communication. Defaults to `127.0.0.1:4001`.
* `-cl` - The listening hostname for client communication. Defaults to advertised ip.
* `-C` - A comma separated list of machines in the cluster (i.e `"203.0.113.101:7001,203.0.113.102:7001"`).
* `-CF` - The file path containing a comma separated list of machines in the cluster.
* `-clientCAFile` - The path of the client CAFile. Enables client cert authentication when present.
* `-clientCert` - The cert file of the client.
* `-clientKey` - The key file of the client.
* `-configfile` - The path of the etcd config file. Defaults to `/etc/etcd/etcd.conf`.
* `-cors` - A comma separated white list of origins for cross-origin resource sharing.
* `-cpuprofile` - The path to a file to output cpu profile data. Enables cpu profiling when present.
* `-d` - The directory to store log and snapshot. Defaults to the current working directory.
* `-m` - The max size of result buffer. Defaults to `1024`.
* `-maxsize` - The max size of the cluster. Defaults to `9`.
* `-r` - The max retry attempts when trying to join a cluster. Defaults to `3`.
* `-s` - The advertised public hostname:port for server communication. Defaults to `127.0.0.1:7001`.
* `-sl` - The listening hostname for server communication. Defaults to advertised ip.
* `-serverCAFile` - The path of the CAFile. Enables client/peer cert authentication when present.
* `-serverCert` - The cert file of the server.
* `-serverKey` - The key file of the server.
* `-snapshot` - Open or close snapshot. Defaults to `false`.
* `-v` - Enable verbose logging. Defaults to `false`.
* `-vv` - Enable very verbose logging. Defaults to `false`.
* `-version` - Print the version and exit.
* `-w` - The hostname:port of web interface.
## Configuration File
The etcd configuration file is written in [TOML](https://github.com/mojombo/toml)
and read from `/etc/etcd/etcd.conf` by default.
```TOML
advertised_url = "127.0.0.1:4001"
ca_file = ""
cert_file = ""
cors = []
cpu_profile_file = ""
datadir = "."
key_file = ""
listen_host = "127.0.0.1:4001"
machines = []
machines_file = ""
max_cluster_size = 9
max_result_buffer = 1024
max_retry_attempts = 3
name = "default-name"
snapshot = false
verbose = false
very_verbose = false
web_url = ""
[peer]
advertised_url = "127.0.0.1:7001"
ca_file = ""
cert_file = ""
key_file = ""
listen_host = "127.0.0.1:7001"
```
## Environment Variables
* `ETCD_ADVERTISED_URL`
* `ETCD_CA_FILE`
* `ETCD_CERT_FILE`
* `ETCD_CORS`
* `ETCD_CONFIG_FILE`
* `ETCD_CPU_PROFILE_FILE`
* `ETCD_DATADIR`
* `ETCD_KEY_FILE`
* `ETCD_LISTEN_HOST`
* `ETCD_MACHINES`
* `ETCD_MACHINES_FILE`
* `ETCD_MAX_RETRY_ATTEMPTS`
* `ETCD_MAX_CLUSTER_SIZE`
* `ETCD_MAX_RESULT_BUFFER`
* `ETCD_NAME`
* `ETCD_SNAPSHOT`
* `ETCD_VERBOSE`
* `ETCD_VERY_VERBOSE`
* `ETCD_WEB_URL`
* `ETCD_PEER_ADVERTISED_URL`
* `ETCD_PEER_CA_FILE`
* `ETCD_PEER_CERT_FILE`
* `ETCD_PEER_KEY_FILE`
* `ETCD_PEER_LISTEN_HOST`

143
config.go
View File

@ -1,143 +0,0 @@
package main
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"io/ioutil"
"os"
"path/filepath"
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/server"
)
//--------------------------------------
// Config
//--------------------------------------
// Get the server info from previous conf file
// or from the user
func getInfo(path string) *Info {
infoPath := filepath.Join(path, "info")
if force {
// Delete the old configuration if exist
logPath := filepath.Join(path, "log")
confPath := filepath.Join(path, "conf")
snapshotPath := filepath.Join(path, "snapshot")
os.Remove(infoPath)
os.Remove(logPath)
os.Remove(confPath)
os.RemoveAll(snapshotPath)
} else if info := readInfo(infoPath); info != nil {
log.Infof("Found node configuration in '%s'. Ignoring flags", infoPath)
return info
}
// Read info from command line
info := &argInfo
// Write to file.
content, _ := json.MarshalIndent(info, "", " ")
content = []byte(string(content) + "\n")
if err := ioutil.WriteFile(infoPath, content, 0644); err != nil {
log.Fatalf("Unable to write info to file: %v", err)
}
log.Infof("Wrote node configuration to '%s'", infoPath)
return info
}
// readInfo reads from info file and decode to Info struct
func readInfo(path string) *Info {
file, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
log.Fatal(err)
}
defer file.Close()
info := &Info{}
content, err := ioutil.ReadAll(file)
if err != nil {
log.Fatalf("Unable to read info: %v", err)
return nil
}
if err = json.Unmarshal(content, &info); err != nil {
log.Fatalf("Unable to parse info: %v", err)
return nil
}
return info
}
func tlsConfigFromInfo(info server.TLSInfo) (t server.TLSConfig, ok bool) {
var keyFile, certFile, CAFile string
var tlsCert tls.Certificate
var err error
t.Scheme = "http"
keyFile = info.KeyFile
certFile = info.CertFile
CAFile = info.CAFile
// If the user do not specify key file, cert file and
// CA file, the type will be HTTP
if keyFile == "" && certFile == "" && CAFile == "" {
return t, true
}
// both the key and cert must be present
if keyFile == "" || certFile == "" {
return t, false
}
tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
log.Fatal(err)
}
t.Scheme = "https"
t.Server.ClientAuth, t.Server.ClientCAs = newCertPool(CAFile)
// The client should trust the RootCA that the Server uses since
// everyone is a peer in the network.
t.Client.Certificates = []tls.Certificate{tlsCert}
t.Client.RootCAs = t.Server.ClientCAs
return t, true
}
// newCertPool creates x509 certPool and corresponding Auth Type.
// If the given CAfile is valid, add the cert into the pool and verify the clients'
// certs against the cert in the pool.
// If the given CAfile is empty, do not verify the clients' cert.
// If the given CAfile is not valid, fatal.
func newCertPool(CAFile string) (tls.ClientAuthType, *x509.CertPool) {
if CAFile == "" {
return tls.NoClientCert, nil
}
pemByte, err := ioutil.ReadFile(CAFile)
check(err)
block, pemByte := pem.Decode(pemByte)
cert, err := x509.ParseCertificate(block.Bytes)
check(err)
certPool := x509.NewCertPool()
certPool.AddCert(cert)
return tls.RequireAndVerifyClientCert, certPool
}

237
etcd.go
View File

@ -5,7 +5,8 @@ import (
"fmt"
"io/ioutil"
"os"
"strings"
"os/signal"
"runtime/pprof"
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/server"
@ -13,189 +14,103 @@ import (
"github.com/coreos/go-raft"
)
//------------------------------------------------------------------------------
//
// Initialization
//
//------------------------------------------------------------------------------
var (
veryVerbose bool
machines string
machinesFile string
cluster []string
argInfo Info
dirPath string
force bool
printVersion bool
maxSize int
snapshot bool
retryTimes int
maxClusterSize int
cpuprofile string
cors string
)
func init() {
flag.BoolVar(&printVersion, "version", false, "print the version and exit")
flag.BoolVar(&log.Verbose, "v", false, "verbose logging")
flag.BoolVar(&veryVerbose, "vv", false, "very verbose logging")
flag.StringVar(&machines, "C", "", "the ip address and port of a existing machines in the cluster, sepearate by comma")
flag.StringVar(&machinesFile, "CF", "", "the file contains a list of existing machines in the cluster, seperate by comma")
flag.StringVar(&argInfo.Name, "n", "default-name", "the node name (required)")
flag.StringVar(&argInfo.EtcdURL, "c", "127.0.0.1:4001", "the advertised public hostname:port for etcd client communication")
flag.StringVar(&argInfo.RaftURL, "s", "127.0.0.1:7001", "the advertised public hostname:port for raft server communication")
flag.StringVar(&argInfo.EtcdListenHost, "cl", "", "the listening hostname for etcd client communication (defaults to advertised ip)")
flag.StringVar(&argInfo.RaftListenHost, "sl", "", "the listening hostname for raft server communication (defaults to advertised ip)")
flag.StringVar(&argInfo.WebURL, "w", "", "the hostname:port of web interface")
flag.StringVar(&argInfo.RaftTLS.CAFile, "serverCAFile", "", "the path of the CAFile")
flag.StringVar(&argInfo.RaftTLS.CertFile, "serverCert", "", "the cert file of the server")
flag.StringVar(&argInfo.RaftTLS.KeyFile, "serverKey", "", "the key file of the server")
flag.StringVar(&argInfo.EtcdTLS.CAFile, "clientCAFile", "", "the path of the client CAFile")
flag.StringVar(&argInfo.EtcdTLS.CertFile, "clientCert", "", "the cert file of the client")
flag.StringVar(&argInfo.EtcdTLS.KeyFile, "clientKey", "", "the key file of the client")
flag.StringVar(&dirPath, "d", ".", "the directory to store log and snapshot")
flag.BoolVar(&force, "f", false, "force new node configuration if existing is found (WARNING: data loss!)")
flag.BoolVar(&snapshot, "snapshot", false, "open or close snapshot")
flag.IntVar(&maxSize, "m", 1024, "the max size of result buffer")
flag.IntVar(&retryTimes, "r", 3, "the max retry attempts when trying to join a cluster")
flag.IntVar(&maxClusterSize, "maxsize", 9, "the max size of the cluster")
flag.StringVar(&cpuprofile, "cpuprofile", "", "write cpu profile to file")
flag.StringVar(&cors, "cors", "", "whitelist origins for cross-origin resource sharing (e.g. '*' or 'http://localhost:8001,etc')")
}
//------------------------------------------------------------------------------
//
// Typedefs
//
//------------------------------------------------------------------------------
type Info struct {
Name string `json:"name"`
RaftURL string `json:"raftURL"`
EtcdURL string `json:"etcdURL"`
WebURL string `json:"webURL"`
RaftListenHost string `json:"raftListenHost"`
EtcdListenHost string `json:"etcdListenHost"`
RaftTLS server.TLSInfo `json:"raftTLS"`
EtcdTLS server.TLSInfo `json:"etcdTLS"`
}
//------------------------------------------------------------------------------
//
// Functions
//
//------------------------------------------------------------------------------
//--------------------------------------
// Main
//--------------------------------------
func main() {
flag.Parse()
parseFlags()
if printVersion {
fmt.Println(server.ReleaseVersion)
os.Exit(0)
// Load configuration.
var config = server.NewConfig()
if err := config.Load(os.Args[1:]); err != nil {
log.Fatal("Configuration error:", err)
}
if cpuprofile != "" {
runCPUProfile()
}
if veryVerbose {
// Turn on logging.
if config.VeryVerbose {
log.Verbose = true
raft.SetLogLevel(raft.Debug)
} else if config.Verbose {
log.Verbose = true
}
if machines != "" {
cluster = strings.Split(machines, ",")
} else if machinesFile != "" {
b, err := ioutil.ReadFile(machinesFile)
if err != nil {
log.Fatalf("Unable to read the given machines file: %s", err)
}
cluster = strings.Split(string(b), ",")
}
// Check TLS arguments
raftTLSConfig, ok := tlsConfigFromInfo(argInfo.RaftTLS)
if !ok {
log.Fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
}
etcdTLSConfig, ok := tlsConfigFromInfo(argInfo.EtcdTLS)
if !ok {
log.Fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
}
argInfo.Name = strings.TrimSpace(argInfo.Name)
if argInfo.Name == "" {
log.Fatal("ERROR: server name required. e.g. '-n=server_name'")
}
// Check host name arguments
argInfo.RaftURL = sanitizeURL(argInfo.RaftURL, raftTLSConfig.Scheme)
argInfo.EtcdURL = sanitizeURL(argInfo.EtcdURL, etcdTLSConfig.Scheme)
argInfo.WebURL = sanitizeURL(argInfo.WebURL, "http")
argInfo.RaftListenHost = sanitizeListenHost(argInfo.RaftListenHost, argInfo.RaftURL)
argInfo.EtcdListenHost = sanitizeListenHost(argInfo.EtcdListenHost, argInfo.EtcdURL)
// Read server info from file or grab it from user.
if err := os.MkdirAll(dirPath, 0744); err != nil {
// Create data directory if it doesn't already exist.
if err := os.MkdirAll(config.DataDir, 0744); err != nil {
log.Fatalf("Unable to create path: %s", err)
}
info := getInfo(dirPath)
// Load info object.
info, err := config.Info()
if err != nil {
log.Fatal("info:", err)
}
if info.Name == "" {
log.Fatal("ERROR: server name required. e.g. '-n=server_name'")
}
// Create etcd key-value store
// Retrieve TLS configuration.
tlsConfig, err := info.EtcdTLS.Config()
if err != nil {
log.Fatal("Client TLS:", err)
}
peerTLSConfig, err := info.RaftTLS.Config()
if err != nil {
log.Fatal("Peer TLS:", err)
}
// Create etcd key-value store and registry.
store := store.New()
// Create a shared node registry.
registry := server.NewRegistry(store)
// Create peer server.
ps := server.NewPeerServer(info.Name, dirPath, info.RaftURL, info.RaftListenHost, &raftTLSConfig, &info.RaftTLS, registry, store)
ps.MaxClusterSize = maxClusterSize
ps.RetryTimes = retryTimes
ps := server.NewPeerServer(info.Name, config.DataDir, info.RaftURL, info.RaftListenHost, &peerTLSConfig, &info.RaftTLS, registry, store, config.SnapCount)
ps.MaxClusterSize = config.MaxClusterSize
ps.RetryTimes = config.MaxRetryAttempts
s := server.New(info.Name, info.EtcdURL, info.EtcdListenHost, &etcdTLSConfig, &info.EtcdTLS, ps, registry, store)
if err := s.AllowOrigins(cors); err != nil {
// Create client server.
s := server.New(info.Name, info.EtcdURL, info.EtcdListenHost, &tlsConfig, &info.EtcdTLS, ps, registry, store)
if err := s.AllowOrigins(config.Cors); err != nil {
panic(err)
}
ps.SetServer(s)
// Run peer server in separate thread while the client server blocks.
go func() {
log.Fatal(ps.ListenAndServe(snapshot, cluster))
log.Fatal(ps.ListenAndServe(config.Snapshot, config.Machines))
}()
log.Fatal(s.ListenAndServe())
}
// Parses non-configuration flags.
func parseFlags() {
var versionFlag bool
var cpuprofile string
f := flag.NewFlagSet(os.Args[0], -1)
f.SetOutput(ioutil.Discard)
f.BoolVar(&versionFlag, "version", false, "print the version and exit")
f.StringVar(&cpuprofile, "cpuprofile", "", "write cpu profile to file")
f.Parse(os.Args[1:])
// Print version if necessary.
if versionFlag {
fmt.Println(server.ReleaseVersion)
os.Exit(0)
}
// Begin CPU profiling if specified.
if cpuprofile != "" {
f, err := os.Create(cpuprofile)
if err != nil {
log.Fatal(err)
}
pprof.StartCPUProfile(f)
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
go func() {
sig := <-c
log.Infof("captured %v, stopping profiler and exiting..", sig)
pprof.StopCPUProfile()
os.Exit(1)
}()
}
}

406
server/config.go Normal file
View File

@ -0,0 +1,406 @@
package server
import (
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"net"
"net/url"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"github.com/BurntSushi/toml"
)
// The default location for the etcd configuration file.
const DefaultSystemConfigPath = "/etc/etcd/etcd.conf"
// Config represents the server configuration.
type Config struct {
SystemPath string
AdvertisedUrl string `toml:"advertised_url" env:"ETCD_ADVERTISED_URL"`
CAFile string `toml:"ca_file" env:"ETCD_CA_FILE"`
CertFile string `toml:"cert_file" env:"ETCD_CERT_FILE"`
Cors []string `toml:"cors" env:"ETCD_CORS"`
DataDir string `toml:"datadir" env:"ETCD_DATADIR"`
KeyFile string `toml:"key_file" env:"ETCD_KEY_FILE"`
ListenHost string `toml:"listen_host" env:"ETCD_LISTEN_HOST"`
Machines []string `toml:"machines" env:"ETCD_MACHINES"`
MachinesFile string `toml:"machines_file" env:"ETCD_MACHINES_FILE"`
MaxClusterSize int `toml:"max_cluster_size" env:"ETCD_MAX_CLUSTER_SIZE"`
MaxResultBuffer int `toml:"max_result_buffer" env:"ETCD_MAX_RESULT_BUFFER"`
MaxRetryAttempts int `toml:"max_retry_attempts" env:"ETCD_MAX_RETRY_ATTEMPTS"`
Name string `toml:"name" env:"ETCD_NAME"`
Snapshot bool `toml:"snapshot" env:"ETCD_SNAPSHOT"`
SnapCount int `toml:"snap_count" env:"ETCD_SNAPCOUNT"`
Verbose bool `toml:"verbose" env:"ETCD_VERBOSE"`
VeryVerbose bool `toml:"very_verbose" env:"ETCD_VERY_VERBOSE"`
WebURL string `toml:"web_url" env:"ETCD_WEB_URL"`
Peer struct {
AdvertisedUrl string `toml:"advertised_url" env:"ETCD_PEER_ADVERTISED_URL"`
CAFile string `toml:"ca_file" env:"ETCD_PEER_CA_FILE"`
CertFile string `toml:"cert_file" env:"ETCD_PEER_CERT_FILE"`
KeyFile string `toml:"key_file" env:"ETCD_PEER_KEY_FILE"`
ListenHost string `toml:"listen_host" env:"ETCD_PEER_LISTEN_HOST"`
}
}
// NewConfig returns a Config initialized with default values.
func NewConfig() *Config {
c := new(Config)
c.SystemPath = DefaultSystemConfigPath
c.AdvertisedUrl = "127.0.0.1:4001"
c.AdvertisedUrl = "127.0.0.1:4001"
c.DataDir = "."
c.MaxClusterSize = 9
c.MaxResultBuffer = 1024
c.MaxRetryAttempts = 3
c.Peer.AdvertisedUrl = "127.0.0.1:7001"
c.SnapCount = 10000
return c
}
// Loads the configuration from the system config, command line config,
// environment variables, and finally command line arguments.
func (c *Config) Load(arguments []string) error {
var path string
f := flag.NewFlagSet("etcd", -1)
f.SetOutput(ioutil.Discard)
f.StringVar(&path, "config", "", "path to config file")
f.Parse(arguments)
// Load from system file.
if err := c.LoadSystemFile(); err != nil {
return err
}
// Load from config file specified in arguments.
if path != "" {
if err := c.LoadFile(path); err != nil {
return err
}
}
// Load from the environment variables next.
if err := c.LoadEnv(); err != nil {
return err
}
// Load from command line flags.
if err := c.LoadFlags(arguments); err != nil {
return err
}
// Loads machines if a machine file was specified.
if err := c.LoadMachineFile(); err != nil {
return err
}
// Sanitize all the input fields.
if err := c.Sanitize(); err != nil {
return fmt.Errorf("sanitize:", err)
}
return nil
}
// Loads from the system etcd configuration file if it exists.
func (c *Config) LoadSystemFile() error {
if _, err := os.Stat(c.SystemPath); os.IsNotExist(err) {
return nil
}
return c.LoadFile(c.SystemPath)
}
// Loads configuration from a file.
func (c *Config) LoadFile(path string) error {
_, err := toml.DecodeFile(path, &c)
return err
}
// LoadEnv loads the configuration via environment variables.
func (c *Config) LoadEnv() error {
if err := c.loadEnv(c); err != nil {
return err
}
if err := c.loadEnv(&c.Peer); err != nil {
return err
}
return nil
}
func (c *Config) loadEnv(target interface{}) error {
value := reflect.Indirect(reflect.ValueOf(target))
typ := value.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Retrieve environment variable.
v := strings.TrimSpace(os.Getenv(field.Tag.Get("env")))
if v == "" {
continue
}
// Set the appropriate type.
switch field.Type.Kind() {
case reflect.Bool:
value.Field(i).SetBool(v != "0" && v != "false")
case reflect.Int:
newValue, err := strconv.ParseInt(v, 10, 0)
if err != nil {
return fmt.Errorf("Parse error: %s: %s", field.Tag.Get("env"), err)
}
value.Field(i).SetInt(newValue)
case reflect.String:
value.Field(i).SetString(v)
case reflect.Slice:
value.Field(i).Set(reflect.ValueOf(trimsplit(v, ",")))
}
}
return nil
}
// Loads configuration from command line flags.
func (c *Config) LoadFlags(arguments []string) error {
var machines, cors string
var force bool
f := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
f.BoolVar(&force, "f", false, "force new node configuration if existing is found (WARNING: data loss!)")
f.BoolVar(&c.Verbose, "v", c.Verbose, "verbose logging")
f.BoolVar(&c.VeryVerbose, "vv", c.Verbose, "very verbose logging")
f.StringVar(&machines, "C", "", "the ip address and port of a existing machines in the cluster, sepearate by comma")
f.StringVar(&c.MachinesFile, "CF", c.MachinesFile, "the file contains a list of existing machines in the cluster, seperate by comma")
f.StringVar(&c.Name, "n", c.Name, "the node name (required)")
f.StringVar(&c.AdvertisedUrl, "c", c.AdvertisedUrl, "the advertised public hostname:port for etcd client communication")
f.StringVar(&c.Peer.AdvertisedUrl, "s", c.Peer.AdvertisedUrl, "the advertised public hostname:port for raft server communication")
f.StringVar(&c.ListenHost, "cl", c.ListenHost, "the listening hostname for etcd client communication (defaults to advertised ip)")
f.StringVar(&c.Peer.ListenHost, "sl", c.Peer.ListenHost, "the listening hostname for raft server communication (defaults to advertised ip)")
f.StringVar(&c.WebURL, "w", c.WebURL, "the hostname:port of web interface")
f.StringVar(&c.Peer.CAFile, "serverCAFile", c.Peer.CAFile, "the path of the CAFile")
f.StringVar(&c.Peer.CertFile, "serverCert", c.Peer.CertFile, "the cert file of the server")
f.StringVar(&c.Peer.KeyFile, "serverKey", c.Peer.KeyFile, "the key file of the server")
f.StringVar(&c.CAFile, "clientCAFile", c.CAFile, "the path of the client CAFile")
f.StringVar(&c.CertFile, "clientCert", c.CertFile, "the cert file of the client")
f.StringVar(&c.KeyFile, "clientKey", c.KeyFile, "the key file of the client")
f.StringVar(&c.DataDir, "d", c.DataDir, "the directory to store log and snapshot")
f.IntVar(&c.MaxResultBuffer, "m", c.MaxResultBuffer, "the max size of result buffer")
f.IntVar(&c.MaxRetryAttempts, "r", c.MaxRetryAttempts, "the max retry attempts when trying to join a cluster")
f.IntVar(&c.MaxClusterSize, "maxsize", c.MaxClusterSize, "the max size of the cluster")
f.StringVar(&cors, "cors", "", "whitelist origins for cross-origin resource sharing (e.g. '*' or 'http://localhost:8001,etc')")
f.BoolVar(&c.Snapshot, "snapshot", c.Snapshot, "open or close snapshot")
f.IntVar(&c.SnapCount, "snapCount", c.SnapCount, "save the in memory logs and states to a snapshot file after snapCount transactions")
// These flags are ignored since they were already parsed.
var path string
f.StringVar(&path, "config", "", "path to config file")
f.Parse(arguments)
// Convert some parameters to lists.
if machines != "" {
c.Machines = trimsplit(machines, ",")
}
if cors != "" {
c.Cors = trimsplit(cors, ",")
}
// Force remove server configuration if specified.
if force {
c.Reset()
}
return nil
}
// LoadMachineFile loads the machines listed in the machine file.
func (c *Config) LoadMachineFile() error {
if c.MachinesFile == "" {
return nil
}
b, err := ioutil.ReadFile(c.MachinesFile)
if err != nil {
return fmt.Errorf("Machines file error: %s", err)
}
c.Machines = trimsplit(string(b), ",")
return nil
}
// Reset removes all server configuration files.
func (c *Config) Reset() error {
if err := os.RemoveAll(filepath.Join(c.DataDir, "info")); err != nil {
return err
}
if err := os.RemoveAll(filepath.Join(c.DataDir, "log")); err != nil {
return err
}
if err := os.RemoveAll(filepath.Join(c.DataDir, "conf")); err != nil {
return err
}
if err := os.RemoveAll(filepath.Join(c.DataDir, "snapshot")); err != nil {
return err
}
return nil
}
// Reads the info file from the file system or initializes it based on the config.
func (c *Config) Info() (*Info, error) {
info := &Info{}
path := filepath.Join(c.DataDir, "info")
// Open info file and read it out.
f, err := os.Open(path)
if err != nil && !os.IsNotExist(err) {
return nil, err
} else if f != nil {
defer f.Close()
if err := json.NewDecoder(f).Decode(&info); err != nil {
return nil, err
}
return info, nil
}
// If the file doesn't exist then initialize it.
info.Name = strings.TrimSpace(c.Name)
info.EtcdURL = c.AdvertisedUrl
info.EtcdListenHost = c.ListenHost
info.RaftURL = c.Peer.AdvertisedUrl
info.RaftListenHost = c.Peer.ListenHost
info.WebURL = c.WebURL
info.EtcdTLS = c.TLSInfo()
info.RaftTLS = c.PeerTLSInfo()
// Write to file.
f, err = os.Create(path)
if err != nil {
return nil, err
}
defer f.Close()
if err := json.NewEncoder(f).Encode(info); err != nil {
return nil, err
}
return info, nil
}
// Sanitize cleans the input fields.
func (c *Config) Sanitize() error {
tlsConfig, err := c.TLSConfig()
if err != nil {
return err
}
peerTlsConfig, err := c.PeerTLSConfig()
if err != nil {
return err
}
// Sanitize the URLs first.
if c.AdvertisedUrl, err = sanitizeURL(c.AdvertisedUrl, tlsConfig.Scheme); err != nil {
return fmt.Errorf("Advertised URL: %s", err)
}
if c.ListenHost, err = sanitizeListenHost(c.ListenHost, c.AdvertisedUrl); err != nil {
return fmt.Errorf("Listen Host: %s", err)
}
if c.WebURL, err = sanitizeURL(c.WebURL, "http"); err != nil {
return fmt.Errorf("Web URL: %s", err)
}
if c.Peer.AdvertisedUrl, err = sanitizeURL(c.Peer.AdvertisedUrl, peerTlsConfig.Scheme); err != nil {
return fmt.Errorf("Peer Advertised URL: %s", err)
}
if c.Peer.ListenHost, err = sanitizeListenHost(c.Peer.ListenHost, c.Peer.AdvertisedUrl); err != nil {
return fmt.Errorf("Peer Listen Host: %s", err)
}
return nil
}
// TLSInfo retrieves a TLSInfo object for the client server.
func (c *Config) TLSInfo() TLSInfo {
return TLSInfo{
CAFile: c.CAFile,
CertFile: c.CertFile,
KeyFile: c.KeyFile,
}
}
// ClientTLSConfig generates the TLS configuration for the client server.
func (c *Config) TLSConfig() (TLSConfig, error) {
return c.TLSInfo().Config()
}
// PeerTLSInfo retrieves a TLSInfo object for the peer server.
func (c *Config) PeerTLSInfo() TLSInfo {
return TLSInfo{
CAFile: c.Peer.CAFile,
CertFile: c.Peer.CertFile,
KeyFile: c.Peer.KeyFile,
}
}
// PeerTLSConfig generates the TLS configuration for the peer server.
func (c *Config) PeerTLSConfig() (TLSConfig, error) {
return c.PeerTLSInfo().Config()
}
// sanitizeURL will cleanup a host string in the format hostname:port and
// attach a schema.
func sanitizeURL(host string, defaultScheme string) (string, error) {
// Blank URLs are fine input, just return it
if len(host) == 0 {
return host, nil
}
p, err := url.Parse(host)
if err != nil {
return "", err
}
// Make sure the host is in Host:Port format
_, _, err = net.SplitHostPort(host)
if err != nil {
return "", err
}
p = &url.URL{Host: host, Scheme: defaultScheme}
return p.String(), nil
}
// sanitizeListenHost cleans up the ListenHost parameter and appends a port
// if necessary based on the advertised port.
func sanitizeListenHost(listen string, advertised string) (string, error) {
aurl, err := url.Parse(advertised)
if err != nil {
return "", err
}
ahost, aport, err := net.SplitHostPort(aurl.Host)
if err != nil {
return "", err
}
// If the listen host isn't set use the advertised host
if listen == "" {
listen = ahost
}
return net.JoinHostPort(listen, aport), nil
}

479
server/config_test.go Normal file
View File

@ -0,0 +1,479 @@
package server
import (
"io/ioutil"
"os"
"testing"
"github.com/BurntSushi/toml"
"github.com/stretchr/testify/assert"
)
// Ensures that a configuration can be deserialized from TOML.
func TestConfigTOML(t *testing.T) {
content := `
advertised_url = "127.0.0.1:4002"
ca_file = "/tmp/file.ca"
cert_file = "/tmp/file.cert"
cors = ["*"]
cpu_profile_file = "XXX"
datadir = "/tmp/data"
key_file = "/tmp/file.key"
listen_host = "127.0.0.1:4003"
machines = ["coreos.com:4001", "coreos.com:4002"]
machines_file = "/tmp/machines"
max_cluster_size = 10
max_result_buffer = 512
max_retry_attempts = 5
name = "test-name"
snapshot = true
verbose = true
very_verbose = true
web_url = "/web"
[peer]
advertised_url = "127.0.0.1:7002"
ca_file = "/tmp/peer/file.ca"
cert_file = "/tmp/peer/file.cert"
key_file = "/tmp/peer/file.key"
listen_host = "127.0.0.1:7003"
`
c := NewConfig()
_, err := toml.Decode(content, &c)
assert.Nil(t, err, "")
assert.Equal(t, c.AdvertisedUrl, "127.0.0.1:4002", "")
assert.Equal(t, c.CAFile, "/tmp/file.ca", "")
assert.Equal(t, c.CertFile, "/tmp/file.cert", "")
assert.Equal(t, c.Cors, []string{"*"}, "")
assert.Equal(t, c.DataDir, "/tmp/data", "")
assert.Equal(t, c.KeyFile, "/tmp/file.key", "")
assert.Equal(t, c.ListenHost, "127.0.0.1:4003", "")
assert.Equal(t, c.Machines, []string{"coreos.com:4001", "coreos.com:4002"}, "")
assert.Equal(t, c.MachinesFile, "/tmp/machines", "")
assert.Equal(t, c.MaxClusterSize, 10, "")
assert.Equal(t, c.MaxResultBuffer, 512, "")
assert.Equal(t, c.MaxRetryAttempts, 5, "")
assert.Equal(t, c.Name, "test-name", "")
assert.Equal(t, c.Snapshot, true, "")
assert.Equal(t, c.Verbose, true, "")
assert.Equal(t, c.VeryVerbose, true, "")
assert.Equal(t, c.WebURL, "/web", "")
assert.Equal(t, c.Peer.AdvertisedUrl, "127.0.0.1:7002", "")
assert.Equal(t, c.Peer.CAFile, "/tmp/peer/file.ca", "")
assert.Equal(t, c.Peer.CertFile, "/tmp/peer/file.cert", "")
assert.Equal(t, c.Peer.KeyFile, "/tmp/peer/file.key", "")
assert.Equal(t, c.Peer.ListenHost, "127.0.0.1:7003", "")
}
// Ensures that a configuration can be retrieved from environment variables.
func TestConfigEnv(t *testing.T) {
os.Setenv("ETCD_CA_FILE", "/tmp/file.ca")
os.Setenv("ETCD_CERT_FILE", "/tmp/file.cert")
os.Setenv("ETCD_CPU_PROFILE_FILE", "XXX")
os.Setenv("ETCD_CORS", "localhost:4001,localhost:4002")
os.Setenv("ETCD_DATADIR", "/tmp/data")
os.Setenv("ETCD_KEY_FILE", "/tmp/file.key")
os.Setenv("ETCD_LISTEN_HOST", "127.0.0.1:4003")
os.Setenv("ETCD_MACHINES", "coreos.com:4001,coreos.com:4002")
os.Setenv("ETCD_MACHINES_FILE", "/tmp/machines")
os.Setenv("ETCD_MAX_CLUSTER_SIZE", "10")
os.Setenv("ETCD_MAX_RESULT_BUFFER", "512")
os.Setenv("ETCD_MAX_RETRY_ATTEMPTS", "5")
os.Setenv("ETCD_NAME", "test-name")
os.Setenv("ETCD_SNAPSHOT", "true")
os.Setenv("ETCD_VERBOSE", "1")
os.Setenv("ETCD_VERY_VERBOSE", "yes")
os.Setenv("ETCD_WEB_URL", "/web")
os.Setenv("ETCD_PEER_ADVERTISED_URL", "127.0.0.1:7002")
os.Setenv("ETCD_PEER_CA_FILE", "/tmp/peer/file.ca")
os.Setenv("ETCD_PEER_CERT_FILE", "/tmp/peer/file.cert")
os.Setenv("ETCD_PEER_KEY_FILE", "/tmp/peer/file.key")
os.Setenv("ETCD_PEER_LISTEN_HOST", "127.0.0.1:7003")
c := NewConfig()
c.LoadEnv()
assert.Equal(t, c.CAFile, "/tmp/file.ca", "")
assert.Equal(t, c.CertFile, "/tmp/file.cert", "")
assert.Equal(t, c.Cors, []string{"localhost:4001", "localhost:4002"}, "")
assert.Equal(t, c.DataDir, "/tmp/data", "")
assert.Equal(t, c.KeyFile, "/tmp/file.key", "")
assert.Equal(t, c.ListenHost, "127.0.0.1:4003", "")
assert.Equal(t, c.Machines, []string{"coreos.com:4001", "coreos.com:4002"}, "")
assert.Equal(t, c.MachinesFile, "/tmp/machines", "")
assert.Equal(t, c.MaxClusterSize, 10, "")
assert.Equal(t, c.MaxResultBuffer, 512, "")
assert.Equal(t, c.MaxRetryAttempts, 5, "")
assert.Equal(t, c.Name, "test-name", "")
assert.Equal(t, c.Snapshot, true, "")
assert.Equal(t, c.Verbose, true, "")
assert.Equal(t, c.VeryVerbose, true, "")
assert.Equal(t, c.WebURL, "/web", "")
assert.Equal(t, c.Peer.AdvertisedUrl, "127.0.0.1:7002", "")
assert.Equal(t, c.Peer.CAFile, "/tmp/peer/file.ca", "")
assert.Equal(t, c.Peer.CertFile, "/tmp/peer/file.cert", "")
assert.Equal(t, c.Peer.KeyFile, "/tmp/peer/file.key", "")
assert.Equal(t, c.Peer.ListenHost, "127.0.0.1:7003", "")
}
// Ensures that a the advertised url can be parsed from the environment.
func TestConfigAdvertisedUrlEnv(t *testing.T) {
withEnv("ETCD_ADVERTISED_URL", "127.0.0.1:4002", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.AdvertisedUrl, "127.0.0.1:4002", "")
})
}
// Ensures that a the advertised flag can be parsed.
func TestConfigAdvertisedUrlFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-c", "127.0.0.1:4002"}), "")
assert.Equal(t, c.AdvertisedUrl, "127.0.0.1:4002", "")
}
// Ensures that a the CA file can be parsed from the environment.
func TestConfigCAFileEnv(t *testing.T) {
withEnv("ETCD_CA_FILE", "/tmp/file.ca", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.CAFile, "/tmp/file.ca", "")
})
}
// Ensures that a the CA file flag can be parsed.
func TestConfigCAFileFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-clientCAFile", "/tmp/file.ca"}), "")
assert.Equal(t, c.CAFile, "/tmp/file.ca", "")
}
// Ensures that a the CA file can be parsed from the environment.
func TestConfigCertFileEnv(t *testing.T) {
withEnv("ETCD_CERT_FILE", "/tmp/file.cert", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.CertFile, "/tmp/file.cert", "")
})
}
// Ensures that a the Cert file flag can be parsed.
func TestConfigCertFileFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-clientCert", "/tmp/file.cert"}), "")
assert.Equal(t, c.CertFile, "/tmp/file.cert", "")
}
// Ensures that a the Key file can be parsed from the environment.
func TestConfigKeyFileEnv(t *testing.T) {
withEnv("ETCD_KEY_FILE", "/tmp/file.key", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.KeyFile, "/tmp/file.key", "")
})
}
// Ensures that a the Key file flag can be parsed.
func TestConfigKeyFileFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-clientKey", "/tmp/file.key"}), "")
assert.Equal(t, c.KeyFile, "/tmp/file.key", "")
}
// Ensures that a the Listen Host can be parsed from the environment.
func TestConfigListenHostEnv(t *testing.T) {
withEnv("ETCD_LISTEN_HOST", "127.0.0.1:4003", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.ListenHost, "127.0.0.1:4003", "")
})
}
// Ensures that a the Listen Host file flag can be parsed.
func TestConfigListenHostFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-cl", "127.0.0.1:4003"}), "")
assert.Equal(t, c.ListenHost, "127.0.0.1:4003", "")
}
// Ensures that the Machines can be parsed from the environment.
func TestConfigMachinesEnv(t *testing.T) {
withEnv("ETCD_MACHINES", "coreos.com:4001,coreos.com:4002", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Machines, []string{"coreos.com:4001", "coreos.com:4002"}, "")
})
}
// Ensures that a the Machines flag can be parsed.
func TestConfigMachinesFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-C", "coreos.com:4001,coreos.com:4002"}), "")
assert.Equal(t, c.Machines, []string{"coreos.com:4001", "coreos.com:4002"}, "")
}
// Ensures that the Machines File can be parsed from the environment.
func TestConfigMachinesFileEnv(t *testing.T) {
withEnv("ETCD_MACHINES_FILE", "/tmp/machines", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.MachinesFile, "/tmp/machines", "")
})
}
// Ensures that a the Machines File flag can be parsed.
func TestConfigMachinesFileFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-CF", "/tmp/machines"}), "")
assert.Equal(t, c.MachinesFile, "/tmp/machines", "")
}
// Ensures that the Max Cluster Size can be parsed from the environment.
func TestConfigMaxClusterSizeEnv(t *testing.T) {
withEnv("ETCD_MAX_CLUSTER_SIZE", "5", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.MaxClusterSize, 5, "")
})
}
// Ensures that a the Max Cluster Size flag can be parsed.
func TestConfigMaxClusterSizeFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-maxsize", "5"}), "")
assert.Equal(t, c.MaxClusterSize, 5, "")
}
// Ensures that the Max Result Buffer can be parsed from the environment.
func TestConfigMaxResultBufferEnv(t *testing.T) {
withEnv("ETCD_MAX_RESULT_BUFFER", "512", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.MaxResultBuffer, 512, "")
})
}
// Ensures that a the Max Result Buffer flag can be parsed.
func TestConfigMaxResultBufferFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-m", "512"}), "")
assert.Equal(t, c.MaxResultBuffer, 512, "")
}
// Ensures that the Max Retry Attempts can be parsed from the environment.
func TestConfigMaxRetryAttemptsEnv(t *testing.T) {
withEnv("ETCD_MAX_RETRY_ATTEMPTS", "10", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.MaxRetryAttempts, 10, "")
})
}
// Ensures that a the Max Retry Attempts flag can be parsed.
func TestConfigMaxRetryAttemptsFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-r", "10"}), "")
assert.Equal(t, c.MaxRetryAttempts, 10, "")
}
// Ensures that the Name can be parsed from the environment.
func TestConfigNameEnv(t *testing.T) {
withEnv("ETCD_NAME", "test-name", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Name, "test-name", "")
})
}
// Ensures that a the Name flag can be parsed.
func TestConfigNameFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-n", "test-name"}), "")
assert.Equal(t, c.Name, "test-name", "")
}
// Ensures that Snapshot can be parsed from the environment.
func TestConfigSnapshotEnv(t *testing.T) {
withEnv("ETCD_SNAPSHOT", "1", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Snapshot, true, "")
})
}
// Ensures that a the Snapshot flag can be parsed.
func TestConfigSnapshotFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-snapshot"}), "")
assert.Equal(t, c.Snapshot, true, "")
}
// Ensures that Verbose can be parsed from the environment.
func TestConfigVerboseEnv(t *testing.T) {
withEnv("ETCD_VERBOSE", "true", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Verbose, true, "")
})
}
// Ensures that a the Verbose flag can be parsed.
func TestConfigVerboseFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-v"}), "")
assert.Equal(t, c.Verbose, true, "")
}
// Ensures that Very Verbose can be parsed from the environment.
func TestConfigVeryVerboseEnv(t *testing.T) {
withEnv("ETCD_VERY_VERBOSE", "true", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.VeryVerbose, true, "")
})
}
// Ensures that a the Very Verbose flag can be parsed.
func TestConfigVeryVerboseFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-vv"}), "")
assert.Equal(t, c.VeryVerbose, true, "")
}
// Ensures that Web URL can be parsed from the environment.
func TestConfigWebURLEnv(t *testing.T) {
withEnv("ETCD_WEB_URL", "/web", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.WebURL, "/web", "")
})
}
// Ensures that a the Web URL flag can be parsed.
func TestConfigWebURLFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-w", "/web"}), "")
assert.Equal(t, c.WebURL, "/web", "")
}
// Ensures that the Peer Advertised URL can be parsed from the environment.
func TestConfigPeerAdvertisedUrlEnv(t *testing.T) {
withEnv("ETCD_PEER_ADVERTISED_URL", "localhost:7002", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Peer.AdvertisedUrl, "localhost:7002", "")
})
}
// Ensures that a the Peer Advertised URL flag can be parsed.
func TestConfigPeerAdvertisedUrlFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-s", "localhost:7002"}), "")
assert.Equal(t, c.Peer.AdvertisedUrl, "localhost:7002", "")
}
// Ensures that the Peer CA File can be parsed from the environment.
func TestConfigPeerCAFileEnv(t *testing.T) {
withEnv("ETCD_PEER_CA_FILE", "/tmp/peer/file.ca", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Peer.CAFile, "/tmp/peer/file.ca", "")
})
}
// Ensures that a the Peer CA file flag can be parsed.
func TestConfigPeerCAFileFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-serverCAFile", "/tmp/peer/file.ca"}), "")
assert.Equal(t, c.Peer.CAFile, "/tmp/peer/file.ca", "")
}
// Ensures that the Peer Cert File can be parsed from the environment.
func TestConfigPeerCertFileEnv(t *testing.T) {
withEnv("ETCD_PEER_CERT_FILE", "/tmp/peer/file.cert", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Peer.CertFile, "/tmp/peer/file.cert", "")
})
}
// Ensures that a the Cert file flag can be parsed.
func TestConfigPeerCertFileFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-serverCert", "/tmp/peer/file.cert"}), "")
assert.Equal(t, c.Peer.CertFile, "/tmp/peer/file.cert", "")
}
// Ensures that the Peer Key File can be parsed from the environment.
func TestConfigPeerKeyFileEnv(t *testing.T) {
withEnv("ETCD_PEER_KEY_FILE", "/tmp/peer/file.key", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Peer.KeyFile, "/tmp/peer/file.key", "")
})
}
// Ensures that a the Peer Key file flag can be parsed.
func TestConfigPeerKeyFileFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-serverKey", "/tmp/peer/file.key"}), "")
assert.Equal(t, c.Peer.KeyFile, "/tmp/peer/file.key", "")
}
// Ensures that the Peer Listen Host can be parsed from the environment.
func TestConfigPeerListenHostEnv(t *testing.T) {
withEnv("ETCD_PEER_LISTEN_HOST", "localhost:7004", func(c *Config) {
assert.Nil(t, c.LoadEnv(), "")
assert.Equal(t, c.Peer.ListenHost, "localhost:7004", "")
})
}
// Ensures that a the Peer Listen Host file flag can be parsed.
func TestConfigPeerListenHostFlag(t *testing.T) {
c := NewConfig()
assert.Nil(t, c.LoadFlags([]string{"-sl", "127.0.0.1:4003"}), "")
assert.Equal(t, c.Peer.ListenHost, "127.0.0.1:4003", "")
}
// Ensures that a system config field is overridden by a custom config field.
func TestConfigCustomConfigOverrideSystemConfig(t *testing.T) {
system := `advertised_url = "127.0.0.1:5000"`
custom := `advertised_url = "127.0.0.1:6000"`
withTempFile(system, func(p1 string) {
withTempFile(custom, func(p2 string) {
c := NewConfig()
c.SystemPath = p1
assert.Nil(t, c.Load([]string{"-config", p2}), "")
assert.Equal(t, c.AdvertisedUrl, "http://127.0.0.1:6000", "")
})
})
}
// Ensures that a custom config field is overridden by an environment variable.
func TestConfigEnvVarOverrideCustomConfig(t *testing.T) {
os.Setenv("ETCD_PEER_ADVERTISED_URL", "127.0.0.1:8000")
defer os.Setenv("ETCD_PEER_ADVERTISED_URL", "")
custom := `[peer]`+"\n"+`advertised_url = "127.0.0.1:9000"`
withTempFile(custom, func(path string) {
c := NewConfig()
c.SystemPath = ""
assert.Nil(t, c.Load([]string{"-config", path}), "")
assert.Equal(t, c.Peer.AdvertisedUrl, "http://127.0.0.1:8000", "")
})
}
// Ensures that an environment variable field is overridden by a command line argument.
func TestConfigCLIArgsOverrideEnvVar(t *testing.T) {
os.Setenv("ETCD_ADVERTISED_URL", "127.0.0.1:1000")
defer os.Setenv("ETCD_ADVERTISED_URL", "")
c := NewConfig()
c.SystemPath = ""
assert.Nil(t, c.Load([]string{"-c", "127.0.0.1:2000"}), "")
assert.Equal(t, c.AdvertisedUrl, "http://127.0.0.1:2000", "")
}
//--------------------------------------
// Helpers
//--------------------------------------
// Sets up the environment with a given environment variable set.
func withEnv(key, value string, f func(c *Config)) {
os.Setenv(key, value)
defer os.Setenv(key, "")
c := NewConfig()
f(c)
}
// Creates a temp file and calls a function with the context.
func withTempFile(content string, fn func(string)) {
f, _ := ioutil.TempFile("", "")
f.WriteString(content)
f.Close()
defer os.Remove(f.Name())
fn(f.Name())
}

19
server/info.go Normal file
View File

@ -0,0 +1,19 @@
package server
// Info describes the non-mutable state of the server upon initialization.
// These fields cannot be changed without deleting the server fields and
// reinitializing.
type Info struct {
Name string `json:"name"`
RaftURL string `json:"raftURL"`
EtcdURL string `json:"etcdURL"`
WebURL string `json:"webURL"`
RaftListenHost string `json:"raftListenHost"`
EtcdListenHost string `json:"etcdListenHost"`
RaftTLS TLSInfo `json:"raftTLS"`
EtcdTLS TLSInfo `json:"etcdTLS"`
}

View File

@ -14,15 +14,17 @@ func init() {
// The JoinCommand adds a node to the cluster.
type JoinCommand struct {
RaftVersion string `json:"raftVersion"`
MinVersion int `json:"minVersion"`
MaxVersion int `json:"maxVersion"`
Name string `json:"name"`
RaftURL string `json:"raftURL"`
EtcdURL string `json:"etcdURL"`
}
func NewJoinCommand(version, name, raftUrl, etcdUrl string) *JoinCommand {
func NewJoinCommand(minVersion int, maxVersion int, name, raftUrl, etcdUrl string) *JoinCommand {
return &JoinCommand{
RaftVersion: version,
MinVersion: minVersion,
MaxVersion: maxVersion,
Name: name,
RaftURL: raftUrl,
EtcdURL: etcdUrl,
@ -56,7 +58,7 @@ func (c *JoinCommand) Apply(server raft.Server) (interface{}, error) {
}
// Add to shared machine registry.
ps.registry.Register(c.Name, c.RaftVersion, c.RaftURL, c.EtcdURL, server.CommitIndex(), server.Term())
ps.registry.Register(c.Name, c.RaftURL, c.EtcdURL, server.CommitIndex(), server.Term())
// Add peer in raft
err := server.AddPeer(c.Name, "")

View File

@ -10,12 +10,14 @@ import (
"net"
"net/http"
"net/url"
"strconv"
"time"
etcdErr "github.com/coreos/etcd/error"
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/store"
"github.com/coreos/go-raft"
"github.com/gorilla/mux"
)
type PeerServer struct {
@ -51,7 +53,7 @@ type snapshotConf struct {
writesThr uint64
}
func NewPeerServer(name string, path string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo, registry *Registry, store store.Store) *PeerServer {
func NewPeerServer(name string, path string, url string, listenHost string, tlsConf *TLSConfig, tlsInfo *TLSInfo, registry *Registry, store store.Store, snapCount int) *PeerServer {
s := &PeerServer{
name: name,
url: url,
@ -60,7 +62,7 @@ func NewPeerServer(name string, path string, url string, listenHost string, tlsC
tlsInfo: tlsInfo,
registry: registry,
store: store,
snapConf: &snapshotConf{time.Second * 3, 0, 2},
snapConf: &snapshotConf{time.Second * 3, 0, uint64(snapCount)},
followersStats: &raftFollowersStats{
Leader: name,
Followers: make(map[string]*raftFollowerStats),
@ -209,7 +211,7 @@ func (s *PeerServer) SetServer(server *Server) {
func (s *PeerServer) startAsLeader() {
// leader need to join self as a peer
for {
_, err := s.raftServer.Do(NewJoinCommand(PeerVersion, s.raftServer.Name(), s.url, s.server.URL()))
_, err := s.raftServer.Do(NewJoinCommand(store.MinVersion(), store.MaxVersion(), s.raftServer.Name(), s.url, s.server.URL()))
if err == nil {
break
}
@ -235,25 +237,27 @@ func (s *PeerServer) startAsFollower(cluster []string) {
func (s *PeerServer) startTransport(scheme string, tlsConf tls.Config) error {
log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.name, s.listenHost, s.url)
raftMux := http.NewServeMux()
router := mux.NewRouter()
s.httpServer = &http.Server{
Handler: raftMux,
Handler: router,
TLSConfig: &tlsConf,
Addr: s.listenHost,
}
// internal commands
raftMux.HandleFunc("/name", s.NameHttpHandler)
raftMux.HandleFunc("/version", s.RaftVersionHttpHandler)
raftMux.HandleFunc("/join", s.JoinHttpHandler)
raftMux.HandleFunc("/remove/", s.RemoveHttpHandler)
raftMux.HandleFunc("/vote", s.VoteHttpHandler)
raftMux.HandleFunc("/log", s.GetLogHttpHandler)
raftMux.HandleFunc("/log/append", s.AppendEntriesHttpHandler)
raftMux.HandleFunc("/snapshot", s.SnapshotHttpHandler)
raftMux.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
raftMux.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
router.HandleFunc("/name", s.NameHttpHandler)
router.HandleFunc("/version", s.VersionHttpHandler)
router.HandleFunc("/version/{version:[0-9]+}/check", s.VersionCheckHttpHandler)
router.HandleFunc("/upgrade", s.UpgradeHttpHandler)
router.HandleFunc("/join", s.JoinHttpHandler)
router.HandleFunc("/remove/{name:.+}", s.RemoveHttpHandler)
router.HandleFunc("/vote", s.VoteHttpHandler)
router.HandleFunc("/log", s.GetLogHttpHandler)
router.HandleFunc("/log/append", s.AppendEntriesHttpHandler)
router.HandleFunc("/snapshot", s.SnapshotHttpHandler)
router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
if scheme == "http" {
return s.listenAndServe()
@ -263,21 +267,46 @@ func (s *PeerServer) startTransport(scheme string, tlsConf tls.Config) error {
}
// getVersion fetches the raft version of a peer. This works for now but we
// will need to do something more sophisticated later when we allow mixed
// version clusters.
func getVersion(t *transporter, versionURL url.URL) (string, error) {
// getVersion fetches the peer version of a cluster.
func getVersion(t *transporter, versionURL url.URL) (int, error) {
resp, req, err := t.Get(versionURL.String())
if err != nil {
return "", err
return 0, err
}
defer resp.Body.Close()
t.CancelWhenTimeout(req)
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return 0, err
}
return string(body), nil
// Parse version number.
version, _ := strconv.Atoi(string(body))
return version, nil
}
// Upgradable checks whether all peers in a cluster support an upgrade to the next store version.
func (s *PeerServer) Upgradable() error {
nextVersion := s.store.Version() + 1
for _, peerURL := range s.registry.PeerURLs(s.raftServer.Leader(), s.name) {
u, err := url.Parse(peerURL)
if err != nil {
return fmt.Errorf("PeerServer: Cannot parse URL: '%s' (%s)", peerURL, err)
}
t, _ := s.raftServer.Transporter().(*transporter)
checkURL := (&url.URL{Host: u.Host, Scheme: s.tlsConf.Scheme, Path: fmt.Sprintf("/version/%d/check", nextVersion)}).String()
resp, _, err := t.Get(checkURL)
if err != nil {
return fmt.Errorf("PeerServer: Cannot check version compatibility: %s", u.Host)
}
if resp.StatusCode != 200 {
return fmt.Errorf("PeerServer: Version %d is not compatible with peer: %s", nextVersion, u.Host)
}
}
return nil
}
func (s *PeerServer) joinCluster(cluster []string) bool {
@ -315,14 +344,11 @@ func (s *PeerServer) joinByMachine(server raft.Server, machine string, scheme st
if err != nil {
return fmt.Errorf("Error during join version check: %v", err)
}
// TODO: versioning of the internal protocol. See:
// Documentation/internatl-protocol-versioning.md
if version != PeerVersion {
return fmt.Errorf("Unable to join: internal version mismatch, entire cluster must be running identical versions of etcd")
if version < store.MinVersion() || version > store.MaxVersion() {
return fmt.Errorf("Unable to join: cluster version is %d; version compatibility is %d - %d", version, store.MinVersion(), store.MaxVersion())
}
json.NewEncoder(&b).Encode(NewJoinCommand(PeerVersion, server.Name(), s.url, s.server.URL()))
json.NewEncoder(&b).Encode(NewJoinCommand(store.MinVersion(), store.MaxVersion(), server.Name(), s.url, s.server.URL()))
joinURL := url.URL{Host: machine, Scheme: scheme, Path: "/join"}
@ -347,7 +373,7 @@ func (s *PeerServer) joinByMachine(server raft.Server, machine string, scheme st
if resp.StatusCode == http.StatusTemporaryRedirect {
address := resp.Header.Get("Location")
log.Debugf("Send Join Request to %s", address)
json.NewEncoder(&b).Encode(NewJoinCommand(PeerVersion, server.Name(), s.url, s.server.URL()))
json.NewEncoder(&b).Encode(NewJoinCommand(store.MinVersion(), store.MaxVersion(), server.Name(), s.url, s.server.URL()))
resp, req, err = t.Post(address, &b)
} else if resp.StatusCode == http.StatusBadRequest {

View File

@ -3,10 +3,13 @@ package server
import (
"encoding/json"
"net/http"
"strconv"
etcdErr "github.com/coreos/etcd/error"
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/store"
"github.com/coreos/go-raft"
"github.com/gorilla/mux"
)
// Get all the current logs
@ -133,9 +136,9 @@ func (ps *PeerServer) RemoveHttpHandler(w http.ResponseWriter, req *http.Request
return
}
nodeName := req.URL.Path[len("/remove/"):]
vars := mux.Vars(req)
command := &RemoveCommand{
Name: nodeName,
Name: vars["name"],
}
log.Debugf("[recv] Remove Request [%s]", command.Name)
@ -151,8 +154,40 @@ func (ps *PeerServer) NameHttpHandler(w http.ResponseWriter, req *http.Request)
}
// Response to the name request
func (ps *PeerServer) RaftVersionHttpHandler(w http.ResponseWriter, req *http.Request) {
func (ps *PeerServer) VersionHttpHandler(w http.ResponseWriter, req *http.Request) {
log.Debugf("[recv] Get %s/version/ ", ps.url)
w.WriteHeader(http.StatusOK)
w.Write([]byte(PeerVersion))
w.Write([]byte(strconv.Itoa(ps.store.Version())))
}
// Checks whether a given version is supported.
func (ps *PeerServer) VersionCheckHttpHandler(w http.ResponseWriter, req *http.Request) {
log.Debugf("[recv] Get %s%s ", ps.url, req.URL.Path)
vars := mux.Vars(req)
version, _ := strconv.Atoi(vars["version"])
if version >= store.MinVersion() && version <= store.MaxVersion() {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusForbidden)
}
}
// Upgrades the current store version to the next version.
func (ps *PeerServer) UpgradeHttpHandler(w http.ResponseWriter, req *http.Request) {
log.Debugf("[recv] Get %s/version", ps.url)
// Check if upgrade is possible for all nodes.
if err := ps.Upgradable(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Create an upgrade command from the current version.
c := ps.store.CommandFactory().CreateUpgradeCommand()
if err := ps.server.Dispatch(c, w, req); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}

View File

@ -38,13 +38,13 @@ func NewRegistry(s store.Store) *Registry {
}
// Adds a node to the registry.
func (r *Registry) Register(name string, peerVersion string, peerURL string, url string, commitIndex uint64, term uint64) error {
func (r *Registry) Register(name string, peerURL string, url string, commitIndex uint64, term uint64) error {
r.Lock()
defer r.Unlock()
// Write data to store.
key := path.Join(RegistryKey, name)
value := fmt.Sprintf("raft=%s&etcd=%s&raftVersion=%s", peerURL, url, peerVersion)
value := fmt.Sprintf("raft=%s&etcd=%s", peerURL, url)
_, err := r.store.Create(key, value, false, store.Permanent, commitIndex, term)
log.Debugf("Register: %s", name)
return err
@ -175,6 +175,5 @@ func (r *Registry) load(name string) {
r.nodes[name] = &node{
url: m["etcd"][0],
peerURL: m["raft"][0],
peerVersion: m["raftVersion"][0],
}
}

View File

@ -15,6 +15,7 @@ import (
"github.com/coreos/etcd/server/v1"
"github.com/coreos/etcd/server/v2"
"github.com/coreos/etcd/store"
_ "github.com/coreos/etcd/store/v2"
"github.com/coreos/go-raft"
"github.com/gorilla/mux"
)
@ -283,10 +284,10 @@ func (s *Server) Dispatch(c raft.Command, w http.ResponseWriter, req *http.Reque
}
// Sets a comma-delimited list of origins that are allowed.
func (s *Server) AllowOrigins(origins string) error {
func (s *Server) AllowOrigins(origins []string) error {
// Construct a lookup of all origins.
m := make(map[string]bool)
for _, v := range strings.Split(origins, ",") {
for _, v := range origins {
if v != "*" {
if _, err := url.Parse(v); err != nil {
return fmt.Errorf("Invalid CORS origin: %s", err)
@ -366,11 +367,7 @@ func (s *Server) SpeedTestHandler(w http.ResponseWriter, req *http.Request) erro
for i := 0; i < count; i++ {
go func() {
for j := 0; j < 10; j++ {
c := &store.SetCommand{
Key: "foo",
Value: "bar",
ExpireTime: time.Unix(0, 0),
}
c := s.Store().CommandFactory().CreateSetCommand("foo", "bar", time.Unix(0, 0))
s.peerServer.RaftServer().Do(c)
}
c <- true

View File

@ -4,6 +4,7 @@ import (
"crypto/tls"
)
// TLSConfig holds the TLS configuration.
type TLSConfig struct {
Scheme string
Server tls.Config

View File

@ -1,7 +1,76 @@
package server
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"io/ioutil"
)
// TLSInfo holds the SSL certificates paths.
type TLSInfo struct {
CertFile string `json:"CertFile"`
KeyFile string `json:"KeyFile"`
CAFile string `json:"CAFile"`
}
// Generates a TLS configuration from the given files.
func (info TLSInfo) Config() (TLSConfig, error) {
var t TLSConfig
t.Scheme = "http"
// If the user do not specify key file, cert file and CA file, the type will be HTTP
if info.KeyFile == "" && info.CertFile == "" && info.CAFile == "" {
return t, nil
}
// Both the key and cert must be present.
if info.KeyFile == "" || info.CertFile == "" {
return t, errors.New("KeyFile and CertFile must both be present")
}
tlsCert, err := tls.LoadX509KeyPair(info.CertFile, info.KeyFile)
if err != nil {
return t, err
}
t.Scheme = "https"
t.Server.ClientAuth, t.Server.ClientCAs, err = newCertPool(info.CAFile)
if err != nil {
return t, err
}
// The client should trust the RootCA that the Server uses since
// everyone is a peer in the network.
t.Client.Certificates = []tls.Certificate{tlsCert}
t.Client.RootCAs = t.Server.ClientCAs
return t, nil
}
// newCertPool creates x509 certPool and corresponding Auth Type.
// If the given CAfile is valid, add the cert into the pool and verify the clients'
// certs against the cert in the pool.
// If the given CAfile is empty, do not verify the clients' cert.
// If the given CAfile is not valid, fatal.
func newCertPool(CAFile string) (tls.ClientAuthType, *x509.CertPool, error) {
if CAFile == "" {
return tls.NoClientCert, nil, nil
}
pemByte, err := ioutil.ReadFile(CAFile)
if err != nil {
return 0, nil, err
}
block, pemByte := pem.Decode(pemByte)
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return 0, nil, err
}
certPool := x509.NewCertPool()
certPool.AddCert(cert)
return tls.RequireAndVerifyClientCert, certPool, nil
}

View File

@ -18,13 +18,15 @@ import (
// This should not exceed 3 * RTT
var dailTimeout = 3 * HeartbeatTimeout
// Timeout for setup internal raft http connection + receive response header
// This should not exceed 3 * RTT + RTT
var responseHeaderTimeout = 4 * HeartbeatTimeout
// Timeout for setup internal raft http connection + receive all post body
// The raft server will not send back response header until it received all the
// post body.
// This should not exceed dailTimeout + electionTimeout
var responseHeaderTimeout = 3*HeartbeatTimeout + ElectionTimeout
// Timeout for receiving the response body from the server
// This should not exceed election timeout
var tranTimeout = ElectionTimeout
// This should not exceed heartbeatTimeout
var tranTimeout = HeartbeatTimeout
// Transporter layer for communication between raft nodes
type transporter struct {
@ -221,7 +223,7 @@ func (t *transporter) Get(urlStr string) (*http.Response, *http.Request, error)
// Cancel the on fly HTTP transaction when timeout happens.
func (t *transporter) CancelWhenTimeout(req *http.Request) {
go func() {
time.Sleep(ElectionTimeout)
time.Sleep(tranTimeout)
t.transport.CancelRequest(req)
}()
}

View File

@ -1,5 +1,6 @@
package server
/*
import (
"crypto/tls"
"fmt"
@ -59,3 +60,4 @@ func TestTransporterTimeout(t *testing.T) {
}
}
*/

View File

@ -6,6 +6,7 @@ import (
"io"
"net/http"
"net/url"
"strings"
"github.com/coreos/etcd/log"
)
@ -31,3 +32,15 @@ func redirect(hostname string, w http.ResponseWriter, req *http.Request) {
log.Debugf("Redirect to %s", redirectURL.String())
http.Redirect(w, req, redirectURL.String(), http.StatusTemporaryRedirect)
}
// trimsplit slices s into all substrings separated by sep and returns a
// slice of the substrings between the separator with all leading and trailing
// white space removed, as defined by Unicode.
func trimsplit(s, sep string) []string {
raw := strings.Split(s, ",")
trimmed := make([]string, 0)
for _, r := range raw {
trimmed = append(trimmed, strings.TrimSpace(r))
}
return trimmed
}

View File

@ -1,7 +1,6 @@
package v1
import (
"github.com/coreos/etcd/store"
"github.com/gorilla/mux"
"net/http"
)
@ -10,6 +9,6 @@ import (
func DeleteKeyHandler(w http.ResponseWriter, req *http.Request, s Server) error {
vars := mux.Vars(req)
key := "/" + vars["key"]
c := &store.DeleteCommand{Key: key}
c := s.Store().CommandFactory().CreateDeleteCommand(key, false)
return s.Dispatch(c, w, req)
}

View File

@ -31,27 +31,16 @@ func SetKeyHandler(w http.ResponseWriter, req *http.Request, s Server) error {
// If the "prevValue" is specified then test-and-set. Otherwise create a new key.
var c raft.Command
if prevValueArr, ok := req.Form["prevValue"]; ok {
if len(prevValueArr[0]) > 0 { // test against previous value
c = &store.CompareAndSwapCommand{
Key: key,
Value: value,
PrevValue: prevValueArr[0],
ExpireTime: expireTime,
}
if len(prevValueArr[0]) > 0 {
// test against previous value
c = s.Store().CommandFactory().CreateCompareAndSwapCommand(key, value, prevValueArr[0], 0, expireTime)
} else {
c = &store.CreateCommand{ // test against existence
Key: key,
Value: value,
ExpireTime: expireTime,
}
// test against existence
c = s.Store().CommandFactory().CreateCreateCommand(key, value, expireTime, false)
}
} else {
c = &store.SetCommand{
Key: key,
Value: value,
ExpireTime: expireTime,
}
c = s.Store().CommandFactory().CreateSetCommand(key, value, expireTime)
}
return s.Dispatch(c, w, req)

View File

@ -3,18 +3,14 @@ package v2
import (
"net/http"
"github.com/coreos/etcd/store"
"github.com/gorilla/mux"
)
func DeleteHandler(w http.ResponseWriter, req *http.Request, s Server) error {
vars := mux.Vars(req)
key := "/" + vars["key"]
recursive := (req.FormValue("recursive") == "true")
c := &store.DeleteCommand{
Key: key,
Recursive: (req.FormValue("recursive") == "true"),
}
c := s.Store().CommandFactory().CreateDeleteCommand(key, recursive)
return s.Dispatch(c, w, req)
}

View File

@ -18,12 +18,6 @@ func PostHandler(w http.ResponseWriter, req *http.Request, s Server) error {
return etcdErr.NewError(etcdErr.EcodeTTLNaN, "Create", store.UndefIndex, store.UndefTerm)
}
c := &store.CreateCommand{
Key: key,
Value: value,
ExpireTime: expireTime,
Unique: true,
}
c := s.Store().CommandFactory().CreateCreateCommand(key, value, expireTime, true)
return s.Dispatch(c, w, req)
}

View File

@ -71,31 +71,17 @@ func PutHandler(w http.ResponseWriter, req *http.Request, s Server) error {
}
}
c = &store.CompareAndSwapCommand{
Key: key,
Value: value,
PrevValue: prevValue,
PrevIndex: prevIndex,
}
c = s.Store().CommandFactory().CreateCompareAndSwapCommand(key, value, prevValue, prevIndex, expireTime)
return s.Dispatch(c, w, req)
}
func SetHandler(w http.ResponseWriter, req *http.Request, s Server, key, value string, expireTime time.Time) error {
c := &store.SetCommand{
Key: key,
Value: value,
ExpireTime: expireTime,
}
c := s.Store().CommandFactory().CreateSetCommand(key, value, expireTime)
return s.Dispatch(c, w, req)
}
func CreateHandler(w http.ResponseWriter, req *http.Request, s Server, key, value string, expireTime time.Time) error {
c := &store.CreateCommand{
Key: key,
Value: value,
ExpireTime: expireTime,
}
c := s.Store().CommandFactory().CreateCreateCommand(key, value, expireTime, false)
return s.Dispatch(c, w, req)
}
@ -105,10 +91,6 @@ func UpdateHandler(w http.ResponseWriter, req *http.Request, s Server, key, valu
return etcdErr.NewError(etcdErr.EcodeValueOrTTLRequired, "Update", store.UndefIndex, store.UndefTerm)
}
c := &store.UpdateCommand{
Key: key,
Value: value,
ExpireTime: expireTime,
}
c := s.Store().CommandFactory().CreateUpdateCommand(key, value, expireTime)
return s.Dispatch(c, w, req)
}

View File

@ -0,0 +1,29 @@
package v2
import (
"fmt"
"net/url"
"testing"
"github.com/coreos/etcd/server"
"github.com/coreos/etcd/tests"
"github.com/stretchr/testify/assert"
)
// Ensures that a key is deleted.
//
// $ curl -X PUT localhost:4001/v2/keys/foo/bar -d value=XXX
// $ curl -X DELETE localhost:4001/v2/keys/foo/bar
//
func TestV2DeleteKey(t *testing.T) {
tests.RunServer(func(s *server.Server) {
v := url.Values{}
v.Set("value", "XXX")
resp, err := tests.PutForm(fmt.Sprintf("http://%s%s", s.URL(), "/v2/keys/foo/bar"), v)
tests.ReadBody(resp)
resp, err = tests.DeleteForm(fmt.Sprintf("http://%s%s", s.URL(), "/v2/keys/foo/bar"), url.Values{})
body := tests.ReadBody(resp)
assert.Nil(t, err, "")
assert.Equal(t, string(body), `{"action":"delete","key":"/foo/bar","prevValue":"XXX","index":4,"term":0}`, "")
})
}

View File

@ -42,6 +42,7 @@ func TestV2GetKeyRecursively(t *testing.T) {
tests.RunServer(func(s *server.Server) {
v := url.Values{}
v.Set("value", "XXX")
v.Set("ttl", "10")
resp, _ := tests.PutForm(fmt.Sprintf("http://%s%s", s.URL(), "/v2/keys/foo/x"), v)
tests.ReadBody(resp)
@ -60,6 +61,7 @@ func TestV2GetKeyRecursively(t *testing.T) {
kv0 := body["kvs"].([]interface{})[0].(map[string]interface{})
assert.Equal(t, kv0["key"], "/foo/x", "")
assert.Equal(t, kv0["value"], "XXX", "")
assert.Equal(t, kv0["ttl"], 10, "")
kv1 := body["kvs"].([]interface{})[1].(map[string]interface{})
assert.Equal(t, kv1["key"], "/foo/y", "")
@ -105,7 +107,6 @@ func TestV2WatchKey(t *testing.T) {
})
}
// Ensures that a watcher can wait for a value to be set after a given index.
//
// $ curl localhost:4001/v2/keys/foo/bar?wait=true&waitIndex=4
@ -115,9 +116,11 @@ func TestV2WatchKey(t *testing.T) {
func TestV2WatchKeyWithIndex(t *testing.T) {
tests.RunServer(func(s *server.Server) {
var body map[string]interface{}
c := make(chan bool)
go func() {
resp, _ := tests.Get(fmt.Sprintf("http://%s%s", s.URL(), "/v2/keys/foo/bar?wait=true&waitIndex=5"))
body = tests.ReadBodyJSON(resp)
c <- true
}()
// Make sure response didn't fire early.
@ -141,6 +144,14 @@ func TestV2WatchKeyWithIndex(t *testing.T) {
// A response should follow from the GET above.
time.Sleep(1 * time.Millisecond)
select {
case <-c:
default:
t.Fatal("cannot get watch result")
}
assert.NotNil(t, body, "")
assert.Equal(t, body["action"], "set", "")
assert.Equal(t, body["key"], "/foo/bar", "")
@ -149,4 +160,3 @@ func TestV2WatchKeyWithIndex(t *testing.T) {
assert.Equal(t, body["term"], 0, "")
})
}

View File

@ -1,8 +1,3 @@
package server
const Version = "v2"
// TODO: The release version (generated from the git tag) will be the raft
// protocol version for now. When things settle down we will fix it like the
// client API above.
const PeerVersion = ReleaseVersion

58
store/command_factory.go Normal file
View File

@ -0,0 +1,58 @@
package store
import (
"fmt"
"time"
"github.com/coreos/go-raft"
)
// A lookup of factories by version.
var factories = make(map[int]CommandFactory)
var minVersion, maxVersion int
// The CommandFactory provides a way to create different types of commands
// depending on the current version of the store.
type CommandFactory interface {
Version() int
CreateUpgradeCommand() raft.Command
CreateSetCommand(key string, value string, expireTime time.Time) raft.Command
CreateCreateCommand(key string, value string, expireTime time.Time, unique bool) raft.Command
CreateUpdateCommand(key string, value string, expireTime time.Time) raft.Command
CreateDeleteCommand(key string, recursive bool) raft.Command
CreateCompareAndSwapCommand(key string, value string, prevValue string, prevIndex uint64, expireTime time.Time) raft.Command
}
// RegisterCommandFactory adds a command factory to the global registry.
func RegisterCommandFactory(factory CommandFactory) {
version := factory.Version()
if GetCommandFactory(version) != nil {
panic(fmt.Sprintf("Command factory already registered for version: %d", factory.Version()))
}
factories[version] = factory
// Update compatibility versions.
if minVersion == 0 || version > minVersion {
minVersion = version
}
if maxVersion == 0 || version > maxVersion {
maxVersion = version
}
}
// GetCommandFactory retrieves a command factory for a given command version.
func GetCommandFactory(version int) CommandFactory {
return factories[version]
}
// MinVersion returns the minimum compatible store version.
func MinVersion() int {
return minVersion
}
// MaxVersion returns the maximum compatible store version.
func MaxVersion() int {
return maxVersion
}

View File

@ -375,8 +375,8 @@ func (n *Node) UpdateTTL(expireTime time.Time) {
}
}
n.ExpireTime = expireTime
if expireTime.Sub(Permanent) != 0 {
n.ExpireTime = expireTime
n.Expire()
}
}

View File

@ -13,7 +13,12 @@ import (
etcdErr "github.com/coreos/etcd/error"
)
// The default version to set when the store is first initialized.
const defaultVersion = 2
type Store interface {
Version() int
CommandFactory() CommandFactory
Get(nodePath string, recursive, sorted bool, index uint64, term uint64) (*Event, error)
Set(nodePath string, value string, expireTime time.Time, index uint64, term uint64) (*Event, error)
Update(nodePath string, newValue string, expireTime time.Time, index uint64, term uint64) (*Event, error)
@ -30,12 +35,13 @@ type Store interface {
}
type store struct {
Root *Node
WatcherHub *watcherHub
Index uint64
Term uint64
Stats *Stats
worldLock sync.RWMutex // stop the world lock
Root *Node
WatcherHub *watcherHub
Index uint64
Term uint64
Stats *Stats
CurrentVersion int
worldLock sync.RWMutex // stop the world lock
}
func New() Store {
@ -44,13 +50,23 @@ func New() Store {
func newStore() *store {
s := new(store)
s.CurrentVersion = defaultVersion
s.Root = newDir(s, "/", UndefIndex, UndefTerm, nil, "", Permanent)
s.Stats = newStats()
s.WatcherHub = newWatchHub(1000)
return s
}
// Version retrieves current version of the store.
func (s *store) Version() int {
return s.CurrentVersion
}
// CommandFactory retrieves the command factory for the current version of the store.
func (s *store) CommandFactory() CommandFactory {
return GetCommandFactory(s.Version())
}
// Get function returns a get event.
// If recursive is true, it will return all the content under the node path.
// If sorted is true, it will sort the content by keys.
@ -450,6 +466,7 @@ func (s *store) Save() ([]byte, error) {
clonedStore.Root = s.Root.Clone()
clonedStore.WatcherHub = s.WatcherHub.clone()
clonedStore.Stats = s.Stats.clone()
clonedStore.CurrentVersion = s.CurrentVersion
s.worldLock.Unlock()

View File

@ -0,0 +1,73 @@
package v2
import (
"time"
"github.com/coreos/etcd/store"
"github.com/coreos/go-raft"
)
func init() {
store.RegisterCommandFactory(&CommandFactory{})
}
// CommandFactory provides a pluggable way to create version 2 commands.
type CommandFactory struct {
}
// Version returns the version of this factory.
func (f *CommandFactory) Version() int {
return 2
}
// CreateUpgradeCommand is a no-op since version 2 is the first version to support store versioning.
func (f *CommandFactory) CreateUpgradeCommand() raft.Command {
return &raft.NOPCommand{}
}
// CreateSetCommand creates a version 2 command to set a key to a given value in the store.
func (f *CommandFactory) CreateSetCommand(key string, value string, expireTime time.Time) raft.Command {
return &SetCommand{
Key: key,
Value: value,
ExpireTime: expireTime,
}
}
// CreateCreateCommand creates a version 2 command to create a new key in the store.
func (f *CommandFactory) CreateCreateCommand(key string, value string, expireTime time.Time, unique bool) raft.Command {
return &CreateCommand{
Key: key,
Value: value,
ExpireTime: expireTime,
Unique: unique,
}
}
// CreateUpdateCommand creates a version 2 command to update a key to a given value in the store.
func (f *CommandFactory) CreateUpdateCommand(key string, value string, expireTime time.Time) raft.Command {
return &UpdateCommand{
Key: key,
Value: value,
ExpireTime: expireTime,
}
}
// CreateDeleteCommand creates a version 2 command to delete a key from the store.
func (f *CommandFactory) CreateDeleteCommand(key string, recursive bool) raft.Command {
return &DeleteCommand{
Key: key,
Recursive: recursive,
}
}
// CreateCompareAndSwapCommand creates a version 2 command to conditionally set a key in the store.
func (f *CommandFactory) CreateCompareAndSwapCommand(key string, value string, prevValue string, prevIndex uint64, expireTime time.Time) raft.Command {
return &CompareAndSwapCommand{
Key: key,
Value: value,
PrevValue: prevValue,
PrevIndex: prevIndex,
ExpireTime: expireTime,
}
}

View File

@ -1,9 +1,10 @@
package store
package v2
import (
"time"
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/store"
"github.com/coreos/go-raft"
)
@ -27,7 +28,7 @@ func (c *CompareAndSwapCommand) CommandName() string {
// Set the key-value pair if the current value of the key equals to the given prevValue
func (c *CompareAndSwapCommand) Apply(server raft.Server) (interface{}, error) {
s, _ := server.StateMachine().(Store)
s, _ := server.StateMachine().(store.Store)
e, err := s.CompareAndSwap(c.Key, c.PrevValue, c.PrevIndex,
c.Value, c.ExpireTime, server.CommitIndex(), server.Term())

View File

@ -1,9 +1,11 @@
package store
package v2
import (
"github.com/coreos/etcd/log"
"github.com/coreos/go-raft"
"time"
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/store"
"github.com/coreos/go-raft"
)
func init() {
@ -25,7 +27,7 @@ func (c *CreateCommand) CommandName() string {
// Create node
func (c *CreateCommand) Apply(server raft.Server) (interface{}, error) {
s, _ := server.StateMachine().(Store)
s, _ := server.StateMachine().(store.Store)
e, err := s.Create(c.Key, c.Value, c.Unique, c.ExpireTime, server.CommitIndex(), server.Term())

View File

@ -1,6 +1,7 @@
package store
package v2
import (
"github.com/coreos/etcd/store"
"github.com/coreos/etcd/log"
"github.com/coreos/go-raft"
)
@ -22,7 +23,7 @@ func (c *DeleteCommand) CommandName() string {
// Delete the key
func (c *DeleteCommand) Apply(server raft.Server) (interface{}, error) {
s, _ := server.StateMachine().(Store)
s, _ := server.StateMachine().(store.Store)
e, err := s.Delete(c.Key, c.Recursive, server.CommitIndex(), server.Term())

View File

@ -1,9 +1,11 @@
package store
package v2
import (
"github.com/coreos/etcd/log"
"github.com/coreos/go-raft"
"time"
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/store"
"github.com/coreos/go-raft"
)
func init() {
@ -24,7 +26,7 @@ func (c *SetCommand) CommandName() string {
// Create node
func (c *SetCommand) Apply(server raft.Server) (interface{}, error) {
s, _ := server.StateMachine().(Store)
s, _ := server.StateMachine().(store.Store)
// create a new node or replace the old node.
e, err := s.Set(c.Key, c.Value, c.ExpireTime, server.CommitIndex(), server.Term())

View File

@ -1,7 +1,8 @@
package store
package v2
import (
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/store"
"github.com/coreos/go-raft"
"time"
)
@ -24,7 +25,7 @@ func (c *UpdateCommand) CommandName() string {
// Create node
func (c *UpdateCommand) Apply(server raft.Server) (interface{}, error) {
s, _ := server.StateMachine().(Store)
s, _ := server.StateMachine().(store.Store)
e, err := s.Update(c.Key, c.Value, c.ExpireTime, server.CommitIndex(), server.Term())

View File

@ -8,6 +8,9 @@ set -e
export GOPATH="${PWD}"
# Unit tests
go test -i ./server
go test -v ./server
go test -i ./server/v2/tests
go test -v ./server/v2/tests

View File

@ -0,0 +1,46 @@
package test
import (
"net/http"
"os"
"testing"
"time"
)
// Ensure that a node can reply to a version check appropriately.
func TestVersionCheck(t *testing.T) {
procAttr := new(os.ProcAttr)
procAttr.Files = []*os.File{nil, os.Stdout, os.Stderr}
args := []string{"etcd", "-n=node1", "-f", "-d=/tmp/version_check"}
process, err := os.StartProcess(EtcdBinPath, args, procAttr)
if err != nil {
t.Fatal("start process failed:" + err.Error())
return
}
defer process.Kill()
time.Sleep(time.Second)
// Check a version too small.
resp, _ := http.Get("http://localhost:7001/version/1/check")
resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatal("Invalid version check: ", resp.StatusCode)
}
// Check a version too large.
resp, _ = http.Get("http://localhost:7001/version/3/check")
resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Fatal("Invalid version check: ", resp.StatusCode)
}
// Check a version that's just right.
resp, _ = http.Get("http://localhost:7001/version/2/check")
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatal("Invalid version check: ", resp.StatusCode)
}
}

View File

@ -55,6 +55,14 @@ func PutForm(url string, data url.Values) (*http.Response, error) {
return Put(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
func Delete(url string, bodyType string, body io.Reader) (*http.Response, error) {
return send("DELETE", url, bodyType, body)
}
func DeleteForm(url string, data url.Values) (*http.Response, error) {
return Delete(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
func send(method string, url string, bodyType string, body io.Reader) (*http.Response, error) {
c := NewHTTPClient()

View File

@ -5,14 +5,15 @@ import (
"os"
"time"
"github.com/coreos/etcd/store"
"github.com/coreos/etcd/server"
"github.com/coreos/etcd/store"
)
const (
testName = "ETCDTEST"
testName = "ETCDTEST"
testClientURL = "localhost:4401"
testRaftURL = "localhost:7701"
testRaftURL = "localhost:7701"
testSnapCount = 10000
)
// Starts a server in a temporary directory.
@ -22,8 +23,8 @@ func RunServer(f func(*server.Server)) {
store := store.New()
registry := server.NewRegistry(store)
ps := server.NewPeerServer(testName, path, testRaftURL, testRaftURL, &server.TLSConfig{Scheme:"http"}, &server.TLSInfo{}, registry, store)
s := server.New(testName, testClientURL, testClientURL, &server.TLSConfig{Scheme:"http"}, &server.TLSInfo{}, ps, registry, store)
ps := server.NewPeerServer(testName, path, testRaftURL, testRaftURL, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, registry, store, testSnapCount)
s := server.New(testName, testClientURL, testClientURL, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, ps, registry, store)
ps.SetServer(s)
// Start up peer server.
@ -32,17 +33,17 @@ func RunServer(f func(*server.Server)) {
c <- true
ps.ListenAndServe(false, []string{})
}()
<- c
<-c
// Start up etcd server.
go func() {
c <- true
s.ListenAndServe()
}()
<- c
<-c
// Wait to make sure servers have started.
time.Sleep(5 * time.Millisecond)
time.Sleep(50 * time.Millisecond)
// Execute the function passed in.
f(s)

1
third_party/deps vendored
View File

@ -1,4 +1,5 @@
packages="
github.com/BurntSushi/toml
github.com/coreos/go-raft
github.com/coreos/go-etcd
github.com/coreos/go-log/log

View File

@ -0,0 +1,4 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck

View File

@ -0,0 +1,3 @@
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)

View File

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@ -0,0 +1,14 @@
install:
go install
fmt:
gofmt -w *.go */*.go
colcheck *.go */*.go
tags:
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
push:
git push origin master
git push github master

View File

@ -0,0 +1,163 @@
# TOML parser for Go with reflection
TOML stands for Tom's Obvious, Minimal Language.
Spec: https://github.com/mojombo/toml
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)
Documentation: http://godoc.org/github.com/BurntSushi/toml
Installation:
```bash
go get github.com/BurntSushi/toml
```
Try the toml validator:
```bash
go get github.com/BurntSushi/toml/tomlv
tomlv some-toml-file.toml
```
## Testing
This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test).
## Examples
This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys
and values:
```toml
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```
Which could be defined in Go as:
```go
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
}
```
And then decoded with:
```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
// handle error
}
```
You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:
```toml
some_key_NAME = "wat"
```
```go
type TOML struct {
ObscureKey string `toml:"some_key_NAME"`
}
```
## More complex usage
Here's an example of how to load the example from the official spec page:
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.

View File

@ -0,0 +1,59 @@
package main
import (
"fmt"
"time"
"github.com/BurntSushi/toml"
)
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
func main() {
var config tomlConfig
if _, err := toml.DecodeFile("example.toml", &config); err != nil {
fmt.Println(err)
return
}
fmt.Printf("Title: %s\n", config.Title)
fmt.Printf("Owner: %s (%s, %s), Born: %s\n",
config.Owner.Name, config.Owner.Org, config.Owner.Bio, config.Owner.DOB)
fmt.Printf("Database: %s %v (Max conn. %d), Enabled? %v\n",
config.DB.Server, config.DB.Ports, config.DB.ConnMax, config.DB.Enabled)
for serverName, server := range config.Servers {
fmt.Printf("Server: %s (%s, %s)\n", serverName, server.IP, server.DC)
}
fmt.Printf("Client data: %v\n", config.Clients.Data)
fmt.Printf("Client hosts: %v\n", config.Clients.Hosts)
}

View File

@ -0,0 +1,35 @@
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]

View File

@ -0,0 +1,22 @@
# Test file for TOML
# Only this one tries to emulate a TOML file written by a user of the kind of parser writers probably hate
# This part you'll really hate
[the]
test_string = "You'll hate me after this - #" # " Annoying, isn't it?
[the.hard]
test_array = [ "] ", " # "] # ] There you go, parse this!
test_array2 = [ "Test #11 ]proved that", "Experiment #9 was a success" ]
# You didn't think it'd as easy as chucking out the last #, did you?
another_test_string = " Same thing, but with a string #"
harder_test_string = " And when \"'s are in the string, along with # \"" # "and comments are there too"
# Things will get harder
[the.hard.bit#]
what? = "You don't think some user won't do that?"
multi_line_array = [
"]",
# ] Oh yes I did
]

View File

@ -0,0 +1,4 @@
# [x] you
# [x.y] don't
# [x.y.z] need these
[x.y.z.w] # for this to work

View File

@ -0,0 +1,6 @@
# DO NOT WANT
[fruit]
type = "apple"
[fruit.type]
apple = "yes"

View File

@ -0,0 +1,35 @@
# This is an INVALID TOML document. Boom.
# Can you spot the error without help?
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T7:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]

View File

@ -0,0 +1,5 @@
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z

View File

@ -0,0 +1 @@
some_key_NAME = "wat"

View File

@ -0,0 +1,429 @@
package toml
import (
"fmt"
"io"
"io/ioutil"
"reflect"
"strings"
"time"
)
var e = fmt.Errorf
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
type Primitive interface{}
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including `PrimitiveDecode`. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
return unify(primValue, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case insensitive match to struct names will be tried if an exact match
// can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
return MetaData{p.mapping, p.types, p.ordered}, unify(p.mapping, rvalue(v))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
return unifyAnything(data, rv)
}
// Special case. Go's `time.Time` is a struct, which we don't want
// to confuse with a user struct.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return unifyDatetime(data, rv)
}
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return unifyStruct(data, rv)
case reflect.Map:
return unifyMap(data, rv)
case reflect.Slice:
return unifySlice(data, rv)
case reflect.String:
return unifyString(data, rv)
case reflect.Bool:
return unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("Unsupported type '%s'.", rv.Kind())
}
return unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
return unifyFloat64(data, rv)
}
return e("Unsupported type '%s'.", rv.Kind())
}
func unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return mismatch(rv, "map", mapping)
}
rt := rv.Type()
for i := 0; i < rt.NumField(); i++ {
// A little tricky. We want to use the special `toml` name in the
// struct tag if it exists. In particular, we need to make sure that
// this struct field is in the current map before trying to unify it.
sft := rt.Field(i)
kname := sft.Tag.Get("toml")
if len(kname) == 0 {
kname = sft.Name
}
if datum, ok := insensitiveGet(tmap, kname); ok {
sf := indirect(rv.Field(i))
// Don't try to mess with unexported types and other such things.
if sf.CanSet() {
if err := unify(datum, sf); err != nil {
return e("Type mismatch for '%s.%s': %s",
rt.String(), sft.Name, err)
}
} else if len(sft.Tag.Get("toml")) > 0 {
// Bad user! No soup for you!
return e("Field '%s.%s' is unexported, and therefore cannot "+
"be loaded with reflection.", rt.String(), sft.Name)
}
}
}
return nil
}
func unifyMap(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := unify(v, rvval); err != nil {
return err
}
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
return nil
}
func unifySlice(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if rv.IsNil() {
rv.Set(reflect.MakeSlice(rv.Type(), sliceLen, sliceLen))
}
for i := 0; i < sliceLen; i++ {
v := datav.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := unify(v, sliceval); err != nil {
return err
}
}
return nil
}
func unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
}
func unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
fallthrough
case reflect.Float64:
rv.SetFloat(num)
default:
panic("bug")
}
return nil
}
return badtype("float", data)
}
func unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
switch rv.Kind() {
case reflect.Int:
fallthrough
case reflect.Int8:
fallthrough
case reflect.Int16:
fallthrough
case reflect.Int32:
fallthrough
case reflect.Int64:
rv.SetInt(int64(num))
case reflect.Uint:
fallthrough
case reflect.Uint8:
fallthrough
case reflect.Uint16:
fallthrough
case reflect.Uint32:
fallthrough
case reflect.Uint64:
rv.SetUint(uint64(num))
default:
panic("bug")
}
return nil
}
return badtype("integer", data)
}
func unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("integer", data)
}
func unifyAnything(data interface{}, rv reflect.Value) error {
// too awesome to fail
rv.Set(reflect.ValueOf(data))
return nil
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
return v
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return indirect(reflect.Indirect(v))
}
func tstring(rv reflect.Value) string {
return rv.Type().String()
}
func badtype(expected string, data interface{}) error {
return e("Expected %s but found '%T'.", expected, data)
}
func mismatch(user reflect.Value, expected string, data interface{}) error {
return e("Type mismatch for %s. Expected %s but found '%T'.",
tstring(user), expected, data)
}
func insensitiveGet(
tmap map[string]interface{}, kname string) (interface{}, bool) {
if datum, ok := tmap[kname]; ok {
return datum, true
}
for k, v := range tmap {
if strings.EqualFold(kname, k) {
return v, true
}
}
return nil, false
}
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
//
// (XXX: If TOML gets NULL values, that information will be added here too.)
type MetaData struct {
mapping map[string]interface{}
types map[string]tomlType
keys []Key
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchially. e.g.,
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
//
// IsDefined will return false if an empty key given. Keys are case sensitive.
func (md MetaData) IsDefined(key ...string) bool {
var hashOrVal interface{}
var hash map[string]interface{}
var ok bool
if len(key) == 0 {
return false
}
hashOrVal = md.mapping
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
return false
}
}
return true
}
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
func (md MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k))
copy(newKey, k)
return append(newKey, piece)
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md MetaData) Keys() []Key {
return md.keys
}
func allKeys(m map[string]interface{}, context Key) []Key {
keys := make([]Key, 0, len(m))
for k, v := range m {
keys = append(keys, context.add(k))
if t, ok := v.(map[string]interface{}); ok {
keys = append(keys, allKeys(t, context.add(k))...)
}
}
return keys
}

View File

@ -0,0 +1,343 @@
package toml
import (
"fmt"
"log"
"reflect"
"testing"
"time"
)
func init() {
log.SetFlags(0)
}
var testSimple = `
age = 250
andrew = "gallant"
kait = "brady"
now = 1987-07-05T05:45:00Z
yesOrNo = true
pi = 3.14
colors = [
["red", "green", "blue"],
["cyan", "magenta", "yellow", "black"],
]
[Annoying.Cats]
plato = "smelly"
cauchy = "stupido"
`
type kitties struct {
Plato string
Cauchy string
}
type simple struct {
Age int
Colors [][]string
Pi float64
YesOrNo bool
Now time.Time
Andrew string
Kait string
Annoying map[string]kitties
}
func TestDecode(t *testing.T) {
var val simple
md, err := Decode(testSimple, &val)
if err != nil {
t.Fatal(err)
}
testf("Is 'Annoying.Cats.plato' defined? %v\n",
md.IsDefined("Annoying", "Cats", "plato"))
testf("Is 'Cats.Stinky' defined? %v\n", md.IsDefined("Cats", "Stinky"))
testf("Type of 'colors'? %s\n\n", md.Type("colors"))
testf("%v\n", val)
}
var tomlTableArrays = `
[[albums]]
name = "Born to Run"
[[albums.songs]]
name = "Jungleland"
[[albums.songs]]
name = "Meeting Across the River"
[[albums]]
name = "Born in the USA"
[[albums.songs]]
name = "Glory Days"
[[albums.songs]]
name = "Dancing in the Dark"
`
type Music struct {
Albums []Album
}
type Album struct {
Name string
Songs []Song
}
type Song struct {
Name string
}
func TestTableArrays(t *testing.T) {
expected := Music{[]Album{
{"Born to Run", []Song{{"Jungleland"}, {"Meeting Across the River"}}},
{"Born in the USA", []Song{{"Glory Days"}, {"Dancing in the Dark"}}},
}}
var got Music
if _, err := Decode(tomlTableArrays, &got); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, got) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, got)
}
}
// Case insensitive matching tests.
// A bit more comprehensive than needed given the current implementation,
// but implementations change.
// Probably still missing demonstrations of some ugly corner cases regarding
// case insensitive matching and multiple fields.
var caseToml = `
tOpString = "string"
tOpInt = 1
tOpFloat = 1.1
tOpBool = true
tOpdate = 2006-01-02T15:04:05Z
tOparray = [ "array" ]
Match = "i should be in Match only"
MatcH = "i should be in MatcH only"
Field = "neat"
FielD = "messy"
once = "just once"
[nEst.eD]
nEstedString = "another string"
`
type Insensitive struct {
TopString string
TopInt int
TopFloat float64
TopBool bool
TopDate time.Time
TopArray []string
Match string
MatcH string
Field string
Once string
OncE string
Nest InsensitiveNest
}
type InsensitiveNest struct {
Ed InsensitiveEd
}
type InsensitiveEd struct {
NestedString string
}
func TestCase(t *testing.T) {
tme, err := time.Parse(time.RFC3339, time.RFC3339[:len(time.RFC3339)-5])
if err != nil {
panic(err)
}
expected := Insensitive{
TopString: "string",
TopInt: 1,
TopFloat: 1.1,
TopBool: true,
TopDate: tme,
TopArray: []string{"array"},
MatcH: "i should be in MatcH only",
Match: "i should be in Match only",
Field: "neat", // encoding/json would store "messy" here
Once: "just once",
OncE: "just once", // wait, what?
Nest: InsensitiveNest{
Ed: InsensitiveEd{NestedString: "another string"},
},
}
var got Insensitive
_, err = Decode(caseToml, &got)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, got) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, got)
}
}
func TestPointers(t *testing.T) {
type Object struct {
Type string
Description string
}
type Dict struct {
NamedObject map[string]*Object
BaseObject *Object
Strptr *string
Strptrs []*string
}
s1, s2, s3 := "blah", "abc", "def"
expected := &Dict{
Strptr: &s1,
Strptrs: []*string{&s2, &s3},
NamedObject: map[string]*Object{
"foo": {"FOO", "fooooo!!!"},
"bar": {"BAR", "ba-ba-ba-ba-barrrr!!!"},
},
BaseObject: &Object{"BASE", "da base"},
}
ex1 := `
Strptr = "blah"
Strptrs = ["abc", "def"]
[NamedObject.foo]
Type = "FOO"
Description = "fooooo!!!"
[NamedObject.bar]
Type = "BAR"
Description = "ba-ba-ba-ba-barrrr!!!"
[BaseObject]
Type = "BASE"
Description = "da base"
`
dict := new(Dict)
_, err := Decode(ex1, dict)
if err != nil {
t.Errorf("Decode error: %v", err)
}
if !reflect.DeepEqual(expected, dict) {
t.Fatalf("\n%#v\n!=\n%#v\n", expected, dict)
}
}
func ExamplePrimitiveDecode() {
var md MetaData
var err error
var tomlBlob = `
ranking = ["Springsteen", "J Geils"]
[bands.Springsteen]
started = 1973
albums = ["Greetings", "WIESS", "Born to Run", "Darkness"]
[bands.J Geils]
started = 1970
albums = ["The J. Geils Band", "Full House", "Blow Your Face Out"]
`
type band struct {
Started int
Albums []string
}
type classics struct {
Ranking []string
Bands map[string]Primitive
}
// Do the initial decode. Reflection is delayed on Primitive values.
var music classics
if md, err = Decode(tomlBlob, &music); err != nil {
log.Fatal(err)
}
// MetaData still includes information on Primitive values.
fmt.Printf("Is `bands.Springsteen` defined? %v\n",
md.IsDefined("bands", "Springsteen"))
// Decode primitive data into Go values.
for _, artist := range music.Ranking {
// A band is a primitive value, so we need to decode it to get a
// real `band` value.
primValue := music.Bands[artist]
var aBand band
if err = PrimitiveDecode(primValue, &aBand); err != nil {
log.Fatal(err)
}
fmt.Printf("%s started in %d.\n", artist, aBand.Started)
}
// Output:
// Is `bands.Springsteen` defined? true
// Springsteen started in 1973.
// J Geils started in 1970.
}
func ExampleDecode() {
var tomlBlob = `
# Some comments.
[alpha]
ip = "10.0.0.1"
[alpha.config]
Ports = [ 8001, 8002 ]
Location = "Toronto"
Created = 1987-07-05T05:45:00Z
[beta]
ip = "10.0.0.2"
[beta.config]
Ports = [ 9001, 9002 ]
Location = "New Jersey"
Created = 1887-01-05T05:55:00Z
`
type serverConfig struct {
Ports []int
Location string
Created time.Time
}
type server struct {
IP string `toml:"ip"`
Config serverConfig `toml:"config"`
}
type servers map[string]server
var config servers
if _, err := Decode(tomlBlob, &config); err != nil {
log.Fatal(err)
}
for _, name := range []string{"alpha", "beta"} {
s := config[name]
fmt.Printf("Server: %s (ip: %s) in %s created on %s\n",
name, s.IP, s.Config.Location,
s.Config.Created.Format("2006-01-02"))
fmt.Printf("Ports: %v\n", s.Config.Ports)
}
// // Output:
// Server: alpha (ip: 10.0.0.1) in Toronto created on 1987-07-05
// Ports: [8001 8002]
// Server: beta (ip: 10.0.0.2) in New Jersey created on 1887-01-05
// Ports: [9001 9002]
}

View File

@ -0,0 +1,10 @@
/*
Package toml provides facilities for decoding TOML configuration files
via reflection.
Specification: https://github.com/mojombo/toml
Use github.com/BurntSushi/toml/tomlv to check whether a file is valid
TOML or not, with helpful error messages.
*/
package toml

View File

@ -0,0 +1,99 @@
package toml
// TODO: Build a decent encoder.
// Interestingly, this isn't as trivial as recursing down the type of the
// value given and outputting the corresponding TOML. In particular, multiple
// TOML types (especially if tuples are added) can map to a single Go type, so
// that the reverse correspondence isn't clear.
//
// One possible avenue is to choose a reasonable default (like structs map
// to hashes), but allow the user to override with struct tags. But this seems
// like a mess.
//
// The other possibility is to scrap an encoder altogether. After all, TOML
// is a configuration file format, and not a data exchange format.
import (
"bufio"
"fmt"
"io"
"reflect"
"strings"
)
type encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
w *bufio.Writer
}
func newEncoder(w io.Writer) *encoder {
return &encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
}
func (enc *encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.encode(Key([]string{}), rv); err != nil {
return err
}
return enc.w.Flush()
}
func (enc *encoder) encode(key Key, rv reflect.Value) error {
k := rv.Kind()
switch k {
case reflect.Struct:
return enc.eStruct(key, rv)
case reflect.String:
return enc.eString(key, rv)
}
return e("Unsupported type for key '%s': %s", key, k)
}
func (enc *encoder) eStruct(key Key, rv reflect.Value) error {
rt := rv.Type()
for i := 0; i < rt.NumField(); i++ {
sft := rt.Field(i)
sf := rv.Field(i)
if err := enc.encode(key.add(sft.Name), sf); err != nil {
return err
}
}
return nil
}
func (enc *encoder) eString(key Key, rv reflect.Value) error {
s := rv.String()
s = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
"\"", "\\\"",
"\\", "\\\\",
).Replace(s)
s = "\"" + s + "\""
if err := enc.eKeyVal(key, s); err != nil {
return err
}
return nil
}
func (enc *encoder) eKeyVal(key Key, value string) error {
out := fmt.Sprintf("%s%s = %s",
strings.Repeat(enc.Indent, len(key)-1), key[len(key)-1], value)
if _, err := fmt.Fprintln(enc.w, out); err != nil {
return err
}
return nil
}
func eindirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
return v
}
return eindirect(reflect.Indirect(v))
}

View File

@ -0,0 +1,25 @@
package toml
import (
"bytes"
"testing"
)
type encodeSimple struct {
Location string
// Ages []int
// DOB time.Time
}
func TestEncode(t *testing.T) {
v := encodeSimple{
Location: "Westborough, MA",
}
buf := new(bytes.Buffer)
e := newEncoder(buf)
if err := e.Encode(v); err != nil {
t.Fatal(err)
}
testf(buf.String())
}

View File

@ -0,0 +1,741 @@
package toml
import (
"fmt"
"unicode/utf8"
)
type itemType int
const (
itemError itemType = iota
itemNIL // used in the parser to indicate no type
itemEOF
itemText
itemString
itemBool
itemInteger
itemFloat
itemDatetime
itemArray // the start of an array
itemArrayEnd
itemTableStart
itemTableEnd
itemArrayTableStart
itemArrayTableEnd
itemKeyStart
itemCommentStart
)
const (
eof = 0
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
arrayValTerm = ','
commentStart = '#'
stringStart = '"'
stringEnd = '"'
)
type stateFn func(lx *lexer) stateFn
type lexer struct {
input string
start int
pos int
width int
line int
state stateFn
items chan item
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
// nested arrays. The last state on the stack is used after a value has
// been lexed. Similarly for comments.
stack []stateFn
}
type item struct {
typ itemType
val string
line int
}
func (lx *lexer) nextItem() item {
for {
select {
case item := <-lx.items:
return item
default:
lx.state = lx.state(lx)
}
}
panic("not reached")
}
func lex(input string) *lexer {
lx := &lexer{
input: input,
state: lexTop,
line: 1,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
}
return lx
}
func (lx *lexer) push(state stateFn) {
lx.stack = append(lx.stack, state)
}
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop.")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
return last
}
func (lx *lexer) current() string {
return lx.input[lx.start:lx.pos]
}
func (lx *lexer) emit(typ itemType) {
lx.items <- item{typ, lx.current(), lx.line}
lx.start = lx.pos
}
func (lx *lexer) next() (r rune) {
if lx.pos >= len(lx.input) {
lx.width = 0
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.pos += lx.width
return r
}
// ignore skips over the pending input before this point.
func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only once per call of next.
func (lx *lexer) backup() {
lx.pos -= lx.width
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
}
// accept consumes the next rune if it's equal to `valid`.
func (lx *lexer) accept(valid rune) bool {
if lx.next() == valid {
return true
}
lx.backup()
return false
}
// peek returns but does not consume the next rune in the input.
func (lx *lexer) peek() rune {
r := lx.next()
lx.backup()
return r
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (new lines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
for i, value := range values {
if v, ok := value.(rune); ok {
values[i] = escapeSpecial(v)
}
}
lx.items <- item{
itemError,
fmt.Sprintf(format, values...),
lx.line,
}
return nil
}
// lexTop consumes elements at the top level of TOML data.
func lexTop(lx *lexer) stateFn {
r := lx.next()
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
return lexCommentStart
case tableStart:
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("Unexpected EOF.")
}
lx.emit(itemEOF)
return nil
}
// At this point, the only valid item can be a key, so we back up
// and let the key lexer do the rest.
lx.backup()
lx.push(lexTopEnd)
return lexKeyStart
}
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a new line. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a new line for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
return lexTopEnd
case isNL(r):
lx.ignore()
return lexTop
case r == eof:
lx.ignore()
return lexTop
}
return lx.errorf("Expected a top-level item to end with a new line, "+
"comment or EOF, but got '%s' instead.", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
// it starts with a character other than '.' and ']'.
// It assumes that '[' has already been consumed.
// It also handles the case that this is an item in an array of tables.
// e.g., '[[name]]'.
func lexTableStart(lx *lexer) stateFn {
if lx.peek() == arrayTableStart {
lx.next()
lx.emit(itemArrayTableStart)
lx.push(lexArrayTableEnd)
} else {
lx.emit(itemTableStart)
lx.push(lexTableEnd)
}
return lexTableNameStart
}
func lexTableEnd(lx *lexer) stateFn {
lx.emit(itemTableEnd)
return lexTopEnd
}
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("Expected end of table array name delimiter '%s', "+
"but got '%s' instead.", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
switch lx.next() {
case tableEnd:
return lx.errorf("Unexpected end of table. (Tables cannot " +
"be empty.)")
case tableSep:
return lx.errorf("Unexpected table separator. (Tables cannot " +
"be empty.)")
}
return lexTableName
}
// lexTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexTableName(lx *lexer) stateFn {
switch lx.peek() {
case tableStart:
return lx.errorf("Table names cannot contain '%s' or '%s'.",
tableStart, tableEnd)
case tableEnd:
lx.emit(itemText)
lx.next()
return lx.pop()
case tableSep:
lx.emit(itemText)
lx.next()
lx.ignore()
return lexTableNameStart
}
lx.next()
return lexTableName
}
// lexKeyStart consumes a key name up until the first non-whitespace character.
// lexKeyStart will ignore whitespace.
func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("Unexpected key separator '%s'.", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
}
lx.ignore()
lx.emit(itemKeyStart)
lx.next()
return lexKey
}
// lexKey consumes the text of a key. Assumes that the first character (which
// is not whitespace) has already been consumed.
func lexKey(lx *lexer) stateFn {
r := lx.peek()
// XXX: Possible divergence from spec?
// "Keys start with the first non-whitespace character and end with the
// last non-whitespace character before the equals sign."
// Note here that whitespace is either a tab or a space.
// But we'll call it quits if we see a new line too.
if isWhitespace(r) || isNL(r) {
lx.emit(itemText)
return lexKeyEnd
}
// Let's also call it quits if we see an equals sign.
if r == keySep {
lx.emit(itemText)
return lexKeyEnd
}
lx.next()
return lexKey
}
// lexKeyEnd consumes the end of a key (up to the key separator).
// Assumes that the first whitespace character after a key (or the '='
// separator) has NOT been consumed.
func lexKeyEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexKeyEnd)
case r == keySep:
return lexSkip(lx, lexValue)
}
return lx.errorf("Expected key separator '%s', but got '%s' instead.",
keySep, r)
}
// lexValue starts the consumption of a value anywhere a value is expected.
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT new lines.
// In array syntax, the array states are responsible for ignoring new lines.
r := lx.next()
if isWhitespace(r) {
return lexSkip(lx, lexValue)
}
switch {
case r == arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case r == stringStart:
lx.ignore() // ignore the '"'
return lexString
case r == 't':
return lexTrue
case r == 'f':
return lexFalse
case r == '-':
return lexNumberStart
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
case r == '.': // special error case, be kind to users
return lx.errorf("Floats must start with a digit, not '.'.")
}
return lx.errorf("Expected value but found '%s' instead.", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and new lines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValue)
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == arrayValTerm:
return lx.errorf("Unexpected array value terminator '%s'.",
arrayValTerm)
case r == arrayEnd:
return lexArrayEnd
}
lx.backup()
lx.push(lexArrayValueEnd)
return lexValue
}
// lexArrayValueEnd consumes the cruft between values of an array. Namely,
// it ignores whitespace and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValueEnd)
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == arrayValTerm:
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf("Expected an array value terminator '%s' or an array "+
"terminator '%s', but got '%s' instead.", arrayValTerm, arrayEnd, r)
}
// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has
// just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == '\\':
return lexStringEscape
case r == stringEnd:
lx.backup()
lx.emit(itemString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexString
}
// lexStringEscape consumes an escaped character. It assumes that the preceding
// '\\' has already been consumed.
func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'b':
fallthrough
case 't':
fallthrough
case 'n':
fallthrough
case 'f':
fallthrough
case 'r':
fallthrough
case '"':
fallthrough
case '/':
fallthrough
case '\\':
return lexString
case 'u':
return lexStringUnicode
}
return lx.errorf("Invalid escape character '%s'. Only the following "+
"escape characters are allowed: "+
"\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, and \\uXXXX.", r)
}
// lexStringBinary consumes two hexadecimal digits following '\x'. It assumes
// that the '\x' has already been consumed.
func lexStringUnicode(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected four hexadecimal digits after '\\x', "+
"but got '%s' instead.", lx.current())
}
}
return lexString
}
// lexNumberOrDateStart consumes either a (positive) integer, float or datetime.
// It assumes that NO negative sign has been consumed.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got '%s'.", r)
}
}
return lexNumberOrDate
}
// lexNumberOrDate consumes either a (positive) integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '-':
if lx.pos-lx.start != 5 {
return lx.errorf("All ISO8601 dates must be in full Zulu form.")
}
return lexDateAfterYear
case isDigit(r):
return lexNumberOrDate
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format.
// It assumes that "YYYY-" has already been consumed.
func lexDateAfterYear(lx *lexer) stateFn {
formats := []rune{
// digits are '0'.
// everything else is direct equality.
'0', '0', '-', '0', '0',
'T',
'0', '0', ':', '0', '0', ':', '0', '0',
'Z',
}
for _, f := range formats {
r := lx.next()
if f == '0' {
if !isDigit(r) {
return lx.errorf("Expected digit in ISO8601 datetime, "+
"but found '%s' instead.", r)
}
} else if f != r {
return lx.errorf("Expected '%s' in ISO8601 datetime, "+
"but found '%s' instead.", f, r)
}
}
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that a
// negative sign has already been read, but that *no* digits have been consumed.
// lexNumberStart will move to the appropriate integer or float states.
func lexNumberStart(lx *lexer) stateFn {
// we MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got '%s'.", r)
}
}
return lexNumber
}
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
switch {
case isDigit(r):
return lexNumber
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexFloatStart starts the consumption of digits of a float after a '.'.
// Namely, at least one digit is required.
func lexFloatStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
return lx.errorf("Floats must have a digit after the '.', but got "+
"'%s' instead.", r)
}
return lexFloat
}
// lexFloat consumes the digits of a float after a '.'.
// Assumes that one digit has been consumed after a '.' already.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexTrue consumes the "rue" in "true". It assumes that 't' has already
// been consumed.
func lexTrue(lx *lexer) stateFn {
if r := lx.next(); r != 'r' {
return lx.errorf("Expected 'tr', but found 't%s' instead.", r)
}
if r := lx.next(); r != 'u' {
return lx.errorf("Expected 'tru', but found 'tr%s' instead.", r)
}
if r := lx.next(); r != 'e' {
return lx.errorf("Expected 'true', but found 'tru%s' instead.", r)
}
lx.emit(itemBool)
return lx.pop()
}
// lexFalse consumes the "alse" in "false". It assumes that 'f' has already
// been consumed.
func lexFalse(lx *lexer) stateFn {
if r := lx.next(); r != 'a' {
return lx.errorf("Expected 'fa', but found 'f%s' instead.", r)
}
if r := lx.next(); r != 'l' {
return lx.errorf("Expected 'fal', but found 'fa%s' instead.", r)
}
if r := lx.next(); r != 's' {
return lx.errorf("Expected 'fals', but found 'fal%s' instead.", r)
}
if r := lx.next(); r != 'e' {
return lx.errorf("Expected 'false', but found 'fals%s' instead.", r)
}
lx.emit(itemBool)
return lx.pop()
}
// lexCommentStart begins the lexing of a comment. It will emit
// itemCommentStart and consume no characters, passing control to lexComment.
func lexCommentStart(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemCommentStart)
return lexComment
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first new line character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
if isNL(r) || r == eof {
lx.emit(itemText)
return lx.pop()
}
lx.next()
return lexComment
}
// lexSkip ignores all slurped input and moves on to the next state.
func lexSkip(lx *lexer, nextState stateFn) stateFn {
return func(lx *lexer) stateFn {
lx.ignore()
return nextState
}
}
// isWhitespace returns true if `r` is a whitespace character according
// to the spec.
func isWhitespace(r rune) bool {
return r == '\t' || r == ' '
}
func isNL(r rune) bool {
return r == '\n' || r == '\r'
}
func isDigit(r rune) bool {
return r >= '0' && r <= '9'
}
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func (itype itemType) String() string {
switch itype {
case itemError:
return "Error"
case itemNIL:
return "NIL"
case itemEOF:
return "EOF"
case itemText:
return "Text"
case itemString:
return "String"
case itemBool:
return "Bool"
case itemInteger:
return "Integer"
case itemFloat:
return "Float"
case itemDatetime:
return "DateTime"
case itemTableStart:
return "TableStart"
case itemTableEnd:
return "TableEnd"
case itemKeyStart:
return "KeyStart"
case itemArray:
return "Array"
case itemArrayEnd:
return "ArrayEnd"
case itemCommentStart:
return "CommentStart"
}
panic(fmt.Sprintf("BUG: Unknown type '%s'.", itype))
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
}
func escapeSpecial(c rune) string {
switch c {
case '\n':
return "\\n"
}
return string(c)
}

View File

@ -0,0 +1,59 @@
package toml
import (
"log"
"testing"
)
func init() {
log.SetFlags(0)
}
var testSmall = `
# This is a TOML document. Boom.
[owner]
[owner] # Whoa there.
andrew = "gallant # poopy" # weeeee
predicate = false
num = -5192
f = -0.5192
zulu = 1979-05-27T07:32:00Z
whoop = "poop"
arrs = [
1987-07-05T05:45:00Z,
5,
"wat?",
"hehe \n\r kewl",
[6], [],
5.0,
# sweetness
] # more comments
# hehe
`
var testSmaller = `
[a.b] # Do you ignore me?
andrew = "ga# ll\"ant" # what about me?
kait = "brady"
awesomeness = true
pi = 3.14
dob = 1987-07-05T17:45:00Z
perfection = [
[6, 28],
[496, 8128]
]
`
func TestLexer(t *testing.T) {
lx := lex(testSmaller)
for {
item := lx.nextItem()
if item.typ == itemEOF {
break
} else if item.typ == itemError {
t.Fatal(item.val)
}
testf("%s\n", item)
}
}

View File

@ -0,0 +1,19 @@
package toml
import (
"flag"
"fmt"
)
var flagOut = false
func init() {
flag.BoolVar(&flagOut, "out", flagOut, "Print debug output.")
flag.Parse()
}
func testf(format string, v ...interface{}) {
if flagOut {
fmt.Printf(format, v...)
}
}

View File

@ -0,0 +1,388 @@
package toml
import (
"fmt"
"log"
"strconv"
"strings"
"time"
"unicode/utf8"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
return
}
panic(r)
}
}()
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
}
for {
item := p.next()
if item.typ == itemEOF {
break
}
p.topLevel(item)
}
return p, nil
}
func (p *parser) panic(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d, key '%s': %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
}
func (p *parser) next() item {
it := p.lx.nextItem()
if it.typ == itemError {
p.panic("Near line %d: %s", it.line, it.val)
}
return it
}
func (p *parser) bug(format string, v ...interface{}) {
log.Fatalf("BUG: %s\n\n", fmt.Sprintf(format, v...))
}
func (p *parser) expect(typ itemType) item {
it := p.next()
p.assertEqual(typ, it.typ)
return it
}
func (p *parser) assertEqual(expected, got itemType) {
if expected != got {
p.bug("Expected '%s' but got '%s'.", expected, got)
}
}
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
p.expect(itemText)
case itemTableStart:
kg := p.expect(itemText)
p.approxLine = kg.line
key := make(Key, 0)
for ; kg.typ == itemText; kg = p.next() {
key = append(key, kg.val)
}
p.assertEqual(itemTableEnd, kg.typ)
p.establishContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.expect(itemText)
p.approxLine = kg.line
key := make(Key, 0)
for ; kg.typ == itemText; kg = p.next() {
key = append(key, kg.val)
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.establishContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.expect(itemText)
p.currentKey = kname.val
p.approxLine = kname.line
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
}
}
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceUnicode(replaceEscapes(it.val)), p.typeOfPrimitive(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
num, err := strconv.ParseInt(it.val, 10, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panic("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
num, err := strconv.ParseFloat(it.val, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panic("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.bug("Expected float value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
t, err := time.Parse("2006-01-02T15:04:05Z", it.val)
if err != nil {
p.bug("Expected Zulu formatted DateTime, but got '%s'.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// establishContext sets the current context of the parser,
// where the context is the hash currently in scope.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
}
// If the hash context is actually an array of tables, then set
// the hash context to the last element in that array.
//
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
hashContext = t[len(t)-1]
case map[string]interface{}:
hashContext = t
default:
p.panic("Key '%s' was already created as a hash.", keyContext)
}
}
p.context = keyContext
if array {
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
}
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panic("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
}
p.context = append(p.context, key[len(key)-1])
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
hash = t[len(t)-1]
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// We need to do some fancy footwork here. If `hash[key]` was implcitly
// created AND `value` is a hash, then let this go through and stop
// tagging this table as implicit.
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
p.panic("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
}
// current returns the full key name of the current context.
func (p *parser) current() string {
if len(p.currentKey) == 0 {
return p.context.String()
}
if len(p.context) == 0 {
return p.currentKey
}
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
}
func replaceEscapes(s string) string {
return strings.NewReplacer(
"\\b", "\u0008",
"\\t", "\u0009",
"\\n", "\u000A",
"\\f", "\u000C",
"\\r", "\u000D",
"\\\"", "\u0022",
"\\/", "\u002F",
"\\\\", "\u005C",
).Replace(s)
}
func (p *parser) replaceUnicode(s string) string {
indexEsc := func() int {
return strings.Index(s, "\\u")
}
for i := indexEsc(); i != -1; i = indexEsc() {
asciiBytes := s[i+2 : i+6]
s = strings.Replace(s, s[i:i+6], p.asciiEscapeToUnicode(asciiBytes), -1)
}
return s
}
func (p *parser) asciiEscapeToUnicode(s string) string {
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
}
// I honestly don't understand how this works. I can't seem to find
// a way to make this fail. I figured this would fail on invalid UTF-8
// characters like U+DCFF, but it doesn't.
r := string(rune(hex))
if !utf8.ValidString(r) {
p.panic("Escaped character '\\u%s' is not valid UTF-8.", s)
}
return string(r)
}

View File

@ -0,0 +1,61 @@
package toml
import (
"strings"
"testing"
)
var testParseSmall = `
# This is a TOML document. Boom.
wat = "chipper"
[owner.andrew.gallant]
hmm = "hi"
[owner] # Whoa there.
andreW = "gallant # poopy" # weeeee
predicate = false
num = -5192
f = -0.5192
zulu = 1979-05-27T07:32:00Z
whoop = "poop"
tests = [ [1, 2, 3], ["abc", "xyz"] ]
arrs = [ # hmm
# more comments are awesome.
1987-07-05T05:45:00Z,
# say wat?
1987-07-05T05:45:00Z,
1987-07-05T05:45:00Z,
# sweetness
] # more comments
# hehe
`
var testParseSmall2 = `
[a]
better = 43
[a.b.c]
answer = 42
`
func TestParse(t *testing.T) {
m, err := parse(testParseSmall)
if err != nil {
t.Fatal(err)
}
printMap(m.mapping, 0)
}
func printMap(m map[string]interface{}, depth int) {
for k, v := range m {
testf("%s%s\n", strings.Repeat(" ", depth), k)
switch subm := v.(type) {
case map[string]interface{}:
printMap(subm, depth+1)
default:
testf("%s%v\n", strings.Repeat(" ", depth+1), v)
}
}
}

View File

@ -0,0 +1 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

View File

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@ -0,0 +1,14 @@
# Implements the TOML test suite interface
This is an implementation of the interface expected by
[toml-test](https://github.com/BurntSushi/toml-test) for my
[toml parser written in Go](https://github.com/BurntSushi/toml).
In particular, it maps TOML data on `stdin` to a JSON format on `stdout`.
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)
Compatible with `toml-test` version
[v0.2.0](https://github.com/BurntSushi/toml-test/tree/v0.2.0)

View File

@ -0,0 +1,89 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path"
"time"
"github.com/BurntSushi/toml"
)
func init() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s < toml-file\n", path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() != 0 {
flag.Usage()
}
var tmp interface{}
if _, err := toml.DecodeReader(os.Stdin, &tmp); err != nil {
log.Fatalf("Error decoding TOML: %s", err)
}
typedTmp := translate(tmp)
if err := json.NewEncoder(os.Stdout).Encode(typedTmp); err != nil {
log.Fatalf("Error encoding JSON: %s", err)
}
}
func translate(tomlData interface{}) interface{} {
switch orig := tomlData.(type) {
case map[string]interface{}:
typed := make(map[string]interface{}, len(orig))
for k, v := range orig {
typed[k] = translate(v)
}
return typed
case []map[string]interface{}:
typed := make([]map[string]interface{}, len(orig))
for i, v := range orig {
typed[i] = translate(v).(map[string]interface{})
}
return typed
case []interface{}:
typed := make([]interface{}, len(orig))
for i, v := range orig {
typed[i] = translate(v)
}
// We don't really need to tag arrays, but let's be future proof.
// (If TOML ever supports tuples, we'll need this.)
return tag("array", typed)
case time.Time:
return tag("datetime", orig.Format("2006-01-02T15:04:05Z"))
case bool:
return tag("bool", fmt.Sprintf("%v", orig))
case int64:
return tag("integer", fmt.Sprintf("%d", orig))
case float64:
return tag("float", fmt.Sprintf("%v", orig))
case string:
return tag("string", orig)
}
panic(fmt.Sprintf("Unknown type: %T", tomlData))
}
func tag(typeName string, data interface{}) map[string]interface{} {
return map[string]interface{}{
"type": typeName,
"value": data,
}
}

View File

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View File

@ -0,0 +1,22 @@
# TOML Validator
If Go is installed, it's simple to try it out:
```bash
go get github.com/BurntSushi/toml/tomlv
tomlv some-toml-file.toml
```
You can see the types of every key in a TOML file with:
```bash
tomlv -types some-toml-file.toml
```
At the moment, only one error message is reported at a time. Error messages
include line numbers. No output means that the files given are valid TOML, or
there is a bug in `tomlv`.
Compatible with TOML version
[v0.1.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.1.0.md)

View File

@ -0,0 +1,60 @@
package main
import (
"flag"
"fmt"
"log"
"os"
"path"
"strings"
"text/tabwriter"
"github.com/BurntSushi/toml"
)
var (
flagTypes = false
)
func init() {
log.SetFlags(0)
flag.BoolVar(&flagTypes, "types", flagTypes,
"When set, the types of every defined key will be shown.")
flag.Usage = usage
flag.Parse()
}
func usage() {
log.Printf("Usage: %s toml-file [ toml-file ... ]\n",
path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
func main() {
if flag.NArg() < 1 {
flag.Usage()
}
for _, f := range flag.Args() {
var tmp interface{}
md, err := toml.DecodeFile(f, &tmp)
if err != nil {
log.Fatalf("Error in '%s': %s", f, err)
}
if flagTypes {
printTypes(md)
}
}
}
func printTypes(md toml.MetaData) {
tabw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
for _, key := range md.Keys() {
fmt.Fprintf(tabw, "%s%s\t%s\n",
strings.Repeat(" ", len(key)-1), key, md.Type(key...))
}
tabw.Flush()
}

View File

@ -0,0 +1,78 @@
package toml
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be militating
// toward adding real composite types.
type tomlType interface {
typeString() string
}
// typeEqual accepts any two types and returns true if they are equal.
func typeEqual(t1, t2 tomlType) bool {
return t1.typeString() == t2.typeString()
}
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
var (
tomlInteger tomlBaseType = "Integer"
tomlFloat tomlBaseType = "Float"
tomlDatetime tomlBaseType = "Datetime"
tomlString tomlBaseType = "String"
tomlBool tomlBaseType = "Bool"
tomlArray tomlBaseType = "Array"
tomlHash tomlBaseType = "Hash"
tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
// Primitive values are: Integer, Float, Datetime, String and Bool.
//
// Passing a lexer item other than the following will cause a BUG message
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
switch lexItem.typ {
case itemInteger:
return tomlInteger
case itemFloat:
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
return tomlString
case itemBool:
return tomlBool
}
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panic("Array contains values of type '%s' and '%s', but arrays "+
"must be homogeneous.", theType, t)
}
}
return tomlArray
}

88
util.go
View File

@ -1,88 +0,0 @@
package main
import (
"net"
"net/url"
"os"
"os/signal"
"runtime/pprof"
"github.com/coreos/etcd/log"
)
//--------------------------------------
// HTTP Utilities
//--------------------------------------
// sanitizeURL will cleanup a host string in the format hostname:port and
// attach a schema.
func sanitizeURL(host string, defaultScheme string) string {
// Blank URLs are fine input, just return it
if len(host) == 0 {
return host
}
p, err := url.Parse(host)
if err != nil {
log.Fatal(err)
}
// Make sure the host is in Host:Port format
_, _, err = net.SplitHostPort(host)
if err != nil {
log.Fatal(err)
}
p = &url.URL{Host: host, Scheme: defaultScheme}
return p.String()
}
// sanitizeListenHost cleans up the ListenHost parameter and appends a port
// if necessary based on the advertised port.
func sanitizeListenHost(listen string, advertised string) string {
aurl, err := url.Parse(advertised)
if err != nil {
log.Fatal(err)
}
ahost, aport, err := net.SplitHostPort(aurl.Host)
if err != nil {
log.Fatal(err)
}
// If the listen host isn't set use the advertised host
if listen == "" {
listen = ahost
}
return net.JoinHostPort(listen, aport)
}
func check(err error) {
if err != nil {
log.Fatal(err)
}
}
//--------------------------------------
// CPU profile
//--------------------------------------
func runCPUProfile() {
f, err := os.Create(cpuprofile)
if err != nil {
log.Fatal(err)
}
pprof.StartCPUProfile(f)
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
go func() {
for sig := range c {
log.Infof("captured %v, stopping profiler and exiting..", sig)
pprof.StopCPUProfile()
os.Exit(1)
}
}()
}