diff --git a/pkg/netutil/netutil.go b/pkg/netutil/netutil.go index b3ff499d6..6c1dc96c8 100644 --- a/pkg/netutil/netutil.go +++ b/pkg/netutil/netutil.go @@ -35,15 +35,24 @@ var ( ) // resolveTCPAddrs is a convenience wrapper for net.ResolveTCPAddr. -// resolveTCPAddrs resolves all DNS hostnames in-place for the given set of -// url.URLs. -func resolveTCPAddrs(urls ...[]url.URL) error { +// resolveTCPAddrs return a new set of url.URLs, in which all DNS hostnames +// are resolved. +func resolveTCPAddrs(urls [][]url.URL) ([][]url.URL, error) { + newurls := make([][]url.URL, 0) for _, us := range urls { + nus := make([]url.URL, len(us)) for i, u := range us { + nu, err := url.Parse(u.String()) + if err != nil { + return nil, err + } + nus[i] = *nu + } + for i, u := range nus { host, _, err := net.SplitHostPort(u.Host) if err != nil { plog.Errorf("could not parse url %s during tcp resolving", u.Host) - return err + return nil, err } if host == "localhost" { continue @@ -54,13 +63,14 @@ func resolveTCPAddrs(urls ...[]url.URL) error { tcpAddr, err := resolveTCPAddr("tcp", u.Host) if err != nil { plog.Errorf("could not resolve host %s", u.Host) - return err + return nil, err } plog.Infof("resolving %s to %s", u.Host, tcpAddr.String()) - us[i].Host = tcpAddr.String() + nus[i].Host = tcpAddr.String() } + newurls = append(newurls, nus) } - return nil + return newurls, nil } // urlsEqual checks equality of url.URLS between two arrays. @@ -69,7 +79,11 @@ func urlsEqual(a []url.URL, b []url.URL) bool { if len(a) != len(b) { return false } - resolveTCPAddrs(a, b) + urls, err := resolveTCPAddrs([][]url.URL{a, b}) + if err != nil { + return false + } + a, b = urls[0], urls[1] sort.Sort(types.URLs(a)) sort.Sort(types.URLs(b)) for i := range a { diff --git a/pkg/netutil/netutil_test.go b/pkg/netutil/netutil_test.go index 1bf000694..123829dd0 100644 --- a/pkg/netutil/netutil_test.go +++ b/pkg/netutil/netutil_test.go @@ -124,15 +124,15 @@ func TestResolveTCPAddrs(t *testing.T) { } return &net.TCPAddr{IP: net.ParseIP(tt.hostMap[host]), Port: i, Zone: ""}, nil } - err := resolveTCPAddrs(tt.urls...) + urls, err := resolveTCPAddrs(tt.urls) if tt.hasError { if err == nil { t.Errorf("expected error") } continue } - if !reflect.DeepEqual(tt.urls, tt.expected) { - t.Errorf("expected: %v, got %v", tt.expected, tt.urls) + if !reflect.DeepEqual(urls, tt.expected) { + t.Errorf("expected: %v, got %v", tt.expected, urls) } } }