From d055141f646f3e9ce941fb9a4dbbd602a378ce98 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Wed, 17 Aug 2016 14:01:35 +0300 Subject: [PATCH] Propagate 'https' scheme to request URI for TLS connections --- http.go | 6 +++++- http_test.go | 22 ++++++++++++++++++++++ server.go | 2 ++ uri.go | 5 ++++- 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/http.go b/http.go index 59997b8..a698986 100644 --- a/http.go +++ b/http.go @@ -42,6 +42,8 @@ type Request struct { parsedPostArgs bool keepBodyBuffer bool + + isTLS bool } // Response represents HTTP response. @@ -552,6 +554,7 @@ func (req *Request) copyToSkipBody(dst *Request) { req.postArgs.CopyTo(&dst.postArgs) dst.parsedPostArgs = req.parsedPostArgs + dst.isTLS = req.isTLS // do not copy multipartForm - it will be automatically // re-created on the first call to MultipartForm. @@ -595,7 +598,7 @@ func (req *Request) parseURI() { } req.parsedURI = true - req.uri.parseQuick(req.Header.RequestURI(), &req.Header) + req.uri.parseQuick(req.Header.RequestURI(), &req.Header, req.isTLS) } // PostArgs returns POST arguments. @@ -744,6 +747,7 @@ func (req *Request) resetSkipHeader() { req.parsedURI = false req.postArgs.Reset() req.parsedPostArgs = false + req.isTLS = false } // RemoveMultipartFormFiles removes multipart/form-data temporary files diff --git a/http_test.go b/http_test.go index 715d2e6..1ce9a14 100644 --- a/http_test.go +++ b/http_test.go @@ -1377,6 +1377,28 @@ func TestReadBodyChunked(t *testing.T) { testReadBodyChunked(t, b, 12343) } +func TestRequestURITLS(t *testing.T) { + uriNoScheme := "//foobar.com/baz/aa?bb=dd&dd#sdf" + requestURI := "http:" + uriNoScheme + requestURITLS := "https:" + uriNoScheme + + var req Request + + req.isTLS = true + req.SetRequestURI(requestURI) + uri := req.URI().String() + if uri != requestURITLS { + t.Fatalf("unexpected request uri: %q. Expecting %q", uri, requestURITLS) + } + + req.Reset() + req.SetRequestURI(requestURI) + uri = req.URI().String() + if uri != requestURI { + t.Fatalf("unexpected request uri: %q. Expecting %q", uri, requestURI) + } +} + func TestRequestURI(t *testing.T) { host := "foobar.com" requestURI := "/aaa/bb+b%20d?ccc=ddd&qqq#1334dfds&=d" diff --git a/server.go b/server.go index 2d96816..87f97c2 100644 --- a/server.go +++ b/server.go @@ -1418,6 +1418,7 @@ func (s *Server) serveConn(c net.Conn) error { ctx := s.acquireCtx(c) ctx.connTime = connTime + isTLS := ctx.IsTLS() var ( br *bufio.Reader bw *bufio.Writer @@ -1450,6 +1451,7 @@ func (s *Server) serveConn(c net.Conn) error { } } else { br, err = acquireByteReader(&ctx) + ctx.Request.isTLS = isTLS } if err == nil { diff --git a/uri.go b/uri.go index 684b8f9..48a445e 100644 --- a/uri.go +++ b/uri.go @@ -217,8 +217,11 @@ func (u *URI) Parse(host, uri []byte) { u.parse(host, uri, nil) } -func (u *URI) parseQuick(uri []byte, h *RequestHeader) { +func (u *URI) parseQuick(uri []byte, h *RequestHeader, isTLS bool) { u.parse(nil, uri, h) + if isTLS { + u.scheme = append(u.scheme[:0], strHTTPS...) + } } func (u *URI) parse(host, uri []byte, h *RequestHeader) {