client: allow caller to decide HTTP redirect policy
parent
1c03df62a5
commit
9b334e07a6
|
@ -39,7 +39,6 @@ var (
|
||||||
ErrKeyExists = errors.New("client: key already exists")
|
ErrKeyExists = errors.New("client: key already exists")
|
||||||
|
|
||||||
DefaultRequestTimeout = 5 * time.Second
|
DefaultRequestTimeout = 5 * time.Second
|
||||||
DefaultMaxRedirects = 10
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var DefaultTransport CancelableTransport = &http.Transport{
|
var DefaultTransport CancelableTransport = &http.Transport{
|
||||||
|
@ -72,6 +71,17 @@ type Config struct {
|
||||||
// Transport is used by the Client to drive HTTP requests. If not
|
// Transport is used by the Client to drive HTTP requests. If not
|
||||||
// provided, DefaultTransport will be used.
|
// provided, DefaultTransport will be used.
|
||||||
Transport CancelableTransport
|
Transport CancelableTransport
|
||||||
|
|
||||||
|
// CheckRedirect specifies the policy for handling HTTP redirects.
|
||||||
|
// If CheckRedirect is not nil, the Client calls it before
|
||||||
|
// following an HTTP redirect. The sole argument is the number of
|
||||||
|
// requests that have alrady been made. If CheckRedirect returns
|
||||||
|
// an error, Client.Do will not make any further requests and return
|
||||||
|
// the error back it to the caller.
|
||||||
|
//
|
||||||
|
// If CheckRedirect is nil, the Client uses its default policy,
|
||||||
|
// which is to stop after 10 consecutive requests.
|
||||||
|
CheckRedirect CheckRedirectFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) transport() CancelableTransport {
|
func (cfg *Config) transport() CancelableTransport {
|
||||||
|
@ -81,6 +91,13 @@ func (cfg *Config) transport() CancelableTransport {
|
||||||
return cfg.Transport
|
return cfg.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cfg *Config) checkRedirect() CheckRedirectFunc {
|
||||||
|
if cfg.CheckRedirect == nil {
|
||||||
|
return DefaultCheckRedirect
|
||||||
|
}
|
||||||
|
return cfg.CheckRedirect
|
||||||
|
}
|
||||||
|
|
||||||
// CancelableTransport mimics net/http.Transport, but requires that
|
// CancelableTransport mimics net/http.Transport, but requires that
|
||||||
// the object also support request cancellation.
|
// the object also support request cancellation.
|
||||||
type CancelableTransport interface {
|
type CancelableTransport interface {
|
||||||
|
@ -88,6 +105,16 @@ type CancelableTransport interface {
|
||||||
CancelRequest(req *http.Request)
|
CancelRequest(req *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CheckRedirectFunc func(via int) error
|
||||||
|
|
||||||
|
// DefaultCheckRedirect follows up to 10 redirects, but no more.
|
||||||
|
var DefaultCheckRedirect CheckRedirectFunc = func(via int) error {
|
||||||
|
if via > 10 {
|
||||||
|
return ErrTooManyRedirects
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
// Sync updates the internal cache of the etcd cluster's membership.
|
// Sync updates the internal cache of the etcd cluster's membership.
|
||||||
Sync(context.Context) error
|
Sync(context.Context) error
|
||||||
|
@ -101,7 +128,7 @@ type Client interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg Config) (Client, error) {
|
func New(cfg Config) (Client, error) {
|
||||||
c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport())}
|
c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect())}
|
||||||
if err := c.reset(cfg.Endpoints); err != nil {
|
if err := c.reset(cfg.Endpoints); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -112,10 +139,10 @@ type httpClient interface {
|
||||||
Do(context.Context, httpAction) (*http.Response, []byte, error)
|
Do(context.Context, httpAction) (*http.Response, []byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPClientFactory(tr CancelableTransport) httpClientFactory {
|
func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc) httpClientFactory {
|
||||||
return func(ep url.URL) httpClient {
|
return func(ep url.URL) httpClient {
|
||||||
return &redirectFollowingHTTPClient{
|
return &redirectFollowingHTTPClient{
|
||||||
max: DefaultMaxRedirects,
|
checkRedirect: cr,
|
||||||
client: &simpleHTTPClient{
|
client: &simpleHTTPClient{
|
||||||
transport: tr,
|
transport: tr,
|
||||||
endpoint: ep,
|
endpoint: ep,
|
||||||
|
@ -270,12 +297,17 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon
|
||||||
}
|
}
|
||||||
|
|
||||||
type redirectFollowingHTTPClient struct {
|
type redirectFollowingHTTPClient struct {
|
||||||
client httpClient
|
client httpClient
|
||||||
max int
|
checkRedirect CheckRedirectFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) {
|
func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) {
|
||||||
for i := 0; i <= r.max; i++ {
|
for i := 0; ; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
if err := r.checkRedirect(i); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, body, err := r.client.Do(ctx, act)
|
resp, body, err := r.client.Do(ctx, act)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -297,7 +329,6 @@ func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*
|
||||||
}
|
}
|
||||||
return resp, body, nil
|
return resp, body, nil
|
||||||
}
|
}
|
||||||
return nil, nil, ErrTooManyRedirects
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type redirectedHTTPAction struct {
|
type redirectedHTTPAction struct {
|
||||||
|
|
|
@ -258,7 +258,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
|
||||||
{
|
{
|
||||||
client: &httpClusterClient{
|
client: &httpClusterClient{
|
||||||
endpoints: []url.URL{},
|
endpoints: []url.URL{},
|
||||||
clientFactory: newHTTPClientFactory(nil),
|
clientFactory: newHTTPClientFactory(nil, nil),
|
||||||
},
|
},
|
||||||
wantErr: ErrNoEndpoints,
|
wantErr: ErrNoEndpoints,
|
||||||
},
|
},
|
||||||
|
@ -349,14 +349,14 @@ func TestRedirectedHTTPAction(t *testing.T) {
|
||||||
|
|
||||||
func TestRedirectFollowingHTTPClient(t *testing.T) {
|
func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
max int
|
checkRedirect CheckRedirectFunc
|
||||||
client httpClient
|
client httpClient
|
||||||
wantCode int
|
wantCode int
|
||||||
wantErr error
|
wantErr error
|
||||||
}{
|
}{
|
||||||
// errors bubbled up
|
// errors bubbled up
|
||||||
{
|
{
|
||||||
max: 2,
|
checkRedirect: func(int) error { return ErrTooManyRedirects },
|
||||||
client: &multiStaticHTTPClient{
|
client: &multiStaticHTTPClient{
|
||||||
responses: []staticHTTPResponse{
|
responses: []staticHTTPResponse{
|
||||||
staticHTTPResponse{
|
staticHTTPResponse{
|
||||||
|
@ -369,7 +369,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
|
|
||||||
// no need to follow redirect if none given
|
// no need to follow redirect if none given
|
||||||
{
|
{
|
||||||
max: 2,
|
checkRedirect: func(int) error { return ErrTooManyRedirects },
|
||||||
client: &multiStaticHTTPClient{
|
client: &multiStaticHTTPClient{
|
||||||
responses: []staticHTTPResponse{
|
responses: []staticHTTPResponse{
|
||||||
staticHTTPResponse{
|
staticHTTPResponse{
|
||||||
|
@ -384,7 +384,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
|
|
||||||
// redirects if less than max
|
// redirects if less than max
|
||||||
{
|
{
|
||||||
max: 2,
|
checkRedirect: func(via int) error {
|
||||||
|
if via >= 2 {
|
||||||
|
return ErrTooManyRedirects
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
client: &multiStaticHTTPClient{
|
client: &multiStaticHTTPClient{
|
||||||
responses: []staticHTTPResponse{
|
responses: []staticHTTPResponse{
|
||||||
staticHTTPResponse{
|
staticHTTPResponse{
|
||||||
|
@ -405,7 +410,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
|
|
||||||
// succeed after reaching max redirects
|
// succeed after reaching max redirects
|
||||||
{
|
{
|
||||||
max: 2,
|
checkRedirect: func(via int) error {
|
||||||
|
if via >= 3 {
|
||||||
|
return ErrTooManyRedirects
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
client: &multiStaticHTTPClient{
|
client: &multiStaticHTTPClient{
|
||||||
responses: []staticHTTPResponse{
|
responses: []staticHTTPResponse{
|
||||||
staticHTTPResponse{
|
staticHTTPResponse{
|
||||||
|
@ -430,9 +440,14 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
wantCode: http.StatusTeapot,
|
wantCode: http.StatusTeapot,
|
||||||
},
|
},
|
||||||
|
|
||||||
// fail at max+1 redirects
|
// fail if too many redirects
|
||||||
{
|
{
|
||||||
max: 1,
|
checkRedirect: func(via int) error {
|
||||||
|
if via >= 2 {
|
||||||
|
return ErrTooManyRedirects
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
client: &multiStaticHTTPClient{
|
client: &multiStaticHTTPClient{
|
||||||
responses: []staticHTTPResponse{
|
responses: []staticHTTPResponse{
|
||||||
staticHTTPResponse{
|
staticHTTPResponse{
|
||||||
|
@ -459,7 +474,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
|
|
||||||
// fail if Location header not set
|
// fail if Location header not set
|
||||||
{
|
{
|
||||||
max: 1,
|
checkRedirect: func(int) error { return ErrTooManyRedirects },
|
||||||
client: &multiStaticHTTPClient{
|
client: &multiStaticHTTPClient{
|
||||||
responses: []staticHTTPResponse{
|
responses: []staticHTTPResponse{
|
||||||
staticHTTPResponse{
|
staticHTTPResponse{
|
||||||
|
@ -474,7 +489,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
|
|
||||||
// fail if Location header is invalid
|
// fail if Location header is invalid
|
||||||
{
|
{
|
||||||
max: 1,
|
checkRedirect: func(int) error { return ErrTooManyRedirects },
|
||||||
client: &multiStaticHTTPClient{
|
client: &multiStaticHTTPClient{
|
||||||
responses: []staticHTTPResponse{
|
responses: []staticHTTPResponse{
|
||||||
staticHTTPResponse{
|
staticHTTPResponse{
|
||||||
|
@ -490,7 +505,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max}
|
client := &redirectFollowingHTTPClient{client: tt.client, checkRedirect: tt.checkRedirect}
|
||||||
resp, _, err := client.Do(context.Background(), nil)
|
resp, _, err := client.Do(context.Background(), nil)
|
||||||
if !reflect.DeepEqual(tt.wantErr, err) {
|
if !reflect.DeepEqual(tt.wantErr, err) {
|
||||||
t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
|
t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
|
||||||
|
|
Loading…
Reference in New Issue