diff --git a/http.go b/http.go index 4cb1fa3..4f08f21 100644 --- a/http.go +++ b/http.go @@ -3,8 +3,10 @@ package fasthttp import ( "bufio" "bytes" + "errors" "fmt" "io" + "time" ) type Request struct { @@ -18,6 +20,8 @@ type Request struct { // PostArgs becomes available only after Request.ParsePostArgs() call. PostArgs Args parsedPostArgs bool + + timeoutCh chan error } type Response struct { @@ -27,6 +31,8 @@ type Response struct { // if set to true, Response.Read() skips reading body. // Use it for HEAD requests. SkipBody bool + + timeoutCh chan error } func (req *Request) ParseURI() { @@ -68,6 +74,64 @@ func (resp *Response) Clear() { resp.Body = resp.Body[:0] } +var ErrReadTimeout = errors.New("read timeout") + +func (req *Request) ReadTimeout(r *bufio.Reader, timeout time.Duration) error { + if timeout <= 0 { + return req.Read(r) + } + + ch := req.timeoutCh + if ch == nil { + ch = make(chan error, 1) + req.timeoutCh = ch + } else if len(ch) > 0 { + panic("BUG: Request.timeoutCh must be empty!") + } + + go func() { + ch <- req.Read(r) + }() + + tc := acquireTimer(timeout) + select { + case err := <-ch: + releaseTimer(tc) + return err + case <-tc.C: + req.timeoutCh = nil + return ErrReadTimeout + } +} + +func (resp *Response) ReadTimeout(r *bufio.Reader, timeout time.Duration) error { + if timeout <= 0 { + return resp.Read(r) + } + + ch := resp.timeoutCh + if ch == nil { + ch = make(chan error, 1) + resp.timeoutCh = ch + } else if len(ch) > 0 { + panic("BUG: Response.timeoutCh must be empty!") + } + + go func() { + ch <- resp.Read(r) + }() + + tc := acquireTimer(timeout) + select { + case err := <-ch: + releaseTimer(tc) + return err + case <-tc.C: + resp.timeoutCh = nil + return ErrReadTimeout + } +} + func (req *Request) Read(r *bufio.Reader) error { req.Body = req.Body[:0] req.URI.Clear() diff --git a/http_test.go b/http_test.go index 5bbb5c5..d52e599 100644 --- a/http_test.go +++ b/http_test.go @@ -4,10 +4,76 @@ import ( "bufio" "bytes" "fmt" + "io" "strings" "testing" + "time" ) +func TestResponseReadTimeout(t *testing.T) { + var resp Response + + for i := 0; i < 5; i++ { + testResponseReadTimeoutError(t, &resp) + } + + s := "HTTP/1.1 200 OK\r\nContent-Type: text/aaa\r\nContent-Length: 5\r\n\r\n12345" + r := bytes.NewBufferString(s) + rb := bufio.NewReader(r) + if err := resp.ReadTimeout(rb, 100*time.Millisecond); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + verifyResponseHeader(t, &resp.Header, 200, 5, "text/aaa") + + for i := 0; i < 5; i++ { + testResponseReadTimeoutError(t, &resp) + } +} + +func TestRequestReadTimeout(t *testing.T) { + var req Request + + for i := 0; i < 5; i++ { + testRequestReadTimeoutError(t, &req) + } + + s := "GET /abc HTTP/1.1\r\nHost: google.com\r\n\r\n" + r := bytes.NewBufferString(s) + rb := bufio.NewReader(r) + if err := req.ReadTimeout(rb, 100*time.Millisecond); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + verifyRequestHeader(t, &req.Header, 0, "/abc", "google.com", "", "") + + for i := 0; i < 5; i++ { + testRequestReadTimeoutError(t, &req) + } +} + +func testResponseReadTimeoutError(t *testing.T, resp *Response) { + r, _ := io.Pipe() + rb := bufio.NewReader(r) + err := resp.ReadTimeout(rb, 5*time.Millisecond) + if err == nil { + t.Fatalf("Expecting error") + } + if err != ErrReadTimeout { + t.Fatalf("Unexpected error: %s. Expecting %s", err, ErrReadTimeout) + } +} + +func testRequestReadTimeoutError(t *testing.T, req *Request) { + r, _ := io.Pipe() + rb := bufio.NewReader(r) + err := req.ReadTimeout(rb, 5*time.Millisecond) + if err == nil { + t.Fatalf("Expecting error") + } + if err != ErrReadTimeout { + t.Fatalf("Unexpected error: %s. Expecting %s", err, ErrReadTimeout) + } +} + func TestRequestReadChunked(t *testing.T) { var req Request diff --git a/server.go b/server.go index f81b82d..9f8fdf2 100644 --- a/server.go +++ b/server.go @@ -30,6 +30,11 @@ type Server struct { // Per-connection buffer size for responses' writing. WriteBufferSize int + // Maximum duration for full request reading (including body). + // + // By default request read timeout is unlimited. + RequestReadTimeout time.Duration + // Logger. Logger Logger @@ -138,22 +143,6 @@ func (ctx *RequestCtx) TimeoutError(msg string) { } } -func (ctx *RequestCtx) writeResponse() error { - if atomic.LoadPointer(&ctx.shadow) != nil { - panic("BUG: cannot write response with shadow") - } - h := &ctx.Response.Header - serverOld := h.server - if len(serverOld) == 0 { - h.server = ctx.s.getServerName() - } - err := ctx.Response.Write(ctx.w) - if len(serverOld) == 0 { - h.server = serverOld - } - return err -} - const defaultConcurrency = 64 * 1024 func (s *Server) Serve(ln net.Listener) error { @@ -288,7 +277,7 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error { initRequestCtx(ctx, c) var err error for { - if err = ctx.Request.Read(ctx.r); err != nil { + if err = ctx.Request.ReadTimeout(ctx.r, s.RequestReadTimeout); err != nil { if err == io.EOF { err = nil } @@ -302,7 +291,7 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error { ctx = (*RequestCtx)(shadow) *ctxP = ctx } - if err = ctx.writeResponse(); err != nil { + if err = writeResponse(ctx); err != nil { break } connectionClose := ctx.Response.Header.ConnectionClose @@ -322,6 +311,22 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error { return err } +func writeResponse(ctx *RequestCtx) error { + if atomic.LoadPointer(&ctx.shadow) != nil { + panic("BUG: cannot write response with shadow") + } + h := &ctx.Response.Header + serverOld := h.server + if len(serverOld) == 0 { + h.server = ctx.s.getServerName() + } + err := ctx.Response.Write(ctx.w) + if len(serverOld) == 0 { + h.server = serverOld + } + return err +} + const bigBufferLimit = 16 * 1024 func trimBigBuffers(ctx *RequestCtx) { diff --git a/server_timing_test.go b/server_timing_test.go index 3e86d82..af8fb09 100644 --- a/server_timing_test.go +++ b/server_timing_test.go @@ -78,6 +78,22 @@ func BenchmarkNetHTTPServerPost10000ReqPerConn(b *testing.B) { benchmarkNetHTTPServerPost(b, 10000) } +func BenchmarkServerGetRequestReadTimeout1ReqPerConn(b *testing.B) { + benchmarkServerGetRequestReadTimeout(b, 1) +} + +func BenchmarkServerGetRequestReadTimeout2ReqPerConn(b *testing.B) { + benchmarkServerGetRequestReadTimeout(b, 2) +} + +func BenchmarkServerGetRequestReadTimeout10ReqPerConn(b *testing.B) { + benchmarkServerGetRequestReadTimeout(b, 10) +} + +func BenchmarkServerGetRequestReadTimeout10000ReqPerConn(b *testing.B) { + benchmarkServerGetRequestReadTimeout(b, 10000) +} + func BenchmarkServerTimeoutError(b *testing.B) { requestsPerConn := 10 ch := make(chan struct{}, b.N) @@ -264,6 +280,22 @@ func benchmarkNetHTTPServerPost(b *testing.B, requestsPerConn int) { verifyRequestsServed(b, requestsSent, ch) } +func benchmarkServerGetRequestReadTimeout(b *testing.B, requestsPerConn int) { + ch := make(chan struct{}, b.N) + s := &Server{ + Handler: func(ctx *RequestCtx) { + if !ctx.Request.Header.IsMethodGet() { + b.Fatalf("Unexpected request method: %s", ctx.Request.Header.Method) + } + ctx.Success("text/plain", fakeResponse) + registerServedRequest(b, ch) + }, + RequestReadTimeout: 5 * time.Second, + } + requestsSent := benchmarkServer(b, &testServer{s}, requestsPerConn, getRequest) + verifyRequestsServed(b, requestsSent, ch) +} + func registerServedRequest(b *testing.B, ch chan<- struct{}) { select { case ch <- struct{}{}: