fix: accept invalid headers with a space (#1953)

* fix: accept invalid headers with a space #1917

Make behavior consistent with net/http by allowing header keys and trailers containing spaces without canonicalizing them

* fix: lint paramTypeCombine

* fix: https://github.com/valyala/fasthttp/pull/1953#issuecomment-2660691298

* fix: golangci-lint nestingReduce
This commit is contained in:
Kashiwa
2025-02-19 18:49:48 +08:00
committed by GitHub
parent b59f47e3ee
commit 086a114445
3 changed files with 114 additions and 148 deletions
+101 -142
View File
@@ -520,38 +520,8 @@ var ErrBadTrailer = errors.New("contain forbidden trailer")
// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
//
// Return ErrBadTrailer if contain any forbidden trailers.
func (h *ResponseHeader) AddTrailerBytes(trailer []byte) error {
var err error
for i := -1; i+1 < len(trailer); {
trailer = trailer[i+1:]
i = bytes.IndexByte(trailer, ',')
if i < 0 {
i = len(trailer)
}
key := trailer[:i]
for len(key) > 0 && key[0] == ' ' {
key = key[1:]
}
for len(key) > 0 && key[len(key)-1] == ' ' {
key = key[:len(key)-1]
}
// Forbidden by RFC 7230, section 4.1.2
if isBadTrailer(key) {
err = ErrBadTrailer
continue
}
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing)
if cap(h.trailer) > len(h.trailer) {
h.trailer = h.trailer[:len(h.trailer)+1]
h.trailer[len(h.trailer)-1] = append(h.trailer[len(h.trailer)-1][:0], h.bufK...)
} else {
key = make([]byte, len(h.bufK))
copy(key, h.bufK)
h.trailer = append(h.trailer, key)
}
}
func (h *ResponseHeader) AddTrailerBytes(trailer []byte) (err error) {
h.bufK, h.trailer, err = addTrailerBytes(trailer, h.bufK, h.trailer, h.disableNormalizing)
return err
}
@@ -875,38 +845,8 @@ func (h *RequestHeader) AddTrailer(trailer string) error {
// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
//
// Return ErrBadTrailer if contain any forbidden trailers.
func (h *RequestHeader) AddTrailerBytes(trailer []byte) error {
var err error
for i := -1; i+1 < len(trailer); {
trailer = trailer[i+1:]
i = bytes.IndexByte(trailer, ',')
if i < 0 {
i = len(trailer)
}
key := trailer[:i]
for len(key) > 0 && key[0] == ' ' {
key = key[1:]
}
for len(key) > 0 && key[len(key)-1] == ' ' {
key = key[:len(key)-1]
}
// Forbidden by RFC 7230, section 4.1.2
if isBadTrailer(key) {
err = ErrBadTrailer
continue
}
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing)
if cap(h.trailer) > len(h.trailer) {
h.trailer = h.trailer[:len(h.trailer)+1]
h.trailer[len(h.trailer)-1] = append(h.trailer[len(h.trailer)-1][:0], h.bufK...)
} else {
key = make([]byte, len(h.bufK))
copy(key, h.bufK)
h.trailer = append(h.trailer, key)
}
}
func (h *RequestHeader) AddTrailerBytes(trailer []byte) (err error) {
h.bufK, h.trailer, err = addTrailerBytes(trailer, h.bufK, h.trailer, h.disableNormalizing)
return err
}
@@ -1321,8 +1261,8 @@ func (h *RequestHeader) VisitAll(f func(key, value []byte)) {
func (h *RequestHeader) VisitAllInOrder(f func(key, value []byte)) {
var s headerScanner
s.b = h.rawHeaders
s.disableNormalizing = h.disableNormalizing
for s.next() {
normalizeHeaderKey(s.key, h.disableNormalizing || bytes.IndexByte(s.key, ' ') != -1)
if len(s.key) > 0 {
f(s.key, s.value)
}
@@ -1338,7 +1278,7 @@ func (h *ResponseHeader) Del(key string) {
// DelBytes deletes header with the given key.
func (h *ResponseHeader) DelBytes(key []byte) {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
h.del(h.bufK)
}
@@ -1372,7 +1312,7 @@ func (h *RequestHeader) Del(key string) {
// DelBytes deletes header with the given key.
func (h *RequestHeader) DelBytes(key []byte) {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
h.del(h.bufK)
}
@@ -1638,7 +1578,7 @@ func (h *ResponseHeader) SetBytesV(key string, value []byte) {
// Use AddBytesKV for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesKV(key, value []byte) {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
h.SetCanonical(h.bufK, value)
}
@@ -1864,7 +1804,7 @@ func (h *RequestHeader) SetBytesV(key string, value []byte) {
// Use AddBytesKV for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesKV(key, value []byte) {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
h.SetCanonical(h.bufK, value)
}
@@ -1900,7 +1840,7 @@ func (h *ResponseHeader) Peek(key string) []byte {
// Do not store references to returned value. Make copies instead.
func (h *ResponseHeader) PeekBytes(key []byte) []byte {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
return h.peek(h.bufK)
}
@@ -1921,7 +1861,7 @@ func (h *RequestHeader) Peek(key string) []byte {
// Do not store references to returned value. Make copies instead.
func (h *RequestHeader) PeekBytes(key []byte) []byte {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
return h.peek(h.bufK)
}
@@ -2239,7 +2179,8 @@ func (h *ResponseHeader) tryReadTrailer(r *bufio.Reader, n int) error {
return fmt.Errorf("error when reading response trailer: %w", err)
}
b = mustPeekBuffered(r)
headersLen, errParse := h.parseTrailer(b)
trailers, headersLen, errParse := parseTrailer(b, h.h, h.disableNormalizing)
h.h = trailers
if errParse != nil {
if err == io.EOF {
return err
@@ -2348,7 +2289,8 @@ func (h *RequestHeader) tryReadTrailer(r *bufio.Reader, n int) error {
return fmt.Errorf("error when reading request trailer: %w", err)
}
b = mustPeekBuffered(r)
headersLen, errParse := h.parseTrailer(b)
trailers, headersLen, errParse := parseTrailer(b, h.h, h.disableNormalizing)
h.h = trailers
if errParse != nil {
if err == io.EOF {
return err
@@ -2725,43 +2667,6 @@ func (h *ResponseHeader) parse(buf []byte) (int, error) {
return m + n, nil
}
func (h *ResponseHeader) parseTrailer(buf []byte) (int, error) {
// Skip any 0 length chunk.
if buf[0] == '0' {
skip := len(strCRLF) + 1
if len(buf) < skip {
return 0, io.EOF
}
buf = buf[skip:]
}
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var err error
for s.next() {
if len(s.key) > 0 {
if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 {
err = fmt.Errorf("invalid trailer key %q", s.key)
continue
}
// Forbidden by RFC 7230, section 4.1.2
if isBadTrailer(s.key) {
err = fmt.Errorf("forbidden trailer key %q", s.key)
continue
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
}
if s.err != nil {
return 0, s.err
}
if err != nil {
return 0, err
}
return s.hLen, nil
}
func (h *RequestHeader) ignoreBody() bool {
return h.IsGet() || h.IsHead()
}
@@ -2784,41 +2689,82 @@ func (h *RequestHeader) parse(buf []byte) (int, error) {
return m + n, nil
}
func (h *RequestHeader) parseTrailer(buf []byte) (int, error) {
// Skip any 0 length chunk.
if buf[0] == '0' {
skip := len(strCRLF) + 1
if len(buf) < skip {
return 0, io.EOF
func addTrailerBytes(src, buf []byte, trailers [][]byte, disableNormalizing bool) ([]byte, [][]byte, error) {
var err error
for i := -1; i+1 < len(src); {
src = src[i+1:]
i = bytes.IndexByte(src, ',')
if i < 0 {
i = len(src)
}
buf = buf[skip:]
key := src[:i]
for len(key) > 0 && key[0] == ' ' {
key = key[1:]
}
for len(key) > 0 && key[len(key)-1] == ' ' {
key = key[:len(key)-1]
}
// Forbidden by RFC 7230, section 4.1.2
if isBadTrailer(key) {
err = ErrBadTrailer
continue
}
buf = append(buf[:0], key...)
normalizeHeaderKey(buf, disableNormalizing || bytes.IndexByte(buf, ' ') != -1)
if cap(trailers) > len(trailers) {
trailers = trailers[:len(trailers)+1]
trailers[len(trailers)-1] = append(trailers[len(trailers)-1][:0], buf...)
} else {
key = make([]byte, len(buf))
copy(key, buf)
trailers = append(trailers, key)
}
}
return buf, trailers, err
}
func parseTrailer(src []byte, dest []argsKV, disableNormalizing bool) ([]argsKV, int, error) {
// Skip any 0 length chunk.
if src[0] == '0' {
skip := len(strCRLF) + 1
if len(src) < skip {
return dest, 0, io.EOF
}
src = src[skip:]
}
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var err error
s.b = src
for s.next() {
if len(s.key) > 0 {
if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 {
err = fmt.Errorf("invalid trailer key %q", s.key)
continue
}
// Forbidden by RFC 7230, section 4.1.2
if isBadTrailer(s.key) {
err = fmt.Errorf("forbidden trailer key %q", s.key)
continue
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
if len(s.key) == 0 {
continue
}
disable := disableNormalizing
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
// We accept invalid headers with a space before the
// colon, but must not canonicalize them.
// See: https://github.com/valyala/fasthttp/issues/1917
if ch == ' ' {
disable = true
continue
}
return dest, 0, fmt.Errorf("invalid trailer key %q", s.key)
}
}
// Forbidden by RFC 7230, section 4.1.2
if isBadTrailer(s.key) {
return dest, 0, fmt.Errorf("forbidden trailer key %q", s.key)
}
normalizeHeaderKey(s.key, disable)
dest = appendArgBytes(dest, s.key, s.value, argsHasValue)
}
if s.err != nil {
return 0, s.err
return dest, 0, s.err
}
if err != nil {
return 0, err
}
return s.hLen, nil
return dest, s.hLen, nil
}
func isBadTrailer(key []byte) bool {
@@ -3019,7 +2965,6 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var kv *argsKV
for s.next() {
@@ -3028,12 +2973,22 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
return 0, fmt.Errorf("invalid header key %q", s.key)
}
disableNormalizing := h.disableNormalizing
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
h.connectionClose = true
// We accept invalid headers with a space before the
// colon, but must not canonicalize them.
// See: https://github.com/valyala/fasthttp/issues/1917
if ch == ' ' {
disableNormalizing = true
continue
}
return 0, fmt.Errorf("invalid header key %q", s.key)
}
}
normalizeHeaderKey(s.key, disableNormalizing)
for _, ch := range s.value {
if !validHeaderValueByte(ch) {
h.connectionClose = true
@@ -3136,7 +3091,6 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
for s.next() {
if len(s.key) == 0 {
@@ -3144,12 +3098,19 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
return 0, fmt.Errorf("invalid header key %q", s.key)
}
disableNormalizing := h.disableNormalizing
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
if ch == ' ' {
disableNormalizing = true
continue
}
h.connectionClose = true
return 0, fmt.Errorf("invalid header key %q", s.key)
}
}
normalizeHeaderKey(s.key, disableNormalizing)
for _, ch := range s.value {
if !validHeaderValueByte(ch) {
h.connectionClose = true
@@ -3304,8 +3265,7 @@ type headerScanner struct {
nextColon int
nextNewLine int
disableNormalizing bool
initialized bool
initialized bool
}
func (s *headerScanner) next() bool {
@@ -3351,7 +3311,6 @@ func (s *headerScanner) next() bool {
return false
}
s.key = s.b[:n]
normalizeHeaderKey(s.key, s.disableNormalizing)
n++
for len(s.b) > n && (s.b[n] == ' ' || s.b[n] == '\t') {
n++
@@ -3482,7 +3441,7 @@ func initHeaderKV(bufK, bufV []byte, key, value string, disableNormalizing bool)
func getHeaderKeyBytes(bufK []byte, key string, disableNormalizing bool) []byte {
bufK = append(bufK[:0], key...)
normalizeHeaderKey(bufK, disableNormalizing)
normalizeHeaderKey(bufK, disableNormalizing || bytes.IndexByte(bufK, ' ') != -1)
return bufK
}
+9 -2
View File
@@ -327,6 +327,7 @@ func TestRequestRawHeaders(t *testing.T) {
kvs := "hOsT: foobar\r\n" +
"value: b\r\n" +
"uSeR agent: agent\r\n" +
"\r\n"
t.Run("normalized", func(t *testing.T) {
s := "GET / HTTP/1.1\r\n" + kvs
@@ -343,6 +344,12 @@ func TestRequestRawHeaders(t *testing.T) {
if !bytes.Equal(v2, []byte{'b'}) {
t.Fatalf("expecting non empty value. Got %q", v2)
}
// We accept invalid headers with a space.
// See: https://github.com/valyala/fasthttp/issues/1917
v3 := h.Peek("uSeR agent")
if !bytes.Equal(v3, []byte("agent")) {
t.Fatalf("expecting non empty value. Got %q", v3)
}
if raw := h.RawHeaders(); string(raw) != exp {
t.Fatalf("expected header %q, got %q", exp, raw)
}
@@ -1860,8 +1867,8 @@ func TestResponseHeaderAddTrailerError(t *testing.T) {
t.Parallel()
var h ResponseHeader
err := h.AddTrailer("Foo, Content-Length , Bar,Transfer-Encoding,")
expectedTrailer := "Foo, Bar"
err := h.AddTrailer("Foo, Content-Length , bAr,Transfer-Encoding, uSer aGent")
expectedTrailer := "Foo, Bar, uSer aGent"
if !errors.Is(err, ErrBadTrailer) {
t.Fatalf("unexpected err %q. Expected %q", err, ErrBadTrailer)
+4 -4
View File
@@ -189,8 +189,8 @@ func TestServerInvalidHeader(t *testing.T) {
s := &Server{
Handler: func(ctx *RequestCtx) {
if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil {
t.Error("expected Foo header")
if ctx.Request.Header.Peek("Foő") != nil || ctx.Request.Header.Peek("Foő ") != nil {
t.Error("expected Foő header")
}
},
Logger: &testLogger{},
@@ -208,7 +208,7 @@ func TestServerInvalidHeader(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil {
if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoő : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil {
t.Fatal(err)
}
@@ -225,7 +225,7 @@ func TestServerInvalidHeader(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil {
if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoő : bar\r\n\r\n")); err != nil {
t.Fatal(err)
}