Follow redirects in client Get* and Post* methods. Added Redirect method to RequestCtx.

This commit is contained in:
Aliaksandr Valialkin
2015-11-28 14:47:05 +02:00
parent 9f27e4c2b0
commit 8563a2e762
4 changed files with 157 additions and 14 deletions
+44 -8
View File
@@ -441,7 +441,7 @@ type clientDoer interface {
func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
req := acquireRequest()
statusCode, body, err = doRequest(req, dst, url, c)
statusCode, body, err = doRequestFollowRedirects(req, dst, url, c)
releaseRequest(req)
return statusCode, body, err
@@ -498,7 +498,7 @@ func clientGetURLTimeoutFreeConn(dst []byte, url string, timeout time.Duration,
// concurrent requests, since timed out requests on client side
// usually continue execution on the host.
go func() {
statusCodeCopy, bodyCopy, errCopy := doRequest(req, dst, url, c)
statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirects(req, dst, url, c)
ch <- clientURLResponse{
statusCode: statusCodeCopy,
body: bodyCopy,
@@ -542,22 +542,51 @@ func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (status
postArgs.WriteTo(req.BodyWriter())
}
statusCode, body, err = doRequest(req, dst, url, c)
statusCode, body, err = doRequestFollowRedirects(req, dst, url, c)
releaseRequest(req)
return statusCode, body, err
}
func doRequest(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
req.SetRequestURI(url)
var (
errMissingLocation = errors.New("missing Location header for http redirect")
errTooManyRedirects = errors.New("too many redirects detected when doing the request")
)
const maxRedirectsCount = 16
func doRequestFollowRedirects(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
resp := acquireResponse()
oldBody := resp.body
resp.body = dst
if err = c.Do(req, resp); err != nil {
return 0, dst, err
redirectsCount := 0
for {
req.parsedURI = false
req.Header.host = req.Header.host[:0]
req.SetRequestURI(url)
if err = c.Do(req, resp); err != nil {
break
}
statusCode = resp.Header.StatusCode()
if statusCode != StatusMovedPermanently && statusCode != StatusFound && statusCode != StatusSeeOther {
break
}
redirectsCount++
if redirectsCount > maxRedirectsCount {
err = errTooManyRedirects
break
}
location := resp.Header.peek(strLocation)
if len(location) == 0 {
err = errMissingLocation
break
}
url = getRedirectURL(url, location)
}
statusCode = resp.Header.StatusCode()
body = resp.body
resp.body = oldBody
releaseResponse(resp)
@@ -565,6 +594,13 @@ func doRequest(req *Request, dst []byte, url string, c clientDoer) (statusCode i
return statusCode, body, err
}
func getRedirectURL(baseURL string, location []byte) string {
var u URI
u.Parse(nil, []byte(baseURL))
u.UpdateBytes(location)
return u.String()
}
var (
requestPool sync.Pool
responsePool sync.Pool
+61 -6
View File
@@ -12,6 +12,61 @@ import (
"time"
)
func TestClientFollowRedirects(t *testing.T) {
addr := "127.0.0.1:55234"
s := &Server{
Handler: func(ctx *RequestCtx) {
if EqualBytesStr(ctx.Path(), "/foo") {
u := ctx.URI()
u.Update("/bar")
ctx.Redirect(u.String(), StatusFound)
} else {
ctx.Success("text/plain", ctx.Path())
}
},
}
ln, err := net.Listen("tcp4", addr)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Fatalf("unexpected error: %s", err)
}
close(serverStopCh)
}()
uri := fmt.Sprintf("http://%s/foo", addr)
for i := 0; i < 10; i++ {
statusCode, body, err := Get(nil, uri)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d", statusCode)
}
if string(body) != "/bar" {
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
}
}
uri = fmt.Sprintf("http://%s/aaab/sss", addr)
for i := 0; i < 10; i++ {
statusCode, body, err := Get(nil, uri)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d", statusCode)
}
if string(body) != "/aaab/sss" {
t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss")
}
}
}
func TestClientGetTimeoutSuccess(t *testing.T) {
addr := "127.0.0.1:56889"
s := startEchoServer(t, "tcp", addr)
@@ -251,8 +306,8 @@ func TestClientHTTPSConcurrent(t *testing.T) {
}
go func() {
defer wg.Done()
testClientGet(t, &defaultClient, addr, 3000)
testClientPost(t, &defaultClient, addr, 1000)
testClientGet(t, &defaultClient, addr, 300)
testClientPost(t, &defaultClient, addr, 100)
}()
}
wg.Wait()
@@ -309,8 +364,8 @@ func TestClientConcurrent(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
testClientGet(t, &defaultClient, addr, 3000)
testClientPost(t, &defaultClient, addr, 1000)
testClientGet(t, &defaultClient, addr, 300)
testClientPost(t, &defaultClient, addr, 100)
}()
}
wg.Wait()
@@ -345,8 +400,8 @@ func TestHostClientConcurrent(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
testHostClientGet(t, c, 3000)
testHostClientPost(t, c, 1000)
testHostClientGet(t, c, 300)
testHostClientPost(t, c, 100)
}()
}
wg.Wait()
+51
View File
@@ -517,6 +517,57 @@ func (ctx *RequestCtx) Success(contentType string, body []byte) {
ctx.SetBody(body)
}
// Redirect sets 'Location: uri' response header and sets the given statusCode.
//
// statusCode must have one of the following values:
//
// * StatusMovedPermanently (301)
// * StatusFound (302)
// * StatusSeeOther (303)
//
// All other statusCode values are replaced by StatusFound (302).
//
// The redirect uri may be either absolute or relative to the current
// request uri.
func (ctx *RequestCtx) Redirect(uri string, statusCode int) {
var u URI
ctx.URI().CopyTo(&u)
u.Update(uri)
ctx.redirect(u.FullURI(), statusCode)
}
// Redirect sets 'Location: uri' response header and sets the given statusCode.
//
// statusCode must have one of the following values:
//
// * StatusMovedPermanently (301)
// * StatusFound (302)
// * StatusSeeOther (303)
//
// All other statusCode values are replaced by StatusFound (302).
//
// The redirect uri may be either absolute or relative to the current
// request uri.
func (ctx *RequestCtx) RedirectBytes(uri []byte, statusCode int) {
var u URI
ctx.URI().CopyTo(&u)
u.UpdateBytes(uri)
ctx.redirect(u.FullURI(), statusCode)
}
func (ctx *RequestCtx) redirect(uri []byte, statusCode int) {
ctx.Response.Header.SetCanonical(strLocation, uri)
statusCode = getRedirectStatusCode(statusCode)
ctx.Response.SetStatusCode(statusCode)
}
func getRedirectStatusCode(statusCode int) int {
if statusCode == StatusMovedPermanently || statusCode == StatusFound || statusCode == StatusSeeOther {
return statusCode
}
return StatusFound
}
// SetBody sets response body to the given value.
func (ctx *RequestCtx) SetBody(body []byte) {
ctx.Response.SetBody(body)
+1
View File
@@ -33,6 +33,7 @@ var (
strUserAgent = []byte("User-Agent")
strCookie = []byte("Cookie")
strSetCookie = []byte("Set-Cookie")
strLocation = []byte("Location")
strCookieExpires = []byte("expires")
strCookieDomain = []byte("domain")