Add support for user:pass in URLs (#614)

Fixes #609
This commit is contained in:
Erik Dubbelboer
2019-08-18 11:23:33 +02:00
committed by GitHub
parent 85217e0d5e
commit 2edabf3b76
5 changed files with 126 additions and 1 deletions
+15 -1
View File
@@ -49,7 +49,7 @@ func TestAllocationClient(t *testing.T) {
go s.Serve(ln)
c := &Client{}
url := "http://" + ln.Addr().String()
url := "http://test:test@" + ln.Addr().String() + "/foo?bar=baz"
n := testing.AllocsPerRun(100, func() {
req := AcquireRequest()
@@ -68,3 +68,17 @@ func TestAllocationClient(t *testing.T) {
t.Fatalf("expected 0 allocations, got %f", n)
}
}
func TestAllocationURI(t *testing.T) {
uri := []byte("http://username:password@example.com/some/path?foo=bar#test")
n := testing.AllocsPerRun(100, func() {
u := AcquireURI()
u.Parse(nil, uri)
ReleaseURI(u)
})
if n != 0 {
t.Fatalf("expected 0 allocations, got %f", n)
}
}
+38
View File
@@ -19,6 +19,44 @@ import (
"github.com/valyala/fasthttp/fasthttputil"
)
func TestClientURLAuth(t *testing.T) {
cases := map[string]string{
"user:pass@": "dXNlcjpwYXNz",
"foo:@": "Zm9vOg==",
":@": "",
"@": "",
"": "",
}
ch := make(chan string, 1)
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ch <- string(ctx.Request.Header.Peek(HeaderAuthorization))
},
}
go s.Serve(ln)
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
for up, expected := range cases {
req := AcquireRequest()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://" + up + "example.com")
if err := c.Do(req, nil); err != nil {
t.Fatal(err)
}
val := <-ch
if val != expected {
t.Fatalf("wrong %s header: %s expected %s", HeaderAuthorization, val, expected)
}
}
}
func TestClientNilResp(t *testing.T) {
ln := fasthttputil.NewInmemoryListener()
s := &Server{
+19
View File
@@ -3,6 +3,7 @@ package fasthttp
import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -1148,6 +1149,24 @@ func (req *Request) Write(w *bufio.Writer) error {
}
req.Header.SetHostBytes(host)
req.Header.SetRequestURIBytes(uri.RequestURI())
if len(uri.username) > 0 {
// RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key
// So we are free to use RequestHeader.bufKV.value as a scratch pad for
// the base64 encoding.
nl := len(uri.username) + len(uri.password) + 1
tl := nl + base64.StdEncoding.EncodedLen(nl)
if tl > cap(req.Header.bufKV.value) {
req.Header.bufKV.value = make([]byte, 0, tl)
}
buf := req.Header.bufKV.value[:0]
buf = append(buf, uri.username...)
buf = append(buf, strColon...)
buf = append(buf, uri.password...)
buf = buf[:tl]
base64.StdEncoding.Encode(buf[nl:], buf[:nl])
req.Header.SetBytesKV(strAuthorization, buf[nl:])
}
}
if req.bodyStream != nil {
+3
View File
@@ -16,9 +16,11 @@ var (
strHTTP = []byte("http")
strHTTPS = []byte("https")
strHTTP11 = []byte("HTTP/1.1")
strColon = []byte(":")
strColonSlashSlash = []byte("://")
strColonSpace = []byte(": ")
strGMT = []byte("GMT")
strAt = []byte("@")
strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n")
@@ -52,6 +54,7 @@ var (
strAcceptRanges = []byte(HeaderAcceptRanges)
strRange = []byte(HeaderRange)
strContentRange = []byte(HeaderContentRange)
strAuthorization = []byte(HeaderAuthorization)
strCookieExpires = []byte("expires")
strCookieDomain = []byte("domain")
+51
View File
@@ -51,6 +51,9 @@ type URI struct {
fullURI []byte
requestURI []byte
username []byte
password []byte
h *RequestHeader
}
@@ -63,6 +66,8 @@ func (u *URI) CopyTo(dst *URI) {
dst.queryString = append(dst.queryString[:0], u.queryString...)
dst.hash = append(dst.hash[:0], u.hash...)
dst.host = append(dst.host[:0], u.host...)
dst.username = append(dst.username[:0], u.username...)
dst.password = append(dst.password[:0], u.password...)
u.queryArgs.CopyTo(&dst.queryArgs)
dst.parsedQueryArgs = u.parsedQueryArgs
@@ -89,6 +94,36 @@ func (u *URI) SetHashBytes(hash []byte) {
u.hash = append(u.hash[:0], hash...)
}
// Username returns URI username
func (u *URI) Username() []byte {
return u.username
}
// SetUsername sets URI username.
func (u *URI) SetUsername(username string) {
u.username = append(u.username[:0], username...)
}
// SetUsernameBytes sets URI username.
func (u *URI) SetUsernameBytes(username []byte) {
u.username = append(u.username[:0], username...)
}
// Password returns URI password
func (u *URI) Password() []byte {
return u.password
}
// SetPassword sets URI password.
func (u *URI) SetPassword(password string) {
u.password = append(u.password[:0], password...)
}
// SetPasswordBytes sets URI password.
func (u *URI) SetPasswordBytes(password []byte) {
u.password = append(u.password[:0], password...)
}
// QueryString returns URI query string,
// i.e. baz=123 of http://aaa.com/foo/bar?baz=123#qwe .
//
@@ -174,6 +209,8 @@ func (u *URI) Reset() {
u.path = u.path[:0]
u.queryString = u.queryString[:0]
u.hash = u.hash[:0]
u.username = u.username[:0]
u.password = u.password[:0]
u.host = u.host[:0]
u.queryArgs.Reset()
@@ -236,6 +273,20 @@ func (u *URI) parse(host, uri []byte, h *RequestHeader) {
scheme, host, uri := splitHostURI(host, uri)
u.scheme = append(u.scheme, scheme...)
lowercaseBytes(u.scheme)
if n := bytes.Index(host, strAt); n >= 0 {
auth := host[:n]
host = host[n+1:]
if n := bytes.Index(auth, strColon); n >= 0 {
u.username = auth[:n]
u.password = auth[n+1:]
} else {
u.username = auth
u.password = auth[:0] // Make sure it's not nil
}
}
u.host = append(u.host, host...)
lowercaseBytes(u.host)