client: allow caller to decide HTTP redirect policy
parent
1c03df62a5
commit
9b334e07a6
|
@ -39,7 +39,6 @@ var (
|
|||
ErrKeyExists = errors.New("client: key already exists")
|
||||
|
||||
DefaultRequestTimeout = 5 * time.Second
|
||||
DefaultMaxRedirects = 10
|
||||
)
|
||||
|
||||
var DefaultTransport CancelableTransport = &http.Transport{
|
||||
|
@ -72,6 +71,17 @@ type Config struct {
|
|||
// Transport is used by the Client to drive HTTP requests. If not
|
||||
// provided, DefaultTransport will be used.
|
||||
Transport CancelableTransport
|
||||
|
||||
// CheckRedirect specifies the policy for handling HTTP redirects.
|
||||
// If CheckRedirect is not nil, the Client calls it before
|
||||
// following an HTTP redirect. The sole argument is the number of
|
||||
// requests that have alrady been made. If CheckRedirect returns
|
||||
// an error, Client.Do will not make any further requests and return
|
||||
// the error back it to the caller.
|
||||
//
|
||||
// If CheckRedirect is nil, the Client uses its default policy,
|
||||
// which is to stop after 10 consecutive requests.
|
||||
CheckRedirect CheckRedirectFunc
|
||||
}
|
||||
|
||||
func (cfg *Config) transport() CancelableTransport {
|
||||
|
@ -81,6 +91,13 @@ func (cfg *Config) transport() CancelableTransport {
|
|||
return cfg.Transport
|
||||
}
|
||||
|
||||
func (cfg *Config) checkRedirect() CheckRedirectFunc {
|
||||
if cfg.CheckRedirect == nil {
|
||||
return DefaultCheckRedirect
|
||||
}
|
||||
return cfg.CheckRedirect
|
||||
}
|
||||
|
||||
// CancelableTransport mimics net/http.Transport, but requires that
|
||||
// the object also support request cancellation.
|
||||
type CancelableTransport interface {
|
||||
|
@ -88,6 +105,16 @@ type CancelableTransport interface {
|
|||
CancelRequest(req *http.Request)
|
||||
}
|
||||
|
||||
type CheckRedirectFunc func(via int) error
|
||||
|
||||
// DefaultCheckRedirect follows up to 10 redirects, but no more.
|
||||
var DefaultCheckRedirect CheckRedirectFunc = func(via int) error {
|
||||
if via > 10 {
|
||||
return ErrTooManyRedirects
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
// Sync updates the internal cache of the etcd cluster's membership.
|
||||
Sync(context.Context) error
|
||||
|
@ -101,7 +128,7 @@ type Client interface {
|
|||
}
|
||||
|
||||
func New(cfg Config) (Client, error) {
|
||||
c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport())}
|
||||
c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect())}
|
||||
if err := c.reset(cfg.Endpoints); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -112,10 +139,10 @@ type httpClient interface {
|
|||
Do(context.Context, httpAction) (*http.Response, []byte, error)
|
||||
}
|
||||
|
||||
func newHTTPClientFactory(tr CancelableTransport) httpClientFactory {
|
||||
func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc) httpClientFactory {
|
||||
return func(ep url.URL) httpClient {
|
||||
return &redirectFollowingHTTPClient{
|
||||
max: DefaultMaxRedirects,
|
||||
checkRedirect: cr,
|
||||
client: &simpleHTTPClient{
|
||||
transport: tr,
|
||||
endpoint: ep,
|
||||
|
@ -270,12 +297,17 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon
|
|||
}
|
||||
|
||||
type redirectFollowingHTTPClient struct {
|
||||
client httpClient
|
||||
max int
|
||||
client httpClient
|
||||
checkRedirect CheckRedirectFunc
|
||||
}
|
||||
|
||||
func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) {
|
||||
for i := 0; i <= r.max; i++ {
|
||||
for i := 0; ; i++ {
|
||||
if i > 0 {
|
||||
if err := r.checkRedirect(i); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
resp, body, err := r.client.Do(ctx, act)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -297,7 +329,6 @@ func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*
|
|||
}
|
||||
return resp, body, nil
|
||||
}
|
||||
return nil, nil, ErrTooManyRedirects
|
||||
}
|
||||
|
||||
type redirectedHTTPAction struct {
|
||||
|
|
|
@ -258,7 +258,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
|
|||
{
|
||||
client: &httpClusterClient{
|
||||
endpoints: []url.URL{},
|
||||
clientFactory: newHTTPClientFactory(nil),
|
||||
clientFactory: newHTTPClientFactory(nil, nil),
|
||||
},
|
||||
wantErr: ErrNoEndpoints,
|
||||
},
|
||||
|
@ -349,14 +349,14 @@ func TestRedirectedHTTPAction(t *testing.T) {
|
|||
|
||||
func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
max int
|
||||
client httpClient
|
||||
wantCode int
|
||||
wantErr error
|
||||
checkRedirect CheckRedirectFunc
|
||||
client httpClient
|
||||
wantCode int
|
||||
wantErr error
|
||||
}{
|
||||
// errors bubbled up
|
||||
{
|
||||
max: 2,
|
||||
checkRedirect: func(int) error { return ErrTooManyRedirects },
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
|
@ -369,7 +369,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
|||
|
||||
// no need to follow redirect if none given
|
||||
{
|
||||
max: 2,
|
||||
checkRedirect: func(int) error { return ErrTooManyRedirects },
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
|
@ -384,7 +384,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
|||
|
||||
// redirects if less than max
|
||||
{
|
||||
max: 2,
|
||||
checkRedirect: func(via int) error {
|
||||
if via >= 2 {
|
||||
return ErrTooManyRedirects
|
||||
}
|
||||
return nil
|
||||
},
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
|
@ -405,7 +410,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
|||
|
||||
// succeed after reaching max redirects
|
||||
{
|
||||
max: 2,
|
||||
checkRedirect: func(via int) error {
|
||||
if via >= 3 {
|
||||
return ErrTooManyRedirects
|
||||
}
|
||||
return nil
|
||||
},
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
|
@ -430,9 +440,14 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
|||
wantCode: http.StatusTeapot,
|
||||
},
|
||||
|
||||
// fail at max+1 redirects
|
||||
// fail if too many redirects
|
||||
{
|
||||
max: 1,
|
||||
checkRedirect: func(via int) error {
|
||||
if via >= 2 {
|
||||
return ErrTooManyRedirects
|
||||
}
|
||||
return nil
|
||||
},
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
|
@ -459,7 +474,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
|||
|
||||
// fail if Location header not set
|
||||
{
|
||||
max: 1,
|
||||
checkRedirect: func(int) error { return ErrTooManyRedirects },
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
|
@ -474,7 +489,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
|||
|
||||
// fail if Location header is invalid
|
||||
{
|
||||
max: 1,
|
||||
checkRedirect: func(int) error { return ErrTooManyRedirects },
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
|
@ -490,7 +505,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max}
|
||||
client := &redirectFollowingHTTPClient{client: tt.client, checkRedirect: tt.checkRedirect}
|
||||
resp, _, err := client.Do(context.Background(), nil)
|
||||
if !reflect.DeepEqual(tt.wantErr, err) {
|
||||
t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
|
||||
|
|
Loading…
Reference in New Issue