diff --git a/header.go b/header.go index ce2b832..57e5cb8 100644 --- a/header.go +++ b/header.go @@ -520,38 +520,8 @@ var ErrBadTrailer = errors.New("contain forbidden trailer") // 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer) // // Return ErrBadTrailer if contain any forbidden trailers. -func (h *ResponseHeader) AddTrailerBytes(trailer []byte) error { - var err error - for i := -1; i+1 < len(trailer); { - trailer = trailer[i+1:] - i = bytes.IndexByte(trailer, ',') - if i < 0 { - i = len(trailer) - } - key := trailer[:i] - for len(key) > 0 && key[0] == ' ' { - key = key[1:] - } - for len(key) > 0 && key[len(key)-1] == ' ' { - key = key[:len(key)-1] - } - // Forbidden by RFC 7230, section 4.1.2 - if isBadTrailer(key) { - err = ErrBadTrailer - continue - } - h.bufK = append(h.bufK[:0], key...) - normalizeHeaderKey(h.bufK, h.disableNormalizing) - if cap(h.trailer) > len(h.trailer) { - h.trailer = h.trailer[:len(h.trailer)+1] - h.trailer[len(h.trailer)-1] = append(h.trailer[len(h.trailer)-1][:0], h.bufK...) - } else { - key = make([]byte, len(h.bufK)) - copy(key, h.bufK) - h.trailer = append(h.trailer, key) - } - } - +func (h *ResponseHeader) AddTrailerBytes(trailer []byte) (err error) { + h.bufK, h.trailer, err = addTrailerBytes(trailer, h.bufK, h.trailer, h.disableNormalizing) return err } @@ -875,38 +845,8 @@ func (h *RequestHeader) AddTrailer(trailer string) error { // 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer) // // Return ErrBadTrailer if contain any forbidden trailers. -func (h *RequestHeader) AddTrailerBytes(trailer []byte) error { - var err error - for i := -1; i+1 < len(trailer); { - trailer = trailer[i+1:] - i = bytes.IndexByte(trailer, ',') - if i < 0 { - i = len(trailer) - } - key := trailer[:i] - for len(key) > 0 && key[0] == ' ' { - key = key[1:] - } - for len(key) > 0 && key[len(key)-1] == ' ' { - key = key[:len(key)-1] - } - // Forbidden by RFC 7230, section 4.1.2 - if isBadTrailer(key) { - err = ErrBadTrailer - continue - } - h.bufK = append(h.bufK[:0], key...) - normalizeHeaderKey(h.bufK, h.disableNormalizing) - if cap(h.trailer) > len(h.trailer) { - h.trailer = h.trailer[:len(h.trailer)+1] - h.trailer[len(h.trailer)-1] = append(h.trailer[len(h.trailer)-1][:0], h.bufK...) - } else { - key = make([]byte, len(h.bufK)) - copy(key, h.bufK) - h.trailer = append(h.trailer, key) - } - } - +func (h *RequestHeader) AddTrailerBytes(trailer []byte) (err error) { + h.bufK, h.trailer, err = addTrailerBytes(trailer, h.bufK, h.trailer, h.disableNormalizing) return err } @@ -1321,8 +1261,8 @@ func (h *RequestHeader) VisitAll(f func(key, value []byte)) { func (h *RequestHeader) VisitAllInOrder(f func(key, value []byte)) { var s headerScanner s.b = h.rawHeaders - s.disableNormalizing = h.disableNormalizing for s.next() { + normalizeHeaderKey(s.key, h.disableNormalizing || bytes.IndexByte(s.key, ' ') != -1) if len(s.key) > 0 { f(s.key, s.value) } @@ -1338,7 +1278,7 @@ func (h *ResponseHeader) Del(key string) { // DelBytes deletes header with the given key. func (h *ResponseHeader) DelBytes(key []byte) { h.bufK = append(h.bufK[:0], key...) - normalizeHeaderKey(h.bufK, h.disableNormalizing) + normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1) h.del(h.bufK) } @@ -1372,7 +1312,7 @@ func (h *RequestHeader) Del(key string) { // DelBytes deletes header with the given key. func (h *RequestHeader) DelBytes(key []byte) { h.bufK = append(h.bufK[:0], key...) - normalizeHeaderKey(h.bufK, h.disableNormalizing) + normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1) h.del(h.bufK) } @@ -1638,7 +1578,7 @@ func (h *ResponseHeader) SetBytesV(key string, value []byte) { // Use AddBytesKV for setting multiple header values under the same key. func (h *ResponseHeader) SetBytesKV(key, value []byte) { h.bufK = append(h.bufK[:0], key...) - normalizeHeaderKey(h.bufK, h.disableNormalizing) + normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1) h.SetCanonical(h.bufK, value) } @@ -1864,7 +1804,7 @@ func (h *RequestHeader) SetBytesV(key string, value []byte) { // Use AddBytesKV for setting multiple header values under the same key. func (h *RequestHeader) SetBytesKV(key, value []byte) { h.bufK = append(h.bufK[:0], key...) - normalizeHeaderKey(h.bufK, h.disableNormalizing) + normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1) h.SetCanonical(h.bufK, value) } @@ -1900,7 +1840,7 @@ func (h *ResponseHeader) Peek(key string) []byte { // Do not store references to returned value. Make copies instead. func (h *ResponseHeader) PeekBytes(key []byte) []byte { h.bufK = append(h.bufK[:0], key...) - normalizeHeaderKey(h.bufK, h.disableNormalizing) + normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1) return h.peek(h.bufK) } @@ -1921,7 +1861,7 @@ func (h *RequestHeader) Peek(key string) []byte { // Do not store references to returned value. Make copies instead. func (h *RequestHeader) PeekBytes(key []byte) []byte { h.bufK = append(h.bufK[:0], key...) - normalizeHeaderKey(h.bufK, h.disableNormalizing) + normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1) return h.peek(h.bufK) } @@ -2239,7 +2179,8 @@ func (h *ResponseHeader) tryReadTrailer(r *bufio.Reader, n int) error { return fmt.Errorf("error when reading response trailer: %w", err) } b = mustPeekBuffered(r) - headersLen, errParse := h.parseTrailer(b) + trailers, headersLen, errParse := parseTrailer(b, h.h, h.disableNormalizing) + h.h = trailers if errParse != nil { if err == io.EOF { return err @@ -2348,7 +2289,8 @@ func (h *RequestHeader) tryReadTrailer(r *bufio.Reader, n int) error { return fmt.Errorf("error when reading request trailer: %w", err) } b = mustPeekBuffered(r) - headersLen, errParse := h.parseTrailer(b) + trailers, headersLen, errParse := parseTrailer(b, h.h, h.disableNormalizing) + h.h = trailers if errParse != nil { if err == io.EOF { return err @@ -2725,43 +2667,6 @@ func (h *ResponseHeader) parse(buf []byte) (int, error) { return m + n, nil } -func (h *ResponseHeader) parseTrailer(buf []byte) (int, error) { - // Skip any 0 length chunk. - if buf[0] == '0' { - skip := len(strCRLF) + 1 - if len(buf) < skip { - return 0, io.EOF - } - buf = buf[skip:] - } - - var s headerScanner - s.b = buf - s.disableNormalizing = h.disableNormalizing - var err error - for s.next() { - if len(s.key) > 0 { - if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 { - err = fmt.Errorf("invalid trailer key %q", s.key) - continue - } - // Forbidden by RFC 7230, section 4.1.2 - if isBadTrailer(s.key) { - err = fmt.Errorf("forbidden trailer key %q", s.key) - continue - } - h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) - } - } - if s.err != nil { - return 0, s.err - } - if err != nil { - return 0, err - } - return s.hLen, nil -} - func (h *RequestHeader) ignoreBody() bool { return h.IsGet() || h.IsHead() } @@ -2784,41 +2689,82 @@ func (h *RequestHeader) parse(buf []byte) (int, error) { return m + n, nil } -func (h *RequestHeader) parseTrailer(buf []byte) (int, error) { - // Skip any 0 length chunk. - if buf[0] == '0' { - skip := len(strCRLF) + 1 - if len(buf) < skip { - return 0, io.EOF +func addTrailerBytes(src, buf []byte, trailers [][]byte, disableNormalizing bool) ([]byte, [][]byte, error) { + var err error + for i := -1; i+1 < len(src); { + src = src[i+1:] + i = bytes.IndexByte(src, ',') + if i < 0 { + i = len(src) } - buf = buf[skip:] + key := src[:i] + for len(key) > 0 && key[0] == ' ' { + key = key[1:] + } + for len(key) > 0 && key[len(key)-1] == ' ' { + key = key[:len(key)-1] + } + // Forbidden by RFC 7230, section 4.1.2 + if isBadTrailer(key) { + err = ErrBadTrailer + continue + } + buf = append(buf[:0], key...) + normalizeHeaderKey(buf, disableNormalizing || bytes.IndexByte(buf, ' ') != -1) + if cap(trailers) > len(trailers) { + trailers = trailers[:len(trailers)+1] + trailers[len(trailers)-1] = append(trailers[len(trailers)-1][:0], buf...) + } else { + key = make([]byte, len(buf)) + copy(key, buf) + trailers = append(trailers, key) + } + } + + return buf, trailers, err +} + +func parseTrailer(src []byte, dest []argsKV, disableNormalizing bool) ([]argsKV, int, error) { + // Skip any 0 length chunk. + if src[0] == '0' { + skip := len(strCRLF) + 1 + if len(src) < skip { + return dest, 0, io.EOF + } + src = src[skip:] } var s headerScanner - s.b = buf - s.disableNormalizing = h.disableNormalizing - var err error + s.b = src + for s.next() { - if len(s.key) > 0 { - if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 { - err = fmt.Errorf("invalid trailer key %q", s.key) - continue - } - // Forbidden by RFC 7230, section 4.1.2 - if isBadTrailer(s.key) { - err = fmt.Errorf("forbidden trailer key %q", s.key) - continue - } - h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + if len(s.key) == 0 { + continue } + disable := disableNormalizing + for _, ch := range s.key { + if !validHeaderFieldByte(ch) { + // We accept invalid headers with a space before the + // colon, but must not canonicalize them. + // See: https://github.com/valyala/fasthttp/issues/1917 + if ch == ' ' { + disable = true + continue + } + return dest, 0, fmt.Errorf("invalid trailer key %q", s.key) + } + } + // Forbidden by RFC 7230, section 4.1.2 + if isBadTrailer(s.key) { + return dest, 0, fmt.Errorf("forbidden trailer key %q", s.key) + } + normalizeHeaderKey(s.key, disable) + dest = appendArgBytes(dest, s.key, s.value, argsHasValue) } if s.err != nil { - return 0, s.err + return dest, 0, s.err } - if err != nil { - return 0, err - } - return s.hLen, nil + return dest, s.hLen, nil } func isBadTrailer(key []byte) bool { @@ -3019,7 +2965,6 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { var s headerScanner s.b = buf - s.disableNormalizing = h.disableNormalizing var kv *argsKV for s.next() { @@ -3028,12 +2973,22 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { return 0, fmt.Errorf("invalid header key %q", s.key) } + disableNormalizing := h.disableNormalizing for _, ch := range s.key { if !validHeaderFieldByte(ch) { h.connectionClose = true + // We accept invalid headers with a space before the + // colon, but must not canonicalize them. + // See: https://github.com/valyala/fasthttp/issues/1917 + if ch == ' ' { + disableNormalizing = true + continue + } return 0, fmt.Errorf("invalid header key %q", s.key) } } + normalizeHeaderKey(s.key, disableNormalizing) + for _, ch := range s.value { if !validHeaderValueByte(ch) { h.connectionClose = true @@ -3136,7 +3091,6 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { var s headerScanner s.b = buf - s.disableNormalizing = h.disableNormalizing for s.next() { if len(s.key) == 0 { @@ -3144,12 +3098,19 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { return 0, fmt.Errorf("invalid header key %q", s.key) } + disableNormalizing := h.disableNormalizing for _, ch := range s.key { if !validHeaderFieldByte(ch) { + if ch == ' ' { + disableNormalizing = true + continue + } h.connectionClose = true return 0, fmt.Errorf("invalid header key %q", s.key) } } + normalizeHeaderKey(s.key, disableNormalizing) + for _, ch := range s.value { if !validHeaderValueByte(ch) { h.connectionClose = true @@ -3304,8 +3265,7 @@ type headerScanner struct { nextColon int nextNewLine int - disableNormalizing bool - initialized bool + initialized bool } func (s *headerScanner) next() bool { @@ -3351,7 +3311,6 @@ func (s *headerScanner) next() bool { return false } s.key = s.b[:n] - normalizeHeaderKey(s.key, s.disableNormalizing) n++ for len(s.b) > n && (s.b[n] == ' ' || s.b[n] == '\t') { n++ @@ -3482,7 +3441,7 @@ func initHeaderKV(bufK, bufV []byte, key, value string, disableNormalizing bool) func getHeaderKeyBytes(bufK []byte, key string, disableNormalizing bool) []byte { bufK = append(bufK[:0], key...) - normalizeHeaderKey(bufK, disableNormalizing) + normalizeHeaderKey(bufK, disableNormalizing || bytes.IndexByte(bufK, ' ') != -1) return bufK } diff --git a/header_test.go b/header_test.go index 1af2631..3c3affa 100644 --- a/header_test.go +++ b/header_test.go @@ -327,6 +327,7 @@ func TestRequestRawHeaders(t *testing.T) { kvs := "hOsT: foobar\r\n" + "value: b\r\n" + + "uSeR agent: agent\r\n" + "\r\n" t.Run("normalized", func(t *testing.T) { s := "GET / HTTP/1.1\r\n" + kvs @@ -343,6 +344,12 @@ func TestRequestRawHeaders(t *testing.T) { if !bytes.Equal(v2, []byte{'b'}) { t.Fatalf("expecting non empty value. Got %q", v2) } + // We accept invalid headers with a space. + // See: https://github.com/valyala/fasthttp/issues/1917 + v3 := h.Peek("uSeR agent") + if !bytes.Equal(v3, []byte("agent")) { + t.Fatalf("expecting non empty value. Got %q", v3) + } if raw := h.RawHeaders(); string(raw) != exp { t.Fatalf("expected header %q, got %q", exp, raw) } @@ -1860,8 +1867,8 @@ func TestResponseHeaderAddTrailerError(t *testing.T) { t.Parallel() var h ResponseHeader - err := h.AddTrailer("Foo, Content-Length , Bar,Transfer-Encoding,") - expectedTrailer := "Foo, Bar" + err := h.AddTrailer("Foo, Content-Length , bAr,Transfer-Encoding, uSer aGent") + expectedTrailer := "Foo, Bar, uSer aGent" if !errors.Is(err, ErrBadTrailer) { t.Fatalf("unexpected err %q. Expected %q", err, ErrBadTrailer) diff --git a/server_test.go b/server_test.go index 59f27d1..c0ab2ce 100644 --- a/server_test.go +++ b/server_test.go @@ -189,8 +189,8 @@ func TestServerInvalidHeader(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { - if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil { - t.Error("expected Foo header") + if ctx.Request.Header.Peek("Foő") != nil || ctx.Request.Header.Peek("Foő ") != nil { + t.Error("expected Foő header") } }, Logger: &testLogger{}, @@ -208,7 +208,7 @@ func TestServerInvalidHeader(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil { + if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoő : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil { t.Fatal(err) } @@ -225,7 +225,7 @@ func TestServerInvalidHeader(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil { + if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoő : bar\r\n\r\n")); err != nil { t.Fatal(err) }