diff --git a/header.go b/header.go index 45ce267..0407284 100644 --- a/header.go +++ b/header.go @@ -2320,10 +2320,22 @@ func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error { if errParse != nil { return headerError("request", err, errParse, b, h.secureErrorLogMessage) } + if errValidate := h.validate(); errValidate != nil { + return headerError("request", err, errValidate, b, h.secureErrorLogMessage) + } mustDiscard(r, headersLen) return nil } +func (h *RequestHeader) validate() error { + // Host header is mandatory in HTTP/1.1 requests. + if h.IsHTTP11() && len(h.Host()) == 0 { + h.connectionClose = true + return errRequestHostRequired + } + return nil +} + func bufferSnippet(b []byte) string { n := len(b) start := 200 @@ -3090,6 +3102,7 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { h.contentLength = -2 contentLengthSeen := false + hostSeen := false var s headerScanner s.b = buf @@ -3128,6 +3141,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { switch s.key[0] | 0x20 { case 'h': if caseInsensitiveCompare(s.key, strHost) { + if hostSeen { + h.connectionClose = true + return 0, errors.New("too many Host headers") + } + hostSeen = true h.host = append(h.host[:0], s.value...) continue } diff --git a/header_regression_test.go b/header_regression_test.go index 42f0344..8ece8c6 100644 --- a/header_regression_test.go +++ b/header_regression_test.go @@ -58,6 +58,7 @@ func testIssue6RequestHeaderSetContentType(t *testing.T, method string) { var h RequestHeader h.SetMethod(method) h.SetRequestURI("http://localhost/test") + h.SetHost("localhost") h.SetContentType(contentType) h.SetContentLength(contentLength) diff --git a/header_test.go b/header_test.go index 6cc27a3..cd149ad 100644 --- a/header_test.go +++ b/header_test.go @@ -371,7 +371,7 @@ func TestRequestRawHeaders(t *testing.T) { } }) t.Run("no-kvs", func(t *testing.T) { - s := "GET / HTTP/1.1\r\n\r\n" + s := "GET / HTTP/1.0\r\n\r\n" exp := "" var h RequestHeader h.DisableNormalizing() @@ -512,6 +512,7 @@ func TestRequestHeaderSetCookieWithSpecialChars(t *testing.T) { var h RequestHeader h.Set("Cookie", "ID&14") s := h.String() + s = strings.Replace(s, "\r\n", "\r\nHost: example.com\r\n", 1) if !strings.Contains(s, "Cookie: ID&14") { t.Fatalf("Missing cookie in request header: %q", s) @@ -676,7 +677,8 @@ func TestRequestHeaderAdd(t *testing.T) { if strings.Contains(s, "\r\nX-Injected: yes\r\n") { t.Fatalf("serialized request header contains injected header line: %q", s) } - br := bufio.NewReader(bytes.NewBufferString(s)) + sWithHost := strings.Replace(s, "User-Agent: xxx\r\n", "User-Agent: xxx\r\nHost: example.com\r\n", 1) + br := bufio.NewReader(bytes.NewBufferString(sWithHost)) var h1 RequestHeader if err := h1.Read(br); err != nil { t.Fatalf("unexpected error: %v", err) @@ -694,8 +696,8 @@ func TestRequestHeaderAdd(t *testing.T) { t.Fatalf("unexpected number of headers: %d. Expecting 13", len(m)) } s1 := h1.String() - if s != s1 { - t.Fatalf("unexpected headers %q. Expecting %q", s1, s) + if sWithHost != s1 { + t.Fatalf("unexpected headers %q. Expecting %q", s1, sWithHost) } } @@ -1284,6 +1286,10 @@ func TestRequestMultipartFormBoundary(t *testing.T) { } func testRequestMultipartFormBoundary(t *testing.T, s, boundary string) { + if strings.HasPrefix(s, "POST / HTTP/1.1\r\n") && !strings.Contains(s, "\r\nHost: ") { + s = strings.Replace(s, "\r\n", "\r\nHost: example.com\r\n", 1) + } + var h RequestHeader r := bytes.NewBufferString(s) br := bufio.NewReader(r) @@ -1624,6 +1630,7 @@ func TestRequestContentTypeDefaultNotEmpty(t *testing.T) { var h RequestHeader h.SetMethod(MethodPost) + h.SetHost("example.com") h.SetContentLength(5) w := &bytes.Buffer{} @@ -1651,6 +1658,7 @@ func TestRequestContentTypeNoDefault(t *testing.T) { var h RequestHeader h.SetMethod(MethodDelete) + h.SetHost("example.com") h.SetNoDefaultContentType(true) w := &bytes.Buffer{} @@ -2431,6 +2439,7 @@ func TestRequestHeaderMethod(t *testing.T) { func testRequestHeaderMethod(t *testing.T, expectedMethod string) { var h RequestHeader h.SetMethod(expectedMethod) + h.SetHost("example.com") m := h.Method() if string(m) != expectedMethod { t.Fatalf("unexpected method: %q. Expecting %q", m, expectedMethod) @@ -2881,6 +2890,16 @@ func TestRequestHeaderReadSuccess(t *testing.T) { t.Fatalf("expecting connectionClose for ancient http protocol") } + // ancient http protocol without Host + testRequestHeaderReadSuccess(t, h, "GET /bar HTTP/1.0\r\n\r\npppp", + -2, "/bar", "", "", "") + if h.IsHTTP11() { + t.Fatalf("ancient http protocol cannot be http/1.1") + } + if !h.ConnectionClose() { + t.Fatalf("expecting connectionClose for ancient http protocol") + } + // ancient http protocol with 'Connection: keep-alive' header testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.0\r\nHost: bb\r\nConnection: keep-alive\r\n\r\nxxx", -2, "/aa", "bb", "", "") @@ -2925,10 +2944,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) { testRequestHeaderReadSuccess(t, h, "GET /asdf HTTP/1.1\r\nHost: aaa.com\r\nReferer: bb.com\r\n\r\naaa", -2, "/asdf", "aaa.com", "bb.com", "") - // duplicate host - testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aaaaaa.com\r\nHost: bb.com\r\n\r\n", - -2, "/aa", "bb.com", "", "") - // post with duplicate content-type testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: aa\r\nContent-Type: ab\r\nContent-Length: 123\r\nContent-Type: xx\r\n\r\n", 123, "/a", "aa", "", "xx") @@ -2957,12 +2972,10 @@ func TestRequestHeaderReadSuccess(t *testing.T) { testRequestHeaderReadError(t, h, "GET /foo/ bar baz HTTP/1.1\r\nHost: aa.com\r\n\r\nxxx") // no host - testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nFOObar: assdfd\r\n\r\naaa", - -2, "/foo/bar", "", "", "") + testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\nFOObar: assdfd\r\n\r\naaa") // no host, no headers - testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar", - -2, "/foo/bar", "", "", "") + testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar") // post without content-length and content-type testRequestHeaderReadSuccess(t, h, "POST /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc", @@ -3079,6 +3092,12 @@ func TestRequestHeaderReadError(t *testing.T) { // Space before header name testRequestHeaderReadError(t, h, "G(ET /foo/bar HTTP/1.1\r\n foo: bar\r\n\r\n") + + // Duplicate host header + testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaa.com\r\nhost: bbb.com\r\n\r\n") + + // Missing host header + testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\n\r\n") } func TestRequestHeaderReadSecuredError(t *testing.T) { diff --git a/http_test.go b/http_test.go index 276b1a7..5525c2d 100644 --- a/http_test.go +++ b/http_test.go @@ -711,7 +711,7 @@ func TestRequestContentTypeWithCharsetIssue100(t *testing.T) { expectedContentType := "application/x-www-form-urlencoded; charset=UTF-8" expectedBody := "0123=56789" - s := fmt.Sprintf("POST / HTTP/1.1\r\nContent-Type: %s\r\nContent-Length: %d\r\n\r\n%s", + s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: example.com\r\nContent-Type: %s\r\nContent-Length: %d\r\n\r\n%s", expectedContentType, len(expectedBody), expectedBody) br := bufio.NewReader(bytes.NewBufferString(s)) @@ -1157,7 +1157,7 @@ func TestRequestReadNoBody(t *testing.T) { var r Request - br := bufio.NewReader(bytes.NewBufferString("GET / HTTP/1.1\r\n\r\n")) + br := bufio.NewReader(bytes.NewBufferString("GET / HTTP/1.1\r\nHost: foobar\r\n\r\n")) err := r.Read(br) r.SetHost("foobar") if err != nil { @@ -1325,7 +1325,7 @@ func TestRequestReadGzippedBody(t *testing.T) { bodyOriginal := "foo bar baz compress me better!" body := AppendGzipBytes(nil, []byte(bodyOriginal)) - s := fmt.Sprintf("POST /foobar HTTP/1.1\r\nContent-Type: foo/bar\r\nContent-Encoding: gzip\r\nContent-Length: %d\r\n\r\n%s", + s := fmt.Sprintf("POST /foobar HTTP/1.1\r\nHost: example.com\r\nContent-Type: foo/bar\r\nContent-Encoding: gzip\r\nContent-Length: %d\r\n\r\n%s", len(body), body) br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { @@ -1356,7 +1356,7 @@ func TestRequestReadPostNoBody(t *testing.T) { var r Request - s := "POST /foo/bar HTTP/1.1\r\nContent-Type: aaa/bbb\r\n\r\naaaa" + s := "POST /foo/bar HTTP/1.1\r\nHost: example.com\r\nContent-Type: aaa/bbb\r\n\r\naaaa" br := bufio.NewReader(bytes.NewBufferString(s)) if err := r.Read(br); err != nil { t.Fatalf("unexpected error: %v", err) @@ -1387,7 +1387,7 @@ func TestRequestReadPostNoBody(t *testing.T) { func TestRequestContinueReadBody(t *testing.T) { t.Parallel() - s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s := "PUT /foo/bar HTTP/1.1\r\nHost: example.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" br := bufio.NewReader(bytes.NewBufferString(s)) var r Request @@ -1771,6 +1771,10 @@ func TestRequestReadLimitBody(t *testing.T) { testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\nHost: a.com\nTransfer-Encoding: chunked\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nFoo: bar\r\n\r\n", 9) testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 999) testRequestReadLimitBodyError(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 8, ErrBodyTooLarge) + + // missing Host header is invalid in HTTP/1.1, but still allowed in HTTP/1.0 + testRequestReadLimitBodyError(t, "GET /foo HTTP/1.1\r\n\r\n", 0, errRequestHostRequired) + testRequestReadLimitBodySuccess(t, "GET /foo HTTP/1.0\r\n\r\n", 0) } func testResponseReadLimitBodyError(t *testing.T, s string, maxBodySize int, expectedErr error) { @@ -1798,6 +1802,8 @@ func testResponseReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) { } func testRequestReadLimitBodyError(t *testing.T, s string, maxBodySize int, expectedErr error) { + t.Helper() + var req Request r := bytes.NewBufferString(s) br := bufio.NewReader(r) @@ -1805,7 +1811,7 @@ func testRequestReadLimitBodyError(t *testing.T, s string, maxBodySize int, expe if err == nil { t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize) } - if err != expectedErr { + if !errors.Is(err, expectedErr) { t.Fatalf("unexpected error: %v. Expecting %v. s=%q, maxBodySize=%d", err, expectedErr, s, maxBodySize) } } diff --git a/server_test.go b/server_test.go index deb37cc..0aa2945 100644 --- a/server_test.go +++ b/server_test.go @@ -1970,7 +1970,7 @@ func TestServerExpect103EarlyHints(t *testing.T) { } rw := &readWriter{} - rw.r.WriteString("GET /foo HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: example.com\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") if err := s.ServeConn(rw); err != nil { t.Fatalf("Unexpected error from serveConn: %v", err) @@ -1999,6 +1999,30 @@ func TestServerExpect103EarlyHints(t *testing.T) { } } +func TestServerRejectsMissingHostHTTP11(t *testing.T) { + t.Parallel() + + var handlerCalled atomic.Bool + s := &Server{ + Handler: func(ctx *RequestCtx) { + handlerCalled.Store(true) + ctx.Success("text/plain", []byte("ok")) + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\n\r\n") + + _ = s.ServeConn(rw) + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "Error when parsing request") + + if handlerCalled.Load() { + t.Fatal("handler should not run for HTTP/1.1 request without Host") + } +} + func TestServerContinueHandler(t *testing.T) { t.Parallel()