From 079f39bddceb89b52f80952dd9beed0a8fc331d2 Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sat, 25 Apr 2020 20:54:59 +0200 Subject: [PATCH] Don't allow ASCII control character in URLs (#790) * Don't allow ASCII control character in URLs * Add tests --- client_test.go | 32 ++++++++++++++++++++++++++++++++ uri.go | 15 +++++++++++++++ uri_test.go | 3 +++ 3 files changed, 50 insertions(+) diff --git a/client_test.go b/client_test.go index 59bc413..3f9bc77 100644 --- a/client_test.go +++ b/client_test.go @@ -20,6 +20,38 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func TestClientInvalidURI(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + requests := int64(0) + s := &Server{ + Handler: func(ctx *RequestCtx) { + atomic.AddInt64(&requests, 1) + }, + } + go s.Serve(ln) + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req, res := AcquireRequest(), AcquireResponse() + defer func() { + ReleaseRequest(req) + ReleaseResponse(res) + }() + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n") + err := c.Do(req, res) + if err == nil { + t.Fatal("expected error (missing required Host header in request)") + } + if n := atomic.LoadInt64(&requests); n != 0 { + t.Fatalf("0 requests expected, got %d", n) + } +} + func TestClientGetWithBody(t *testing.T) { t.Parallel() diff --git a/uri.go b/uri.go index c9a81d4..9d64db8 100644 --- a/uri.go +++ b/uri.go @@ -263,6 +263,10 @@ func (u *URI) Parse(host, uri []byte) { func (u *URI) parse(host, uri []byte, isTLS bool) { u.Reset() + if stringContainsCTLByte(uri) { + return + } + if len(host) == 0 || bytes.Contains(uri, strColonSlashSlash) { scheme, newHost, newURI := splitHostURI(host, uri) u.scheme = append(u.scheme, scheme...) @@ -581,3 +585,14 @@ func (u *URI) parseQueryArgs() { u.queryArgs.ParseBytes(u.queryString) u.parsedQueryArgs = true } + +// stringContainsCTLByte reports whether s contains any ASCII control character. +func stringContainsCTLByte(s []byte) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false +} diff --git a/uri_test.go b/uri_test.go index 9c4a67b..d41c067 100644 --- a/uri_test.go +++ b/uri_test.go @@ -354,6 +354,9 @@ func TestURIParse(t *testing.T) { testURIParse(t, &u, "", "//aaa.com//absolute", "http://aaa.com/absolute", "aaa.com", "/absolute", "//absolute", "", "") + + testURIParse(t, &u, "", "//aaa.com\r\n\r\nGET x", + "http:///", "", "/", "", "", "") } func testURIParse(t *testing.T, u *URI, host, uri,