Validate header values (#1796)

* Validate header values

Fixes https://github.com/valyala/fasthttp/issues/1794

* Don't allow empty header keys

And improve error handling for bad headers.
This commit is contained in:
Erik Dubbelboer
2024-07-03 10:04:04 +02:00
committed by GitHub
parent 21b235d033
commit b4c0b2b47d
4 changed files with 222 additions and 168 deletions
+1
View File
@@ -9,3 +9,4 @@ const toUpperTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11
const quotedArgShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01"
const quotedPathShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x00\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01"
const validHeaderFieldByteTable = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x01\x01\x01\x01\x00\x00\x01\x01\x00\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x00\x01\x00"
const validHeaderValueByteTable = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01"
+43 -15
View File
@@ -100,23 +100,50 @@ func main() {
validHeaderFieldByteTable := func() [128]byte {
// Should match net/textproto's validHeaderFieldByte(c byte) bool
// Defined by RFC 9110 5.6.2
// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
var a [128]byte
for _, v := range "!#$%&'*+-.^_`|~" {
a[v] = 1
// Defined by RFC 7230 and 9110:
//
// header-field = field-name ":" OWS field-value OWS
// field-name = token
// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
// token = 1*tchar
var table [128]byte
for c := 0; c < 128; c++ {
if (c >= '0' && c <= '9') ||
(c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
c == '!' || c == '#' || c == '$' || c == '%' || c == '&' ||
c == '\'' || c == '*' || c == '+' || c == '-' || c == '.' ||
c == '^' || c == '_' || c == '`' || c == '|' || c == '~' {
table[c] = 1
}
}
for i := 'a'; i <= 'z'; i++ {
a[i] = 1
return table
}()
validHeaderValueByteTable := func() [256]byte {
// Should match net/textproto's validHeaderValueByte(c byte) bool
// Defined by RFC 7230 and 9110:
//
// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
// field-vchar = VCHAR / obs-text
// obs-text = %x80-FF
//
// RFC 5234:
//
// HTAB = %x09
// SP = %x20
// VCHAR = %x21-7E
var table [256]byte
for c := 0; c < 256; c++ {
if (c >= 0x21 && c <= 0x7E) || // VCHAR
c == 0x20 || // SP
c == 0x09 || // HTAB
c >= 0x80 { // obs-text
table[c] = 1
}
}
for i := 'A'; i <= 'Z'; i++ {
a[i] = 1
}
for i := '0'; i <= '9'; i++ {
a[i] = 1
}
return a
return table
}()
w := bytes.NewBufferString(pre)
@@ -126,6 +153,7 @@ func main() {
fmt.Fprintf(w, "const quotedArgShouldEscapeTable = %q\n", quotedArgShouldEscapeTable)
fmt.Fprintf(w, "const quotedPathShouldEscapeTable = %q\n", quotedPathShouldEscapeTable)
fmt.Fprintf(w, "const validHeaderFieldByteTable = %q\n", validHeaderFieldByteTable)
fmt.Fprintf(w, "const validHeaderValueByteTable = %q\n", validHeaderValueByteTable)
if err := os.WriteFile("bytesconv_table.go", w.Bytes(), 0o660); err != nil {
log.Fatal(err)
+172 -145
View File
@@ -545,12 +545,18 @@ func (h *ResponseHeader) AddTrailerBytes(trailer []byte) error {
return err
}
// validHeaderFieldByte returns true if c is a valid tchar as defined
// by section 5.6.2 of [RFC9110].
// validHeaderFieldByte returns true if c valid header field byte
// as defined by RFC 7230.
func validHeaderFieldByte(c byte) bool {
return c < 128 && validHeaderFieldByteTable[c] == 1
}
// validHeaderValueByte returns true if c valid header value byte
// as defined by RFC 7230.
func validHeaderValueByte(c byte) bool {
return validHeaderValueByteTable[c] == 1
}
// VisitHeaderParams calls f for each parameter in the given header bytes.
// It stops processing when f returns false or an invalid parameter is found.
// Parameter values may be quoted, in which case \ is treated as an escape
@@ -2945,75 +2951,90 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var err error
var kv *argsKV
outer:
for s.next() {
if len(s.key) > 0 {
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
err = fmt.Errorf("invalid header key %q", s.key)
continue outer
}
}
switch s.key[0] | 0x20 {
case 'c':
if caseInsensitiveCompare(s.key, strContentType) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentEncoding) {
h.contentEncoding = append(h.contentEncoding[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if h.contentLength != -1 {
if h.contentLength, err = parseContentLength(s.value); err != nil {
h.contentLength = -2
} else {
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
}
}
continue
}
if caseInsensitiveCompare(s.key, strConnection) {
if bytes.Equal(s.value, strClose) {
h.connectionClose = true
} else {
h.connectionClose = false
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
continue
}
case 's':
if caseInsensitiveCompare(s.key, strServer) {
h.server = append(h.server[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strSetCookie) {
h.cookies, kv = allocArg(h.cookies)
kv.key = getCookieKey(kv.key, s.value)
kv.value = append(kv.value[:0], s.value...)
continue
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
if len(s.value) > 0 && !bytes.Equal(s.value, strIdentity) {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
continue
}
if caseInsensitiveCompare(s.key, strTrailer) {
err = h.SetTrailerBytes(s.value)
continue
}
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
if len(s.key) == 0 {
h.connectionClose = true
return 0, fmt.Errorf("invalid header key %q", s.key)
}
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
h.connectionClose = true
return 0, fmt.Errorf("invalid header key %q", s.key)
}
}
for _, ch := range s.value {
if !validHeaderValueByte(ch) {
h.connectionClose = true
return 0, fmt.Errorf("invalid header value %q", s.value)
}
}
switch s.key[0] | 0x20 {
case 'c':
if caseInsensitiveCompare(s.key, strContentType) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentEncoding) {
h.contentEncoding = append(h.contentEncoding[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if h.contentLength != -1 {
var err error
h.contentLength, err = parseContentLength(s.value)
if err != nil {
h.contentLength = -2
h.connectionClose = true
return 0, err
}
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
}
continue
}
if caseInsensitiveCompare(s.key, strConnection) {
if bytes.Equal(s.value, strClose) {
h.connectionClose = true
} else {
h.connectionClose = false
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
continue
}
case 's':
if caseInsensitiveCompare(s.key, strServer) {
h.server = append(h.server[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strSetCookie) {
h.cookies, kv = allocArg(h.cookies)
kv.key = getCookieKey(kv.key, s.value)
kv.value = append(kv.value[:0], s.value...)
continue
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
if len(s.value) > 0 && !bytes.Equal(s.value, strIdentity) {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
continue
}
if caseInsensitiveCompare(s.key, strTrailer) {
err := h.SetTrailerBytes(s.value)
if err != nil {
h.connectionClose = true
return 0, err
}
continue
}
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
if s.err != nil {
h.connectionClose = true
return 0, s.err
@@ -3032,7 +3053,7 @@ outer:
h.connectionClose = !hasHeaderValue(v, strKeepAlive)
}
return len(buf) - len(s.b), err
return len(buf) - len(s.b), nil
}
func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
@@ -3043,103 +3064,109 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var err error
outer:
for s.next() {
if len(s.key) > 0 {
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
err = fmt.Errorf("invalid header key %q", s.key)
continue outer
}
}
if len(s.key) == 0 {
h.connectionClose = true
return 0, fmt.Errorf("invalid header key %q", s.key)
}
if h.disableSpecialHeader {
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
h.connectionClose = true
return 0, fmt.Errorf("invalid header key %q", s.key)
}
}
for _, ch := range s.value {
if !validHeaderValueByte(ch) {
h.connectionClose = true
return 0, fmt.Errorf("invalid header value %q", s.value)
}
}
if h.disableSpecialHeader {
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
continue
}
switch s.key[0] | 0x20 {
case 'h':
if caseInsensitiveCompare(s.key, strHost) {
h.host = append(h.host[:0], s.value...)
continue
}
case 'u':
if caseInsensitiveCompare(s.key, strUserAgent) {
h.userAgent = append(h.userAgent[:0], s.value...)
continue
}
case 'c':
if caseInsensitiveCompare(s.key, strContentType) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if contentLengthSeen {
h.connectionClose = true
return 0, errors.New("duplicate Content-Length header")
}
contentLengthSeen = true
switch s.key[0] | 0x20 {
case 'h':
if caseInsensitiveCompare(s.key, strHost) {
h.host = append(h.host[:0], s.value...)
continue
}
case 'u':
if caseInsensitiveCompare(s.key, strUserAgent) {
h.userAgent = append(h.userAgent[:0], s.value...)
continue
}
case 'c':
if caseInsensitiveCompare(s.key, strContentType) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if contentLengthSeen {
return 0, errors.New("duplicate Content-Length header")
}
contentLengthSeen = true
if h.contentLength != -1 {
var nerr error
if h.contentLength, nerr = parseContentLength(s.value); nerr != nil {
if err == nil {
err = nerr
}
h.contentLength = -2
} else {
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
}
}
continue
}
if caseInsensitiveCompare(s.key, strConnection) {
if bytes.Equal(s.value, strClose) {
if h.contentLength != -1 {
var err error
h.contentLength, err = parseContentLength(s.value)
if err != nil {
h.contentLength = -2
h.connectionClose = true
} else {
h.connectionClose = false
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
return 0, err
}
continue
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
isIdentity := caseInsensitiveCompare(s.value, strIdentity)
isChunked := caseInsensitiveCompare(s.value, strChunked)
continue
}
if caseInsensitiveCompare(s.key, strConnection) {
if bytes.Equal(s.value, strClose) {
h.connectionClose = true
} else {
h.connectionClose = false
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
continue
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
isIdentity := caseInsensitiveCompare(s.value, strIdentity)
isChunked := caseInsensitiveCompare(s.value, strChunked)
if !isIdentity && !isChunked {
if h.secureErrorLogMessage {
return 0, errors.New("unsupported Transfer-Encoding")
}
return 0, fmt.Errorf("unsupported Transfer-Encoding: %q", s.value)
if !isIdentity && !isChunked {
h.connectionClose = true
if h.secureErrorLogMessage {
return 0, errors.New("unsupported Transfer-Encoding")
}
return 0, fmt.Errorf("unsupported Transfer-Encoding: %q", s.value)
}
if isChunked {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
continue
if isChunked {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
if caseInsensitiveCompare(s.key, strTrailer) {
if nerr := h.SetTrailerBytes(s.value); nerr != nil {
if err == nil {
err = nerr
}
}
continue
continue
}
if caseInsensitiveCompare(s.key, strTrailer) {
err := h.SetTrailerBytes(s.value)
if err != nil {
h.connectionClose = true
return 0, err
}
continue
}
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
if s.err != nil && err == nil {
err = s.err
}
if err != nil {
if s.err != nil {
h.connectionClose = true
return 0, err
return 0, s.err
}
if h.contentLength < 0 {
+6 -8
View File
@@ -2439,10 +2439,6 @@ func TestResponseHeaderReadSuccess(t *testing.T) {
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length: 123\nContent-Type: text/html\n\n",
200, 123, "text/html")
// Zero-length headers with mixed crlf and lf
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\nContent-Length: 345\nZero-Value: \r\nContent-Type: aaa\n: zero-key\r\n\r\nooa",
400, 345, "aaa")
// No space after colon
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length:34\nContent-Type: sss\n\naaaa",
200, 34, "sss")
@@ -2600,10 +2596,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) {
testRequestHeaderReadSuccess(t, h, "POST /aaa?bbb HTTP/1.1\r\nHost: foobar.com\r\nContent-Length: 1235\r\nContent-Type: aaa\r\n\r\nabcdef",
1235, "/aaa?bbb", "foobar.com", "", "aaa")
// zero-length headers with mixed crlf and lf
testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost: aaa\r\nZero: \n: Zero-Value\n\r\nxccv",
-2, "/a", "aaa", "", "")
// no space after colon
testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost:aaaxd\n\nsdfds",
-2, "/a", "aaaxd", "", "")
@@ -2719,6 +2711,9 @@ func TestResponseHeaderReadError(t *testing.T) {
// no protocol in the first line
testResponseHeaderReadError(t, h, "GET /foo/bar\r\nHost: google.com\r\n\r\nisdD")
// zero-length headers
testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\n: zero-key\r\n\r\n")
}
func TestResponseHeaderReadErrorSecureLog(t *testing.T) {
@@ -2769,6 +2764,9 @@ func TestRequestHeaderReadError(t *testing.T) {
// post with duplicate content-length
testRequestHeaderReadError(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n")
// Zero-length header
testRequestHeaderReadError(t, h, "GET /foo/bar HTTP/1.1\r\n: zero-key\r\n\r\n")
}
func TestRequestHeaderReadSecuredError(t *testing.T) {