diff --git a/client_test.go b/client_test.go index 7f90f7c..efd73f8 100644 --- a/client_test.go +++ b/client_test.go @@ -273,14 +273,89 @@ func TestClientInvalidURI(t *testing.T) { req.Header.SetMethod(MethodGet) req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n") err := c.Do(req, res) - if err == nil { - t.Fatal("expected error (missing required Host header in request)") + if err == nil && res.StatusCode() != StatusBadRequest { + t.Fatalf("expected invalid URI to be rejected, got status code %d", res.StatusCode()) } if n := requests.Load(); n != 0 { t.Fatalf("0 requests expected, got %d", n) } } +func TestClientRequestProtocolSetterSanitizesNewlines(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + var requests atomic.Int64 + s := &Server{ + Handler: func(_ *RequestCtx) { + requests.Add(1) + }, + } + go s.Serve(ln) //nolint:errcheck + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + + req, res := AcquireRequest(), AcquireResponse() + defer func() { + ReleaseRequest(req) + ReleaseResponse(res) + }() + + req.SetRequestURI("http://example.com/") + req.Header.SetProtocol("HTTP/1.1\r\nX-Injected-Protocol: true") + + if err := c.Do(req, res); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := res.StatusCode(); got != StatusBadRequest { + t.Fatalf("unexpected status code: %d. Expected %d", got, StatusBadRequest) + } + if n := requests.Load(); n != 0 { + t.Fatalf("expected malformed request to be rejected before reaching handler, got %d handled requests", n) + } +} + +func TestClientResponseStatusMessageSetterSanitizesNewlines(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Response.Header.SetStatusCode(StatusOK) + ctx.Response.Header.SetStatusMessage([]byte("OK\r\nX-Injected-Status: true")) + }, + } + go s.Serve(ln) //nolint:errcheck + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + + req, res := AcquireRequest(), AcquireResponse() + defer func() { + ReleaseRequest(req) + ReleaseResponse(res) + }() + + req.SetRequestURI("http://example.com/") + + if err := c.Do(req, res); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := string(res.Header.StatusMessage()); got != "OK X-Injected-Status: true" { + t.Fatalf("unexpected status message: %q. Expected %q", got, "OK X-Injected-Status: true") + } + if got := string(res.Header.Peek("X-Injected-Status")); got != "" { + t.Fatalf("unexpected injected response header value: %q", got) + } +} + func TestClientGetWithBody(t *testing.T) { t.Parallel() diff --git a/header.go b/header.go index f3fc0e5..2cacf0f 100644 --- a/header.go +++ b/header.go @@ -143,12 +143,12 @@ func (h *ResponseHeader) StatusMessage() []byte { // SetStatusMessage sets response status message bytes. func (h *ResponseHeader) SetStatusMessage(statusMessage []byte) { - h.statusMessage = append(h.statusMessage[:0], statusMessage...) + h.statusMessage = initHeaderValueBytes(h.statusMessage, statusMessage) } // SetProtocol sets response protocol bytes. func (h *ResponseHeader) SetProtocol(protocol []byte) { - h.protocol = append(h.protocol[:0], protocol...) + h.protocol = initHeaderValueBytes(h.protocol, protocol) } // SetLastModified sets 'Last-Modified' header to the given value. @@ -750,12 +750,12 @@ func (h *RequestHeader) Method() []byte { // SetMethod sets HTTP request method. func (h *RequestHeader) SetMethod(method string) { - h.method = append(h.method[:0], method...) + h.method = initHeaderValueString(h.method, method) } // SetMethodBytes sets HTTP request method. func (h *RequestHeader) SetMethodBytes(method []byte) { - h.method = append(h.method[:0], method...) + h.method = initHeaderValueBytes(h.method, method) } // Protocol returns HTTP protocol. @@ -768,13 +768,13 @@ func (h *header) Protocol() []byte { // SetProtocol sets HTTP request protocol. func (h *RequestHeader) SetProtocol(protocol string) { - h.protocol = append(h.protocol[:0], protocol...) + h.protocol = initHeaderValueString(h.protocol, protocol) h.noHTTP11 = !bytes.Equal(h.protocol, strHTTP11) } // SetProtocolBytes sets HTTP request protocol. func (h *RequestHeader) SetProtocolBytes(protocol []byte) { - h.protocol = append(h.protocol[:0], protocol...) + h.protocol = initHeaderValueBytes(h.protocol, protocol) h.noHTTP11 = !bytes.Equal(h.protocol, strHTTP11) } @@ -791,14 +791,14 @@ func (h *RequestHeader) RequestURI() []byte { // RequestURI must be properly encoded. // Use URI.RequestURI for constructing proper RequestURI if unsure. func (h *RequestHeader) SetRequestURI(requestURI string) { - h.requestURI = append(h.requestURI[:0], requestURI...) + h.requestURI = initHeaderValueString(h.requestURI, requestURI) } // SetRequestURIBytes sets RequestURI for the first HTTP request line. // RequestURI must be properly encoded. // Use URI.RequestURI for constructing proper RequestURI if unsure. func (h *RequestHeader) SetRequestURIBytes(requestURI []byte) { - h.requestURI = append(h.requestURI[:0], requestURI...) + h.requestURI = initHeaderValueBytes(h.requestURI, requestURI) } // IsGet returns true if request method is GET. diff --git a/header_test.go b/header_test.go index bc5e706..c17abb9 100644 --- a/header_test.go +++ b/header_test.go @@ -58,6 +58,64 @@ func TestResponseHeaderAddContentEncoding(t *testing.T) { } } +func TestResponseHeaderFirstLineSettersSanitizeNewlines(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + set func(*ResponseHeader) + value func(*ResponseHeader) []byte + wantValue string + wantFirstLine string + }{ + { + name: "SetStatusMessage", + set: func(h *ResponseHeader) { + h.SetStatusMessage([]byte("OK\r\nInjected-Status: true")) + }, + value: func(h *ResponseHeader) []byte { return h.StatusMessage() }, + wantValue: "OK Injected-Status: true", + wantFirstLine: "HTTP/1.1 200 OK Injected-Status: true", + }, + { + name: "SetProtocol", + set: func(h *ResponseHeader) { + h.SetProtocol([]byte("HTTP/1.1\r\nInjected-Protocol: true")) + }, + value: func(h *ResponseHeader) []byte { return h.Protocol() }, + wantValue: "HTTP/1.1 Injected-Protocol: true", + wantFirstLine: "HTTP/1.1 Injected-Protocol: true 200 OK", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var h ResponseHeader + h.SetStatusCode(StatusOK) + h.SetStatusMessage([]byte("OK")) + + tc.set(&h) + + if got := string(tc.value(&h)); got != tc.wantValue { + t.Fatalf("unexpected sanitized value: %q. Expected %q", got, tc.wantValue) + } + + firstLine, _, ok := bytes.Cut(h.Header(), strCRLF) + if !ok { + t.Fatalf("missing response first line terminator in header %q", h.Header()) + } + if got := string(firstLine); got != tc.wantFirstLine { + t.Fatalf("unexpected response first line: %q. Expected %q", got, tc.wantFirstLine) + } + if bytes.Contains(h.Header(), []byte("\r\nInjected-")) { + t.Fatalf("unexpected injected header line in %q", h.Header()) + } + }) + } +} + func TestResponseHeaderMultiLineValue(t *testing.T) { t.Parallel() @@ -2442,6 +2500,122 @@ func TestRequestHeaderMethod(t *testing.T) { testRequestHeaderMethod(t, "ABC") } +func TestRequestHeaderFirstLineSettersSanitizeNewlines(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + set func(*RequestHeader) + value func(*RequestHeader) []byte + wantValue string + wantFirstLine string + wantNoHTTP11 bool + }{ + { + name: "SetMethod", + set: func(h *RequestHeader) { + h.SetMethod("GET\r\nInjected-Method: true") + }, + value: func(h *RequestHeader) []byte { return h.Method() }, + wantValue: "GET Injected-Method: true", + wantFirstLine: "GET Injected-Method: true / HTTP/1.1", + }, + { + name: "SetMethodBytes", + set: func(h *RequestHeader) { + h.SetMethodBytes([]byte("GET\r\nInjected-Method-Bytes: true")) + }, + value: func(h *RequestHeader) []byte { return h.Method() }, + wantValue: "GET Injected-Method-Bytes: true", + wantFirstLine: "GET Injected-Method-Bytes: true / HTTP/1.1", + }, + { + name: "SetRequestURI", + set: func(h *RequestHeader) { + h.SetRequestURI("/\r\nInjected-URI: true") + }, + value: func(h *RequestHeader) []byte { return h.RequestURI() }, + wantValue: "/ Injected-URI: true", + wantFirstLine: "GET / Injected-URI: true HTTP/1.1", + }, + { + name: "SetRequestURIBytes", + set: func(h *RequestHeader) { + h.SetRequestURIBytes([]byte("/\r\nInjected-URI-Bytes: true")) + }, + value: func(h *RequestHeader) []byte { return h.RequestURI() }, + wantValue: "/ Injected-URI-Bytes: true", + wantFirstLine: "GET / Injected-URI-Bytes: true HTTP/1.1", + }, + { + name: "SetProtocol", + set: func(h *RequestHeader) { + h.SetProtocol("HTTP/1.1\r\nInjected-Protocol: true") + }, + value: func(h *RequestHeader) []byte { return h.Protocol() }, + wantValue: "HTTP/1.1 Injected-Protocol: true", + wantFirstLine: "GET / HTTP/1.1 Injected-Protocol: true", + wantNoHTTP11: true, + }, + { + name: "SetProtocolBytes", + set: func(h *RequestHeader) { + h.SetProtocolBytes([]byte("HTTP/1.1\r\nInjected-Protocol-Bytes: true")) + }, + value: func(h *RequestHeader) []byte { return h.Protocol() }, + wantValue: "HTTP/1.1 Injected-Protocol-Bytes: true", + wantFirstLine: "GET / HTTP/1.1 Injected-Protocol-Bytes: true", + wantNoHTTP11: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var h RequestHeader + h.SetMethod(MethodGet) + h.SetRequestURI("/") + h.SetProtocol("HTTP/1.1") + h.SetHost("example.com") + + tc.set(&h) + + if got := string(tc.value(&h)); got != tc.wantValue { + t.Fatalf("unexpected sanitized value: %q. Expected %q", got, tc.wantValue) + } + + firstLine, _, ok := bytes.Cut(h.Header(), strCRLF) + if !ok { + t.Fatalf("missing request first line terminator in header %q", h.Header()) + } + if got := string(firstLine); got != tc.wantFirstLine { + t.Fatalf("unexpected request first line: %q. Expected %q", got, tc.wantFirstLine) + } + if bytes.Contains(h.Header(), []byte("\r\nInjected-")) { + t.Fatalf("unexpected injected header line in %q", h.Header()) + } + if h.noHTTP11 != tc.wantNoHTTP11 { + t.Fatalf("unexpected noHTTP11 flag: %v. Expected %v", h.noHTTP11, tc.wantNoHTTP11) + } + }) + } +} + +func TestRequestHeaderSetProtocolKeepsHTTP11FlagForSanitizedHTTP11(t *testing.T) { + t.Parallel() + + var h RequestHeader + h.SetProtocolBytes([]byte("HTTP/1.1")) + + if h.noHTTP11 { + t.Fatalf("expected noHTTP11 to remain false for HTTP/1.1") + } + if got := string(h.Protocol()); got != "HTTP/1.1" { + t.Fatalf("unexpected protocol: %q. Expected %q", got, "HTTP/1.1") + } +} + func testRequestHeaderMethod(t *testing.T, expectedMethod string) { var h RequestHeader h.SetMethod(expectedMethod)