diff --git a/client_test.go b/client_test.go index 2ffdd39..0dcfd39 100644 --- a/client_test.go +++ b/client_test.go @@ -351,7 +351,7 @@ func TestClientReadTimeout(t *testing.T) { timeout = true } }, - Logger: &customLogger{}, // Don't print closed pipe errors. + Logger: &testLogger{}, // Don't print closed pipe errors. } go s.Serve(ln) @@ -631,7 +631,7 @@ func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay MaxConns: maxConns, MaxPendingRequests: concurrency, MaxBatchDelay: maxBatchDelay, - Logger: &customLogger{}, + Logger: &testLogger{}, } clientStopCh := make(chan struct{}, concurrency) @@ -1807,7 +1807,7 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch ctx.PostArgs().WriteTo(ctx) } }, - Logger: &customLogger{}, // Ignore log output. + Logger: &testLogger{}, // Ignore log output. } ch := make(chan struct{}) go func() { diff --git a/header.go b/header.go index 4561764..399b486 100644 --- a/header.go +++ b/header.go @@ -1893,6 +1893,13 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { var err error for s.next() { if len(s.key) > 0 { + // Spaces between the header key and colon are not allowed. + // See RFC 7230, Section 3.2.4. + if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 { + err = fmt.Errorf("invalid header key %q", s.key) + continue + } + switch s.key[0] | 0x20 { case 'h': if caseInsensitiveCompare(s.key, strHost) { @@ -1911,7 +1918,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { } if caseInsensitiveCompare(s.key, strContentLength) { if h.contentLength != -1 { - if h.contentLength, err = parseContentLength(s.value); err != nil { + var nerr error + if h.contentLength, nerr = parseContentLength(s.value); nerr != nil { + if err == nil { + err = nerr + } h.contentLength = -2 } else { h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) @@ -1940,9 +1951,12 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { } h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) } - if s.err != nil { + if s.err != nil && err == nil { + err = s.err + } + if err != nil { h.connectionClose = true - return 0, s.err + return 0, err } if h.contentLength < 0 { diff --git a/header_test.go b/header_test.go index b6ca26d..3020b92 100644 --- a/header_test.go +++ b/header_test.go @@ -2175,10 +2175,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) { testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar", 0, "/foo/bar", "", "", "", "foobar") - // post with invalid content-length - testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty", - -2, "/a", "bb", "", "aa", "qwerty") - // post without content-length and content-type testRequestHeaderReadSuccess(t, h, "POST /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc", -2, "/aaa", "aaa.com", "", "", "zxc") @@ -2234,6 +2230,9 @@ func TestRequestHeaderReadError(t *testing.T) { // missing RequestURI testRequestHeaderReadError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n") + + // post with invalid content-length + testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty") } func testResponseHeaderReadError(t *testing.T, h *ResponseHeader, headers string) { diff --git a/server_test.go b/server_test.go index 1db50c2..cd199c3 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,69 @@ import ( // Make sure RequestCtx implements context.Context var _ context.Context = &RequestCtx{} +func TestServerInvalidHeader(t *testing.T) { + s := &Server{ + Handler: func(ctx *RequestCtx) { + if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil { + t.Fatal("expected Foo header") + } + }, + Logger: &testLogger{}, + } + + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Fatalf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil { + t.Fatal(err) + } + + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusBadRequest { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest) + } + + c, err = ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil { + t.Fatal(err) + } + + br = bufio.NewReader(c) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + // Since we delay header parsing for GET and HEAD requests until the users asks for it + // we can't return 400 in case of a bad header. + // Inside the handler above we make sure to test that the invalid Foo header was ignored. + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + if err := c.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + func TestServerConnState(t *testing.T) { states := make([]string, 0) s := &Server{ @@ -595,7 +658,7 @@ func TestServerMaxConnsPerIPLimit(t *testing.T) { ctx.WriteString("OK") }, MaxConnsPerIP: 1, - Logger: &customLogger{}, + Logger: &testLogger{}, } ln := fasthttputil.NewInmemoryListener() @@ -697,7 +760,7 @@ func TestServerConcurrencyLimit(t *testing.T) { ctx.WriteString("OK") }, Concurrency: 1, - Logger: &customLogger{}, + Logger: &testLogger{}, } ln := fasthttputil.NewInmemoryListener() @@ -1970,7 +2033,7 @@ func TestRequestCtxHijack(t *testing.T) { func TestRequestCtxInit(t *testing.T) { var ctx RequestCtx - var logger customLogger + var logger testLogger globalConnID = 0x123456 ctx.Init(&ctx.Request, zeroTCPAddr, &logger) ip := ctx.RemoteIP() @@ -2443,19 +2506,8 @@ func TestServerEmptyResponse(t *testing.T) { verifyResponse(t, br, 200, string(defaultContentType), "") } -type customLogger struct { - lock sync.Mutex - out string -} - -func (cl *customLogger) Printf(format string, args ...interface{}) { - cl.lock.Lock() - cl.out += fmt.Sprintf(format, args...)[6:] + "\n" - cl.lock.Unlock() -} - func TestServerLogger(t *testing.T) { - cl := &customLogger{} + cl := &testLogger{} s := &Server{ Handler: func(ctx *RequestCtx) { logger := ctx.Logger() @@ -2740,7 +2792,7 @@ func TestShutdownReuse(t *testing.T) { ctx.Success("aaa/bbb", []byte("real response")) }, ReadTimeout: time.Second, - Logger: &customLogger{}, // Ignore log output. + Logger: &testLogger{}, // Ignore log output. } go func() { if err := s.Serve(ln); err != nil { @@ -3030,3 +3082,14 @@ func (rw *readWriter) SetReadDeadline(t time.Time) error { func (rw *readWriter) SetWriteDeadline(t time.Time) error { return nil } + +type testLogger struct { + lock sync.Mutex + out string +} + +func (cl *testLogger) Printf(format string, args ...interface{}) { + cl.lock.Lock() + cl.out += fmt.Sprintf(format, args...)[6:] + "\n" + cl.lock.Unlock() +}