HostClient can't switch between protocols (#800)

This commit is contained in:
Erik Dubbelboer
2020-05-18 18:30:29 +02:00
committed by GitHub
3 changed files with 51 additions and 42 deletions
+8 -27
View File
@@ -881,6 +881,9 @@ var (
// ErrTooManyRedirects is returned by clients when the number of redirects followed
// exceed the max count.
ErrTooManyRedirects = errors.New("too many redirects detected when doing the request")
// HostClients are only able to follow redirects to the same protocol.
ErrHostClientRedirectToDifferentScheme = errors.New("HostClient can't follow redirects to a different protocol, please use Client instead")
)
const defaultMaxRedirectsCount = 16
@@ -903,27 +906,11 @@ func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clie
}
func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer) (statusCode int, body []byte, err error) {
scheme := req.uri.Scheme()
req.schemaUpdate = false
redirectsCount := 0
for {
// In case redirect to different scheme
if redirectsCount > 0 && !bytes.Equal(scheme, req.uri.Scheme()) {
if strings.HasPrefix(url, string(strHTTPS)) {
req.isTLS = true
req.uri.SetSchemeBytes(strHTTPS)
} else {
req.isTLS = false
req.uri.SetSchemeBytes(strHTTP)
}
scheme = req.uri.Scheme()
req.schemaUpdate = true
}
req.parsedURI = false
req.Header.host = req.Header.host[:0]
req.SetRequestURI(url)
req.parseURI()
if err = c.Do(req, resp); err != nil {
break
@@ -1271,6 +1258,10 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
panic("BUG: resp cannot be nil")
}
if c.IsTLS != bytes.Equal(req.uri.Scheme(), strHTTPS) {
return false, ErrHostClientRedirectToDifferentScheme
}
atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix))
// Free up resources occupied by response before sending the request,
@@ -1285,16 +1276,6 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
req.URI().DisablePathNormalizing = true
}
// If we detected a redirect to another schema
if req.schemaUpdate {
c.IsTLS = bytes.Equal(req.URI().Scheme(), strHTTPS)
c.Addr = addMissingPort(string(req.Host()), c.IsTLS)
c.addrIdx = 0
c.addrs = nil
req.schemaUpdate = false
req.SetConnectionClose()
}
cc, err := c.acquireConn(req.timeout)
if err != nil {
return false, err
+41 -12
View File
@@ -245,7 +245,7 @@ func TestClientRedirectSameSchema(t *testing.T) {
urlParsed, err := url.Parse(destURL)
if err != nil {
fmt.Println(err)
t.Fatal(err)
return
}
@@ -270,7 +270,42 @@ func TestClientRedirectSameSchema(t *testing.T) {
}
func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) {
func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) {
t.Parallel()
listenHTTPS := testClientRedirectListener(t, true)
defer listenHTTPS.Close()
listenHTTP := testClientRedirectListener(t, false)
defer listenHTTP.Close()
sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true)
defer sHTTPS.Stop()
sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false)
defer sHTTP.Stop()
destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())
reqClient := &Client{
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != nil {
t.Fatalf("HostClient error: %s", err)
return
}
if statusCode != 200 {
t.Fatalf("HostClient error code response %d", statusCode)
return
}
}
func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) {
t.Parallel()
listenHTTPS := testClientRedirectListener(t, true)
@@ -289,7 +324,7 @@ func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) {
urlParsed, err := url.Parse(destURL)
if err != nil {
fmt.Println(err)
t.Fatal(err)
return
}
@@ -300,15 +335,9 @@ func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) {
},
}
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != nil {
t.Fatalf("HostClient error: %s", err)
return
}
if statusCode != 200 {
t.Fatalf("HostClient error code response %d", statusCode)
return
_, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != ErrHostClientRedirectToDifferentScheme {
t.Fatal("expected HostClient error")
}
}
+2 -3
View File
@@ -46,11 +46,10 @@ type Request struct {
keepBodyBuffer bool
// Used by Server to indicate the request was received on a HTTPS endpoint.
// Client/HostClient shouldn't use this field but should depend on the uri.scheme instead.
isTLS bool
// To detect scheme changes in redirects
schemaUpdate bool
// Request timeout. Usually set by DoDealine or DoTimeout
// if <= 0, means not set
timeout time.Duration