client: allow caller to decide HTTP redirect policy

release-2.1
Brian Waldon 2015-01-28 15:09:00 -08:00 committed by Yicheng Qin
parent 1c03df62a5
commit 9b334e07a6
2 changed files with 68 additions and 22 deletions

View File

@ -39,7 +39,6 @@ var (
ErrKeyExists = errors.New("client: key already exists")
DefaultRequestTimeout = 5 * time.Second
DefaultMaxRedirects = 10
)
var DefaultTransport CancelableTransport = &http.Transport{
@ -72,6 +71,17 @@ type Config struct {
// Transport is used by the Client to drive HTTP requests. If not
// provided, DefaultTransport will be used.
Transport CancelableTransport
// CheckRedirect specifies the policy for handling HTTP redirects.
// If CheckRedirect is not nil, the Client calls it before
// following an HTTP redirect. The sole argument is the number of
// requests that have alrady been made. If CheckRedirect returns
// an error, Client.Do will not make any further requests and return
// the error back it to the caller.
//
// If CheckRedirect is nil, the Client uses its default policy,
// which is to stop after 10 consecutive requests.
CheckRedirect CheckRedirectFunc
}
func (cfg *Config) transport() CancelableTransport {
@ -81,6 +91,13 @@ func (cfg *Config) transport() CancelableTransport {
return cfg.Transport
}
func (cfg *Config) checkRedirect() CheckRedirectFunc {
if cfg.CheckRedirect == nil {
return DefaultCheckRedirect
}
return cfg.CheckRedirect
}
// CancelableTransport mimics net/http.Transport, but requires that
// the object also support request cancellation.
type CancelableTransport interface {
@ -88,6 +105,16 @@ type CancelableTransport interface {
CancelRequest(req *http.Request)
}
type CheckRedirectFunc func(via int) error
// DefaultCheckRedirect follows up to 10 redirects, but no more.
var DefaultCheckRedirect CheckRedirectFunc = func(via int) error {
if via > 10 {
return ErrTooManyRedirects
}
return nil
}
type Client interface {
// Sync updates the internal cache of the etcd cluster's membership.
Sync(context.Context) error
@ -101,7 +128,7 @@ type Client interface {
}
func New(cfg Config) (Client, error) {
c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport())}
c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect())}
if err := c.reset(cfg.Endpoints); err != nil {
return nil, err
}
@ -112,10 +139,10 @@ type httpClient interface {
Do(context.Context, httpAction) (*http.Response, []byte, error)
}
func newHTTPClientFactory(tr CancelableTransport) httpClientFactory {
func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc) httpClientFactory {
return func(ep url.URL) httpClient {
return &redirectFollowingHTTPClient{
max: DefaultMaxRedirects,
checkRedirect: cr,
client: &simpleHTTPClient{
transport: tr,
endpoint: ep,
@ -270,12 +297,17 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon
}
type redirectFollowingHTTPClient struct {
client httpClient
max int
client httpClient
checkRedirect CheckRedirectFunc
}
func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) {
for i := 0; i <= r.max; i++ {
for i := 0; ; i++ {
if i > 0 {
if err := r.checkRedirect(i); err != nil {
return nil, nil, err
}
}
resp, body, err := r.client.Do(ctx, act)
if err != nil {
return nil, nil, err
@ -297,7 +329,6 @@ func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*
}
return resp, body, nil
}
return nil, nil, ErrTooManyRedirects
}
type redirectedHTTPAction struct {

View File

@ -258,7 +258,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
{
client: &httpClusterClient{
endpoints: []url.URL{},
clientFactory: newHTTPClientFactory(nil),
clientFactory: newHTTPClientFactory(nil, nil),
},
wantErr: ErrNoEndpoints,
},
@ -349,14 +349,14 @@ func TestRedirectedHTTPAction(t *testing.T) {
func TestRedirectFollowingHTTPClient(t *testing.T) {
tests := []struct {
max int
client httpClient
wantCode int
wantErr error
checkRedirect CheckRedirectFunc
client httpClient
wantCode int
wantErr error
}{
// errors bubbled up
{
max: 2,
checkRedirect: func(int) error { return ErrTooManyRedirects },
client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{
staticHTTPResponse{
@ -369,7 +369,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// no need to follow redirect if none given
{
max: 2,
checkRedirect: func(int) error { return ErrTooManyRedirects },
client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{
staticHTTPResponse{
@ -384,7 +384,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// redirects if less than max
{
max: 2,
checkRedirect: func(via int) error {
if via >= 2 {
return ErrTooManyRedirects
}
return nil
},
client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{
staticHTTPResponse{
@ -405,7 +410,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// succeed after reaching max redirects
{
max: 2,
checkRedirect: func(via int) error {
if via >= 3 {
return ErrTooManyRedirects
}
return nil
},
client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{
staticHTTPResponse{
@ -430,9 +440,14 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
wantCode: http.StatusTeapot,
},
// fail at max+1 redirects
// fail if too many redirects
{
max: 1,
checkRedirect: func(via int) error {
if via >= 2 {
return ErrTooManyRedirects
}
return nil
},
client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{
staticHTTPResponse{
@ -459,7 +474,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// fail if Location header not set
{
max: 1,
checkRedirect: func(int) error { return ErrTooManyRedirects },
client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{
staticHTTPResponse{
@ -474,7 +489,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
// fail if Location header is invalid
{
max: 1,
checkRedirect: func(int) error { return ErrTooManyRedirects },
client: &multiStaticHTTPClient{
responses: []staticHTTPResponse{
staticHTTPResponse{
@ -490,7 +505,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
}
for i, tt := range tests {
client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max}
client := &redirectFollowingHTTPClient{client: tt.client, checkRedirect: tt.checkRedirect}
resp, _, err := client.Do(context.Background(), nil)
if !reflect.DeepEqual(tt.wantErr, err) {
t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)