mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-26 17:46:34 +03:00
Support {readTimeout,maxBodySize,writeTimeout} per request based on the headers. (#598)
This commit is contained in:
committed by
Erik Dubbelboer
parent
a0248ed3a1
commit
ccaae97f5b
@@ -926,6 +926,10 @@ var ErrGetOnly = errors.New("non-GET request received")
|
||||
// io.EOF is returned if r is closed before reading the first header byte.
|
||||
func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
|
||||
req.resetSkipHeader()
|
||||
if err := req.Header.Read(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return req.readLimitBody(r, maxBodySize, false)
|
||||
}
|
||||
|
||||
@@ -933,10 +937,6 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool
|
||||
// Do not reset the request here - the caller must reset it before
|
||||
// calling this method.
|
||||
|
||||
err := req.Header.Read(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if getOnly && !req.Header.IsGet() {
|
||||
return ErrGetOnly
|
||||
}
|
||||
|
||||
@@ -167,6 +167,11 @@ type Server struct {
|
||||
// * ErrBrokenChunks
|
||||
ErrorHandler func(ctx *RequestCtx, err error)
|
||||
|
||||
// HeaderReceived is called after receiving the header
|
||||
//
|
||||
// non zero RequestConfig field values will overwrite the default configs
|
||||
HeaderReceived func(header *RequestHeader) RequestConfig
|
||||
|
||||
// Server name for sending in response headers.
|
||||
//
|
||||
// Default server name is used if left blank.
|
||||
@@ -415,6 +420,21 @@ func TimeoutWithCodeHandler(h RequestHandler, timeout time.Duration, msg string,
|
||||
}
|
||||
}
|
||||
|
||||
//RequestConfig configure the per request deadline and body limits
|
||||
type RequestConfig struct {
|
||||
// ReadTimeout is the maximum duration for reading the entire
|
||||
// request body.
|
||||
// a zero value means that default values will be honored
|
||||
ReadTimeout time.Duration
|
||||
// WriteTimeout is the maximum duration before timing out
|
||||
// writes of the response.
|
||||
// a zero value means that default values will be honored
|
||||
WriteTimeout time.Duration
|
||||
// Maximum request body size.
|
||||
// a zero value means that default values will be honored
|
||||
MaxRequestBodySize int
|
||||
}
|
||||
|
||||
// CompressHandler returns RequestHandler that transparently compresses
|
||||
// response body generated by h if the request contains 'gzip' or 'deflate'
|
||||
// 'Accept-Encoding' header.
|
||||
@@ -1834,6 +1854,7 @@ func (s *Server) serveConn(c net.Conn) error {
|
||||
if maxRequestBodySize <= 0 {
|
||||
maxRequestBodySize = DefaultMaxRequestBodySize
|
||||
}
|
||||
writeTimeout := s.WriteTimeout
|
||||
|
||||
ctx := s.acquireCtx(c)
|
||||
ctx.connTime = connTime
|
||||
@@ -1896,17 +1917,35 @@ func (s *Server) serveConn(c net.Conn) error {
|
||||
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err))
|
||||
}
|
||||
}
|
||||
|
||||
if s.DisableHeaderNamesNormalizing {
|
||||
ctx.Request.Header.DisableNormalizing()
|
||||
ctx.Response.Header.DisableNormalizing()
|
||||
}
|
||||
// reading Headers and Body
|
||||
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly)
|
||||
// reading Headers
|
||||
if err = ctx.Request.Header.Read(br); err == nil {
|
||||
if onHdrRecv := s.HeaderReceived; onHdrRecv != nil {
|
||||
reqConf := onHdrRecv(&ctx.Request.Header)
|
||||
if reqConf.ReadTimeout > 0 {
|
||||
deadline := time.Now().Add(reqConf.ReadTimeout)
|
||||
if err := c.SetReadDeadline(deadline); err != nil {
|
||||
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", deadline, err))
|
||||
}
|
||||
}
|
||||
if reqConf.MaxRequestBodySize > 0 {
|
||||
maxRequestBodySize = reqConf.MaxRequestBodySize
|
||||
}
|
||||
if reqConf.WriteTimeout > 0 {
|
||||
writeTimeout = reqConf.WriteTimeout
|
||||
}
|
||||
}
|
||||
//read body
|
||||
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly)
|
||||
}
|
||||
if err == nil {
|
||||
// If we read any bytes off the wire, we're active.
|
||||
s.setState(c, StateActive)
|
||||
}
|
||||
|
||||
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
|
||||
releaseReader(s, br)
|
||||
br = nil
|
||||
@@ -2003,8 +2042,8 @@ func (s *Server) serveConn(c net.Conn) error {
|
||||
ctx.SetConnectionClose()
|
||||
}
|
||||
|
||||
if s.WriteTimeout > 0 {
|
||||
if err := c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil {
|
||||
if writeTimeout > 0 {
|
||||
if err := c.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
|
||||
panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", s.WriteTimeout, err))
|
||||
}
|
||||
}
|
||||
|
||||
+136
@@ -2849,6 +2849,142 @@ func TestShutdownErr(t *testing.T) {
|
||||
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
|
||||
}
|
||||
|
||||
func TestMaxBodySizePerRequest(t *testing.T) {
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
// do nothing :)
|
||||
},
|
||||
HeaderReceived: func(header *RequestHeader) RequestConfig {
|
||||
return RequestConfig{
|
||||
MaxRequestBodySize: 5 << 10,
|
||||
}
|
||||
},
|
||||
ReadTimeout: time.Second * 5,
|
||||
WriteTimeout: time.Second * 5,
|
||||
MaxRequestBodySize: 1 << 20,
|
||||
}
|
||||
|
||||
rw := &readWriter{}
|
||||
rw.r.WriteString(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", (5<<10)+1, strings.Repeat("a", (5<<10)+1)))
|
||||
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
ch <- s.ServeConn(rw)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
if err != ErrBodyTooLarge {
|
||||
t.Fatalf("Unexpected error from serveConn: %s", err)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxReadTimeoutPerRequest(t *testing.T) {
|
||||
headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024))
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
t.Fatal("shouldn't reach handler")
|
||||
},
|
||||
HeaderReceived: func(header *RequestHeader) RequestConfig {
|
||||
return RequestConfig{
|
||||
ReadTimeout: time.Millisecond,
|
||||
}
|
||||
},
|
||||
ReadBufferSize: len(headers),
|
||||
ReadTimeout: time.Second * 5,
|
||||
WriteTimeout: time.Second * 5,
|
||||
}
|
||||
|
||||
pipe := fasthttputil.NewPipeConns()
|
||||
cc, sc := pipe.Conn1(), pipe.Conn2()
|
||||
go func() {
|
||||
//write headers
|
||||
_, err := cc.Write(headers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//write body
|
||||
for i := 0; i < 5*1024; i++ {
|
||||
time.Sleep(time.Millisecond)
|
||||
cc.Write([]byte{'a'})
|
||||
}
|
||||
}()
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
ch <- s.ServeConn(sc)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") {
|
||||
t.Fatalf("Unexpected error from serveConn: %s", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("test timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxWriteTimeoutPerRequest(t *testing.T) {
|
||||
headers := []byte("GET /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: aa\r\n\r\n")
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
var buf [192]byte
|
||||
for {
|
||||
w.Write(buf[:])
|
||||
}
|
||||
})
|
||||
},
|
||||
HeaderReceived: func(header *RequestHeader) RequestConfig {
|
||||
return RequestConfig{
|
||||
WriteTimeout: time.Millisecond,
|
||||
}
|
||||
},
|
||||
ReadBufferSize: 192,
|
||||
ReadTimeout: time.Second * 5,
|
||||
WriteTimeout: time.Second * 5,
|
||||
}
|
||||
|
||||
pipe := fasthttputil.NewPipeConns()
|
||||
cc, sc := pipe.Conn1(), pipe.Conn2()
|
||||
|
||||
var resp Response
|
||||
go func() {
|
||||
//write headers
|
||||
_, err := cc.Write(headers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
br := bufio.NewReaderSize(cc, 192)
|
||||
err = resp.Header.Read(br)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var chunk [192]byte
|
||||
for {
|
||||
time.Sleep(time.Millisecond)
|
||||
br.Read(chunk[:])
|
||||
}
|
||||
}()
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
ch <- s.ServeConn(sc)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") {
|
||||
t.Fatalf("Unexpected error from serveConn: %s", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("test timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) {
|
||||
var resp Response
|
||||
if err := resp.Read(r); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user