diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index 3e5ac2881..349ee7825 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -21,6 +21,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "sync" "time" @@ -149,8 +150,11 @@ type ServerConfig struct { type server struct { lg *zap.Logger - from url.URL - to url.URL + from url.URL + fromPort int + to url.URL + toPort int + tlsInfo transport.TLSInfo dialTimeout time.Duration @@ -198,8 +202,9 @@ func NewServer(cfg ServerConfig) Server { s := &server{ lg: cfg.Logger, - from: cfg.From, - to: cfg.To, + from: cfg.From, + to: cfg.To, + tlsInfo: cfg.TLSInfo, dialTimeout: cfg.DialTimeout, @@ -215,6 +220,16 @@ func NewServer(cfg ServerConfig) Server { pauseRxc: make(chan struct{}), } + _, fromPort, err := net.SplitHostPort(cfg.From.Host) + if err == nil { + s.fromPort, err = strconv.Atoi(fromPort) + } + var toPort string + _, toPort, err = net.SplitHostPort(cfg.To.Host) + if err == nil { + s.toPort, _ = strconv.Atoi(toPort) + } + if s.dialTimeout == 0 { s.dialTimeout = defaultDialTimeout } @@ -239,12 +254,16 @@ func NewServer(cfg ServerConfig) Server { s.to.Scheme = "tcp" } + addr := fmt.Sprintf(":%d", s.fromPort) + if s.fromPort == 0 { // unix + addr = s.from.Host + } + var ln net.Listener - var err error if !s.tlsInfo.Empty() { - ln, err = transport.NewListener(s.from.Host, s.from.Scheme, &s.tlsInfo) + ln, err = transport.NewListener(addr, s.from.Scheme, &s.tlsInfo) } else { - ln, err = net.Listen(s.from.Scheme, s.from.Host) + ln, err = net.Listen(s.from.Scheme, addr) } if err != nil { s.errc <- err