mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-26 17:46:34 +03:00
Follow redirects in client Get* and Post* methods. Added Redirect method to RequestCtx.
This commit is contained in:
@@ -441,7 +441,7 @@ type clientDoer interface {
|
||||
func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
|
||||
req := acquireRequest()
|
||||
|
||||
statusCode, body, err = doRequest(req, dst, url, c)
|
||||
statusCode, body, err = doRequestFollowRedirects(req, dst, url, c)
|
||||
|
||||
releaseRequest(req)
|
||||
return statusCode, body, err
|
||||
@@ -498,7 +498,7 @@ func clientGetURLTimeoutFreeConn(dst []byte, url string, timeout time.Duration,
|
||||
// concurrent requests, since timed out requests on client side
|
||||
// usually continue execution on the host.
|
||||
go func() {
|
||||
statusCodeCopy, bodyCopy, errCopy := doRequest(req, dst, url, c)
|
||||
statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirects(req, dst, url, c)
|
||||
ch <- clientURLResponse{
|
||||
statusCode: statusCodeCopy,
|
||||
body: bodyCopy,
|
||||
@@ -542,22 +542,51 @@ func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (status
|
||||
postArgs.WriteTo(req.BodyWriter())
|
||||
}
|
||||
|
||||
statusCode, body, err = doRequest(req, dst, url, c)
|
||||
statusCode, body, err = doRequestFollowRedirects(req, dst, url, c)
|
||||
|
||||
releaseRequest(req)
|
||||
return statusCode, body, err
|
||||
}
|
||||
|
||||
func doRequest(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
|
||||
req.SetRequestURI(url)
|
||||
var (
|
||||
errMissingLocation = errors.New("missing Location header for http redirect")
|
||||
errTooManyRedirects = errors.New("too many redirects detected when doing the request")
|
||||
)
|
||||
|
||||
const maxRedirectsCount = 16
|
||||
|
||||
func doRequestFollowRedirects(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
|
||||
resp := acquireResponse()
|
||||
oldBody := resp.body
|
||||
resp.body = dst
|
||||
if err = c.Do(req, resp); err != nil {
|
||||
return 0, dst, err
|
||||
|
||||
redirectsCount := 0
|
||||
for {
|
||||
req.parsedURI = false
|
||||
req.Header.host = req.Header.host[:0]
|
||||
req.SetRequestURI(url)
|
||||
|
||||
if err = c.Do(req, resp); err != nil {
|
||||
break
|
||||
}
|
||||
statusCode = resp.Header.StatusCode()
|
||||
if statusCode != StatusMovedPermanently && statusCode != StatusFound && statusCode != StatusSeeOther {
|
||||
break
|
||||
}
|
||||
|
||||
redirectsCount++
|
||||
if redirectsCount > maxRedirectsCount {
|
||||
err = errTooManyRedirects
|
||||
break
|
||||
}
|
||||
location := resp.Header.peek(strLocation)
|
||||
if len(location) == 0 {
|
||||
err = errMissingLocation
|
||||
break
|
||||
}
|
||||
url = getRedirectURL(url, location)
|
||||
}
|
||||
statusCode = resp.Header.StatusCode()
|
||||
|
||||
body = resp.body
|
||||
resp.body = oldBody
|
||||
releaseResponse(resp)
|
||||
@@ -565,6 +594,13 @@ func doRequest(req *Request, dst []byte, url string, c clientDoer) (statusCode i
|
||||
return statusCode, body, err
|
||||
}
|
||||
|
||||
func getRedirectURL(baseURL string, location []byte) string {
|
||||
var u URI
|
||||
u.Parse(nil, []byte(baseURL))
|
||||
u.UpdateBytes(location)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
var (
|
||||
requestPool sync.Pool
|
||||
responsePool sync.Pool
|
||||
|
||||
+61
-6
@@ -12,6 +12,61 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestClientFollowRedirects(t *testing.T) {
|
||||
addr := "127.0.0.1:55234"
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
if EqualBytesStr(ctx.Path(), "/foo") {
|
||||
u := ctx.URI()
|
||||
u.Update("/bar")
|
||||
ctx.Redirect(u.String(), StatusFound)
|
||||
} else {
|
||||
ctx.Success("text/plain", ctx.Path())
|
||||
}
|
||||
},
|
||||
}
|
||||
ln, err := net.Listen("tcp4", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
serverStopCh := make(chan struct{})
|
||||
go func() {
|
||||
if err := s.Serve(ln); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
close(serverStopCh)
|
||||
}()
|
||||
|
||||
uri := fmt.Sprintf("http://%s/foo", addr)
|
||||
for i := 0; i < 10; i++ {
|
||||
statusCode, body, err := Get(nil, uri)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if statusCode != StatusOK {
|
||||
t.Fatalf("unexpected status code: %d", statusCode)
|
||||
}
|
||||
if string(body) != "/bar" {
|
||||
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
|
||||
}
|
||||
}
|
||||
|
||||
uri = fmt.Sprintf("http://%s/aaab/sss", addr)
|
||||
for i := 0; i < 10; i++ {
|
||||
statusCode, body, err := Get(nil, uri)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if statusCode != StatusOK {
|
||||
t.Fatalf("unexpected status code: %d", statusCode)
|
||||
}
|
||||
if string(body) != "/aaab/sss" {
|
||||
t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientGetTimeoutSuccess(t *testing.T) {
|
||||
addr := "127.0.0.1:56889"
|
||||
s := startEchoServer(t, "tcp", addr)
|
||||
@@ -251,8 +306,8 @@ func TestClientHTTPSConcurrent(t *testing.T) {
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
testClientGet(t, &defaultClient, addr, 3000)
|
||||
testClientPost(t, &defaultClient, addr, 1000)
|
||||
testClientGet(t, &defaultClient, addr, 300)
|
||||
testClientPost(t, &defaultClient, addr, 100)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
@@ -309,8 +364,8 @@ func TestClientConcurrent(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
testClientGet(t, &defaultClient, addr, 3000)
|
||||
testClientPost(t, &defaultClient, addr, 1000)
|
||||
testClientGet(t, &defaultClient, addr, 300)
|
||||
testClientPost(t, &defaultClient, addr, 100)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
@@ -345,8 +400,8 @@ func TestHostClientConcurrent(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
testHostClientGet(t, c, 3000)
|
||||
testHostClientPost(t, c, 1000)
|
||||
testHostClientGet(t, c, 300)
|
||||
testHostClientPost(t, c, 100)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
@@ -517,6 +517,57 @@ func (ctx *RequestCtx) Success(contentType string, body []byte) {
|
||||
ctx.SetBody(body)
|
||||
}
|
||||
|
||||
// Redirect sets 'Location: uri' response header and sets the given statusCode.
|
||||
//
|
||||
// statusCode must have one of the following values:
|
||||
//
|
||||
// * StatusMovedPermanently (301)
|
||||
// * StatusFound (302)
|
||||
// * StatusSeeOther (303)
|
||||
//
|
||||
// All other statusCode values are replaced by StatusFound (302).
|
||||
//
|
||||
// The redirect uri may be either absolute or relative to the current
|
||||
// request uri.
|
||||
func (ctx *RequestCtx) Redirect(uri string, statusCode int) {
|
||||
var u URI
|
||||
ctx.URI().CopyTo(&u)
|
||||
u.Update(uri)
|
||||
ctx.redirect(u.FullURI(), statusCode)
|
||||
}
|
||||
|
||||
// Redirect sets 'Location: uri' response header and sets the given statusCode.
|
||||
//
|
||||
// statusCode must have one of the following values:
|
||||
//
|
||||
// * StatusMovedPermanently (301)
|
||||
// * StatusFound (302)
|
||||
// * StatusSeeOther (303)
|
||||
//
|
||||
// All other statusCode values are replaced by StatusFound (302).
|
||||
//
|
||||
// The redirect uri may be either absolute or relative to the current
|
||||
// request uri.
|
||||
func (ctx *RequestCtx) RedirectBytes(uri []byte, statusCode int) {
|
||||
var u URI
|
||||
ctx.URI().CopyTo(&u)
|
||||
u.UpdateBytes(uri)
|
||||
ctx.redirect(u.FullURI(), statusCode)
|
||||
}
|
||||
|
||||
func (ctx *RequestCtx) redirect(uri []byte, statusCode int) {
|
||||
ctx.Response.Header.SetCanonical(strLocation, uri)
|
||||
statusCode = getRedirectStatusCode(statusCode)
|
||||
ctx.Response.SetStatusCode(statusCode)
|
||||
}
|
||||
|
||||
func getRedirectStatusCode(statusCode int) int {
|
||||
if statusCode == StatusMovedPermanently || statusCode == StatusFound || statusCode == StatusSeeOther {
|
||||
return statusCode
|
||||
}
|
||||
return StatusFound
|
||||
}
|
||||
|
||||
// SetBody sets response body to the given value.
|
||||
func (ctx *RequestCtx) SetBody(body []byte) {
|
||||
ctx.Response.SetBody(body)
|
||||
|
||||
@@ -33,6 +33,7 @@ var (
|
||||
strUserAgent = []byte("User-Agent")
|
||||
strCookie = []byte("Cookie")
|
||||
strSetCookie = []byte("Set-Cookie")
|
||||
strLocation = []byte("Location")
|
||||
|
||||
strCookieExpires = []byte("expires")
|
||||
strCookieDomain = []byte("domain")
|
||||
|
||||
Reference in New Issue
Block a user