client: allow caller to decide HTTP redirect policy

release-2.1
Brian Waldon 2015-01-28 15:09:00 -08:00 committed by Yicheng Qin
parent 1c03df62a5
commit 9b334e07a6
2 changed files with 68 additions and 22 deletions

View File

@ -39,7 +39,6 @@ var (
ErrKeyExists = errors.New("client: key already exists") ErrKeyExists = errors.New("client: key already exists")
DefaultRequestTimeout = 5 * time.Second DefaultRequestTimeout = 5 * time.Second
DefaultMaxRedirects = 10
) )
var DefaultTransport CancelableTransport = &http.Transport{ var DefaultTransport CancelableTransport = &http.Transport{
@ -72,6 +71,17 @@ type Config struct {
// Transport is used by the Client to drive HTTP requests. If not // Transport is used by the Client to drive HTTP requests. If not
// provided, DefaultTransport will be used. // provided, DefaultTransport will be used.
Transport CancelableTransport Transport CancelableTransport
// CheckRedirect specifies the policy for handling HTTP redirects.
// If CheckRedirect is not nil, the Client calls it before
// following an HTTP redirect. The sole argument is the number of
// requests that have alrady been made. If CheckRedirect returns
// an error, Client.Do will not make any further requests and return
// the error back it to the caller.
//
// If CheckRedirect is nil, the Client uses its default policy,
// which is to stop after 10 consecutive requests.
CheckRedirect CheckRedirectFunc
} }
func (cfg *Config) transport() CancelableTransport { func (cfg *Config) transport() CancelableTransport {
@ -81,6 +91,13 @@ func (cfg *Config) transport() CancelableTransport {
return cfg.Transport return cfg.Transport
} }
func (cfg *Config) checkRedirect() CheckRedirectFunc {
if cfg.CheckRedirect == nil {
return DefaultCheckRedirect
}
return cfg.CheckRedirect
}
// CancelableTransport mimics net/http.Transport, but requires that // CancelableTransport mimics net/http.Transport, but requires that
// the object also support request cancellation. // the object also support request cancellation.
type CancelableTransport interface { type CancelableTransport interface {
@ -88,6 +105,16 @@ type CancelableTransport interface {
CancelRequest(req *http.Request) CancelRequest(req *http.Request)
} }
type CheckRedirectFunc func(via int) error
// DefaultCheckRedirect follows up to 10 redirects, but no more.
var DefaultCheckRedirect CheckRedirectFunc = func(via int) error {
if via > 10 {
return ErrTooManyRedirects
}
return nil
}
type Client interface { type Client interface {
// Sync updates the internal cache of the etcd cluster's membership. // Sync updates the internal cache of the etcd cluster's membership.
Sync(context.Context) error Sync(context.Context) error
@ -101,7 +128,7 @@ type Client interface {
} }
func New(cfg Config) (Client, error) { func New(cfg Config) (Client, error) {
c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport())} c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect())}
if err := c.reset(cfg.Endpoints); err != nil { if err := c.reset(cfg.Endpoints); err != nil {
return nil, err return nil, err
} }
@ -112,10 +139,10 @@ type httpClient interface {
Do(context.Context, httpAction) (*http.Response, []byte, error) Do(context.Context, httpAction) (*http.Response, []byte, error)
} }
func newHTTPClientFactory(tr CancelableTransport) httpClientFactory { func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc) httpClientFactory {
return func(ep url.URL) httpClient { return func(ep url.URL) httpClient {
return &redirectFollowingHTTPClient{ return &redirectFollowingHTTPClient{
max: DefaultMaxRedirects, checkRedirect: cr,
client: &simpleHTTPClient{ client: &simpleHTTPClient{
transport: tr, transport: tr,
endpoint: ep, endpoint: ep,
@ -270,12 +297,17 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon
} }
type redirectFollowingHTTPClient struct { type redirectFollowingHTTPClient struct {
client httpClient client httpClient
max int checkRedirect CheckRedirectFunc
} }
func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) { func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) {
for i := 0; i <= r.max; i++ { for i := 0; ; i++ {
if i > 0 {
if err := r.checkRedirect(i); err != nil {
return nil, nil, err
}
}
resp, body, err := r.client.Do(ctx, act) resp, body, err := r.client.Do(ctx, act)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -297,7 +329,6 @@ func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*
} }
return resp, body, nil return resp, body, nil
} }
return nil, nil, ErrTooManyRedirects
} }
type redirectedHTTPAction struct { type redirectedHTTPAction struct {

View File

@ -258,7 +258,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
{ {
client: &httpClusterClient{ client: &httpClusterClient{
endpoints: []url.URL{}, endpoints: []url.URL{},
clientFactory: newHTTPClientFactory(nil), clientFactory: newHTTPClientFactory(nil, nil),
}, },
wantErr: ErrNoEndpoints, wantErr: ErrNoEndpoints,
}, },
@ -349,14 +349,14 @@ func TestRedirectedHTTPAction(t *testing.T) {
func TestRedirectFollowingHTTPClient(t *testing.T) { func TestRedirectFollowingHTTPClient(t *testing.T) {
tests := []struct { tests := []struct {
max int checkRedirect CheckRedirectFunc
client httpClient client httpClient
wantCode int wantCode int
wantErr error wantErr error
}{ }{
// errors bubbled up // errors bubbled up
{ {
max: 2, checkRedirect: func(int) error { return ErrTooManyRedirects },
client: &multiStaticHTTPClient{ client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{ responses: []staticHTTPResponse{
staticHTTPResponse{ staticHTTPResponse{
@ -369,7 +369,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// no need to follow redirect if none given // no need to follow redirect if none given
{ {
max: 2, checkRedirect: func(int) error { return ErrTooManyRedirects },
client: &multiStaticHTTPClient{ client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{ responses: []staticHTTPResponse{
staticHTTPResponse{ staticHTTPResponse{
@ -384,7 +384,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// redirects if less than max // redirects if less than max
{ {
max: 2, checkRedirect: func(via int) error {
if via >= 2 {
return ErrTooManyRedirects
}
return nil
},
client: &multiStaticHTTPClient{ client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{ responses: []staticHTTPResponse{
staticHTTPResponse{ staticHTTPResponse{
@ -405,7 +410,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// succeed after reaching max redirects // succeed after reaching max redirects
{ {
max: 2, checkRedirect: func(via int) error {
if via >= 3 {
return ErrTooManyRedirects
}
return nil
},
client: &multiStaticHTTPClient{ client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{ responses: []staticHTTPResponse{
staticHTTPResponse{ staticHTTPResponse{
@ -430,9 +440,14 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
wantCode: http.StatusTeapot, wantCode: http.StatusTeapot,
}, },
// fail at max+1 redirects // fail if too many redirects
{ {
max: 1, checkRedirect: func(via int) error {
if via >= 2 {
return ErrTooManyRedirects
}
return nil
},
client: &multiStaticHTTPClient{ client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{ responses: []staticHTTPResponse{
staticHTTPResponse{ staticHTTPResponse{
@ -459,7 +474,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// fail if Location header not set // fail if Location header not set
{ {
max: 1, checkRedirect: func(int) error { return ErrTooManyRedirects },
client: &multiStaticHTTPClient{ client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{ responses: []staticHTTPResponse{
staticHTTPResponse{ staticHTTPResponse{
@ -474,7 +489,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// fail if Location header is invalid // fail if Location header is invalid
{ {
max: 1, checkRedirect: func(int) error { return ErrTooManyRedirects },
client: &multiStaticHTTPClient{ client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{ responses: []staticHTTPResponse{
staticHTTPResponse{ staticHTTPResponse{
@ -490,7 +505,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max} client := &redirectFollowingHTTPClient{client: tt.client, checkRedirect: tt.checkRedirect}
resp, _, err := client.Do(context.Background(), nil) resp, _, err := client.Do(context.Background(), nil)
if !reflect.DeepEqual(tt.wantErr, err) { if !reflect.DeepEqual(tt.wantErr, err) {
t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr) t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)