diff --git a/http.go b/http.go index 444f94e..33fe3d8 100644 --- a/http.go +++ b/http.go @@ -21,7 +21,8 @@ type Request struct { PostArgs Args parsedPostArgs bool - timeoutCh chan error + timeoutCh chan error + timeoutTimer *time.Timer } type Response struct { @@ -32,7 +33,8 @@ type Response struct { // Use it for HEAD requests. SkipBody bool - timeoutCh chan error + timeoutCh chan error + timeoutTimer *time.Timer } func (req *Request) ParseURI() { @@ -94,14 +96,14 @@ func (req *Request) ReadTimeout(r *bufio.Reader, timeout time.Duration) error { }() var err error - tc := acquireTimer(timeout) + req.timeoutTimer = initTimer(req.timeoutTimer, timeout) select { case err = <-ch: - case <-tc.C: + case <-req.timeoutTimer.C: req.timeoutCh = nil err = ErrReadTimeout } - releaseTimer(tc) + stopTimer(req.timeoutTimer) return err } @@ -123,14 +125,14 @@ func (resp *Response) ReadTimeout(r *bufio.Reader, timeout time.Duration) error }() var err error - tc := acquireTimer(timeout) + resp.timeoutTimer = initTimer(resp.timeoutTimer, timeout) select { case err = <-ch: - case <-tc.C: + case <-resp.timeoutTimer.C: resp.timeoutCh = nil err = ErrReadTimeout } - releaseTimer(tc) + stopTimer(resp.timeoutTimer) return err } diff --git a/server.go b/server.go index deb3c26..23d00c2 100644 --- a/server.go +++ b/server.go @@ -16,18 +16,34 @@ import ( "unsafe" ) +// RequestHandler must process incoming requests. +// +// ResponseHandler must call ctx.TimeoutError() before return +// if it keeps references to ctx and/or its' members after the return. +type RequestHandler func(ctx *RequestCtx) + +// Server implements HTTP server. +// +// It is forbidden copying Server instances. Create new Server instances +// instead. type Server struct { - // Request handler + // Handler for processing incoming requests. Handler RequestHandler // Server name for sending in response headers. + // + // Default server name is used if left blank. Name string // Per-connection buffer size for requests' reading. // This also limits the maximum header size. + // + // Default buffer size is used if 0. ReadBufferSize int // Per-connection buffer size for responses' writing. + // + // Default buffer size is used if 0. WriteBufferSize int // Maximum duration for full request reading (including body). @@ -40,17 +56,57 @@ type Server struct { // By default response write timeout is unlimited. WriteTimeout time.Duration - // Logger. + // Logger, which is used by ServerCtx.Logger(). + // + // By default standard logger from log package is used. Logger Logger serverName atomic.Value ctxPool sync.Pool } -type RequestHandler func(ctx *RequestCtx) +// TimeoutHandler returns StatusRequestTimeout error with the given msg +// in the body to the client if h didn't return during the given duration. +func TimeoutHandler(h RequestHandler, timeout time.Duration, msg string) RequestHandler { + if timeout <= 0 { + return h + } + return func(ctx *RequestCtx) { + ch := ctx.timeoutCh + if ch == nil { + ch = make(chan struct{}, 1) + ctx.timeoutCh = ch + } + go func() { + h(ctx) + ch <- struct{}{} + }() + ctx.timeoutTimer = initTimer(ctx.timeoutTimer, timeout) + select { + case <-ch: + case <-ctx.timeoutTimer.C: + ctx.TimeoutError(msg) + } + stopTimer(ctx.timeoutTimer) + } +} + +// RequestCtx contains incoming request and manages outgoing response. +// +// It is forbidden copying RequestCtx instances. +// +// RequestHandler should avoid holding references to incoming RequestCtx and/or +// its' members after the return. +// If holding RequestCtx references after the return is unavoidable +// (for instance, ctx is passed to a separate goroutine and we cannot control +// ctx lifetime in this goroutine), then the RequestHandler MUST call +// ctx.TimeoutError() before return. type RequestCtx struct { - Request Request + // Incoming request. + Request Request + + // Outgoing response. Response Response // Unique id of the request. @@ -70,6 +126,9 @@ type RequestCtx struct { w *bufio.Writer shadow unsafe.Pointer + timeoutCh chan struct{} + timeoutTimer *time.Timer + v interface{} } @@ -85,7 +144,9 @@ type writeDeadliner interface { SetWriteDeadline(time.Time) error } +// Logger is used for logging formatted messages. type Logger interface { + // Printf must have the same semantics as log.Printf. Printf(format string, args ...interface{}) } @@ -105,6 +166,7 @@ func (cl *ctxLogger) Printf(format string, args ...interface{}) { ctxLoggerLock.Unlock() } +// RemoteAddr returns client address for the given request. func (ctx *RequestCtx) RemoteAddr() string { x, ok := ctx.c.(remoteAddrer) if !ok { @@ -113,6 +175,7 @@ func (ctx *RequestCtx) RemoteAddr() string { return x.RemoteAddr().String() } +// RemoteIP returns client ip for the given request. func (ctx *RequestCtx) RemoteIP() string { addr := ctx.RemoteAddr() n := strings.LastIndexByte(addr, ':') @@ -122,6 +185,10 @@ func (ctx *RequestCtx) RemoteIP() string { return addr[:n] } +// Error sets response status code to the given value and sets response body +// to the given message. +// +// Error calls are ignored after TimeoutError call. func (ctx *RequestCtx) Error(msg string, statusCode int) { resp := &ctx.Response resp.Clear() @@ -130,16 +197,35 @@ func (ctx *RequestCtx) Error(msg string, statusCode int) { resp.Body = append(resp.Body, []byte(msg)...) } +// Success sets response Content-Type and body to the given values. +// +// It is safe modifying body buffer after the Success() call. +// +// Success calls are ignored after TimeoutError call. func (ctx *RequestCtx) Success(contentType string, body []byte) { resp := &ctx.Response resp.Header.setStr(strContentType, contentType) resp.Body = append(resp.Body, body...) } +// Logger returns logger, which may be used for logging arbitrary +// request-specific messages inside RequestHandler. +// +// Each message logged via returned logger contains request-specific information +// such as request id, remote address, request method and request url. +// +// It is safe re-using returned logger for logging multiple messages. func (ctx *RequestCtx) Logger() Logger { return &ctx.logger } +// TimeoutError sets response status code to StatusRequestTimeout and sets +// body to the given msg. +// +// All response modifications after TimeoutError call are ignored. +// +// TimeoutError MUST be called before returning from RequestHandler if there are +// references to ctx and/or its members in other goroutines. func (ctx *RequestCtx) TimeoutError(msg string) { var shadow RequestCtx shadow.Request = Request{} @@ -159,10 +245,19 @@ func (ctx *RequestCtx) TimeoutError(msg string) { const defaultConcurrency = 64 * 1024 +// Serve serves incoming connections from the given listener. +// +// Serve blocks until the given listener returns permanent error. +// This error is returned from Serve. func (s *Server) Serve(ln net.Listener) error { return s.ServeConcurrency(ln, defaultConcurrency) } +// ServeConcurrency serves incoming connections from the given listener. +// It may serve maximum concurrency simultaneous connections. +// +// ServeConcurrency blocks until the given listener returns permanent error. +// This error is returned from ServeConcurrency. func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error { ch := make(chan net.Conn, 16*concurrency) stopCh := make(chan struct{}) @@ -188,6 +283,7 @@ func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error { func connWorkersMonitor(s *Server, ch <-chan net.Conn, maxWorkers int, stopCh <-chan struct{}) { workersCount := uint32(0) + var tc *time.Timer for { n := int(atomic.LoadUint32(&workersCount)) pendingConns := len(ch) @@ -201,13 +297,14 @@ func connWorkersMonitor(s *Server, ch <-chan net.Conn, maxWorkers int, stopCh <- } runtime.Gosched() } else { - tc := acquireTimer(100 * time.Millisecond) + tc = initTimer(tc, 100*time.Millisecond) select { case <-stopCh: + stopTimer(tc) return case <-tc.C: + stopTimer(tc) } - releaseTimer(tc) } } } @@ -225,18 +322,20 @@ func connWorker(s *Server, ch <-chan net.Conn) { } }() + var tc *time.Timer for { select { case c = <-ch: default: - tc := acquireTimer(time.Second) + tc = initTimer(tc, time.Second) select { case c = <-ch: + stopTimer(tc) case <-tc.C: + stopTimer(tc) s.releaseCtx(ctx) return } - releaseTimer(tc) } serveConn(s, c, &ctx) c.Close() @@ -279,6 +378,10 @@ func (s *Server) logger() Logger { return defaultLogger } +// ServeConn serves HTTP requests from the given connection. +// +// ServeConn returns nil if all requests from the c are successfully served. +// It returns non-nil error otherwise. func (s *Server) ServeConn(c io.ReadWriter) error { ctx := s.acquireCtx() err := s.serveConn(c, &ctx) diff --git a/server_test.go b/server_test.go index 0446814..0960790 100644 --- a/server_test.go +++ b/server_test.go @@ -11,6 +11,65 @@ import ( "time" ) +func TestTimeoutHandlerSuccess(t *testing.T) { + h := func(ctx *RequestCtx) { + ctx.Success("aaa/bbb", []byte("real response")) + } + s := &Server{ + Handler: TimeoutHandler(h, 100*time.Millisecond, "timeout!!!"), + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") +} + +func TestTimeoutHandlerTimeout(t *testing.T) { + h := func(ctx *RequestCtx) { + time.Sleep(time.Second) + ctx.Success("aaa/bbb", []byte("this shouldn't pass to client because of timeout")) + } + s := &Server{ + Handler: TimeoutHandler(h, 10*time.Millisecond, "timeout!!!"), + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!") +} + func TestServerTimeoutError(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { diff --git a/timer.go b/timer.go index 8227c7d..4693ecf 100644 --- a/timer.go +++ b/timer.go @@ -1,26 +1,20 @@ package fasthttp import ( - "sync" "time" ) -var timerPool sync.Pool - -func acquireTimer(timeout time.Duration) *time.Timer { - tv := timerPool.Get() - if tv == nil { +func initTimer(t *time.Timer, timeout time.Duration) *time.Timer { + if t == nil { return time.NewTimer(timeout) } - - t := tv.(*time.Timer) if t.Reset(timeout) { - panic("BUG: Active timer trapped into AcquireTimer()") + panic("BUG: active timer trapped into initTimer()") } return t } -func releaseTimer(t *time.Timer) { +func stopTimer(t *time.Timer) { if !t.Stop() { // Collect possibly added time from the channel // if timer has been stopped and nobody collected its' value. @@ -29,6 +23,4 @@ func releaseTimer(t *time.Timer) { default: } } - - timerPool.Put(t) }