Don't allow spaces in request header keys

See: https://github.com/golang/go/commit/6e6f4aaf70c8b1cc81e65a26332aa9409de03ad8

Reject any non GET or HEAD requests with a 400.

We can't reject GET or HEAD requests with bad headers as we delay
parsing of these headers until the user asks for one. So in this case we
just ignore the header and don't return a value for it.
This commit is contained in:
Erik Dubbelboer
2019-10-16 00:57:13 +02:00
parent 4ebe993965
commit 9dbe5fc77c
4 changed files with 102 additions and 26 deletions
+3 -3
View File
@@ -351,7 +351,7 @@ func TestClientReadTimeout(t *testing.T) {
timeout = true
}
},
Logger: &customLogger{}, // Don't print closed pipe errors.
Logger: &testLogger{}, // Don't print closed pipe errors.
}
go s.Serve(ln)
@@ -631,7 +631,7 @@ func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay
MaxConns: maxConns,
MaxPendingRequests: concurrency,
MaxBatchDelay: maxBatchDelay,
Logger: &customLogger{},
Logger: &testLogger{},
}
clientStopCh := make(chan struct{}, concurrency)
@@ -1807,7 +1807,7 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch
ctx.PostArgs().WriteTo(ctx)
}
},
Logger: &customLogger{}, // Ignore log output.
Logger: &testLogger{}, // Ignore log output.
}
ch := make(chan struct{})
go func() {
+17 -3
View File
@@ -1893,6 +1893,13 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
var err error
for s.next() {
if len(s.key) > 0 {
// Spaces between the header key and colon are not allowed.
// See RFC 7230, Section 3.2.4.
if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 {
err = fmt.Errorf("invalid header key %q", s.key)
continue
}
switch s.key[0] | 0x20 {
case 'h':
if caseInsensitiveCompare(s.key, strHost) {
@@ -1911,7 +1918,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
}
if caseInsensitiveCompare(s.key, strContentLength) {
if h.contentLength != -1 {
if h.contentLength, err = parseContentLength(s.value); err != nil {
var nerr error
if h.contentLength, nerr = parseContentLength(s.value); nerr != nil {
if err == nil {
err = nerr
}
h.contentLength = -2
} else {
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
@@ -1940,9 +1951,12 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
if s.err != nil {
if s.err != nil && err == nil {
err = s.err
}
if err != nil {
h.connectionClose = true
return 0, s.err
return 0, err
}
if h.contentLength < 0 {
+3 -4
View File
@@ -2175,10 +2175,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) {
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar",
0, "/foo/bar", "", "", "", "foobar")
// post with invalid content-length
testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty",
-2, "/a", "bb", "", "aa", "qwerty")
// post without content-length and content-type
testRequestHeaderReadSuccess(t, h, "POST /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc",
-2, "/aaa", "aaa.com", "", "", "zxc")
@@ -2234,6 +2230,9 @@ func TestRequestHeaderReadError(t *testing.T) {
// missing RequestURI
testRequestHeaderReadError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n")
// post with invalid content-length
testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty")
}
func testResponseHeaderReadError(t *testing.T, h *ResponseHeader, headers string) {
+79 -16
View File
@@ -23,6 +23,69 @@ import (
// Make sure RequestCtx implements context.Context
var _ context.Context = &RequestCtx{}
func TestServerInvalidHeader(t *testing.T) {
s := &Server{
Handler: func(ctx *RequestCtx) {
if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil {
t.Fatal("expected Foo header")
}
},
Logger: &testLogger{},
}
ln := fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}()
c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil {
t.Fatal(err)
}
br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if resp.StatusCode() != StatusBadRequest {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest)
}
c, err = ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil {
t.Fatal(err)
}
br = bufio.NewReader(c)
if err := resp.Read(br); err != nil {
t.Fatalf("unexpected error: %s", err)
}
// Since we delay header parsing for GET and HEAD requests until the users asks for it
// we can't return 400 in case of a bad header.
// Inside the handler above we make sure to test that the invalid Foo header was ignored.
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := c.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
func TestServerConnState(t *testing.T) {
states := make([]string, 0)
s := &Server{
@@ -595,7 +658,7 @@ func TestServerMaxConnsPerIPLimit(t *testing.T) {
ctx.WriteString("OK")
},
MaxConnsPerIP: 1,
Logger: &customLogger{},
Logger: &testLogger{},
}
ln := fasthttputil.NewInmemoryListener()
@@ -697,7 +760,7 @@ func TestServerConcurrencyLimit(t *testing.T) {
ctx.WriteString("OK")
},
Concurrency: 1,
Logger: &customLogger{},
Logger: &testLogger{},
}
ln := fasthttputil.NewInmemoryListener()
@@ -1970,7 +2033,7 @@ func TestRequestCtxHijack(t *testing.T) {
func TestRequestCtxInit(t *testing.T) {
var ctx RequestCtx
var logger customLogger
var logger testLogger
globalConnID = 0x123456
ctx.Init(&ctx.Request, zeroTCPAddr, &logger)
ip := ctx.RemoteIP()
@@ -2443,19 +2506,8 @@ func TestServerEmptyResponse(t *testing.T) {
verifyResponse(t, br, 200, string(defaultContentType), "")
}
type customLogger struct {
lock sync.Mutex
out string
}
func (cl *customLogger) Printf(format string, args ...interface{}) {
cl.lock.Lock()
cl.out += fmt.Sprintf(format, args...)[6:] + "\n"
cl.lock.Unlock()
}
func TestServerLogger(t *testing.T) {
cl := &customLogger{}
cl := &testLogger{}
s := &Server{
Handler: func(ctx *RequestCtx) {
logger := ctx.Logger()
@@ -2740,7 +2792,7 @@ func TestShutdownReuse(t *testing.T) {
ctx.Success("aaa/bbb", []byte("real response"))
},
ReadTimeout: time.Second,
Logger: &customLogger{}, // Ignore log output.
Logger: &testLogger{}, // Ignore log output.
}
go func() {
if err := s.Serve(ln); err != nil {
@@ -3030,3 +3082,14 @@ func (rw *readWriter) SetReadDeadline(t time.Time) error {
func (rw *readWriter) SetWriteDeadline(t time.Time) error {
return nil
}
type testLogger struct {
lock sync.Mutex
out string
}
func (cl *testLogger) Printf(format string, args ...interface{}) {
cl.lock.Lock()
cl.out += fmt.Sprintf(format, args...)[6:] + "\n"
cl.lock.Unlock()
}