From 2f5015552e8db84d04cb75b0adb3f7b9d3be5c56 Mon Sep 17 00:00:00 2001 From: Brandon Philips Date: Mon, 2 Sep 2013 22:17:39 -0700 Subject: [PATCH] feat(etcd_handlers): enable CORS When developing or using web frontends for etcd it will be necessary to enable Cross-Origin Resource Sharing. Add a flag that lets the user enable this feature via a whitelist. --- etcd.go | 27 +++++++++++++++++++++++++++ etcd_handlers.go | 19 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/etcd.go b/etcd.go index 46546e8cc..5a0b4dee6 100644 --- a/etcd.go +++ b/etcd.go @@ -3,9 +3,11 @@ package main import ( "crypto/tls" "flag" + "fmt" "github.com/coreos/etcd/store" "github.com/coreos/go-raft" "io/ioutil" + "net/url" "os" "strings" "time" @@ -40,6 +42,9 @@ var ( maxClusterSize int cpuprofile string + + cors string + corsList map[string]bool ) func init() { @@ -77,6 +82,8 @@ func init() { flag.IntVar(&maxClusterSize, "maxsize", 9, "the max size of the cluster") flag.StringVar(&cpuprofile, "cpuprofile", "", "write cpu profile to file") + + flag.StringVar(&cors, "cors", "", "whitelist origins for cross-origin resource sharing (e.g. '*' or 'http://localhost:8001,etc')") } const ( @@ -152,6 +159,8 @@ func main() { raft.SetLogLevel(raft.Debug) } + parseCorsFlag() + if machines != "" { cluster = strings.Split(machines, ",") } else if machinesFile != "" { @@ -206,3 +215,21 @@ func main() { e.ListenAndServe() } + +// parseCorsFlag gathers up the cors whitelist and puts it into the corsList. +func parseCorsFlag() { + if cors != "" { + corsList = make(map[string]bool) + list := strings.Split(cors, ",") + for _, v := range list { + fmt.Println(v) + if v != "*" { + _, err := url.Parse(v) + if err != nil { + panic(fmt.Sprintf("bad cors url: %s", err)) + } + } + corsList[v] = true + } + } +} diff --git a/etcd_handlers.go b/etcd_handlers.go index 60e7b35b5..02cb2316a 100644 --- a/etcd_handlers.go +++ b/etcd_handlers.go @@ -29,7 +29,26 @@ func NewEtcdMuxer() *http.ServeMux { type errorHandler func(http.ResponseWriter, *http.Request) error +// addCorsHeader parses the request Origin header and loops through the user +// provided allowed origins and sets the Access-Control-Allow-Origin header if +// there is a match. +func addCorsHeader(w http.ResponseWriter, r *http.Request) { + val, ok := corsList["*"] + if val && ok { + w.Header().Add("Access-Control-Allow-Origin", "*") + return + } + + requestOrigin := r.Header.Get("Origin") + val, ok = corsList[requestOrigin] + if val && ok { + w.Header().Add("Access-Control-Allow-Origin", requestOrigin) + return + } +} + func (fn errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + addCorsHeader(w, r) if e := fn(w, r); e != nil { if etcdErr, ok := e.(etcdErr.Error); ok { debug("Return error: ", etcdErr.Error())