mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
Don't allow spaces in request header keys
See: https://github.com/golang/go/commit/6e6f4aaf70c8b1cc81e65a26332aa9409de03ad8 Reject any non GET or HEAD requests with a 400. We can't reject GET or HEAD requests with bad headers as we delay parsing of these headers until the user asks for one. So in this case we just ignore the header and don't return a value for it.
This commit is contained in:
+3
-3
@@ -351,7 +351,7 @@ func TestClientReadTimeout(t *testing.T) {
|
||||
timeout = true
|
||||
}
|
||||
},
|
||||
Logger: &customLogger{}, // Don't print closed pipe errors.
|
||||
Logger: &testLogger{}, // Don't print closed pipe errors.
|
||||
}
|
||||
go s.Serve(ln)
|
||||
|
||||
@@ -631,7 +631,7 @@ func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay
|
||||
MaxConns: maxConns,
|
||||
MaxPendingRequests: concurrency,
|
||||
MaxBatchDelay: maxBatchDelay,
|
||||
Logger: &customLogger{},
|
||||
Logger: &testLogger{},
|
||||
}
|
||||
|
||||
clientStopCh := make(chan struct{}, concurrency)
|
||||
@@ -1807,7 +1807,7 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch
|
||||
ctx.PostArgs().WriteTo(ctx)
|
||||
}
|
||||
},
|
||||
Logger: &customLogger{}, // Ignore log output.
|
||||
Logger: &testLogger{}, // Ignore log output.
|
||||
}
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
@@ -1893,6 +1893,13 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
|
||||
var err error
|
||||
for s.next() {
|
||||
if len(s.key) > 0 {
|
||||
// Spaces between the header key and colon are not allowed.
|
||||
// See RFC 7230, Section 3.2.4.
|
||||
if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 {
|
||||
err = fmt.Errorf("invalid header key %q", s.key)
|
||||
continue
|
||||
}
|
||||
|
||||
switch s.key[0] | 0x20 {
|
||||
case 'h':
|
||||
if caseInsensitiveCompare(s.key, strHost) {
|
||||
@@ -1911,7 +1918,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
|
||||
}
|
||||
if caseInsensitiveCompare(s.key, strContentLength) {
|
||||
if h.contentLength != -1 {
|
||||
if h.contentLength, err = parseContentLength(s.value); err != nil {
|
||||
var nerr error
|
||||
if h.contentLength, nerr = parseContentLength(s.value); nerr != nil {
|
||||
if err == nil {
|
||||
err = nerr
|
||||
}
|
||||
h.contentLength = -2
|
||||
} else {
|
||||
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
|
||||
@@ -1940,9 +1951,12 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
|
||||
}
|
||||
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
|
||||
}
|
||||
if s.err != nil {
|
||||
if s.err != nil && err == nil {
|
||||
err = s.err
|
||||
}
|
||||
if err != nil {
|
||||
h.connectionClose = true
|
||||
return 0, s.err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if h.contentLength < 0 {
|
||||
|
||||
+3
-4
@@ -2175,10 +2175,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) {
|
||||
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar",
|
||||
0, "/foo/bar", "", "", "", "foobar")
|
||||
|
||||
// post with invalid content-length
|
||||
testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty",
|
||||
-2, "/a", "bb", "", "aa", "qwerty")
|
||||
|
||||
// post without content-length and content-type
|
||||
testRequestHeaderReadSuccess(t, h, "POST /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc",
|
||||
-2, "/aaa", "aaa.com", "", "", "zxc")
|
||||
@@ -2234,6 +2230,9 @@ func TestRequestHeaderReadError(t *testing.T) {
|
||||
|
||||
// missing RequestURI
|
||||
testRequestHeaderReadError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n")
|
||||
|
||||
// post with invalid content-length
|
||||
testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty")
|
||||
}
|
||||
|
||||
func testResponseHeaderReadError(t *testing.T, h *ResponseHeader, headers string) {
|
||||
|
||||
+79
-16
@@ -23,6 +23,69 @@ import (
|
||||
// Make sure RequestCtx implements context.Context
|
||||
var _ context.Context = &RequestCtx{}
|
||||
|
||||
func TestServerInvalidHeader(t *testing.T) {
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil {
|
||||
t.Fatal("expected Foo header")
|
||||
}
|
||||
},
|
||||
Logger: &testLogger{},
|
||||
}
|
||||
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(ln); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := ln.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
br := bufio.NewReader(c)
|
||||
var resp Response
|
||||
if err := resp.Read(br); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if resp.StatusCode() != StatusBadRequest {
|
||||
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest)
|
||||
}
|
||||
|
||||
c, err = ln.Dial()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
br = bufio.NewReader(c)
|
||||
if err := resp.Read(br); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// Since we delay header parsing for GET and HEAD requests until the users asks for it
|
||||
// we can't return 400 in case of a bad header.
|
||||
// Inside the handler above we make sure to test that the invalid Foo header was ignored.
|
||||
if resp.StatusCode() != StatusOK {
|
||||
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
|
||||
}
|
||||
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerConnState(t *testing.T) {
|
||||
states := make([]string, 0)
|
||||
s := &Server{
|
||||
@@ -595,7 +658,7 @@ func TestServerMaxConnsPerIPLimit(t *testing.T) {
|
||||
ctx.WriteString("OK")
|
||||
},
|
||||
MaxConnsPerIP: 1,
|
||||
Logger: &customLogger{},
|
||||
Logger: &testLogger{},
|
||||
}
|
||||
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
@@ -697,7 +760,7 @@ func TestServerConcurrencyLimit(t *testing.T) {
|
||||
ctx.WriteString("OK")
|
||||
},
|
||||
Concurrency: 1,
|
||||
Logger: &customLogger{},
|
||||
Logger: &testLogger{},
|
||||
}
|
||||
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
@@ -1970,7 +2033,7 @@ func TestRequestCtxHijack(t *testing.T) {
|
||||
|
||||
func TestRequestCtxInit(t *testing.T) {
|
||||
var ctx RequestCtx
|
||||
var logger customLogger
|
||||
var logger testLogger
|
||||
globalConnID = 0x123456
|
||||
ctx.Init(&ctx.Request, zeroTCPAddr, &logger)
|
||||
ip := ctx.RemoteIP()
|
||||
@@ -2443,19 +2506,8 @@ func TestServerEmptyResponse(t *testing.T) {
|
||||
verifyResponse(t, br, 200, string(defaultContentType), "")
|
||||
}
|
||||
|
||||
type customLogger struct {
|
||||
lock sync.Mutex
|
||||
out string
|
||||
}
|
||||
|
||||
func (cl *customLogger) Printf(format string, args ...interface{}) {
|
||||
cl.lock.Lock()
|
||||
cl.out += fmt.Sprintf(format, args...)[6:] + "\n"
|
||||
cl.lock.Unlock()
|
||||
}
|
||||
|
||||
func TestServerLogger(t *testing.T) {
|
||||
cl := &customLogger{}
|
||||
cl := &testLogger{}
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
logger := ctx.Logger()
|
||||
@@ -2740,7 +2792,7 @@ func TestShutdownReuse(t *testing.T) {
|
||||
ctx.Success("aaa/bbb", []byte("real response"))
|
||||
},
|
||||
ReadTimeout: time.Second,
|
||||
Logger: &customLogger{}, // Ignore log output.
|
||||
Logger: &testLogger{}, // Ignore log output.
|
||||
}
|
||||
go func() {
|
||||
if err := s.Serve(ln); err != nil {
|
||||
@@ -3030,3 +3082,14 @@ func (rw *readWriter) SetReadDeadline(t time.Time) error {
|
||||
func (rw *readWriter) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
lock sync.Mutex
|
||||
out string
|
||||
}
|
||||
|
||||
func (cl *testLogger) Printf(format string, args ...interface{}) {
|
||||
cl.lock.Lock()
|
||||
cl.out += fmt.Sprintf(format, args...)[6:] + "\n"
|
||||
cl.lock.Unlock()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user