mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-13 15:46:49 +03:00
Added TimeoutHandler
This commit is contained in:
@@ -21,7 +21,8 @@ type Request struct {
|
||||
PostArgs Args
|
||||
parsedPostArgs bool
|
||||
|
||||
timeoutCh chan error
|
||||
timeoutCh chan error
|
||||
timeoutTimer *time.Timer
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
@@ -32,7 +33,8 @@ type Response struct {
|
||||
// Use it for HEAD requests.
|
||||
SkipBody bool
|
||||
|
||||
timeoutCh chan error
|
||||
timeoutCh chan error
|
||||
timeoutTimer *time.Timer
|
||||
}
|
||||
|
||||
func (req *Request) ParseURI() {
|
||||
@@ -94,14 +96,14 @@ func (req *Request) ReadTimeout(r *bufio.Reader, timeout time.Duration) error {
|
||||
}()
|
||||
|
||||
var err error
|
||||
tc := acquireTimer(timeout)
|
||||
req.timeoutTimer = initTimer(req.timeoutTimer, timeout)
|
||||
select {
|
||||
case err = <-ch:
|
||||
case <-tc.C:
|
||||
case <-req.timeoutTimer.C:
|
||||
req.timeoutCh = nil
|
||||
err = ErrReadTimeout
|
||||
}
|
||||
releaseTimer(tc)
|
||||
stopTimer(req.timeoutTimer)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -123,14 +125,14 @@ func (resp *Response) ReadTimeout(r *bufio.Reader, timeout time.Duration) error
|
||||
}()
|
||||
|
||||
var err error
|
||||
tc := acquireTimer(timeout)
|
||||
resp.timeoutTimer = initTimer(resp.timeoutTimer, timeout)
|
||||
select {
|
||||
case err = <-ch:
|
||||
case <-tc.C:
|
||||
case <-resp.timeoutTimer.C:
|
||||
resp.timeoutCh = nil
|
||||
err = ErrReadTimeout
|
||||
}
|
||||
releaseTimer(tc)
|
||||
stopTimer(resp.timeoutTimer)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -16,18 +16,34 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// RequestHandler must process incoming requests.
|
||||
//
|
||||
// ResponseHandler must call ctx.TimeoutError() before return
|
||||
// if it keeps references to ctx and/or its' members after the return.
|
||||
type RequestHandler func(ctx *RequestCtx)
|
||||
|
||||
// Server implements HTTP server.
|
||||
//
|
||||
// It is forbidden copying Server instances. Create new Server instances
|
||||
// instead.
|
||||
type Server struct {
|
||||
// Request handler
|
||||
// Handler for processing incoming requests.
|
||||
Handler RequestHandler
|
||||
|
||||
// Server name for sending in response headers.
|
||||
//
|
||||
// Default server name is used if left blank.
|
||||
Name string
|
||||
|
||||
// Per-connection buffer size for requests' reading.
|
||||
// This also limits the maximum header size.
|
||||
//
|
||||
// Default buffer size is used if 0.
|
||||
ReadBufferSize int
|
||||
|
||||
// Per-connection buffer size for responses' writing.
|
||||
//
|
||||
// Default buffer size is used if 0.
|
||||
WriteBufferSize int
|
||||
|
||||
// Maximum duration for full request reading (including body).
|
||||
@@ -40,17 +56,57 @@ type Server struct {
|
||||
// By default response write timeout is unlimited.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// Logger.
|
||||
// Logger, which is used by ServerCtx.Logger().
|
||||
//
|
||||
// By default standard logger from log package is used.
|
||||
Logger Logger
|
||||
|
||||
serverName atomic.Value
|
||||
ctxPool sync.Pool
|
||||
}
|
||||
|
||||
type RequestHandler func(ctx *RequestCtx)
|
||||
// TimeoutHandler returns StatusRequestTimeout error with the given msg
|
||||
// in the body to the client if h didn't return during the given duration.
|
||||
func TimeoutHandler(h RequestHandler, timeout time.Duration, msg string) RequestHandler {
|
||||
if timeout <= 0 {
|
||||
return h
|
||||
}
|
||||
|
||||
return func(ctx *RequestCtx) {
|
||||
ch := ctx.timeoutCh
|
||||
if ch == nil {
|
||||
ch = make(chan struct{}, 1)
|
||||
ctx.timeoutCh = ch
|
||||
}
|
||||
go func() {
|
||||
h(ctx)
|
||||
ch <- struct{}{}
|
||||
}()
|
||||
ctx.timeoutTimer = initTimer(ctx.timeoutTimer, timeout)
|
||||
select {
|
||||
case <-ch:
|
||||
case <-ctx.timeoutTimer.C:
|
||||
ctx.TimeoutError(msg)
|
||||
}
|
||||
stopTimer(ctx.timeoutTimer)
|
||||
}
|
||||
}
|
||||
|
||||
// RequestCtx contains incoming request and manages outgoing response.
|
||||
//
|
||||
// It is forbidden copying RequestCtx instances.
|
||||
//
|
||||
// RequestHandler should avoid holding references to incoming RequestCtx and/or
|
||||
// its' members after the return.
|
||||
// If holding RequestCtx references after the return is unavoidable
|
||||
// (for instance, ctx is passed to a separate goroutine and we cannot control
|
||||
// ctx lifetime in this goroutine), then the RequestHandler MUST call
|
||||
// ctx.TimeoutError() before return.
|
||||
type RequestCtx struct {
|
||||
Request Request
|
||||
// Incoming request.
|
||||
Request Request
|
||||
|
||||
// Outgoing response.
|
||||
Response Response
|
||||
|
||||
// Unique id of the request.
|
||||
@@ -70,6 +126,9 @@ type RequestCtx struct {
|
||||
w *bufio.Writer
|
||||
shadow unsafe.Pointer
|
||||
|
||||
timeoutCh chan struct{}
|
||||
timeoutTimer *time.Timer
|
||||
|
||||
v interface{}
|
||||
}
|
||||
|
||||
@@ -85,7 +144,9 @@ type writeDeadliner interface {
|
||||
SetWriteDeadline(time.Time) error
|
||||
}
|
||||
|
||||
// Logger is used for logging formatted messages.
|
||||
type Logger interface {
|
||||
// Printf must have the same semantics as log.Printf.
|
||||
Printf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
@@ -105,6 +166,7 @@ func (cl *ctxLogger) Printf(format string, args ...interface{}) {
|
||||
ctxLoggerLock.Unlock()
|
||||
}
|
||||
|
||||
// RemoteAddr returns client address for the given request.
|
||||
func (ctx *RequestCtx) RemoteAddr() string {
|
||||
x, ok := ctx.c.(remoteAddrer)
|
||||
if !ok {
|
||||
@@ -113,6 +175,7 @@ func (ctx *RequestCtx) RemoteAddr() string {
|
||||
return x.RemoteAddr().String()
|
||||
}
|
||||
|
||||
// RemoteIP returns client ip for the given request.
|
||||
func (ctx *RequestCtx) RemoteIP() string {
|
||||
addr := ctx.RemoteAddr()
|
||||
n := strings.LastIndexByte(addr, ':')
|
||||
@@ -122,6 +185,10 @@ func (ctx *RequestCtx) RemoteIP() string {
|
||||
return addr[:n]
|
||||
}
|
||||
|
||||
// Error sets response status code to the given value and sets response body
|
||||
// to the given message.
|
||||
//
|
||||
// Error calls are ignored after TimeoutError call.
|
||||
func (ctx *RequestCtx) Error(msg string, statusCode int) {
|
||||
resp := &ctx.Response
|
||||
resp.Clear()
|
||||
@@ -130,16 +197,35 @@ func (ctx *RequestCtx) Error(msg string, statusCode int) {
|
||||
resp.Body = append(resp.Body, []byte(msg)...)
|
||||
}
|
||||
|
||||
// Success sets response Content-Type and body to the given values.
|
||||
//
|
||||
// It is safe modifying body buffer after the Success() call.
|
||||
//
|
||||
// Success calls are ignored after TimeoutError call.
|
||||
func (ctx *RequestCtx) Success(contentType string, body []byte) {
|
||||
resp := &ctx.Response
|
||||
resp.Header.setStr(strContentType, contentType)
|
||||
resp.Body = append(resp.Body, body...)
|
||||
}
|
||||
|
||||
// Logger returns logger, which may be used for logging arbitrary
|
||||
// request-specific messages inside RequestHandler.
|
||||
//
|
||||
// Each message logged via returned logger contains request-specific information
|
||||
// such as request id, remote address, request method and request url.
|
||||
//
|
||||
// It is safe re-using returned logger for logging multiple messages.
|
||||
func (ctx *RequestCtx) Logger() Logger {
|
||||
return &ctx.logger
|
||||
}
|
||||
|
||||
// TimeoutError sets response status code to StatusRequestTimeout and sets
|
||||
// body to the given msg.
|
||||
//
|
||||
// All response modifications after TimeoutError call are ignored.
|
||||
//
|
||||
// TimeoutError MUST be called before returning from RequestHandler if there are
|
||||
// references to ctx and/or its members in other goroutines.
|
||||
func (ctx *RequestCtx) TimeoutError(msg string) {
|
||||
var shadow RequestCtx
|
||||
shadow.Request = Request{}
|
||||
@@ -159,10 +245,19 @@ func (ctx *RequestCtx) TimeoutError(msg string) {
|
||||
|
||||
const defaultConcurrency = 64 * 1024
|
||||
|
||||
// Serve serves incoming connections from the given listener.
|
||||
//
|
||||
// Serve blocks until the given listener returns permanent error.
|
||||
// This error is returned from Serve.
|
||||
func (s *Server) Serve(ln net.Listener) error {
|
||||
return s.ServeConcurrency(ln, defaultConcurrency)
|
||||
}
|
||||
|
||||
// ServeConcurrency serves incoming connections from the given listener.
|
||||
// It may serve maximum concurrency simultaneous connections.
|
||||
//
|
||||
// ServeConcurrency blocks until the given listener returns permanent error.
|
||||
// This error is returned from ServeConcurrency.
|
||||
func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error {
|
||||
ch := make(chan net.Conn, 16*concurrency)
|
||||
stopCh := make(chan struct{})
|
||||
@@ -188,6 +283,7 @@ func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error {
|
||||
|
||||
func connWorkersMonitor(s *Server, ch <-chan net.Conn, maxWorkers int, stopCh <-chan struct{}) {
|
||||
workersCount := uint32(0)
|
||||
var tc *time.Timer
|
||||
for {
|
||||
n := int(atomic.LoadUint32(&workersCount))
|
||||
pendingConns := len(ch)
|
||||
@@ -201,13 +297,14 @@ func connWorkersMonitor(s *Server, ch <-chan net.Conn, maxWorkers int, stopCh <-
|
||||
}
|
||||
runtime.Gosched()
|
||||
} else {
|
||||
tc := acquireTimer(100 * time.Millisecond)
|
||||
tc = initTimer(tc, 100*time.Millisecond)
|
||||
select {
|
||||
case <-stopCh:
|
||||
stopTimer(tc)
|
||||
return
|
||||
case <-tc.C:
|
||||
stopTimer(tc)
|
||||
}
|
||||
releaseTimer(tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -225,18 +322,20 @@ func connWorker(s *Server, ch <-chan net.Conn) {
|
||||
}
|
||||
}()
|
||||
|
||||
var tc *time.Timer
|
||||
for {
|
||||
select {
|
||||
case c = <-ch:
|
||||
default:
|
||||
tc := acquireTimer(time.Second)
|
||||
tc = initTimer(tc, time.Second)
|
||||
select {
|
||||
case c = <-ch:
|
||||
stopTimer(tc)
|
||||
case <-tc.C:
|
||||
stopTimer(tc)
|
||||
s.releaseCtx(ctx)
|
||||
return
|
||||
}
|
||||
releaseTimer(tc)
|
||||
}
|
||||
serveConn(s, c, &ctx)
|
||||
c.Close()
|
||||
@@ -279,6 +378,10 @@ func (s *Server) logger() Logger {
|
||||
return defaultLogger
|
||||
}
|
||||
|
||||
// ServeConn serves HTTP requests from the given connection.
|
||||
//
|
||||
// ServeConn returns nil if all requests from the c are successfully served.
|
||||
// It returns non-nil error otherwise.
|
||||
func (s *Server) ServeConn(c io.ReadWriter) error {
|
||||
ctx := s.acquireCtx()
|
||||
err := s.serveConn(c, &ctx)
|
||||
|
||||
@@ -11,6 +11,65 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTimeoutHandlerSuccess(t *testing.T) {
|
||||
h := func(ctx *RequestCtx) {
|
||||
ctx.Success("aaa/bbb", []byte("real response"))
|
||||
}
|
||||
s := &Server{
|
||||
Handler: TimeoutHandler(h, 100*time.Millisecond, "timeout!!!"),
|
||||
}
|
||||
|
||||
rw := &readWriter{}
|
||||
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
|
||||
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
ch <- s.ServeConn(rw)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error from serveConn: %s", err)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("timeout")
|
||||
}
|
||||
|
||||
br := bufio.NewReader(&rw.w)
|
||||
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
|
||||
}
|
||||
|
||||
func TestTimeoutHandlerTimeout(t *testing.T) {
|
||||
h := func(ctx *RequestCtx) {
|
||||
time.Sleep(time.Second)
|
||||
ctx.Success("aaa/bbb", []byte("this shouldn't pass to client because of timeout"))
|
||||
}
|
||||
s := &Server{
|
||||
Handler: TimeoutHandler(h, 10*time.Millisecond, "timeout!!!"),
|
||||
}
|
||||
|
||||
rw := &readWriter{}
|
||||
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
|
||||
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
ch <- s.ServeConn(rw)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error from serveConn: %s", err)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("timeout")
|
||||
}
|
||||
|
||||
br := bufio.NewReader(&rw.w)
|
||||
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!")
|
||||
}
|
||||
|
||||
func TestServerTimeoutError(t *testing.T) {
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
|
||||
@@ -1,26 +1,20 @@
|
||||
package fasthttp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var timerPool sync.Pool
|
||||
|
||||
func acquireTimer(timeout time.Duration) *time.Timer {
|
||||
tv := timerPool.Get()
|
||||
if tv == nil {
|
||||
func initTimer(t *time.Timer, timeout time.Duration) *time.Timer {
|
||||
if t == nil {
|
||||
return time.NewTimer(timeout)
|
||||
}
|
||||
|
||||
t := tv.(*time.Timer)
|
||||
if t.Reset(timeout) {
|
||||
panic("BUG: Active timer trapped into AcquireTimer()")
|
||||
panic("BUG: active timer trapped into initTimer()")
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func releaseTimer(t *time.Timer) {
|
||||
func stopTimer(t *time.Timer) {
|
||||
if !t.Stop() {
|
||||
// Collect possibly added time from the channel
|
||||
// if timer has been stopped and nobody collected its' value.
|
||||
@@ -29,6 +23,4 @@ func releaseTimer(t *time.Timer) {
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
timerPool.Put(t)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user