Issue #57: Server: added ability to disable header names' normalizing

This commit is contained in:
Aliaksandr Valialkin
2016-02-25 14:00:04 +02:00
parent c1437a71e6
commit dd6954f4b2
4 changed files with 177 additions and 38 deletions
+98 -37
View File
@@ -18,11 +18,11 @@ import (
// ResponseHeader instance MUST NOT be used from concurrently running
// goroutines.
type ResponseHeader struct {
statusCode int
noHTTP11 bool
connectionClose bool
disableNormalizing bool
noHTTP11 bool
connectionClose bool
statusCode int
contentLength int
contentLengthBytes []byte
@@ -43,9 +43,15 @@ type ResponseHeader struct {
// RequestHeader instance MUST NOT be used from concurrently running
// goroutines.
type RequestHeader struct {
noHTTP11 bool
connectionClose bool
isGet bool
disableNormalizing bool
noHTTP11 bool
connectionClose bool
isGet bool
// These two fields have been moved close to other bool fields
// for reducing RequestHeader object size.
cookiesCollected bool
rawHeadersParsed bool
contentLength int
contentLengthBytes []byte
@@ -59,11 +65,9 @@ type RequestHeader struct {
h []argsKV
bufKV argsKV
cookies []argsKV
cookiesCollected bool
cookies []argsKV
rawHeaders []byte
rawHeadersParsed bool
rawHeaders []byte
}
// SetContentRange sets 'Content-Range: bytes startPos-endPos/contentLength'
@@ -561,12 +565,49 @@ func (h *RequestHeader) Len() int {
return n
}
// DisableNormalizing disables header names' normalization.
//
// By default all the header names are normalized by uppercasing
// the first letter and all the first letters following dashes,
// while lowercasing all the other letters.
// Examples:
//
// * CONNECTION -> Connection
// * conteNT-tYPE -> Content-Type
// * foo-bar-baz -> Foo-Bar-Baz
//
// Disable header names' normalization only if know what are you doing.
func (h *RequestHeader) DisableNormalizing() {
h.disableNormalizing = true
}
// DisableNormalizing disables header names' normalization.
//
// By default all the header names are normalized by uppercasing
// the first letter and all the first letters following dashes,
// while lowercasing all the other letters.
// Examples:
//
// * CONNECTION -> Connection
// * conteNT-tYPE -> Content-Type
// * foo-bar-baz -> Foo-Bar-Baz
//
// Disable header names' normalization only if know what are you doing.
func (h *ResponseHeader) DisableNormalizing() {
h.disableNormalizing = true
}
// Reset clears response header.
func (h *ResponseHeader) Reset() {
h.disableNormalizing = false
h.resetSkipNormalize()
}
func (h *ResponseHeader) resetSkipNormalize() {
h.noHTTP11 = false
h.statusCode = 0
h.connectionClose = false
h.statusCode = 0
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
@@ -579,6 +620,11 @@ func (h *ResponseHeader) Reset() {
// Reset clears request header.
func (h *RequestHeader) Reset() {
h.disableNormalizing = false
h.resetSkipNormalize()
}
func (h *RequestHeader) resetSkipNormalize() {
h.noHTTP11 = false
h.connectionClose = false
h.isGet = false
@@ -603,9 +649,12 @@ func (h *RequestHeader) Reset() {
// CopyTo copies all the headers to dst.
func (h *ResponseHeader) CopyTo(dst *ResponseHeader) {
dst.Reset()
dst.disableNormalizing = h.disableNormalizing
dst.noHTTP11 = h.noHTTP11
dst.statusCode = h.statusCode
dst.connectionClose = h.connectionClose
dst.statusCode = h.statusCode
dst.contentLength = h.contentLength
dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...)
dst.contentType = append(dst.contentType[:0], h.contentType...)
@@ -617,8 +666,12 @@ func (h *ResponseHeader) CopyTo(dst *ResponseHeader) {
// CopyTo copies all the headers to dst.
func (h *RequestHeader) CopyTo(dst *RequestHeader) {
dst.Reset()
dst.disableNormalizing = h.disableNormalizing
dst.noHTTP11 = h.noHTTP11
dst.connectionClose = h.connectionClose
dst.isGet = h.isGet
dst.contentLength = h.contentLength
dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...)
dst.method = append(dst.method[:0], h.method...)
@@ -715,21 +768,21 @@ func (h *RequestHeader) VisitAll(f func(key, value []byte)) {
// Del deletes header with the given key.
func (h *ResponseHeader) Del(key string) {
k := getHeaderKeyBytes(&h.bufKV, key)
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.h = delArg(h.h, k)
}
// DelBytes deletes header with the given key.
func (h *ResponseHeader) DelBytes(key []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
h.h = delArg(h.h, h.bufKV.key)
}
// Del deletes header with the given key.
func (h *RequestHeader) Del(key string) {
h.parseRawHeaders()
k := getHeaderKeyBytes(&h.bufKV, key)
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.h = delArg(h.h, k)
}
@@ -737,13 +790,13 @@ func (h *RequestHeader) Del(key string) {
func (h *RequestHeader) DelBytes(key []byte) {
h.parseRawHeaders()
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
h.h = delArg(h.h, h.bufKV.key)
}
// Set sets the given 'key: value' header.
func (h *ResponseHeader) Set(key, value string) {
initHeaderKV(&h.bufKV, key, value)
initHeaderKV(&h.bufKV, key, value, h.disableNormalizing)
h.SetCanonical(h.bufKV.key, h.bufKV.value)
}
@@ -755,14 +808,14 @@ func (h *ResponseHeader) SetBytesK(key []byte, value string) {
// SetBytesV sets the given 'key: value' header.
func (h *ResponseHeader) SetBytesV(key string, value []byte) {
k := getHeaderKeyBytes(&h.bufKV, key)
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.SetCanonical(k, value)
}
// SetBytesKV sets the given 'key: value' header.
func (h *ResponseHeader) SetBytesKV(key, value []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
h.SetCanonical(h.bufKV.key, value)
}
@@ -826,7 +879,7 @@ func (h *RequestHeader) SetCookieBytesKV(key, value []byte) {
// Set sets the given 'key: value' header.
func (h *RequestHeader) Set(key, value string) {
initHeaderKV(&h.bufKV, key, value)
initHeaderKV(&h.bufKV, key, value, h.disableNormalizing)
h.SetCanonical(h.bufKV.key, h.bufKV.value)
}
@@ -838,14 +891,14 @@ func (h *RequestHeader) SetBytesK(key []byte, value string) {
// SetBytesV sets the given 'key: value' header.
func (h *RequestHeader) SetBytesV(key string, value []byte) {
k := getHeaderKeyBytes(&h.bufKV, key)
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.SetCanonical(k, value)
}
// SetBytesKV sets the given 'key: value' header.
func (h *RequestHeader) SetBytesKV(key, value []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
h.SetCanonical(h.bufKV.key, value)
}
@@ -887,7 +940,7 @@ func (h *RequestHeader) SetCanonical(key, value []byte) {
// Returned value is valid until the next call to ResponseHeader.
// Do not store references to returned value. Make copies instead.
func (h *ResponseHeader) Peek(key string) []byte {
k := getHeaderKeyBytes(&h.bufKV, key)
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peek(k)
}
@@ -897,7 +950,7 @@ func (h *ResponseHeader) Peek(key string) []byte {
// Do not store references to returned value. Make copies instead.
func (h *ResponseHeader) PeekBytes(key []byte) []byte {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
return h.peek(h.bufKV.key)
}
@@ -906,7 +959,7 @@ func (h *ResponseHeader) PeekBytes(key []byte) []byte {
// Returned value is valid until the next call to RequestHeader.
// Do not store references to returned value. Make copies instead.
func (h *RequestHeader) Peek(key string) []byte {
k := getHeaderKeyBytes(&h.bufKV, key)
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peek(k)
}
@@ -916,7 +969,7 @@ func (h *RequestHeader) Peek(key string) []byte {
// Do not store references to returned value. Make copies instead.
func (h *RequestHeader) PeekBytes(key []byte) []byte {
h.bufKV.key = append(h.bufKV.key[:0], key...)
normalizeHeaderKey(h.bufKV.key)
normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
return h.peek(h.bufKV.key)
}
@@ -996,7 +1049,7 @@ func (h *ResponseHeader) Read(r *bufio.Reader) error {
return nil
}
if err != errNeedMore {
h.Reset()
h.resetSkipNormalize()
return err
}
n = r.Buffered() + 1
@@ -1004,7 +1057,7 @@ func (h *ResponseHeader) Read(r *bufio.Reader) error {
}
func (h *ResponseHeader) tryRead(r *bufio.Reader, n int) error {
h.Reset()
h.resetSkipNormalize()
b, err := r.Peek(n)
if len(b) == 0 {
// treat all errors on the first byte read as EOF
@@ -1049,7 +1102,7 @@ func (h *RequestHeader) Read(r *bufio.Reader) error {
return nil
}
if err != errNeedMore {
h.Reset()
h.resetSkipNormalize()
return err
}
n = r.Buffered() + 1
@@ -1057,7 +1110,7 @@ func (h *RequestHeader) Read(r *bufio.Reader) error {
}
func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error {
h.Reset()
h.resetSkipNormalize()
b, err := r.Peek(n)
if len(b) == 0 {
// treat all errors on the first byte read as EOF
@@ -1479,6 +1532,7 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var err error
var kv *argsKV
for s.next() {
@@ -1541,6 +1595,7 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
var err error
for s.next() {
switch {
@@ -1642,6 +1697,8 @@ type headerScanner struct {
key []byte
value []byte
err error
disableNormalizing bool
}
func (s *headerScanner) next() bool {
@@ -1660,7 +1717,7 @@ func (s *headerScanner) next() bool {
return false
}
s.key = s.b[:n]
normalizeHeaderKey(s.key)
normalizeHeaderKey(s.key, s.disableNormalizing)
n++
for len(s.b) > n && s.b[n] == ' ' {
n++
@@ -1696,18 +1753,22 @@ func nextLine(b []byte) ([]byte, []byte, error) {
return b[:n], b[nNext+1:], nil
}
func initHeaderKV(kv *argsKV, key, value string) {
kv.key = getHeaderKeyBytes(kv, key)
func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) {
kv.key = getHeaderKeyBytes(kv, key, disableNormalizing)
kv.value = append(kv.value[:0], value...)
}
func getHeaderKeyBytes(kv *argsKV, key string) []byte {
func getHeaderKeyBytes(kv *argsKV, key string, disableNormalizing bool) []byte {
kv.key = append(kv.key[:0], key...)
normalizeHeaderKey(kv.key)
normalizeHeaderKey(kv.key, disableNormalizing)
return kv.key
}
func normalizeHeaderKey(b []byte) {
func normalizeHeaderKey(b []byte, disableNormalizing bool) {
if disableNormalizing {
return
}
n := len(b)
up := true
for i := 0; i < n; i++ {
+1 -1
View File
@@ -140,7 +140,7 @@ func benchmarkNormalizeHeaderKey(b *testing.B, src []byte) {
buf := make([]byte, len(src))
for pb.Next() {
copy(buf, src)
normalizeHeaderKey(buf)
normalizeHeaderKey(buf, false)
}
})
}
+22
View File
@@ -216,6 +216,24 @@ type Server struct {
// are suppressed in order to limit output log traffic.
LogAllErrors bool
// Header names are passed as-is without normalization
// if this option is set.
//
// Disabled header names' normalization may be useful only for proxying
// incoming requests to other servers expecting case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// The first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// Logger, which is used by RequestCtx.Logger().
//
// By default standard logger from log package is used.
@@ -1271,6 +1289,10 @@ func (s *Server) serveConn(c net.Conn) error {
}
if err == nil {
if s.DisableHeaderNamesNormalizing {
ctx.Request.Header.DisableNormalizing()
ctx.Response.Header.DisableNormalizing()
}
err = ctx.Request.readLimitBody(br, s.MaxRequestBodySize, s.GetOnly)
if br.Buffered() == 0 || err != nil {
releaseReader(s, br)
+56
View File
@@ -8,12 +8,68 @@ import (
"io/ioutil"
"net"
"os"
"strings"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
func TestServerDisableHeaderNamesNormalizing(t *testing.T) {
headerName := "CASE-senSITive-HEAder-NAME"
headerNameLower := strings.ToLower(headerName)
headerValue := "foobar baz"
s := &Server{
Handler: func(ctx *RequestCtx) {
hv := ctx.Request.Header.Peek(headerName)
if string(hv) != headerValue {
t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue)
}
hv = ctx.Request.Header.Peek(headerNameLower)
if len(hv) > 0 {
t.Fatalf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv)
}
ctx.Response.Header.Set(headerName, headerValue)
ctx.WriteString("ok")
ctx.SetContentType("aaa")
},
DisableHeaderNamesNormalizing: true,
}
rw := &readWriter{}
rw.r.WriteString(fmt.Sprintf("GET / HTTP/1.1\r\n%s: %s\r\nHost: google.com\r\n\r\n", headerName, headerValue))
ch := make(chan error)
go func() {
ch <- s.ServeConn(rw)
}()
select {
case err := <-ch:
if err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timeout")
}
br := bufio.NewReader(&rw.w)
var resp Response
resp.Header.DisableNormalizing()
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
hv := resp.Header.Peek(headerName)
if string(hv) != headerValue {
t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue)
}
hv = resp.Header.Peek(headerNameLower)
if len(hv) > 0 {
t.Fatalf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv)
}
}
func TestServerReduceMemoryUsageSerial(t *testing.T) {
ln := fasthttputil.NewInmemoryListener()