From 7796335d5f797638af64bd27956ec6fa0239ef3f Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sun, 9 Sep 2018 16:33:25 +0800 Subject: [PATCH] Do case insensitive comparisons for headers and cookies --- client_test.go | 38 ++++++++++++++ cookie.go | 79 +++++++++++++++++++--------- cookie_test.go | 2 +- header.go | 138 +++++++++++++++++++++++++++++-------------------- 4 files changed, 176 insertions(+), 81 deletions(-) diff --git a/client_test.go b/client_test.go index d92e831..1be6edf 100644 --- a/client_test.go +++ b/client_test.go @@ -17,6 +17,44 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func TestClientHeaderCase(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + c.Write([]byte("HTTP/1.1 200 OK\r\n" + + "content-type: text/plain\r\n" + + "transfer-encoding: chunked\r\n\r\n" + + "24\r\nThis is the data in the first chunk \r\n" + + "1B\r\nand this is the second one \r\n" + + "0\r\n\r\n", + )) + }() + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + ReadTimeout: time.Millisecond * 10, + + // Even without name normalizing we should parse headers correctly. + DisableHeaderNamesNormalizing: true, + } + + code, body, err := c.Get(nil, "http://example.com") + if err != nil { + t.Error(err) + } else if code != 200 { + t.Errorf("expected status code 200 got %d", code) + } else if string(body) != "This is the data in the first chunk and this is the second one " { + t.Errorf("wrong body: %q", body) + } +} + func TestClientReadTimeout(t *testing.T) { // This test is rather slow and increase the total test time // from 2.5 seconds to 6.5 seconds. diff --git a/cookie.go b/cookie.go index bce5d2c..77b6369 100644 --- a/cookie.go +++ b/cookie.go @@ -272,34 +272,49 @@ func (c *Cookie) ParseBytes(src []byte) error { c.value = append(c.value[:0], kv.value...) for s.next(kv) { - if len(kv.key) == 0 && len(kv.value) == 0 { - continue - } - switch string(kv.key) { - case "expires": - v := b2s(kv.value) - // Try the same two formats as net/http - // See: https://github.com/golang/go/blob/00379be17e63a5b75b3237819392d2dc3b313a27/src/net/http/cookie.go#L133-L135 - exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC) - if err != nil { - exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", v) - if err != nil { - return err + if len(kv.key) != 0 { + // Case insensitive switch on first char + switch kv.key[0] | 0x20 { + case 'e': // "expires" + if caseInsensitiveCompare(strCookieExpires, kv.key) { + v := b2s(kv.value) + // Try the same two formats as net/http + // See: https://github.com/golang/go/blob/00379be17e63a5b75b3237819392d2dc3b313a27/src/net/http/cookie.go#L133-L135 + exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC) + if err != nil { + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", v) + if err != nil { + return err + } + } + c.expire = exptime + } + + case 'd': // "domain" + if caseInsensitiveCompare(strCookieDomain, kv.key) { + c.domain = append(c.domain[:0], kv.value...) + } + + case 'p': // "path" + if caseInsensitiveCompare(strCookiePath, kv.key) { + c.path = append(c.path[:0], kv.value...) } } - c.expire = exptime - case "domain": - c.domain = append(c.domain[:0], kv.value...) - case "path": - c.path = append(c.path[:0], kv.value...) - case "": - switch string(kv.value) { - case "HttpOnly": - c.httpOnly = true - case "secure": - c.secure = true + + } else if len(kv.value) != 0 { + // Case insensitive switch on first char + switch kv.value[0] | 0x20 { + case 'h': // "httponly" + if caseInsensitiveCompare(strCookieHTTPOnly, kv.value) { + c.httpOnly = true + } + + case 's': // "secure" + if caseInsensitiveCompare(strCookieSecure, kv.value) { + c.secure = true + } } - } + } // else empty or no match } return nil } @@ -412,3 +427,17 @@ func decodeCookieArg(dst, src []byte, skipQuotes bool) []byte { } return append(dst[:0], src...) } + +// caseInsensitiveCompare does a case insensitive equality comparison of +// two []byte. Assumes only letters need to be matched. +func caseInsensitiveCompare(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i]|0x20 != b[i]|0x20 { + return false + } + } + return true +} diff --git a/cookie_test.go b/cookie_test.go index 85f1569..332c0aa 100644 --- a/cookie_test.go +++ b/cookie_test.go @@ -175,7 +175,7 @@ func TestCookieParse(t *testing.T) { testCookieParse(t, "foo=", "foo=") testCookieParse(t, `foo="bar"`, "foo=bar") testCookieParse(t, `"foo"=bar`, `"foo"=bar`) - testCookieParse(t, "foo=bar; domain=aaa.com; path=/foo/bar", "foo=bar; domain=aaa.com; path=/foo/bar") + testCookieParse(t, "foo=bar; Domain=aaa.com; PATH=/foo/bar", "foo=bar; domain=aaa.com; path=/foo/bar") testCookieParse(t, " xxx = yyy ; path=/a/b;;;domain=foobar.com ; expires= Tue, 10 Nov 2009 23:00:00 GMT ; ;;", "xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b") } diff --git a/header.go b/header.go index da7889b..2097602 100644 --- a/header.go +++ b/header.go @@ -1773,36 +1773,49 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { var err error var kv *argsKV for s.next() { - switch string(s.key) { - case "Content-Type": - h.contentType = append(h.contentType[:0], s.value...) - case "Server": - h.server = append(h.server[:0], s.value...) - case "Content-Length": - if h.contentLength != -1 { - if h.contentLength, err = parseContentLength(s.value); err != nil { - h.contentLength = -2 - } else { - h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) + if len(s.key) > 0 { + switch s.key[0] | 0x20 { + case 'c': + if caseInsensitiveCompare(s.key, strContentType) { + h.contentType = append(h.contentType[:0], s.value...) + continue + } else if caseInsensitiveCompare(s.key, strContentLength) { + if h.contentLength != -1 { + if h.contentLength, err = parseContentLength(s.value); err != nil { + h.contentLength = -2 + } else { + h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) + } + } + continue + } else if caseInsensitiveCompare(s.key, strConnection) { + if bytes.Equal(s.value, strClose) { + h.connectionClose = true + } else { + h.connectionClose = false + h.h = appendArgBytes(h.h, s.key, s.value) + } + continue + } + case 's': + if caseInsensitiveCompare(s.key, strServer) { + h.server = append(h.server[:0], s.value...) + continue + } else if caseInsensitiveCompare(s.key, strSetCookie) { + h.cookies, kv = allocArg(h.cookies) + kv.key = getCookieKey(kv.key, s.value) + kv.value = append(kv.value[:0], s.value...) + continue + } + case 't': + if caseInsensitiveCompare(s.key, strTransferEncoding) { + if !bytes.Equal(s.value, strIdentity) { + h.contentLength = -1 + h.h = setArgBytes(h.h, strTransferEncoding, strChunked) + } + continue } } - case "Transfer-Encoding": - if !bytes.Equal(s.value, strIdentity) { - h.contentLength = -1 - h.h = setArgBytes(h.h, strTransferEncoding, strChunked) - } - case "Set-Cookie": - h.cookies, kv = allocArg(h.cookies) - kv.key = getCookieKey(kv.key, s.value) - kv.value = append(kv.value[:0], s.value...) - case "Connection": - if bytes.Equal(s.value, strClose) { - h.connectionClose = true - } else { - h.connectionClose = false - h.h = appendArgBytes(h.h, s.key, s.value) - } - default: h.h = appendArgBytes(h.h, s.key, s.value) } } @@ -1835,36 +1848,51 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { s.disableNormalizing = h.disableNormalizing var err error for s.next() { - switch string(s.key) { - case "Host": - h.host = append(h.host[:0], s.value...) - case "User-Agent": - h.userAgent = append(h.userAgent[:0], s.value...) - case "Content-Type": - h.contentType = append(h.contentType[:0], s.value...) - case "Content-Length": - if h.contentLength != -1 { - if h.contentLength, err = parseContentLength(s.value); err != nil { - h.contentLength = -2 - } else { - h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) + if len(s.key) > 0 { + switch s.key[0] | 0x20 { + case 'h': + if caseInsensitiveCompare(s.key, strHost) { + h.host = append(h.host[:0], s.value...) + continue + } + case 'u': + if caseInsensitiveCompare(s.key, strUserAgent) { + h.userAgent = append(h.userAgent[:0], s.value...) + continue + } + case 'c': + if caseInsensitiveCompare(s.key, strContentType) { + h.contentType = append(h.contentType[:0], s.value...) + continue + } else if caseInsensitiveCompare(s.key, strContentLength) { + if h.contentLength != -1 { + if h.contentLength, err = parseContentLength(s.value); err != nil { + h.contentLength = -2 + } else { + h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) + } + } + continue + } else if caseInsensitiveCompare(s.key, strConnection) { + if bytes.Equal(s.value, strClose) { + h.connectionClose = true + } else { + h.connectionClose = false + h.h = appendArgBytes(h.h, s.key, s.value) + } + continue + } + case 't': + if caseInsensitiveCompare(s.key, strTransferEncoding) { + if !bytes.Equal(s.value, strIdentity) { + h.contentLength = -1 + h.h = setArgBytes(h.h, strTransferEncoding, strChunked) + } + continue } } - case "Transfer-Encoding": - if !bytes.Equal(s.value, strIdentity) { - h.contentLength = -1 - h.h = setArgBytes(h.h, strTransferEncoding, strChunked) - } - case "Connection": - if bytes.Equal(s.value, strClose) { - h.connectionClose = true - } else { - h.connectionClose = false - h.h = appendArgBytes(h.h, s.key, s.value) - } - default: - h.h = appendArgBytes(h.h, s.key, s.value) } + h.h = appendArgBytes(h.h, s.key, s.value) } if s.err != nil { h.connectionClose = true