From dd6954f4b2bb258ab87aa54ea54d3604a0c052e1 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Thu, 25 Feb 2016 14:00:04 +0200 Subject: [PATCH] Issue #57: Server: added ability to disable header names' normalizing --- header.go | 135 ++++++++++++++++++++++++++++++------------ header_timing_test.go | 2 +- server.go | 22 +++++++ server_test.go | 56 ++++++++++++++++++ 4 files changed, 177 insertions(+), 38 deletions(-) diff --git a/header.go b/header.go index 9ce62e1..93bc8c9 100644 --- a/header.go +++ b/header.go @@ -18,11 +18,11 @@ import ( // ResponseHeader instance MUST NOT be used from concurrently running // goroutines. type ResponseHeader struct { - statusCode int - - noHTTP11 bool - connectionClose bool + disableNormalizing bool + noHTTP11 bool + connectionClose bool + statusCode int contentLength int contentLengthBytes []byte @@ -43,9 +43,15 @@ type ResponseHeader struct { // RequestHeader instance MUST NOT be used from concurrently running // goroutines. type RequestHeader struct { - noHTTP11 bool - connectionClose bool - isGet bool + disableNormalizing bool + noHTTP11 bool + connectionClose bool + isGet bool + + // These two fields have been moved close to other bool fields + // for reducing RequestHeader object size. + cookiesCollected bool + rawHeadersParsed bool contentLength int contentLengthBytes []byte @@ -59,11 +65,9 @@ type RequestHeader struct { h []argsKV bufKV argsKV - cookies []argsKV - cookiesCollected bool + cookies []argsKV - rawHeaders []byte - rawHeadersParsed bool + rawHeaders []byte } // SetContentRange sets 'Content-Range: bytes startPos-endPos/contentLength' @@ -561,12 +565,49 @@ func (h *RequestHeader) Len() int { return n } +// DisableNormalizing disables header names' normalization. +// +// By default all the header names are normalized by uppercasing +// the first letter and all the first letters following dashes, +// while lowercasing all the other letters. +// Examples: +// +// * CONNECTION -> Connection +// * conteNT-tYPE -> Content-Type +// * foo-bar-baz -> Foo-Bar-Baz +// +// Disable header names' normalization only if know what are you doing. +func (h *RequestHeader) DisableNormalizing() { + h.disableNormalizing = true +} + +// DisableNormalizing disables header names' normalization. +// +// By default all the header names are normalized by uppercasing +// the first letter and all the first letters following dashes, +// while lowercasing all the other letters. +// Examples: +// +// * CONNECTION -> Connection +// * conteNT-tYPE -> Content-Type +// * foo-bar-baz -> Foo-Bar-Baz +// +// Disable header names' normalization only if know what are you doing. +func (h *ResponseHeader) DisableNormalizing() { + h.disableNormalizing = true +} + // Reset clears response header. func (h *ResponseHeader) Reset() { + h.disableNormalizing = false + h.resetSkipNormalize() +} + +func (h *ResponseHeader) resetSkipNormalize() { h.noHTTP11 = false - h.statusCode = 0 h.connectionClose = false + h.statusCode = 0 h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] @@ -579,6 +620,11 @@ func (h *ResponseHeader) Reset() { // Reset clears request header. func (h *RequestHeader) Reset() { + h.disableNormalizing = false + h.resetSkipNormalize() +} + +func (h *RequestHeader) resetSkipNormalize() { h.noHTTP11 = false h.connectionClose = false h.isGet = false @@ -603,9 +649,12 @@ func (h *RequestHeader) Reset() { // CopyTo copies all the headers to dst. func (h *ResponseHeader) CopyTo(dst *ResponseHeader) { dst.Reset() + + dst.disableNormalizing = h.disableNormalizing dst.noHTTP11 = h.noHTTP11 - dst.statusCode = h.statusCode dst.connectionClose = h.connectionClose + + dst.statusCode = h.statusCode dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.contentType = append(dst.contentType[:0], h.contentType...) @@ -617,8 +666,12 @@ func (h *ResponseHeader) CopyTo(dst *ResponseHeader) { // CopyTo copies all the headers to dst. func (h *RequestHeader) CopyTo(dst *RequestHeader) { dst.Reset() + + dst.disableNormalizing = h.disableNormalizing dst.noHTTP11 = h.noHTTP11 dst.connectionClose = h.connectionClose + dst.isGet = h.isGet + dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.method = append(dst.method[:0], h.method...) @@ -715,21 +768,21 @@ func (h *RequestHeader) VisitAll(f func(key, value []byte)) { // Del deletes header with the given key. func (h *ResponseHeader) Del(key string) { - k := getHeaderKeyBytes(&h.bufKV, key) + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.h = delArg(h.h, k) } // DelBytes deletes header with the given key. func (h *ResponseHeader) DelBytes(key []byte) { h.bufKV.key = append(h.bufKV.key[:0], key...) - normalizeHeaderKey(h.bufKV.key) + normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.h = delArg(h.h, h.bufKV.key) } // Del deletes header with the given key. func (h *RequestHeader) Del(key string) { h.parseRawHeaders() - k := getHeaderKeyBytes(&h.bufKV, key) + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.h = delArg(h.h, k) } @@ -737,13 +790,13 @@ func (h *RequestHeader) Del(key string) { func (h *RequestHeader) DelBytes(key []byte) { h.parseRawHeaders() h.bufKV.key = append(h.bufKV.key[:0], key...) - normalizeHeaderKey(h.bufKV.key) + normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.h = delArg(h.h, h.bufKV.key) } // Set sets the given 'key: value' header. func (h *ResponseHeader) Set(key, value string) { - initHeaderKV(&h.bufKV, key, value) + initHeaderKV(&h.bufKV, key, value, h.disableNormalizing) h.SetCanonical(h.bufKV.key, h.bufKV.value) } @@ -755,14 +808,14 @@ func (h *ResponseHeader) SetBytesK(key []byte, value string) { // SetBytesV sets the given 'key: value' header. func (h *ResponseHeader) SetBytesV(key string, value []byte) { - k := getHeaderKeyBytes(&h.bufKV, key) + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.SetCanonical(k, value) } // SetBytesKV sets the given 'key: value' header. func (h *ResponseHeader) SetBytesKV(key, value []byte) { h.bufKV.key = append(h.bufKV.key[:0], key...) - normalizeHeaderKey(h.bufKV.key) + normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.SetCanonical(h.bufKV.key, value) } @@ -826,7 +879,7 @@ func (h *RequestHeader) SetCookieBytesKV(key, value []byte) { // Set sets the given 'key: value' header. func (h *RequestHeader) Set(key, value string) { - initHeaderKV(&h.bufKV, key, value) + initHeaderKV(&h.bufKV, key, value, h.disableNormalizing) h.SetCanonical(h.bufKV.key, h.bufKV.value) } @@ -838,14 +891,14 @@ func (h *RequestHeader) SetBytesK(key []byte, value string) { // SetBytesV sets the given 'key: value' header. func (h *RequestHeader) SetBytesV(key string, value []byte) { - k := getHeaderKeyBytes(&h.bufKV, key) + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.SetCanonical(k, value) } // SetBytesKV sets the given 'key: value' header. func (h *RequestHeader) SetBytesKV(key, value []byte) { h.bufKV.key = append(h.bufKV.key[:0], key...) - normalizeHeaderKey(h.bufKV.key) + normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.SetCanonical(h.bufKV.key, value) } @@ -887,7 +940,7 @@ func (h *RequestHeader) SetCanonical(key, value []byte) { // Returned value is valid until the next call to ResponseHeader. // Do not store references to returned value. Make copies instead. func (h *ResponseHeader) Peek(key string) []byte { - k := getHeaderKeyBytes(&h.bufKV, key) + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) return h.peek(k) } @@ -897,7 +950,7 @@ func (h *ResponseHeader) Peek(key string) []byte { // Do not store references to returned value. Make copies instead. func (h *ResponseHeader) PeekBytes(key []byte) []byte { h.bufKV.key = append(h.bufKV.key[:0], key...) - normalizeHeaderKey(h.bufKV.key) + normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) return h.peek(h.bufKV.key) } @@ -906,7 +959,7 @@ func (h *ResponseHeader) PeekBytes(key []byte) []byte { // Returned value is valid until the next call to RequestHeader. // Do not store references to returned value. Make copies instead. func (h *RequestHeader) Peek(key string) []byte { - k := getHeaderKeyBytes(&h.bufKV, key) + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) return h.peek(k) } @@ -916,7 +969,7 @@ func (h *RequestHeader) Peek(key string) []byte { // Do not store references to returned value. Make copies instead. func (h *RequestHeader) PeekBytes(key []byte) []byte { h.bufKV.key = append(h.bufKV.key[:0], key...) - normalizeHeaderKey(h.bufKV.key) + normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) return h.peek(h.bufKV.key) } @@ -996,7 +1049,7 @@ func (h *ResponseHeader) Read(r *bufio.Reader) error { return nil } if err != errNeedMore { - h.Reset() + h.resetSkipNormalize() return err } n = r.Buffered() + 1 @@ -1004,7 +1057,7 @@ func (h *ResponseHeader) Read(r *bufio.Reader) error { } func (h *ResponseHeader) tryRead(r *bufio.Reader, n int) error { - h.Reset() + h.resetSkipNormalize() b, err := r.Peek(n) if len(b) == 0 { // treat all errors on the first byte read as EOF @@ -1049,7 +1102,7 @@ func (h *RequestHeader) Read(r *bufio.Reader) error { return nil } if err != errNeedMore { - h.Reset() + h.resetSkipNormalize() return err } n = r.Buffered() + 1 @@ -1057,7 +1110,7 @@ func (h *RequestHeader) Read(r *bufio.Reader) error { } func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error { - h.Reset() + h.resetSkipNormalize() b, err := r.Peek(n) if len(b) == 0 { // treat all errors on the first byte read as EOF @@ -1479,6 +1532,7 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { var s headerScanner s.b = buf + s.disableNormalizing = h.disableNormalizing var err error var kv *argsKV for s.next() { @@ -1541,6 +1595,7 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { var s headerScanner s.b = buf + s.disableNormalizing = h.disableNormalizing var err error for s.next() { switch { @@ -1642,6 +1697,8 @@ type headerScanner struct { key []byte value []byte err error + + disableNormalizing bool } func (s *headerScanner) next() bool { @@ -1660,7 +1717,7 @@ func (s *headerScanner) next() bool { return false } s.key = s.b[:n] - normalizeHeaderKey(s.key) + normalizeHeaderKey(s.key, s.disableNormalizing) n++ for len(s.b) > n && s.b[n] == ' ' { n++ @@ -1696,18 +1753,22 @@ func nextLine(b []byte) ([]byte, []byte, error) { return b[:n], b[nNext+1:], nil } -func initHeaderKV(kv *argsKV, key, value string) { - kv.key = getHeaderKeyBytes(kv, key) +func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) { + kv.key = getHeaderKeyBytes(kv, key, disableNormalizing) kv.value = append(kv.value[:0], value...) } -func getHeaderKeyBytes(kv *argsKV, key string) []byte { +func getHeaderKeyBytes(kv *argsKV, key string, disableNormalizing bool) []byte { kv.key = append(kv.key[:0], key...) - normalizeHeaderKey(kv.key) + normalizeHeaderKey(kv.key, disableNormalizing) return kv.key } -func normalizeHeaderKey(b []byte) { +func normalizeHeaderKey(b []byte, disableNormalizing bool) { + if disableNormalizing { + return + } + n := len(b) up := true for i := 0; i < n; i++ { diff --git a/header_timing_test.go b/header_timing_test.go index 872a812..a6f7a11 100644 --- a/header_timing_test.go +++ b/header_timing_test.go @@ -140,7 +140,7 @@ func benchmarkNormalizeHeaderKey(b *testing.B, src []byte) { buf := make([]byte, len(src)) for pb.Next() { copy(buf, src) - normalizeHeaderKey(buf) + normalizeHeaderKey(buf, false) } }) } diff --git a/server.go b/server.go index 2495db3..1a89085 100644 --- a/server.go +++ b/server.go @@ -216,6 +216,24 @@ type Server struct { // are suppressed in order to limit output log traffic. LogAllErrors bool + // Header names are passed as-is without normalization + // if this option is set. + // + // Disabled header names' normalization may be useful only for proxying + // incoming requests to other servers expecting case-sensitive + // header names. See https://github.com/valyala/fasthttp/issues/57 + // for details. + // + // By default request and response header names are normalized, i.e. + // The first letter and the first letters following dashes + // are uppercased, while all the other letters are lowercased. + // Examples: + // + // * HOST -> Host + // * content-type -> Content-Type + // * cONTENT-lenGTH -> Content-Length + DisableHeaderNamesNormalizing bool + // Logger, which is used by RequestCtx.Logger(). // // By default standard logger from log package is used. @@ -1271,6 +1289,10 @@ func (s *Server) serveConn(c net.Conn) error { } if err == nil { + if s.DisableHeaderNamesNormalizing { + ctx.Request.Header.DisableNormalizing() + ctx.Response.Header.DisableNormalizing() + } err = ctx.Request.readLimitBody(br, s.MaxRequestBodySize, s.GetOnly) if br.Buffered() == 0 || err != nil { releaseReader(s, br) diff --git a/server_test.go b/server_test.go index 0bd6c3f..2412319 100644 --- a/server_test.go +++ b/server_test.go @@ -8,12 +8,68 @@ import ( "io/ioutil" "net" "os" + "strings" "testing" "time" "github.com/valyala/fasthttp/fasthttputil" ) +func TestServerDisableHeaderNamesNormalizing(t *testing.T) { + headerName := "CASE-senSITive-HEAder-NAME" + headerNameLower := strings.ToLower(headerName) + headerValue := "foobar baz" + s := &Server{ + Handler: func(ctx *RequestCtx) { + hv := ctx.Request.Header.Peek(headerName) + if string(hv) != headerValue { + t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue) + } + hv = ctx.Request.Header.Peek(headerNameLower) + if len(hv) > 0 { + t.Fatalf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv) + } + ctx.Response.Header.Set(headerName, headerValue) + ctx.WriteString("ok") + ctx.SetContentType("aaa") + }, + DisableHeaderNamesNormalizing: true, + } + + rw := &readWriter{} + rw.r.WriteString(fmt.Sprintf("GET / HTTP/1.1\r\n%s: %s\r\nHost: google.com\r\n\r\n", headerName, headerValue)) + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + + br := bufio.NewReader(&rw.w) + var resp Response + resp.Header.DisableNormalizing() + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + hv := resp.Header.Peek(headerName) + if string(hv) != headerValue { + t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue) + } + hv = resp.Header.Peek(headerNameLower) + if len(hv) > 0 { + t.Fatalf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv) + } +} + func TestServerReduceMemoryUsageSerial(t *testing.T) { ln := fasthttputil.NewInmemoryListener()