diff --git a/bytesconv_table.go b/bytesconv_table.go index 5b230f1..b00ec6a 100644 --- a/bytesconv_table.go +++ b/bytesconv_table.go @@ -9,3 +9,4 @@ const toUpperTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11 const quotedArgShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" const quotedPathShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x00\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" const validHeaderFieldByteTable = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x01\x01\x01\x01\x00\x00\x01\x01\x00\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x00\x01\x00" +const validHeaderValueByteTable = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" diff --git a/bytesconv_table_gen.go b/bytesconv_table_gen.go index 4fa8c16..56c33c5 100644 --- a/bytesconv_table_gen.go +++ b/bytesconv_table_gen.go @@ -100,23 +100,50 @@ func main() { validHeaderFieldByteTable := func() [128]byte { // Should match net/textproto's validHeaderFieldByte(c byte) bool - // Defined by RFC 9110 5.6.2 - // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / - // "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA - var a [128]byte - for _, v := range "!#$%&'*+-.^_`|~" { - a[v] = 1 + // Defined by RFC 7230 and 9110: + // + // header-field = field-name ":" OWS field-value OWS + // field-name = token + // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / + // "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA + // token = 1*tchar + var table [128]byte + for c := 0; c < 128; c++ { + if (c >= '0' && c <= '9') || + (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + c == '!' || c == '#' || c == '$' || c == '%' || c == '&' || + c == '\'' || c == '*' || c == '+' || c == '-' || c == '.' || + c == '^' || c == '_' || c == '`' || c == '|' || c == '~' { + table[c] = 1 + } } - for i := 'a'; i <= 'z'; i++ { - a[i] = 1 + return table + }() + + validHeaderValueByteTable := func() [256]byte { + // Should match net/textproto's validHeaderValueByte(c byte) bool + // Defined by RFC 7230 and 9110: + // + // field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] + // field-vchar = VCHAR / obs-text + // obs-text = %x80-FF + // + // RFC 5234: + // + // HTAB = %x09 + // SP = %x20 + // VCHAR = %x21-7E + var table [256]byte + for c := 0; c < 256; c++ { + if (c >= 0x21 && c <= 0x7E) || // VCHAR + c == 0x20 || // SP + c == 0x09 || // HTAB + c >= 0x80 { // obs-text + table[c] = 1 + } } - for i := 'A'; i <= 'Z'; i++ { - a[i] = 1 - } - for i := '0'; i <= '9'; i++ { - a[i] = 1 - } - return a + return table }() w := bytes.NewBufferString(pre) @@ -126,6 +153,7 @@ func main() { fmt.Fprintf(w, "const quotedArgShouldEscapeTable = %q\n", quotedArgShouldEscapeTable) fmt.Fprintf(w, "const quotedPathShouldEscapeTable = %q\n", quotedPathShouldEscapeTable) fmt.Fprintf(w, "const validHeaderFieldByteTable = %q\n", validHeaderFieldByteTable) + fmt.Fprintf(w, "const validHeaderValueByteTable = %q\n", validHeaderValueByteTable) if err := os.WriteFile("bytesconv_table.go", w.Bytes(), 0o660); err != nil { log.Fatal(err) diff --git a/header.go b/header.go index b026294..26e5266 100644 --- a/header.go +++ b/header.go @@ -545,12 +545,18 @@ func (h *ResponseHeader) AddTrailerBytes(trailer []byte) error { return err } -// validHeaderFieldByte returns true if c is a valid tchar as defined -// by section 5.6.2 of [RFC9110]. +// validHeaderFieldByte returns true if c valid header field byte +// as defined by RFC 7230. func validHeaderFieldByte(c byte) bool { return c < 128 && validHeaderFieldByteTable[c] == 1 } +// validHeaderValueByte returns true if c valid header value byte +// as defined by RFC 7230. +func validHeaderValueByte(c byte) bool { + return validHeaderValueByteTable[c] == 1 +} + // VisitHeaderParams calls f for each parameter in the given header bytes. // It stops processing when f returns false or an invalid parameter is found. // Parameter values may be quoted, in which case \ is treated as an escape @@ -2945,75 +2951,90 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { var s headerScanner s.b = buf s.disableNormalizing = h.disableNormalizing - var err error var kv *argsKV -outer: for s.next() { - if len(s.key) > 0 { - for _, ch := range s.key { - if !validHeaderFieldByte(ch) { - err = fmt.Errorf("invalid header key %q", s.key) - continue outer - } - } - - switch s.key[0] | 0x20 { - case 'c': - if caseInsensitiveCompare(s.key, strContentType) { - h.contentType = append(h.contentType[:0], s.value...) - continue - } - if caseInsensitiveCompare(s.key, strContentEncoding) { - h.contentEncoding = append(h.contentEncoding[:0], s.value...) - continue - } - if caseInsensitiveCompare(s.key, strContentLength) { - if h.contentLength != -1 { - if h.contentLength, err = parseContentLength(s.value); err != nil { - h.contentLength = -2 - } else { - h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) - } - } - continue - } - if caseInsensitiveCompare(s.key, strConnection) { - if bytes.Equal(s.value, strClose) { - h.connectionClose = true - } else { - h.connectionClose = false - h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) - } - continue - } - case 's': - if caseInsensitiveCompare(s.key, strServer) { - h.server = append(h.server[:0], s.value...) - continue - } - if caseInsensitiveCompare(s.key, strSetCookie) { - h.cookies, kv = allocArg(h.cookies) - kv.key = getCookieKey(kv.key, s.value) - kv.value = append(kv.value[:0], s.value...) - continue - } - case 't': - if caseInsensitiveCompare(s.key, strTransferEncoding) { - if len(s.value) > 0 && !bytes.Equal(s.value, strIdentity) { - h.contentLength = -1 - h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) - } - continue - } - if caseInsensitiveCompare(s.key, strTrailer) { - err = h.SetTrailerBytes(s.value) - continue - } - } - h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + if len(s.key) == 0 { + h.connectionClose = true + return 0, fmt.Errorf("invalid header key %q", s.key) } + + for _, ch := range s.key { + if !validHeaderFieldByte(ch) { + h.connectionClose = true + return 0, fmt.Errorf("invalid header key %q", s.key) + } + } + for _, ch := range s.value { + if !validHeaderValueByte(ch) { + h.connectionClose = true + return 0, fmt.Errorf("invalid header value %q", s.value) + } + } + + switch s.key[0] | 0x20 { + case 'c': + if caseInsensitiveCompare(s.key, strContentType) { + h.contentType = append(h.contentType[:0], s.value...) + continue + } + if caseInsensitiveCompare(s.key, strContentEncoding) { + h.contentEncoding = append(h.contentEncoding[:0], s.value...) + continue + } + if caseInsensitiveCompare(s.key, strContentLength) { + if h.contentLength != -1 { + var err error + h.contentLength, err = parseContentLength(s.value) + if err != nil { + h.contentLength = -2 + h.connectionClose = true + return 0, err + } + h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) + } + continue + } + if caseInsensitiveCompare(s.key, strConnection) { + if bytes.Equal(s.value, strClose) { + h.connectionClose = true + } else { + h.connectionClose = false + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + } + continue + } + case 's': + if caseInsensitiveCompare(s.key, strServer) { + h.server = append(h.server[:0], s.value...) + continue + } + if caseInsensitiveCompare(s.key, strSetCookie) { + h.cookies, kv = allocArg(h.cookies) + kv.key = getCookieKey(kv.key, s.value) + kv.value = append(kv.value[:0], s.value...) + continue + } + case 't': + if caseInsensitiveCompare(s.key, strTransferEncoding) { + if len(s.value) > 0 && !bytes.Equal(s.value, strIdentity) { + h.contentLength = -1 + h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) + } + continue + } + if caseInsensitiveCompare(s.key, strTrailer) { + err := h.SetTrailerBytes(s.value) + if err != nil { + h.connectionClose = true + return 0, err + } + continue + } + } + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) } + if s.err != nil { h.connectionClose = true return 0, s.err @@ -3032,7 +3053,7 @@ outer: h.connectionClose = !hasHeaderValue(v, strKeepAlive) } - return len(buf) - len(s.b), err + return len(buf) - len(s.b), nil } func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { @@ -3043,103 +3064,109 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { var s headerScanner s.b = buf s.disableNormalizing = h.disableNormalizing - var err error -outer: for s.next() { - if len(s.key) > 0 { - for _, ch := range s.key { - if !validHeaderFieldByte(ch) { - err = fmt.Errorf("invalid header key %q", s.key) - continue outer - } - } + if len(s.key) == 0 { + h.connectionClose = true + return 0, fmt.Errorf("invalid header key %q", s.key) + } - if h.disableSpecialHeader { - h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + for _, ch := range s.key { + if !validHeaderFieldByte(ch) { + h.connectionClose = true + return 0, fmt.Errorf("invalid header key %q", s.key) + } + } + for _, ch := range s.value { + if !validHeaderValueByte(ch) { + h.connectionClose = true + return 0, fmt.Errorf("invalid header value %q", s.value) + } + } + + if h.disableSpecialHeader { + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + continue + } + + switch s.key[0] | 0x20 { + case 'h': + if caseInsensitiveCompare(s.key, strHost) { + h.host = append(h.host[:0], s.value...) continue } + case 'u': + if caseInsensitiveCompare(s.key, strUserAgent) { + h.userAgent = append(h.userAgent[:0], s.value...) + continue + } + case 'c': + if caseInsensitiveCompare(s.key, strContentType) { + h.contentType = append(h.contentType[:0], s.value...) + continue + } + if caseInsensitiveCompare(s.key, strContentLength) { + if contentLengthSeen { + h.connectionClose = true + return 0, errors.New("duplicate Content-Length header") + } + contentLengthSeen = true - switch s.key[0] | 0x20 { - case 'h': - if caseInsensitiveCompare(s.key, strHost) { - h.host = append(h.host[:0], s.value...) - continue - } - case 'u': - if caseInsensitiveCompare(s.key, strUserAgent) { - h.userAgent = append(h.userAgent[:0], s.value...) - continue - } - case 'c': - if caseInsensitiveCompare(s.key, strContentType) { - h.contentType = append(h.contentType[:0], s.value...) - continue - } - if caseInsensitiveCompare(s.key, strContentLength) { - if contentLengthSeen { - return 0, errors.New("duplicate Content-Length header") - } - contentLengthSeen = true - - if h.contentLength != -1 { - var nerr error - if h.contentLength, nerr = parseContentLength(s.value); nerr != nil { - if err == nil { - err = nerr - } - h.contentLength = -2 - } else { - h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) - } - } - continue - } - if caseInsensitiveCompare(s.key, strConnection) { - if bytes.Equal(s.value, strClose) { + if h.contentLength != -1 { + var err error + h.contentLength, err = parseContentLength(s.value) + if err != nil { + h.contentLength = -2 h.connectionClose = true - } else { - h.connectionClose = false - h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + return 0, err } - continue + h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) } - case 't': - if caseInsensitiveCompare(s.key, strTransferEncoding) { - isIdentity := caseInsensitiveCompare(s.value, strIdentity) - isChunked := caseInsensitiveCompare(s.value, strChunked) + continue + } + if caseInsensitiveCompare(s.key, strConnection) { + if bytes.Equal(s.value, strClose) { + h.connectionClose = true + } else { + h.connectionClose = false + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + } + continue + } + case 't': + if caseInsensitiveCompare(s.key, strTransferEncoding) { + isIdentity := caseInsensitiveCompare(s.value, strIdentity) + isChunked := caseInsensitiveCompare(s.value, strChunked) - if !isIdentity && !isChunked { - if h.secureErrorLogMessage { - return 0, errors.New("unsupported Transfer-Encoding") - } - return 0, fmt.Errorf("unsupported Transfer-Encoding: %q", s.value) + if !isIdentity && !isChunked { + h.connectionClose = true + if h.secureErrorLogMessage { + return 0, errors.New("unsupported Transfer-Encoding") } + return 0, fmt.Errorf("unsupported Transfer-Encoding: %q", s.value) + } - if isChunked { - h.contentLength = -1 - h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) - } - continue + if isChunked { + h.contentLength = -1 + h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) } - if caseInsensitiveCompare(s.key, strTrailer) { - if nerr := h.SetTrailerBytes(s.value); nerr != nil { - if err == nil { - err = nerr - } - } - continue + continue + } + if caseInsensitiveCompare(s.key, strTrailer) { + err := h.SetTrailerBytes(s.value) + if err != nil { + h.connectionClose = true + return 0, err } + continue } } h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) } - if s.err != nil && err == nil { - err = s.err - } - if err != nil { + + if s.err != nil { h.connectionClose = true - return 0, err + return 0, s.err } if h.contentLength < 0 { diff --git a/header_test.go b/header_test.go index 39b06e8..527d499 100644 --- a/header_test.go +++ b/header_test.go @@ -2439,10 +2439,6 @@ func TestResponseHeaderReadSuccess(t *testing.T) { testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length: 123\nContent-Type: text/html\n\n", 200, 123, "text/html") - // Zero-length headers with mixed crlf and lf - testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\nContent-Length: 345\nZero-Value: \r\nContent-Type: aaa\n: zero-key\r\n\r\nooa", - 400, 345, "aaa") - // No space after colon testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length:34\nContent-Type: sss\n\naaaa", 200, 34, "sss") @@ -2600,10 +2596,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) { testRequestHeaderReadSuccess(t, h, "POST /aaa?bbb HTTP/1.1\r\nHost: foobar.com\r\nContent-Length: 1235\r\nContent-Type: aaa\r\n\r\nabcdef", 1235, "/aaa?bbb", "foobar.com", "", "aaa") - // zero-length headers with mixed crlf and lf - testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost: aaa\r\nZero: \n: Zero-Value\n\r\nxccv", - -2, "/a", "aaa", "", "") - // no space after colon testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost:aaaxd\n\nsdfds", -2, "/a", "aaaxd", "", "") @@ -2719,6 +2711,9 @@ func TestResponseHeaderReadError(t *testing.T) { // no protocol in the first line testResponseHeaderReadError(t, h, "GET /foo/bar\r\nHost: google.com\r\n\r\nisdD") + + // zero-length headers + testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\n: zero-key\r\n\r\n") } func TestResponseHeaderReadErrorSecureLog(t *testing.T) { @@ -2769,6 +2764,9 @@ func TestRequestHeaderReadError(t *testing.T) { // post with duplicate content-length testRequestHeaderReadError(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n") + + // Zero-length header + testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\n: zero-key\r\n\r\n") } func TestRequestHeaderReadSecuredError(t *testing.T) {