From 9dbe5fc77c0a914d7f437c50fea046aaebdd0d14 Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Wed, 16 Oct 2019 00:57:13 +0200 Subject: [PATCH] Don't allow spaces in request header keys See: https://github.com/golang/go/commit/6e6f4aaf70c8b1cc81e65a26332aa9409de03ad8 Reject any non GET or HEAD requests with a 400. We can't reject GET or HEAD requests with bad headers as we delay parsing of these headers until the user asks for one. So in this case we just ignore the header and don't return a value for it. --- client_test.go | 6 ++-- header.go | 20 +++++++++-- header_test.go | 7 ++-- server_test.go | 95 +++++++++++++++++++++++++++++++++++++++++--------- 4 files changed, 102 insertions(+), 26 deletions(-) diff --git a/client_test.go b/client_test.go index 2ffdd39..0dcfd39 100644 --- a/client_test.go +++ b/client_test.go @@ -351,7 +351,7 @@ func TestClientReadTimeout(t *testing.T) { timeout = true } }, - Logger: &customLogger{}, // Don't print closed pipe errors. + Logger: &testLogger{}, // Don't print closed pipe errors. } go s.Serve(ln) @@ -631,7 +631,7 @@ func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay MaxConns: maxConns, MaxPendingRequests: concurrency, MaxBatchDelay: maxBatchDelay, - Logger: &customLogger{}, + Logger: &testLogger{}, } clientStopCh := make(chan struct{}, concurrency) @@ -1807,7 +1807,7 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch ctx.PostArgs().WriteTo(ctx) } }, - Logger: &customLogger{}, // Ignore log output. + Logger: &testLogger{}, // Ignore log output. } ch := make(chan struct{}) go func() { diff --git a/header.go b/header.go index 4561764..399b486 100644 --- a/header.go +++ b/header.go @@ -1893,6 +1893,13 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { var err error for s.next() { if len(s.key) > 0 { + // Spaces between the header key and colon are not allowed. + // See RFC 7230, Section 3.2.4. + if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 { + err = fmt.Errorf("invalid header key %q", s.key) + continue + } + switch s.key[0] | 0x20 { case 'h': if caseInsensitiveCompare(s.key, strHost) { @@ -1911,7 +1918,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { } if caseInsensitiveCompare(s.key, strContentLength) { if h.contentLength != -1 { - if h.contentLength, err = parseContentLength(s.value); err != nil { + var nerr error + if h.contentLength, nerr = parseContentLength(s.value); nerr != nil { + if err == nil { + err = nerr + } h.contentLength = -2 } else { h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) @@ -1940,9 +1951,12 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { } h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) } - if s.err != nil { + if s.err != nil && err == nil { + err = s.err + } + if err != nil { h.connectionClose = true - return 0, s.err + return 0, err } if h.contentLength < 0 { diff --git a/header_test.go b/header_test.go index b6ca26d..3020b92 100644 --- a/header_test.go +++ b/header_test.go @@ -2175,10 +2175,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) { testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar", 0, "/foo/bar", "", "", "", "foobar") - // post with invalid content-length - testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty", - -2, "/a", "bb", "", "aa", "qwerty") - // post without content-length and content-type testRequestHeaderReadSuccess(t, h, "POST /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc", -2, "/aaa", "aaa.com", "", "", "zxc") @@ -2234,6 +2230,9 @@ func TestRequestHeaderReadError(t *testing.T) { // missing RequestURI testRequestHeaderReadError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n") + + // post with invalid content-length + testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty") } func testResponseHeaderReadError(t *testing.T, h *ResponseHeader, headers string) { diff --git a/server_test.go b/server_test.go index 1db50c2..cd199c3 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,69 @@ import ( // Make sure RequestCtx implements context.Context var _ context.Context = &RequestCtx{} +func TestServerInvalidHeader(t *testing.T) { + s := &Server{ + Handler: func(ctx *RequestCtx) { + if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil { + t.Fatal("expected Foo header") + } + }, + Logger: &testLogger{}, + } + + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Fatalf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil { + t.Fatal(err) + } + + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusBadRequest { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest) + } + + c, err = ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil { + t.Fatal(err) + } + + br = bufio.NewReader(c) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + // Since we delay header parsing for GET and HEAD requests until the users asks for it + // we can't return 400 in case of a bad header. + // Inside the handler above we make sure to test that the invalid Foo header was ignored. + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + if err := c.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + func TestServerConnState(t *testing.T) { states := make([]string, 0) s := &Server{ @@ -595,7 +658,7 @@ func TestServerMaxConnsPerIPLimit(t *testing.T) { ctx.WriteString("OK") }, MaxConnsPerIP: 1, - Logger: &customLogger{}, + Logger: &testLogger{}, } ln := fasthttputil.NewInmemoryListener() @@ -697,7 +760,7 @@ func TestServerConcurrencyLimit(t *testing.T) { ctx.WriteString("OK") }, Concurrency: 1, - Logger: &customLogger{}, + Logger: &testLogger{}, } ln := fasthttputil.NewInmemoryListener() @@ -1970,7 +2033,7 @@ func TestRequestCtxHijack(t *testing.T) { func TestRequestCtxInit(t *testing.T) { var ctx RequestCtx - var logger customLogger + var logger testLogger globalConnID = 0x123456 ctx.Init(&ctx.Request, zeroTCPAddr, &logger) ip := ctx.RemoteIP() @@ -2443,19 +2506,8 @@ func TestServerEmptyResponse(t *testing.T) { verifyResponse(t, br, 200, string(defaultContentType), "") } -type customLogger struct { - lock sync.Mutex - out string -} - -func (cl *customLogger) Printf(format string, args ...interface{}) { - cl.lock.Lock() - cl.out += fmt.Sprintf(format, args...)[6:] + "\n" - cl.lock.Unlock() -} - func TestServerLogger(t *testing.T) { - cl := &customLogger{} + cl := &testLogger{} s := &Server{ Handler: func(ctx *RequestCtx) { logger := ctx.Logger() @@ -2740,7 +2792,7 @@ func TestShutdownReuse(t *testing.T) { ctx.Success("aaa/bbb", []byte("real response")) }, ReadTimeout: time.Second, - Logger: &customLogger{}, // Ignore log output. + Logger: &testLogger{}, // Ignore log output. } go func() { if err := s.Serve(ln); err != nil { @@ -3030,3 +3082,14 @@ func (rw *readWriter) SetReadDeadline(t time.Time) error { func (rw *readWriter) SetWriteDeadline(t time.Time) error { return nil } + +type testLogger struct { + lock sync.Mutex + out string +} + +func (cl *testLogger) Printf(format string, args ...interface{}) { + cl.lock.Lock() + cl.out += fmt.Sprintf(format, args...)[6:] + "\n" + cl.lock.Unlock() +}