Don't allow ASCII control character in URLs (#790)

* Don't allow ASCII control character in URLs

* Add tests
This commit is contained in:
Erik Dubbelboer
2020-04-25 20:54:59 +02:00
committed by GitHub
parent 3e27d8ebad
commit 079f39bddc
3 changed files with 50 additions and 0 deletions
+32
View File
@@ -20,6 +20,38 @@ import (
"github.com/valyala/fasthttp/fasthttputil"
)
func TestClientInvalidURI(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
requests := int64(0)
s := &Server{
Handler: func(ctx *RequestCtx) {
atomic.AddInt64(&requests, 1)
},
}
go s.Serve(ln)
c := &Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
}
req, res := AcquireRequest(), AcquireResponse()
defer func() {
ReleaseRequest(req)
ReleaseResponse(res)
}()
req.Header.SetMethod(MethodGet)
req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n")
err := c.Do(req, res)
if err == nil {
t.Fatal("expected error (missing required Host header in request)")
}
if n := atomic.LoadInt64(&requests); n != 0 {
t.Fatalf("0 requests expected, got %d", n)
}
}
func TestClientGetWithBody(t *testing.T) {
t.Parallel()
+15
View File
@@ -263,6 +263,10 @@ func (u *URI) Parse(host, uri []byte) {
func (u *URI) parse(host, uri []byte, isTLS bool) {
u.Reset()
if stringContainsCTLByte(uri) {
return
}
if len(host) == 0 || bytes.Contains(uri, strColonSlashSlash) {
scheme, newHost, newURI := splitHostURI(host, uri)
u.scheme = append(u.scheme, scheme...)
@@ -581,3 +585,14 @@ func (u *URI) parseQueryArgs() {
u.queryArgs.ParseBytes(u.queryString)
u.parsedQueryArgs = true
}
// stringContainsCTLByte reports whether s contains any ASCII control character.
func stringContainsCTLByte(s []byte) bool {
for i := 0; i < len(s); i++ {
b := s[i]
if b < ' ' || b == 0x7f {
return true
}
}
return false
}
+3
View File
@@ -354,6 +354,9 @@ func TestURIParse(t *testing.T) {
testURIParse(t, &u, "", "//aaa.com//absolute",
"http://aaa.com/absolute", "aaa.com", "/absolute", "//absolute", "", "")
testURIParse(t, &u, "", "//aaa.com\r\n\r\nGET x",
"http:///", "", "/", "", "", "")
}
func testURIParse(t *testing.T, u *URI, host, uri,