diff --git a/main.go b/main.go index 08a1ddb1c..874b3a88e 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "net" "net/http" "os" "path" @@ -200,12 +201,15 @@ func startProxy() { type Addrs []string // Set parses a command line set of listen addresses, formatted like: -// 127.0.0.1:7001,unix:///var/run/etcd.sock,10.1.1.1:8080 +// 127.0.0.1:7001,10.1.1.2:80 func (as *Addrs) Set(s string) error { - // TODO(jonboulle): validate things. parsed := make([]string, 0) - for _, a := range strings.Split(s, ",") { - parsed = append(parsed, strings.TrimSpace(a)) + for _, in := range strings.Split(s, ",") { + a := strings.TrimSpace(in) + if err := validateAddr(a); err != nil { + return err + } + parsed = append(parsed, a) } if len(parsed) == 0 { return errors.New("no valid addresses given!") @@ -218,6 +222,23 @@ func (as *Addrs) String() string { return strings.Join(*as, ",") } +// validateAddr ensures that the provided string is a valid address. Valid +// addresses are of the form IP:port. +// Returns an error if the address is invalid, else nil. +func validateAddr(s string) error { + parts := strings.SplitN(s, ":", 2) + if len(parts) != 2 { + return errors.New("bad format in address specification") + } + if net.ParseIP(parts[0]) == nil { + return errors.New("bad IP in address specification") + } + if _, err := strconv.Atoi(parts[1]); err != nil { + return errors.New("bad port in address specification") + } + return nil +} + // ProxyFlag implements the flag.Value interface. type ProxyFlag string diff --git a/main_test.go b/main_test.go index dffcb0b08..97647cd95 100644 --- a/main_test.go +++ b/main_test.go @@ -64,3 +64,40 @@ func TestProxyFlagSet(t *testing.T) { } } } + +func TestBadValidateAddr(t *testing.T) { + tests := []string{ + // bad IP specification + ":4001", + "127.0:8080", + "123:456", + // bad port specification + "127.0.0.1:foo", + "127.0.0.1:", + // unix sockets not supported + "unix://", + "unix://tmp/etcd.sock", + // bad strings + "somewhere", + "234#$", + "file://foo/bar", + "http://hello", + } + for i, in := range tests { + if err := validateAddr(in); err == nil { + t.Errorf(`#%d: unexpected nil error for in=%q`, i, in) + } + } +} + +func TestValidateAddr(t *testing.T) { + tests := []string{ + "1.2.3.4:8080", + "10.1.1.1:80", + } + for i, in := range tests { + if err := validateAddr(in); err != nil { + t.Errorf("#%d: err=%v, want nil for in=%q", i, err, in) + } + } +}