diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 56dc890..3f0e6ec 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -8,7 +8,7 @@ jobs: test: strategy: matrix: - go-version: [1.20.x] + go-version: [1.25.x] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} env: diff --git a/bytesconv.go b/bytesconv.go index 7350e8b..5632486 100644 --- a/bytesconv.go +++ b/bytesconv.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net" + "net/http" "strconv" "sync" "time" @@ -68,6 +69,8 @@ func AppendIPv4(dst []byte, ip net.IP) []byte { var errEmptyIPStr = errors.New("empty ip address string") +var httpDateGMT = time.FixedZone("GMT", 0) + // ParseIPv4 parses ip address from ipStr into dst and returns the extended dst. func ParseIPv4(dst net.IP, ipStr []byte) (net.IP, error) { if len(ipStr) == 0 { @@ -117,7 +120,131 @@ func AppendHTTPDate(dst []byte, date time.Time) []byte { // ParseHTTPDate parses HTTP-compliant (RFC1123) date. func ParseHTTPDate(date []byte) (time.Time, error) { - return time.Parse(time.RFC1123, b2s(date)) + if t, ok := parseRFC1123DateGMT(date); ok { + return t, nil + } + return time.Parse(http.TimeFormat, b2s(date)) +} + +func parseRFC1123DateGMT(b []byte) (time.Time, bool) { + // Expects "Mon, 02 Jan 2006 15:04:05 GMT". + if len(b) != 29 { + return time.Time{}, false + } + if !isWeekday3(b[0], b[1], b[2]) { + return time.Time{}, false + } + if b[3] != ',' || b[4] != ' ' || b[7] != ' ' || b[11] != ' ' || + b[16] != ' ' || b[19] != ':' || b[22] != ':' || b[25] != ' ' { + return time.Time{}, false + } + if b[26] != 'G' || b[27] != 'M' || b[28] != 'T' { + return time.Time{}, false + } + + day, ok := parse2Digits(b[5], b[6]) + if !ok || day < 1 || day > 31 { + return time.Time{}, false + } + month, ok := parseMonth3(b[8], b[9], b[10]) + if !ok { + return time.Time{}, false + } + year, ok := parse4Digits(b[12], b[13], b[14], b[15]) + if !ok { + return time.Time{}, false + } + hour, ok := parse2Digits(b[17], b[18]) + if !ok || hour > 23 { + return time.Time{}, false + } + minute, ok := parse2Digits(b[20], b[21]) + if !ok || minute > 59 { + return time.Time{}, false + } + second, ok := parse2Digits(b[23], b[24]) + if !ok || second > 59 { + return time.Time{}, false + } + + t := time.Date(year, month, day, hour, minute, second, 0, httpDateGMT) + // Reject calendar-invalid dates like "31 Feb", which time.Date normalizes. + if t.Year() != year || t.Month() != month || t.Day() != day { + return time.Time{}, false + } + return t, true +} + +func isWeekday3(a, b, c byte) bool { + a |= 0x20 + b |= 0x20 + c |= 0x20 + k := uint32(a)<<16 | uint32(b)<<8 | uint32(c) + switch k { + case uint32('m')<<16 | uint32('o')<<8 | uint32('n'), + uint32('t')<<16 | uint32('u')<<8 | uint32('e'), + uint32('w')<<16 | uint32('e')<<8 | uint32('d'), + uint32('t')<<16 | uint32('h')<<8 | uint32('u'), + uint32('f')<<16 | uint32('r')<<8 | uint32('i'), + uint32('s')<<16 | uint32('a')<<8 | uint32('t'), + uint32('s')<<16 | uint32('u')<<8 | uint32('n'): + return true + default: + return false + } +} + +func parse2Digits(a, b byte) (int, bool) { + if a < '0' || a > '9' || b < '0' || b > '9' { + return 0, false + } + return int(a-'0')*10 + int(b-'0'), true +} + +func parse4Digits(a, b, c, d byte) (int, bool) { + v1, ok := parse2Digits(a, b) + if !ok { + return 0, false + } + v2, ok := parse2Digits(c, d) + if !ok { + return 0, false + } + return v1*100 + v2, true +} + +func parseMonth3(a, b, c byte) (time.Month, bool) { + a |= 0x20 + b |= 0x20 + c |= 0x20 + k := uint32(a)<<16 | uint32(b)<<8 | uint32(c) + switch k { + case uint32('j')<<16 | uint32('a')<<8 | uint32('n'): + return time.January, true + case uint32('f')<<16 | uint32('e')<<8 | uint32('b'): + return time.February, true + case uint32('m')<<16 | uint32('a')<<8 | uint32('r'): + return time.March, true + case uint32('a')<<16 | uint32('p')<<8 | uint32('r'): + return time.April, true + case uint32('m')<<16 | uint32('a')<<8 | uint32('y'): + return time.May, true + case uint32('j')<<16 | uint32('u')<<8 | uint32('n'): + return time.June, true + case uint32('j')<<16 | uint32('u')<<8 | uint32('l'): + return time.July, true + case uint32('a')<<16 | uint32('u')<<8 | uint32('g'): + return time.August, true + case uint32('s')<<16 | uint32('e')<<8 | uint32('p'): + return time.September, true + case uint32('o')<<16 | uint32('c')<<8 | uint32('t'): + return time.October, true + case uint32('n')<<16 | uint32('o')<<8 | uint32('v'): + return time.November, true + case uint32('d')<<16 | uint32('e')<<8 | uint32('c'): + return time.December, true + } + return 0, false } // AppendUint appends n to dst and returns the extended dst. diff --git a/bytesconv_test.go b/bytesconv_test.go index 13becc1..e9671ad 100644 --- a/bytesconv_test.go +++ b/bytesconv_test.go @@ -5,6 +5,7 @@ import ( "bytes" "html" "net" + "net/http" "net/url" "strconv" "testing" @@ -200,6 +201,116 @@ func TestAppendHTTPDate(t *testing.T) { } } +func TestParseHTTPDateCompatibility(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + value string + hasError bool + roundTrip bool + }{ + {name: "gmt-fast-path", value: "Tue, 10 Nov 2009 23:00:00 GMT", roundTrip: true}, + {name: "epoch", value: "Thu, 01 Jan 1970 00:00:00 GMT", roundTrip: true}, + {name: "year-boundary", value: "Fri, 31 Dec 1999 23:59:59 GMT", roundTrip: true}, + {name: "leap-year", value: "Mon, 29 Feb 2016 12:34:56 GMT", roundTrip: true}, + {name: "utc-fallback", value: "Tue, 10 Nov 2009 23:00:00 UTC", hasError: true}, + {name: "mixedcase-weekday-month", value: "tUe, 10 nOv 2009 23:00:00 GMT"}, + {name: "day-zero", value: "Tue, 00 Nov 2009 23:00:00 GMT", hasError: true}, + {name: "invalid-day", value: "Tue, 31 Feb 2009 23:00:00 GMT", hasError: true}, + {name: "invalid-weekday", value: "Xxx, 10 Nov 2009 23:00:00 GMT", hasError: true}, + {name: "invalid-month", value: "Tue, 10 Foo 2009 23:00:00 GMT", hasError: true}, + {name: "invalid-hour", value: "Tue, 10 Nov 2009 24:00:00 GMT", hasError: true}, + {name: "invalid-minute", value: "Tue, 10 Nov 2009 23:60:00 GMT", hasError: true}, + {name: "invalid-second", value: "Tue, 10 Nov 2009 23:00:60 GMT", hasError: true}, + {name: "invalid-separator", value: "Tue 10 Nov 2009 23:00:00 GMT", hasError: true}, + {name: "invalid-time-separator", value: "Tue, 10 Nov 2009 23-00-00 GMT", hasError: true}, + {name: "non-leap-year", value: "Tue, 29 Feb 2019 23:00:00 GMT", hasError: true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got, gotErr := ParseHTTPDate([]byte(tc.value)) + want, wantErr := time.Parse(http.TimeFormat, tc.value) + + if (gotErr != nil) != (wantErr != nil) { + t.Fatalf("error mismatch for %q: ParseHTTPDate err=%v, ParseInLocation err=%v", tc.value, gotErr, wantErr) + } + if tc.hasError != (gotErr != nil) { + t.Fatalf("unexpected error state for %q: gotErr=%v, expectedError=%v", tc.value, gotErr, tc.hasError) + } + if gotErr != nil { + return + } + if !got.Equal(want) { + t.Fatalf("parsed time mismatch for %q: got=%v want=%v", tc.value, got, want) + } + if tc.roundTrip && got.Format(time.RFC1123) != tc.value { + t.Fatalf("unexpected formatted date %q. Expecting %q", got.Format(time.RFC1123), tc.value) + } + }) + } +} + +func BenchmarkParseHTTPDate(b *testing.B) { + date := []byte("Tue, 10 Nov 2009 23:00:00 GMT") + + b.Run("fast-path", func(b *testing.B) { + b.ReportAllocs() + for range b.N { + _, _ = ParseHTTPDate(date) + } + }) + + b.Run("stdlib-only", func(b *testing.B) { + b.ReportAllocs() + s := string(date) + for range b.N { + _, _ = time.Parse(http.TimeFormat, s) + } + }) +} + +func FuzzParseHTTPDate(f *testing.F) { + // Seed corpus: valid RFC1123 dates. + seeds := []string{ + "Tue, 10 Nov 2009 23:00:00 GMT", + "Thu, 01 Jan 1970 00:00:00 GMT", + "Fri, 31 Dec 1999 23:59:59 GMT", + "Mon, 29 Feb 2016 12:34:56 GMT", + "Sun, 06 Nov 1994 08:49:37 GMT", + // Invalid inputs to exercise rejection paths. + "Tue, 10 Nov 2009 23:00:00 UTC", + "Tue, 31 Feb 2009 23:00:00 GMT", + "Xxx, 10 Nov 2009 23:00:00 GMT", + "Tue, 00 Nov 2009 23:00:00 GMT", + "Tue, 10 Nov 2009 24:00:00 GMT", + "not a date at all", + "", + } + for _, s := range seeds { + f.Add(s) + } + + f.Fuzz(func(t *testing.T, s string) { + b := []byte(s) + + // Reference: time.Parse with http.TimeFormat is what ParseHTTPDate falls back to. + stdTime, stdErr := time.Parse(http.TimeFormat, s) + + // The public API must always agree with time.Parse. + got, gotErr := ParseHTTPDate(b) + if (gotErr != nil) != (stdErr != nil) { + t.Fatalf("ParseHTTPDate error mismatch for %q: got err=%v, std err=%v", s, gotErr, stdErr) + } + if gotErr == nil && !got.Equal(stdTime) { + t.Fatalf("ParseHTTPDate time mismatch for %q: got=%v std=%v", s, got, stdTime) + } + }) +} + func TestParseUintError(t *testing.T) { t.Parallel() diff --git a/cookie.go b/cookie.go index e7d0592..134cb48 100644 --- a/cookie.go +++ b/cookie.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "io" + "net/http" "sync" "time" ) @@ -379,20 +380,21 @@ func (c *Cookie) ParseBytes(src []byte) error { var s cookieScanner s.b = src - if !s.next(&c.bufK, &c.bufV) { + var k, v []byte + if !s.nextRaw(&k, &v) { return errNoCookies } - c.key = append(c.key, c.bufK...) - c.value = append(c.value, c.bufV...) + c.key = append(c.key, k...) + c.value = append(c.value, v...) - for s.next(&c.bufK, &c.bufV) { - if len(c.bufK) != 0 { + for s.nextRaw(&k, &v) { + if len(k) != 0 { // Case insensitive switch on first char - switch c.bufK[0] | 0x20 { + switch k[0] | 0x20 { case 'm': - if caseInsensitiveCompare(strCookieMaxAge, c.bufK) { - maxAge, err := ParseUint(c.bufV) + if caseInsensitiveCompare(strCookieMaxAge, k) { + maxAge, err := ParseUint(v) if err != nil { return err } @@ -400,67 +402,61 @@ func (c *Cookie) ParseBytes(src []byte) error { } case 'e': // "expires" - if caseInsensitiveCompare(strCookieExpires, c.bufK) { - v := b2s(c.bufV) - // Try the same two formats as net/http - // See: https://github.com/golang/go/blob/00379be17e63a5b75b3237819392d2dc3b313a27/src/net/http/cookie.go#L133-L135 - exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC) + if caseInsensitiveCompare(strCookieExpires, k) { + exptime, err := parseCookieExpires(v) if err != nil { - exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", v) - if err != nil { - return err - } + return err } c.expire = exptime } case 'd': // "domain" - if caseInsensitiveCompare(strCookieDomain, c.bufK) { - c.domain = append(c.domain, c.bufV...) + if caseInsensitiveCompare(strCookieDomain, k) { + c.domain = append(c.domain, v...) } case 'p': // "path" - if caseInsensitiveCompare(strCookiePath, c.bufK) { - c.path = append(c.path, c.bufV...) + if caseInsensitiveCompare(strCookiePath, k) { + c.path = append(c.path, v...) } case 's': // "samesite" - if caseInsensitiveCompare(strCookieSameSite, c.bufK) { - if len(c.bufV) > 0 { + if caseInsensitiveCompare(strCookieSameSite, k) { + if len(v) > 0 { // Case insensitive switch on first char - switch c.bufV[0] | 0x20 { + switch v[0] | 0x20 { case 'l': // "lax" - if caseInsensitiveCompare(strCookieSameSiteLax, c.bufV) { + if caseInsensitiveCompare(strCookieSameSiteLax, v) { c.sameSite = CookieSameSiteLaxMode } case 's': // "strict" - if caseInsensitiveCompare(strCookieSameSiteStrict, c.bufV) { + if caseInsensitiveCompare(strCookieSameSiteStrict, v) { c.sameSite = CookieSameSiteStrictMode } case 'n': // "none" - if caseInsensitiveCompare(strCookieSameSiteNone, c.bufV) { + if caseInsensitiveCompare(strCookieSameSiteNone, v) { c.sameSite = CookieSameSiteNoneMode } } } } } - } else if len(c.bufV) != 0 { + } else if len(v) != 0 { // Case insensitive switch on first char - switch c.bufV[0] | 0x20 { + switch v[0] | 0x20 { case 'h': // "httponly" - if caseInsensitiveCompare(strCookieHTTPOnly, c.bufV) { + if caseInsensitiveCompare(strCookieHTTPOnly, v) { c.httpOnly = true } case 's': // "secure" - if caseInsensitiveCompare(strCookieSecure, c.bufV) { + if caseInsensitiveCompare(strCookieSecure, v) { c.secure = true - } else if caseInsensitiveCompare(strCookieSameSite, c.bufV) { + } else if caseInsensitiveCompare(strCookieSameSite, v) { c.sameSite = CookieSameSiteDefaultMode } case 'p': // "partitioned" - if caseInsensitiveCompare(strCookiePartitioned, c.bufV) { + if caseInsensitiveCompare(strCookiePartitioned, v) { c.partitioned = true } } @@ -529,6 +525,44 @@ type cookieScanner struct { b []byte } +func (s *cookieScanner) nextRaw(key, val *[]byte) bool { + b := s.b + if len(b) == 0 { + return false + } + + isKey := true + k := 0 + for i, c := range b { + switch c { + case '=': + if isKey { + isKey = false + *key = trimCookieArgNoCopy(b[:i], false) + k = i + 1 + } + case ';': + if isKey { + *key = (*key)[:0] + } + *val = trimCookieArgNoCopy(b[k:i], true) + j := i + 1 + if j < len(b) && b[j] == ' ' { + j++ + } + s.b = b[j:] + return true + } + } + + if isKey { + *key = (*key)[:0] + } + *val = trimCookieArgNoCopy(b[k:], true) + s.b = b[len(b):] + return true +} + func (s *cookieScanner) next(key, val *[]byte) bool { b := s.b if len(b) == 0 { @@ -550,7 +584,11 @@ func (s *cookieScanner) next(key, val *[]byte) bool { *key = (*key)[:0] } *val = decodeCookieArg(*val, b[k:i], true) - s.b = b[i+1:] + j := i + 1 + if j < len(b) && b[j] == ' ' { + j++ + } + s.b = b[j:] return true } } @@ -564,6 +602,12 @@ func (s *cookieScanner) next(key, val *[]byte) bool { } func decodeCookieArg(dst, src []byte, skipQuotes bool) []byte { + // Fast path: already trimmed and not quoted. + if n := len(src); n > 0 && src[0] != ' ' && src[n-1] != ' ' && + (!skipQuotes || n < 2 || src[0] != '"' || src[n-1] != '"') { + return append(dst[:0], src...) + } + for len(src) > 0 && src[0] == ' ' { src = src[1:] } @@ -578,8 +622,38 @@ func decodeCookieArg(dst, src []byte, skipQuotes bool) []byte { return append(dst[:0], src...) } +func trimCookieArgNoCopy(src []byte, skipQuotes bool) []byte { + for len(src) > 0 && src[0] == ' ' { + src = src[1:] + } + for len(src) > 0 && src[len(src)-1] == ' ' { + src = src[:len(src)-1] + } + if skipQuotes && len(src) > 1 && src[0] == '"' && src[len(src)-1] == '"' { + src = src[1 : len(src)-1] + } + return src +} + // caseInsensitiveCompare does a case insensitive equality comparison of // two []byte. Assumes only letters need to be matched. +func parseCookieExpires(src []byte) (time.Time, error) { + if t, ok := parseRFC1123DateGMT(src); ok { + return t, nil + } + + s := b2s(src) + + // UTC-anchored RFC1123 parsing behavior for non-GMT. + t, err := time.ParseInLocation(http.TimeFormat, s, time.UTC) + if err == nil { + return t, nil + } + + // Legacy cookie date compatibility used by net/http. + return time.Parse("Mon, 02-Jan-2006 15:04:05 MST", s) +} + func caseInsensitiveCompare(a, b []byte) bool { if len(a) != len(b) { return false diff --git a/header.go b/header.go index b7e8957..caff0bc 100644 --- a/header.go +++ b/header.go @@ -2713,7 +2713,7 @@ func parseTrailer(src []byte, dest []argsKV, disableNormalizing bool) ([]argsKV, if isBadTrailer(s.key) { return dest, 0, fmt.Errorf("forbidden trailer key %q", s.key) } - normalizeHeaderKey(s.key, disable) + normalizeHeaderKeyValidated(s.key, disable) dest = appendArgBytes(dest, s.key, s.value, argsHasValue) } if s.err != nil { @@ -2997,7 +2997,7 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { return 0, fmt.Errorf("invalid header key %q", s.key) } } - normalizeHeaderKey(s.key, disableNormalizing) + normalizeHeaderKeyValidated(s.key, disableNormalizing) for _, ch := range s.value { if !validHeaderValueByte(ch) { @@ -3120,7 +3120,7 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { return 0, fmt.Errorf("invalid header key %q", s.key) } } - normalizeHeaderKey(s.key, disableNormalizing) + normalizeHeaderKeyValidated(s.key, disableNormalizing) for _, ch := range s.value { if !validHeaderValueByte(ch) { @@ -3355,6 +3355,19 @@ func normalizeHeaderKey(b []byte, disableNormalizing bool) { } } + normalizeHeaderKeyValidated(b, false) +} + +func normalizeHeaderKeyValidated(b []byte, disableNormalizing bool) { + if disableNormalizing { + return + } + + n := len(b) + if n == 0 { + return + } + upper := true for i, c := range b { if upper { diff --git a/status.go b/status.go index f92727c..e3d8f09 100644 --- a/status.go +++ b/status.go @@ -164,14 +164,56 @@ func StatusMessage(statusCode int) string { } func formatStatusLine(dst, protocol []byte, statusCode int, statusText []byte) []byte { + if len(statusText) == 0 { + statusText = s2b(StatusMessage(statusCode)) + } + need := len(protocol) + 1 + statusCodeLen(statusCode) + 1 + len(statusText) + len(strCRLF) + if cap(dst)-len(dst) < need { + ndst := make([]byte, len(dst), len(dst)+need) + copy(ndst, dst) + dst = ndst + } + dst = append(dst, protocol...) dst = append(dst, ' ') - dst = strconv.AppendInt(dst, int64(statusCode), 10) + dst = appendStatusCode(dst, statusCode) dst = append(dst, ' ') - if len(statusText) == 0 { - dst = append(dst, s2b(StatusMessage(statusCode))...) - } else { - dst = append(dst, statusText...) - } + dst = append(dst, statusText...) return append(dst, strCRLF...) } + +func statusCodeLen(statusCode int) int { + switch { + case statusCode < 0: + return digits10Int(statusCode) + case statusCode < 10: + return 1 + case statusCode < 100: + return 2 + case statusCode < 1000: + return 3 + default: + return digits10Int(statusCode) + } +} + +func digits10Int(v int) int { + n := 1 + for v <= -10 || v >= 10 { + v /= 10 + n++ + } + return n +} + +func appendStatusCode(dst []byte, statusCode int) []byte { + if statusCode >= 100 && statusCode <= 999 { + dst = append(dst, + byte('0'+statusCode/100), + byte('0'+(statusCode/10)%10), + byte('0'+statusCode%10), + ) + return dst + } + return strconv.AppendInt(dst, int64(statusCode), 10) +}