netutil: use "context" and ctx-ize TCP addr resolution

release-3.2
Anthony Romano 2017-04-21 09:59:34 -07:00
parent 8bad78cb98
commit 85e87e8f6b
2 changed files with 35 additions and 10 deletions

View File

@ -16,14 +16,13 @@
package netutil
import (
"context"
"net"
"net/url"
"reflect"
"sort"
"time"
"golang.org/x/net/context"
"github.com/coreos/etcd/pkg/types"
"github.com/coreos/pkg/capnslog"
)
@ -32,11 +31,38 @@ var (
plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "pkg/netutil")
// indirection for testing
resolveTCPAddr = net.ResolveTCPAddr
resolveTCPAddr = resolveTCPAddrDefault
)
const retryInterval = time.Second
// taken from go's ResolveTCP code but uses configurable ctx
func resolveTCPAddrDefault(ctx context.Context, addr string) (*net.TCPAddr, error) {
host, port, serr := net.SplitHostPort(addr)
if serr != nil {
return nil, serr
}
portnum, perr := net.DefaultResolver.LookupPort(ctx, "tcp", port)
if perr != nil {
return nil, perr
}
var ips []net.IPAddr
if ip := net.ParseIP(host); ip != nil {
ips = []net.IPAddr{{IP: ip}}
} else {
// Try as a DNS name.
ipss, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
ips = ipss
}
// randomize?
ip := ips[0]
return &net.TCPAddr{IP: ip.IP, Port: portnum, Zone: ip.Zone}, nil
}
// resolveTCPAddrs is a convenience wrapper for net.ResolveTCPAddr.
// resolveTCPAddrs return a new set of url.URLs, in which all DNS hostnames
// are resolved.
@ -75,7 +101,7 @@ func resolveURL(ctx context.Context, u url.URL) (string, error) {
if host == "localhost" || net.ParseIP(host) != nil {
return "", nil
}
tcpAddr, err := resolveTCPAddr("tcp", u.Host)
tcpAddr, err := resolveTCPAddr(ctx, u.Host)
if err == nil {
plog.Infof("resolving %s to %s", u.Host, tcpAddr.String())
return tcpAddr.String(), nil

View File

@ -15,6 +15,7 @@
package netutil
import (
"context"
"errors"
"net"
"net/url"
@ -22,12 +23,10 @@ import (
"strconv"
"testing"
"time"
"golang.org/x/net/context"
)
func TestResolveTCPAddrs(t *testing.T) {
defer func() { resolveTCPAddr = net.ResolveTCPAddr }()
defer func() { resolveTCPAddr = resolveTCPAddrDefault }()
tests := []struct {
urls [][]url.URL
expected [][]url.URL
@ -113,7 +112,7 @@ func TestResolveTCPAddrs(t *testing.T) {
},
}
for _, tt := range tests {
resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
resolveTCPAddr = func(ctx context.Context, addr string) (*net.TCPAddr, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@ -143,13 +142,13 @@ func TestResolveTCPAddrs(t *testing.T) {
}
func TestURLsEqual(t *testing.T) {
defer func() { resolveTCPAddr = net.ResolveTCPAddr }()
defer func() { resolveTCPAddr = resolveTCPAddrDefault }()
hostm := map[string]string{
"example.com": "10.0.10.1",
"first.com": "10.0.11.1",
"second.com": "10.0.11.2",
}
resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
resolveTCPAddr = func(ctx context.Context, addr string) (*net.TCPAddr, error) {
host, port, herr := net.SplitHostPort(addr)
if herr != nil {
return nil, herr