diff --git a/http.go b/http.go index ce30f8c..737673a 100644 --- a/http.go +++ b/http.go @@ -926,6 +926,10 @@ var ErrGetOnly = errors.New("non-GET request received") // io.EOF is returned if r is closed before reading the first header byte. func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { req.resetSkipHeader() + if err := req.Header.Read(r); err != nil { + return err + } + return req.readLimitBody(r, maxBodySize, false) } @@ -933,10 +937,6 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool // Do not reset the request here - the caller must reset it before // calling this method. - err := req.Header.Read(r) - if err != nil { - return err - } if getOnly && !req.Header.IsGet() { return ErrGetOnly } diff --git a/server.go b/server.go index def5eb1..5fa4e12 100644 --- a/server.go +++ b/server.go @@ -167,6 +167,11 @@ type Server struct { // * ErrBrokenChunks ErrorHandler func(ctx *RequestCtx, err error) + // HeaderReceived is called after receiving the header + // + // non zero RequestConfig field values will overwrite the default configs + HeaderReceived func(header *RequestHeader) RequestConfig + // Server name for sending in response headers. // // Default server name is used if left blank. @@ -415,6 +420,21 @@ func TimeoutWithCodeHandler(h RequestHandler, timeout time.Duration, msg string, } } +//RequestConfig configure the per request deadline and body limits +type RequestConfig struct { + // ReadTimeout is the maximum duration for reading the entire + // request body. + // a zero value means that default values will be honored + ReadTimeout time.Duration + // WriteTimeout is the maximum duration before timing out + // writes of the response. + // a zero value means that default values will be honored + WriteTimeout time.Duration + // Maximum request body size. + // a zero value means that default values will be honored + MaxRequestBodySize int +} + // CompressHandler returns RequestHandler that transparently compresses // response body generated by h if the request contains 'gzip' or 'deflate' // 'Accept-Encoding' header. @@ -1834,6 +1854,7 @@ func (s *Server) serveConn(c net.Conn) error { if maxRequestBodySize <= 0 { maxRequestBodySize = DefaultMaxRequestBodySize } + writeTimeout := s.WriteTimeout ctx := s.acquireCtx(c) ctx.connTime = connTime @@ -1896,17 +1917,35 @@ func (s *Server) serveConn(c net.Conn) error { panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err)) } } - if s.DisableHeaderNamesNormalizing { ctx.Request.Header.DisableNormalizing() ctx.Response.Header.DisableNormalizing() } - // reading Headers and Body - err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly) + // reading Headers + if err = ctx.Request.Header.Read(br); err == nil { + if onHdrRecv := s.HeaderReceived; onHdrRecv != nil { + reqConf := onHdrRecv(&ctx.Request.Header) + if reqConf.ReadTimeout > 0 { + deadline := time.Now().Add(reqConf.ReadTimeout) + if err := c.SetReadDeadline(deadline); err != nil { + panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", deadline, err)) + } + } + if reqConf.MaxRequestBodySize > 0 { + maxRequestBodySize = reqConf.MaxRequestBodySize + } + if reqConf.WriteTimeout > 0 { + writeTimeout = reqConf.WriteTimeout + } + } + //read body + err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly) + } if err == nil { // If we read any bytes off the wire, we're active. s.setState(c, StateActive) } + if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil { releaseReader(s, br) br = nil @@ -2003,8 +2042,8 @@ func (s *Server) serveConn(c net.Conn) error { ctx.SetConnectionClose() } - if s.WriteTimeout > 0 { - if err := c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil { + if writeTimeout > 0 { + if err := c.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", s.WriteTimeout, err)) } } diff --git a/server_test.go b/server_test.go index a68d432..1db50c2 100644 --- a/server_test.go +++ b/server_test.go @@ -2849,6 +2849,142 @@ func TestShutdownErr(t *testing.T) { verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") } +func TestMaxBodySizePerRequest(t *testing.T) { + s := &Server{ + Handler: func(ctx *RequestCtx) { + // do nothing :) + }, + HeaderReceived: func(header *RequestHeader) RequestConfig { + return RequestConfig{ + MaxRequestBodySize: 5 << 10, + } + }, + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, + MaxRequestBodySize: 1 << 20, + } + + rw := &readWriter{} + rw.r.WriteString(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", (5<<10)+1, strings.Repeat("a", (5<<10)+1))) + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != ErrBodyTooLarge { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } +} + +func TestMaxReadTimeoutPerRequest(t *testing.T) { + headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024)) + s := &Server{ + Handler: func(ctx *RequestCtx) { + t.Fatal("shouldn't reach handler") + }, + HeaderReceived: func(header *RequestHeader) RequestConfig { + return RequestConfig{ + ReadTimeout: time.Millisecond, + } + }, + ReadBufferSize: len(headers), + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, + } + + pipe := fasthttputil.NewPipeConns() + cc, sc := pipe.Conn1(), pipe.Conn2() + go func() { + //write headers + _, err := cc.Write(headers) + if err != nil { + t.Fatal(err) + } + //write body + for i := 0; i < 5*1024; i++ { + time.Sleep(time.Millisecond) + cc.Write([]byte{'a'}) + } + }() + ch := make(chan error) + go func() { + ch <- s.ServeConn(sc) + }() + + select { + case err := <-ch: + if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(time.Second): + t.Fatalf("test timeout") + } +} + +func TestMaxWriteTimeoutPerRequest(t *testing.T) { + headers := []byte("GET /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: aa\r\n\r\n") + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.SetBodyStreamWriter(func(w *bufio.Writer) { + var buf [192]byte + for { + w.Write(buf[:]) + } + }) + }, + HeaderReceived: func(header *RequestHeader) RequestConfig { + return RequestConfig{ + WriteTimeout: time.Millisecond, + } + }, + ReadBufferSize: 192, + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, + } + + pipe := fasthttputil.NewPipeConns() + cc, sc := pipe.Conn1(), pipe.Conn2() + + var resp Response + go func() { + //write headers + _, err := cc.Write(headers) + if err != nil { + t.Fatal(err) + } + br := bufio.NewReaderSize(cc, 192) + err = resp.Header.Read(br) + if err != nil { + t.Fatal(err) + } + + var chunk [192]byte + for { + time.Sleep(time.Millisecond) + br.Read(chunk[:]) + } + }() + ch := make(chan error) + go func() { + ch <- s.ServeConn(sc) + }() + + select { + case err := <-ch: + if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(time.Second): + t.Fatalf("test timeout") + } +} + func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) { var resp Response if err := resp.Read(r); err != nil {