diff --git a/server_test.go b/server_test.go index 727306b..01c832f 100644 --- a/server_test.go +++ b/server_test.go @@ -59,6 +59,7 @@ func TestRequestCtxRedirect(t *testing.T) { testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "./.././../x.html", "http://qqq/x.html") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "http://foo.bar/baz", "http://foo.bar/baz") testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "https://foo.bar/baz", "https://foo.bar/baz") + testRequestCtxRedirect(t, "https://foo.com/bar?aaa", "//google.com/aaa?bb", "https://google.com/aaa?bb") } func testRequestCtxRedirect(t *testing.T, origURL, redirectURL, expectedURL string) { diff --git a/uri.go b/uri.go index 48a445e..43060de 100644 --- a/uri.go +++ b/uri.go @@ -395,6 +395,22 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte { if len(newURI) == 0 { return buf } + + n := bytes.Index(newURI, strSlashSlash) + if n >= 0 { + // absolute uri + var b [32]byte + schemeOriginal := b[:0] + if len(u.scheme) > 0 { + schemeOriginal = append([]byte(nil), u.scheme...) + } + u.Parse(nil, newURI) + if len(schemeOriginal) > 0 && len(u.scheme) == 0 { + u.scheme = append(u.scheme[:0], schemeOriginal...) + } + return buf + } + if newURI[0] == '/' { // uri without host buf = u.appendSchemeHost(buf[:0]) @@ -403,13 +419,6 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte { return buf } - n := bytes.Index(newURI, strColonSlashSlash) - if n >= 0 { - // absolute uri - u.Parse(nil, newURI) - return buf - } - // relative path switch newURI[0] { case '?': @@ -467,7 +476,7 @@ func (u *URI) String() string { } func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) { - n := bytes.Index(uri, strColonSlashSlash) + n := bytes.Index(uri, strSlashSlash) if n < 0 { return strHTTP, host, uri } @@ -475,7 +484,10 @@ func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) { if bytes.IndexByte(scheme, '/') >= 0 { return strHTTP, host, uri } - n += len(strColonSlashSlash) + if len(scheme) > 0 && scheme[len(scheme)-1] == ':' { + scheme = scheme[:len(scheme)-1] + } + n += len(strSlashSlash) uri = uri[n:] n = bytes.IndexByte(uri, '/') if n < 0 { diff --git a/uri_test.go b/uri_test.go index 508a9ba..b86bb48 100644 --- a/uri_test.go +++ b/uri_test.go @@ -112,6 +112,9 @@ func TestURIUpdate(t *testing.T) { // hash testURIUpdate(t, "http://foo.bar/baz#aaa", "#fragment", "http://foo.bar/baz#fragment") + + // uri without scheme + testURIUpdate(t, "https://foo.bar/baz", "//aaa.bbb/cc?dd", "https://aaa.bbb/cc?dd") } func testURIUpdate(t *testing.T, base, update, result string) {