diff --git a/client.go b/client.go index 1b79591..b1eea7d 100644 --- a/client.go +++ b/client.go @@ -321,20 +321,10 @@ func (c *Client) GetTLSClientConfig() *tls.Config { return c.TLSClientConfig } -func (c *Client) defaultCheckRedirect(req *http.Request, via []*http.Request) error { - if len(via) >= 10 { - return errors.New("stopped after 10 redirects") - } - if c.DebugLog { - c.log.Debugf(" %s %s", req.Method, req.URL.String()) - } - return nil -} - // SetRedirectPolicy set the RedirectPolicy which controls the behavior of receiving redirect // responses (usually responses with 301 and 302 status code), see the predefined -// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, MaxRedirectPolicy, NoRedirectPolicy, -// SameDomainRedirectPolicy and SameHostRedirectPolicy. +// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, DefaultRedirectPolicy, MaxRedirectPolicy, +// NoRedirectPolicy, SameDomainRedirectPolicy and SameHostRedirectPolicy. func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { if len(policies) == 0 { return c @@ -1565,7 +1555,7 @@ func C() *Client { xmlUnmarshal: xml.Unmarshal, cookiejarFactory: memoryCookieJarFactory, } - httpClient.CheckRedirect = c.defaultCheckRedirect + c.SetRedirectPolicy(DefaultRedirectPolicy()) c.initCookieJar() c.initTransport() diff --git a/client_test.go b/client_test.go index e9e9f75..7a6aeeb 100644 --- a/client_test.go +++ b/client_test.go @@ -369,6 +369,10 @@ func TestRedirect(t *testing.T) { tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "stopped after 3 redirects", true) + _, err = tc().SetRedirectPolicy(MaxRedirectPolicy(20)).SetRedirectPolicy(DefaultRedirectPolicy()).R().Get("/unlimited-redirect") + tests.AssertNotNil(t, err) + tests.AssertContains(t, err.Error(), "stopped after 10 redirects", true) + _, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other") tests.AssertNotNil(t, err) tests.AssertContains(t, err.Error(), "different domain name is not allowed", true) diff --git a/redirect.go b/redirect.go index f1cc433..fcc13e4 100644 --- a/redirect.go +++ b/redirect.go @@ -21,6 +21,11 @@ func MaxRedirectPolicy(noOfRedirect int) RedirectPolicy { } } +// DefaultRedirectPolicy allows up to 10 redirects +func DefaultRedirectPolicy() RedirectPolicy { + return MaxRedirectPolicy(10) +} + // NoRedirectPolicy disable redirect behaviour func NoRedirectPolicy() RedirectPolicy { return func(req *http.Request, via []*http.Request) error {