diff --git a/embed/etcd.go b/embed/etcd.go index 3dd374637..58449a2f8 100644 --- a/embed/etcd.go +++ b/embed/etcd.go @@ -23,6 +23,7 @@ import ( "net" "net/http" "net/url" + "sort" "strconv" "sync" "time" @@ -33,7 +34,6 @@ import ( "github.com/coreos/etcd/etcdserver/api/v2v3" "github.com/coreos/etcd/etcdserver/api/v3client" "github.com/coreos/etcd/etcdserver/api/v3rpc" - "github.com/coreos/etcd/pkg/cors" "github.com/coreos/etcd/pkg/debugutil" runtimeutil "github.com/coreos/etcd/pkg/runtime" "github.com/coreos/etcd/pkg/transport" @@ -168,6 +168,7 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) { StrictReconfigCheck: cfg.StrictReconfigCheck, ClientCertAuthEnabled: cfg.ClientTLSInfo.ClientCertAuth, AuthToken: cfg.AuthToken, + CORS: cfg.CORS, HostWhitelist: cfg.HostWhitelist, InitialCorruptCheck: cfg.ExperimentalInitialCorruptCheck, CorruptCheckTime: cfg.ExperimentalCorruptCheckTime, @@ -473,8 +474,13 @@ func (e *Etcd) serveClients() (err error) { plog.Infof("ClientTLS: %s", e.cfg.ClientTLSInfo) } - if e.cfg.CorsInfo.String() != "" { - plog.Infof("cors = %s", e.cfg.CorsInfo) + if len(e.cfg.CORS) > 0 { + ss := make([]string, 0, len(e.cfg.CORS)) + for v := range e.cfg.CORS { + ss = append(ss, v) + } + sort.Strings(ss) + plog.Infof("cors = %q", ss) } // Start a client server goroutine for each listen address @@ -491,7 +497,6 @@ func (e *Etcd) serveClients() (err error) { etcdhttp.HandleBasic(mux, e.Server) h = mux } - h = http.Handler(&cors.CORSHandler{Handler: h, Info: e.cfg.CorsInfo}) gopts := []grpc.ServerOption{} if e.cfg.GRPCKeepAliveMinTime > time.Duration(0) { diff --git a/embed/serve.go b/embed/serve.go index 72f162dc1..5f78719a0 100644 --- a/embed/serve.go +++ b/embed/serve.go @@ -116,7 +116,7 @@ func (sctx *serveCtx) serve( httpmux := sctx.createMux(gwmux, handler) srvhttp := &http.Server{ - Handler: wrapMux(s, httpmux), + Handler: createAccessController(s, httpmux), ErrorLog: logger, // do not log user error } httpl := m.Match(cmux.HTTP1()) @@ -159,7 +159,7 @@ func (sctx *serveCtx) serve( httpmux := sctx.createMux(gwmux, handler) srv := &http.Server{ - Handler: wrapMux(s, httpmux), + Handler: createAccessController(s, httpmux), TLSConfig: tlscfg, ErrorLog: logger, // do not log user error } @@ -250,20 +250,20 @@ func (sctx *serveCtx) createMux(gwmux *gw.ServeMux, handler http.Handler) *http. return httpmux } -// wrapMux wraps HTTP multiplexer: +// createAccessController wraps HTTP multiplexer: // - mutate gRPC gateway request paths // - check hostname whitelist // client HTTP requests goes here first -func wrapMux(s *etcdserver.EtcdServer, mux *http.ServeMux) http.Handler { - return &httpWrapper{s: s, mux: mux} +func createAccessController(s *etcdserver.EtcdServer, mux *http.ServeMux) http.Handler { + return &accessController{s: s, mux: mux} } -type httpWrapper struct { +type accessController struct { s *etcdserver.EtcdServer mux *http.ServeMux } -func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) { +func (ac *accessController) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // redirect for backward compatibilities if req != nil && req.URL != nil && strings.HasPrefix(req.URL.Path, "/v3beta/") { req.URL.Path = strings.Replace(req.URL.Path, "/v3beta/", "/v3/", 1) @@ -271,7 +271,7 @@ func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if req.TLS == nil { // check origin if client connection is not secure host := httputil.GetHostname(req) - if !m.s.IsHostWhitelisted(host) { + if !ac.s.AccessController.IsHostWhitelisted(host) { plog.Warningf("rejecting HTTP request from %q to prevent DNS rebinding attacks", host) // TODO: use Go's "http.StatusMisdirectedRequest" (421) // https://github.com/golang/go/commit/4b8a7eafef039af1834ef9bfa879257c4a72b7b5 @@ -280,7 +280,26 @@ func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } - m.mux.ServeHTTP(rw, req) + // Write CORS header. + if ac.s.AccessController.OriginAllowed("*") { + addCORSHeader(rw, "*") + } else if origin := req.Header.Get("Origin"); ac.s.OriginAllowed(origin) { + addCORSHeader(rw, origin) + } + + if req.Method == "OPTIONS" { + rw.WriteHeader(http.StatusOK) + return + } + + ac.mux.ServeHTTP(rw, req) +} + +// addCORSHeader adds the correct cors headers given an origin +func addCORSHeader(w http.ResponseWriter, origin string) { + w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") + w.Header().Add("Access-Control-Allow-Origin", origin) + w.Header().Add("Access-Control-Allow-Headers", "accept, content-type, authorization") } // https://github.com/transmission/transmission/pull/468 @@ -297,6 +316,35 @@ This requirement has been added to help prevent "DNS Rebinding" attacks (CVE-201 `, host) } +// WrapCORS wraps existing handler with CORS. +// TODO: deprecate this after v2 proxy deprecate +func WrapCORS(cors map[string]struct{}, h http.Handler) http.Handler { + return &corsHandler{ + ac: &etcdserver.AccessController{CORS: cors}, + h: h, + } +} + +type corsHandler struct { + ac *etcdserver.AccessController + h http.Handler +} + +func (ch *corsHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if ch.ac.OriginAllowed("*") { + addCORSHeader(rw, "*") + } else if origin := req.Header.Get("Origin"); ch.ac.OriginAllowed(origin) { + addCORSHeader(rw, origin) + } + + if req.Method == "OPTIONS" { + rw.WriteHeader(http.StatusOK) + return + } + + ch.h.ServeHTTP(rw, req) +} + func (sctx *serveCtx) registerUserHandler(s string, h http.Handler) { if sctx.userHandlers[s] != nil { plog.Warningf("path %s already registered by user handler", s)