Support {readTimeout,maxBodySize,writeTimeout} per request based on the headers. (#598)

This commit is contained in:
Marcelo Pires
2019-07-12 14:42:07 +02:00
committed by Erik Dubbelboer
parent a0248ed3a1
commit ccaae97f5b
3 changed files with 184 additions and 9 deletions
+4 -4
View File
@@ -926,6 +926,10 @@ var ErrGetOnly = errors.New("non-GET request received")
// io.EOF is returned if r is closed before reading the first header byte.
func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
req.resetSkipHeader()
if err := req.Header.Read(r); err != nil {
return err
}
return req.readLimitBody(r, maxBodySize, false)
}
@@ -933,10 +937,6 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool
// Do not reset the request here - the caller must reset it before
// calling this method.
err := req.Header.Read(r)
if err != nil {
return err
}
if getOnly && !req.Header.IsGet() {
return ErrGetOnly
}
+44 -5
View File
@@ -167,6 +167,11 @@ type Server struct {
// * ErrBrokenChunks
ErrorHandler func(ctx *RequestCtx, err error)
// HeaderReceived is called after receiving the header
//
// non zero RequestConfig field values will overwrite the default configs
HeaderReceived func(header *RequestHeader) RequestConfig
// Server name for sending in response headers.
//
// Default server name is used if left blank.
@@ -415,6 +420,21 @@ func TimeoutWithCodeHandler(h RequestHandler, timeout time.Duration, msg string,
}
}
//RequestConfig configure the per request deadline and body limits
type RequestConfig struct {
// ReadTimeout is the maximum duration for reading the entire
// request body.
// a zero value means that default values will be honored
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out
// writes of the response.
// a zero value means that default values will be honored
WriteTimeout time.Duration
// Maximum request body size.
// a zero value means that default values will be honored
MaxRequestBodySize int
}
// CompressHandler returns RequestHandler that transparently compresses
// response body generated by h if the request contains 'gzip' or 'deflate'
// 'Accept-Encoding' header.
@@ -1834,6 +1854,7 @@ func (s *Server) serveConn(c net.Conn) error {
if maxRequestBodySize <= 0 {
maxRequestBodySize = DefaultMaxRequestBodySize
}
writeTimeout := s.WriteTimeout
ctx := s.acquireCtx(c)
ctx.connTime = connTime
@@ -1896,17 +1917,35 @@ func (s *Server) serveConn(c net.Conn) error {
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err))
}
}
if s.DisableHeaderNamesNormalizing {
ctx.Request.Header.DisableNormalizing()
ctx.Response.Header.DisableNormalizing()
}
// reading Headers and Body
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly)
// reading Headers
if err = ctx.Request.Header.Read(br); err == nil {
if onHdrRecv := s.HeaderReceived; onHdrRecv != nil {
reqConf := onHdrRecv(&ctx.Request.Header)
if reqConf.ReadTimeout > 0 {
deadline := time.Now().Add(reqConf.ReadTimeout)
if err := c.SetReadDeadline(deadline); err != nil {
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", deadline, err))
}
}
if reqConf.MaxRequestBodySize > 0 {
maxRequestBodySize = reqConf.MaxRequestBodySize
}
if reqConf.WriteTimeout > 0 {
writeTimeout = reqConf.WriteTimeout
}
}
//read body
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly)
}
if err == nil {
// If we read any bytes off the wire, we're active.
s.setState(c, StateActive)
}
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
@@ -2003,8 +2042,8 @@ func (s *Server) serveConn(c net.Conn) error {
ctx.SetConnectionClose()
}
if s.WriteTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil {
if writeTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", s.WriteTimeout, err))
}
}
+136
View File
@@ -2849,6 +2849,142 @@ func TestShutdownErr(t *testing.T) {
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}
func TestMaxBodySizePerRequest(t *testing.T) {
s := &Server{
Handler: func(ctx *RequestCtx) {
// do nothing :)
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
return RequestConfig{
MaxRequestBodySize: 5 << 10,
}
},
ReadTimeout: time.Second * 5,
WriteTimeout: time.Second * 5,
MaxRequestBodySize: 1 << 20,
}
rw := &readWriter{}
rw.r.WriteString(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", (5<<10)+1, strings.Repeat("a", (5<<10)+1)))
ch := make(chan error)
go func() {
ch <- s.ServeConn(rw)
}()
select {
case err := <-ch:
if err != ErrBodyTooLarge {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timeout")
}
}
func TestMaxReadTimeoutPerRequest(t *testing.T) {
headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024))
s := &Server{
Handler: func(ctx *RequestCtx) {
t.Fatal("shouldn't reach handler")
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
return RequestConfig{
ReadTimeout: time.Millisecond,
}
},
ReadBufferSize: len(headers),
ReadTimeout: time.Second * 5,
WriteTimeout: time.Second * 5,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
go func() {
//write headers
_, err := cc.Write(headers)
if err != nil {
t.Fatal(err)
}
//write body
for i := 0; i < 5*1024; i++ {
time.Sleep(time.Millisecond)
cc.Write([]byte{'a'})
}
}()
ch := make(chan error)
go func() {
ch <- s.ServeConn(sc)
}()
select {
case err := <-ch:
if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("test timeout")
}
}
func TestMaxWriteTimeoutPerRequest(t *testing.T) {
headers := []byte("GET /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: aa\r\n\r\n")
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
var buf [192]byte
for {
w.Write(buf[:])
}
})
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
return RequestConfig{
WriteTimeout: time.Millisecond,
}
},
ReadBufferSize: 192,
ReadTimeout: time.Second * 5,
WriteTimeout: time.Second * 5,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
var resp Response
go func() {
//write headers
_, err := cc.Write(headers)
if err != nil {
t.Fatal(err)
}
br := bufio.NewReaderSize(cc, 192)
err = resp.Header.Read(br)
if err != nil {
t.Fatal(err)
}
var chunk [192]byte
for {
time.Sleep(time.Millisecond)
br.Read(chunk[:])
}
}()
ch := make(chan error)
go func() {
ch <- s.ServeConn(sc)
}()
select {
case err := <-ch:
if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("test timeout")
}
}
func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) {
var resp Response
if err := resp.Read(r); err != nil {