diff --git a/header.go b/header.go index 3ce5923..97205b6 100644 --- a/header.go +++ b/header.go @@ -3119,9 +3119,12 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { s.b = buf for s.next() { - // Trim trailing whitespace before the colon to normalize headers - // like "Content-Length :" to "Content-Length:". + key := s.key s.key = trimTrailingSpace(s.key) + if len(s.key) != len(key) { + h.connectionClose = true + return 0, fmt.Errorf("invalid header key %q", key) + } if len(s.key) == 0 { h.connectionClose = true diff --git a/header_test.go b/header_test.go index 19040c4..e3bd5bb 100644 --- a/header_test.go +++ b/header_test.go @@ -3372,6 +3372,12 @@ func TestRequestHeaderReadError(t *testing.T) { // Space before header name testRequestHeaderReadError(t, h, "G(ET /foo/bar HTTP/1.1\r\n foo: bar\r\n\r\n") + // Whitespace before the colon in request header fields + testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaa.com\r\nFoo : bar\r\n\r\n") + testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\nHost : aaa.com\r\n\r\n") + testRequestHeaderReadError(t, h, "POST /foo/bar HTTP/1.1\r\nHost: aaa.com\r\nContent-Length : 4\r\n\r\ntest") + testRequestHeaderReadError(t, h, "POST /foo/bar HTTP/1.1\r\nHost: aaa.com\r\nTransfer-Encoding : chunked\r\n\r\n4\r\ntest\r\n0\r\n\r\n") + // Duplicate host header testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaa.com\r\nhost: bbb.com\r\n\r\n") diff --git a/http_test.go b/http_test.go index baf40a3..63cd67e 100644 --- a/http_test.go +++ b/http_test.go @@ -1777,25 +1777,23 @@ func TestRequestReadLimitBody(t *testing.T) { testRequestReadLimitBodySuccess(t, "GET /foo HTTP/1.0\r\n\r\n", 0) } -func TestRequestReadLimitBodyWhitespaceBeforeColonFramingHeaders(t *testing.T) { +func TestRequestReadLimitBodyRejectWhitespaceBeforeColonFramingHeaders(t *testing.T) { t.Parallel() - var req Request - r := bytes.NewBufferString("POST /foo HTTP/1.1\r\nHost: a.com\r\nContent-Length : 4\r\n\r\ntestNEXT") - br := bufio.NewReader(r) - if err := req.ReadLimitBody(br, 10); err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got := string(req.Body()); got != "test" { - t.Fatalf("unexpected body %q", got) + tests := []string{ + "POST /foo HTTP/1.1\r\nHost: a.com\r\nContent-Length : 4\r\n\r\ntestNEXT", + "POST /foo HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding : chunked\r\n\r\n4\r\ntest\r\n0\r\n\r\n", } - rest, err := io.ReadAll(br) - if err != nil { - t.Fatalf("unexpected read error: %v", err) - } - if got := string(rest); got != "NEXT" { - t.Fatalf("unexpected buffered bytes %q", got) + for _, s := range tests { + var req Request + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := req.ReadLimitBody(br, 10); err == nil { + t.Fatalf("expecting error for %q", s) + } + if body := req.Body(); len(body) != 0 { + t.Fatalf("unexpected body %q for %q", body, s) + } } }