mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-15 16:07:51 +03:00
Added timeout covering full request read
This commit is contained in:
@@ -3,8 +3,10 @@ package fasthttp
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
@@ -18,6 +20,8 @@ type Request struct {
|
||||
// PostArgs becomes available only after Request.ParsePostArgs() call.
|
||||
PostArgs Args
|
||||
parsedPostArgs bool
|
||||
|
||||
timeoutCh chan error
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
@@ -27,6 +31,8 @@ type Response struct {
|
||||
// if set to true, Response.Read() skips reading body.
|
||||
// Use it for HEAD requests.
|
||||
SkipBody bool
|
||||
|
||||
timeoutCh chan error
|
||||
}
|
||||
|
||||
func (req *Request) ParseURI() {
|
||||
@@ -68,6 +74,64 @@ func (resp *Response) Clear() {
|
||||
resp.Body = resp.Body[:0]
|
||||
}
|
||||
|
||||
var ErrReadTimeout = errors.New("read timeout")
|
||||
|
||||
func (req *Request) ReadTimeout(r *bufio.Reader, timeout time.Duration) error {
|
||||
if timeout <= 0 {
|
||||
return req.Read(r)
|
||||
}
|
||||
|
||||
ch := req.timeoutCh
|
||||
if ch == nil {
|
||||
ch = make(chan error, 1)
|
||||
req.timeoutCh = ch
|
||||
} else if len(ch) > 0 {
|
||||
panic("BUG: Request.timeoutCh must be empty!")
|
||||
}
|
||||
|
||||
go func() {
|
||||
ch <- req.Read(r)
|
||||
}()
|
||||
|
||||
tc := acquireTimer(timeout)
|
||||
select {
|
||||
case err := <-ch:
|
||||
releaseTimer(tc)
|
||||
return err
|
||||
case <-tc.C:
|
||||
req.timeoutCh = nil
|
||||
return ErrReadTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func (resp *Response) ReadTimeout(r *bufio.Reader, timeout time.Duration) error {
|
||||
if timeout <= 0 {
|
||||
return resp.Read(r)
|
||||
}
|
||||
|
||||
ch := resp.timeoutCh
|
||||
if ch == nil {
|
||||
ch = make(chan error, 1)
|
||||
resp.timeoutCh = ch
|
||||
} else if len(ch) > 0 {
|
||||
panic("BUG: Response.timeoutCh must be empty!")
|
||||
}
|
||||
|
||||
go func() {
|
||||
ch <- resp.Read(r)
|
||||
}()
|
||||
|
||||
tc := acquireTimer(timeout)
|
||||
select {
|
||||
case err := <-ch:
|
||||
releaseTimer(tc)
|
||||
return err
|
||||
case <-tc.C:
|
||||
resp.timeoutCh = nil
|
||||
return ErrReadTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func (req *Request) Read(r *bufio.Reader) error {
|
||||
req.Body = req.Body[:0]
|
||||
req.URI.Clear()
|
||||
|
||||
@@ -4,10 +4,76 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestResponseReadTimeout(t *testing.T) {
|
||||
var resp Response
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
testResponseReadTimeoutError(t, &resp)
|
||||
}
|
||||
|
||||
s := "HTTP/1.1 200 OK\r\nContent-Type: text/aaa\r\nContent-Length: 5\r\n\r\n12345"
|
||||
r := bytes.NewBufferString(s)
|
||||
rb := bufio.NewReader(r)
|
||||
if err := resp.ReadTimeout(rb, 100*time.Millisecond); err != nil {
|
||||
t.Fatalf("Unexpected error: %s", err)
|
||||
}
|
||||
verifyResponseHeader(t, &resp.Header, 200, 5, "text/aaa")
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
testResponseReadTimeoutError(t, &resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestReadTimeout(t *testing.T) {
|
||||
var req Request
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
testRequestReadTimeoutError(t, &req)
|
||||
}
|
||||
|
||||
s := "GET /abc HTTP/1.1\r\nHost: google.com\r\n\r\n"
|
||||
r := bytes.NewBufferString(s)
|
||||
rb := bufio.NewReader(r)
|
||||
if err := req.ReadTimeout(rb, 100*time.Millisecond); err != nil {
|
||||
t.Fatalf("Unexpected error: %s", err)
|
||||
}
|
||||
verifyRequestHeader(t, &req.Header, 0, "/abc", "google.com", "", "")
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
testRequestReadTimeoutError(t, &req)
|
||||
}
|
||||
}
|
||||
|
||||
func testResponseReadTimeoutError(t *testing.T, resp *Response) {
|
||||
r, _ := io.Pipe()
|
||||
rb := bufio.NewReader(r)
|
||||
err := resp.ReadTimeout(rb, 5*time.Millisecond)
|
||||
if err == nil {
|
||||
t.Fatalf("Expecting error")
|
||||
}
|
||||
if err != ErrReadTimeout {
|
||||
t.Fatalf("Unexpected error: %s. Expecting %s", err, ErrReadTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func testRequestReadTimeoutError(t *testing.T, req *Request) {
|
||||
r, _ := io.Pipe()
|
||||
rb := bufio.NewReader(r)
|
||||
err := req.ReadTimeout(rb, 5*time.Millisecond)
|
||||
if err == nil {
|
||||
t.Fatalf("Expecting error")
|
||||
}
|
||||
if err != ErrReadTimeout {
|
||||
t.Fatalf("Unexpected error: %s. Expecting %s", err, ErrReadTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestReadChunked(t *testing.T) {
|
||||
var req Request
|
||||
|
||||
|
||||
@@ -30,6 +30,11 @@ type Server struct {
|
||||
// Per-connection buffer size for responses' writing.
|
||||
WriteBufferSize int
|
||||
|
||||
// Maximum duration for full request reading (including body).
|
||||
//
|
||||
// By default request read timeout is unlimited.
|
||||
RequestReadTimeout time.Duration
|
||||
|
||||
// Logger.
|
||||
Logger Logger
|
||||
|
||||
@@ -138,22 +143,6 @@ func (ctx *RequestCtx) TimeoutError(msg string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *RequestCtx) writeResponse() error {
|
||||
if atomic.LoadPointer(&ctx.shadow) != nil {
|
||||
panic("BUG: cannot write response with shadow")
|
||||
}
|
||||
h := &ctx.Response.Header
|
||||
serverOld := h.server
|
||||
if len(serverOld) == 0 {
|
||||
h.server = ctx.s.getServerName()
|
||||
}
|
||||
err := ctx.Response.Write(ctx.w)
|
||||
if len(serverOld) == 0 {
|
||||
h.server = serverOld
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
const defaultConcurrency = 64 * 1024
|
||||
|
||||
func (s *Server) Serve(ln net.Listener) error {
|
||||
@@ -288,7 +277,7 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error {
|
||||
initRequestCtx(ctx, c)
|
||||
var err error
|
||||
for {
|
||||
if err = ctx.Request.Read(ctx.r); err != nil {
|
||||
if err = ctx.Request.ReadTimeout(ctx.r, s.RequestReadTimeout); err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
@@ -302,7 +291,7 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error {
|
||||
ctx = (*RequestCtx)(shadow)
|
||||
*ctxP = ctx
|
||||
}
|
||||
if err = ctx.writeResponse(); err != nil {
|
||||
if err = writeResponse(ctx); err != nil {
|
||||
break
|
||||
}
|
||||
connectionClose := ctx.Response.Header.ConnectionClose
|
||||
@@ -322,6 +311,22 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func writeResponse(ctx *RequestCtx) error {
|
||||
if atomic.LoadPointer(&ctx.shadow) != nil {
|
||||
panic("BUG: cannot write response with shadow")
|
||||
}
|
||||
h := &ctx.Response.Header
|
||||
serverOld := h.server
|
||||
if len(serverOld) == 0 {
|
||||
h.server = ctx.s.getServerName()
|
||||
}
|
||||
err := ctx.Response.Write(ctx.w)
|
||||
if len(serverOld) == 0 {
|
||||
h.server = serverOld
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
const bigBufferLimit = 16 * 1024
|
||||
|
||||
func trimBigBuffers(ctx *RequestCtx) {
|
||||
|
||||
@@ -78,6 +78,22 @@ func BenchmarkNetHTTPServerPost10000ReqPerConn(b *testing.B) {
|
||||
benchmarkNetHTTPServerPost(b, 10000)
|
||||
}
|
||||
|
||||
func BenchmarkServerGetRequestReadTimeout1ReqPerConn(b *testing.B) {
|
||||
benchmarkServerGetRequestReadTimeout(b, 1)
|
||||
}
|
||||
|
||||
func BenchmarkServerGetRequestReadTimeout2ReqPerConn(b *testing.B) {
|
||||
benchmarkServerGetRequestReadTimeout(b, 2)
|
||||
}
|
||||
|
||||
func BenchmarkServerGetRequestReadTimeout10ReqPerConn(b *testing.B) {
|
||||
benchmarkServerGetRequestReadTimeout(b, 10)
|
||||
}
|
||||
|
||||
func BenchmarkServerGetRequestReadTimeout10000ReqPerConn(b *testing.B) {
|
||||
benchmarkServerGetRequestReadTimeout(b, 10000)
|
||||
}
|
||||
|
||||
func BenchmarkServerTimeoutError(b *testing.B) {
|
||||
requestsPerConn := 10
|
||||
ch := make(chan struct{}, b.N)
|
||||
@@ -264,6 +280,22 @@ func benchmarkNetHTTPServerPost(b *testing.B, requestsPerConn int) {
|
||||
verifyRequestsServed(b, requestsSent, ch)
|
||||
}
|
||||
|
||||
func benchmarkServerGetRequestReadTimeout(b *testing.B, requestsPerConn int) {
|
||||
ch := make(chan struct{}, b.N)
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
if !ctx.Request.Header.IsMethodGet() {
|
||||
b.Fatalf("Unexpected request method: %s", ctx.Request.Header.Method)
|
||||
}
|
||||
ctx.Success("text/plain", fakeResponse)
|
||||
registerServedRequest(b, ch)
|
||||
},
|
||||
RequestReadTimeout: 5 * time.Second,
|
||||
}
|
||||
requestsSent := benchmarkServer(b, &testServer{s}, requestsPerConn, getRequest)
|
||||
verifyRequestsServed(b, requestsSent, ch)
|
||||
}
|
||||
|
||||
func registerServedRequest(b *testing.B, ch chan<- struct{}) {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
|
||||
Reference in New Issue
Block a user