Added timeout covering full request read

This commit is contained in:
Aliaksandr Valialkin
2015-10-22 14:32:20 +03:00
parent 444dfb7213
commit 9fc3f767e6
4 changed files with 185 additions and 18 deletions
+64
View File
@@ -3,8 +3,10 @@ package fasthttp
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"time"
)
type Request struct {
@@ -18,6 +20,8 @@ type Request struct {
// PostArgs becomes available only after Request.ParsePostArgs() call.
PostArgs Args
parsedPostArgs bool
timeoutCh chan error
}
type Response struct {
@@ -27,6 +31,8 @@ type Response struct {
// if set to true, Response.Read() skips reading body.
// Use it for HEAD requests.
SkipBody bool
timeoutCh chan error
}
func (req *Request) ParseURI() {
@@ -68,6 +74,64 @@ func (resp *Response) Clear() {
resp.Body = resp.Body[:0]
}
var ErrReadTimeout = errors.New("read timeout")
func (req *Request) ReadTimeout(r *bufio.Reader, timeout time.Duration) error {
if timeout <= 0 {
return req.Read(r)
}
ch := req.timeoutCh
if ch == nil {
ch = make(chan error, 1)
req.timeoutCh = ch
} else if len(ch) > 0 {
panic("BUG: Request.timeoutCh must be empty!")
}
go func() {
ch <- req.Read(r)
}()
tc := acquireTimer(timeout)
select {
case err := <-ch:
releaseTimer(tc)
return err
case <-tc.C:
req.timeoutCh = nil
return ErrReadTimeout
}
}
func (resp *Response) ReadTimeout(r *bufio.Reader, timeout time.Duration) error {
if timeout <= 0 {
return resp.Read(r)
}
ch := resp.timeoutCh
if ch == nil {
ch = make(chan error, 1)
resp.timeoutCh = ch
} else if len(ch) > 0 {
panic("BUG: Response.timeoutCh must be empty!")
}
go func() {
ch <- resp.Read(r)
}()
tc := acquireTimer(timeout)
select {
case err := <-ch:
releaseTimer(tc)
return err
case <-tc.C:
resp.timeoutCh = nil
return ErrReadTimeout
}
}
func (req *Request) Read(r *bufio.Reader) error {
req.Body = req.Body[:0]
req.URI.Clear()
+66
View File
@@ -4,10 +4,76 @@ import (
"bufio"
"bytes"
"fmt"
"io"
"strings"
"testing"
"time"
)
func TestResponseReadTimeout(t *testing.T) {
var resp Response
for i := 0; i < 5; i++ {
testResponseReadTimeoutError(t, &resp)
}
s := "HTTP/1.1 200 OK\r\nContent-Type: text/aaa\r\nContent-Length: 5\r\n\r\n12345"
r := bytes.NewBufferString(s)
rb := bufio.NewReader(r)
if err := resp.ReadTimeout(rb, 100*time.Millisecond); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
verifyResponseHeader(t, &resp.Header, 200, 5, "text/aaa")
for i := 0; i < 5; i++ {
testResponseReadTimeoutError(t, &resp)
}
}
func TestRequestReadTimeout(t *testing.T) {
var req Request
for i := 0; i < 5; i++ {
testRequestReadTimeoutError(t, &req)
}
s := "GET /abc HTTP/1.1\r\nHost: google.com\r\n\r\n"
r := bytes.NewBufferString(s)
rb := bufio.NewReader(r)
if err := req.ReadTimeout(rb, 100*time.Millisecond); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
verifyRequestHeader(t, &req.Header, 0, "/abc", "google.com", "", "")
for i := 0; i < 5; i++ {
testRequestReadTimeoutError(t, &req)
}
}
func testResponseReadTimeoutError(t *testing.T, resp *Response) {
r, _ := io.Pipe()
rb := bufio.NewReader(r)
err := resp.ReadTimeout(rb, 5*time.Millisecond)
if err == nil {
t.Fatalf("Expecting error")
}
if err != ErrReadTimeout {
t.Fatalf("Unexpected error: %s. Expecting %s", err, ErrReadTimeout)
}
}
func testRequestReadTimeoutError(t *testing.T, req *Request) {
r, _ := io.Pipe()
rb := bufio.NewReader(r)
err := req.ReadTimeout(rb, 5*time.Millisecond)
if err == nil {
t.Fatalf("Expecting error")
}
if err != ErrReadTimeout {
t.Fatalf("Unexpected error: %s. Expecting %s", err, ErrReadTimeout)
}
}
func TestRequestReadChunked(t *testing.T) {
var req Request
+23 -18
View File
@@ -30,6 +30,11 @@ type Server struct {
// Per-connection buffer size for responses' writing.
WriteBufferSize int
// Maximum duration for full request reading (including body).
//
// By default request read timeout is unlimited.
RequestReadTimeout time.Duration
// Logger.
Logger Logger
@@ -138,22 +143,6 @@ func (ctx *RequestCtx) TimeoutError(msg string) {
}
}
func (ctx *RequestCtx) writeResponse() error {
if atomic.LoadPointer(&ctx.shadow) != nil {
panic("BUG: cannot write response with shadow")
}
h := &ctx.Response.Header
serverOld := h.server
if len(serverOld) == 0 {
h.server = ctx.s.getServerName()
}
err := ctx.Response.Write(ctx.w)
if len(serverOld) == 0 {
h.server = serverOld
}
return err
}
const defaultConcurrency = 64 * 1024
func (s *Server) Serve(ln net.Listener) error {
@@ -288,7 +277,7 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error {
initRequestCtx(ctx, c)
var err error
for {
if err = ctx.Request.Read(ctx.r); err != nil {
if err = ctx.Request.ReadTimeout(ctx.r, s.RequestReadTimeout); err != nil {
if err == io.EOF {
err = nil
}
@@ -302,7 +291,7 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error {
ctx = (*RequestCtx)(shadow)
*ctxP = ctx
}
if err = ctx.writeResponse(); err != nil {
if err = writeResponse(ctx); err != nil {
break
}
connectionClose := ctx.Response.Header.ConnectionClose
@@ -322,6 +311,22 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error {
return err
}
func writeResponse(ctx *RequestCtx) error {
if atomic.LoadPointer(&ctx.shadow) != nil {
panic("BUG: cannot write response with shadow")
}
h := &ctx.Response.Header
serverOld := h.server
if len(serverOld) == 0 {
h.server = ctx.s.getServerName()
}
err := ctx.Response.Write(ctx.w)
if len(serverOld) == 0 {
h.server = serverOld
}
return err
}
const bigBufferLimit = 16 * 1024
func trimBigBuffers(ctx *RequestCtx) {
+32
View File
@@ -78,6 +78,22 @@ func BenchmarkNetHTTPServerPost10000ReqPerConn(b *testing.B) {
benchmarkNetHTTPServerPost(b, 10000)
}
func BenchmarkServerGetRequestReadTimeout1ReqPerConn(b *testing.B) {
benchmarkServerGetRequestReadTimeout(b, 1)
}
func BenchmarkServerGetRequestReadTimeout2ReqPerConn(b *testing.B) {
benchmarkServerGetRequestReadTimeout(b, 2)
}
func BenchmarkServerGetRequestReadTimeout10ReqPerConn(b *testing.B) {
benchmarkServerGetRequestReadTimeout(b, 10)
}
func BenchmarkServerGetRequestReadTimeout10000ReqPerConn(b *testing.B) {
benchmarkServerGetRequestReadTimeout(b, 10000)
}
func BenchmarkServerTimeoutError(b *testing.B) {
requestsPerConn := 10
ch := make(chan struct{}, b.N)
@@ -264,6 +280,22 @@ func benchmarkNetHTTPServerPost(b *testing.B, requestsPerConn int) {
verifyRequestsServed(b, requestsSent, ch)
}
func benchmarkServerGetRequestReadTimeout(b *testing.B, requestsPerConn int) {
ch := make(chan struct{}, b.N)
s := &Server{
Handler: func(ctx *RequestCtx) {
if !ctx.Request.Header.IsMethodGet() {
b.Fatalf("Unexpected request method: %s", ctx.Request.Header.Method)
}
ctx.Success("text/plain", fakeResponse)
registerServedRequest(b, ch)
},
RequestReadTimeout: 5 * time.Second,
}
requestsSent := benchmarkServer(b, &testServer{s}, requestsPerConn, getRequest)
verifyRequestsServed(b, requestsSent, ch)
}
func registerServedRequest(b *testing.B, ch chan<- struct{}) {
select {
case ch <- struct{}{}: