diff --git a/server.go b/server.go index 20f289d..481a53a 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "log" "mime/multipart" "net" + "net/http" "os" "strings" "sync" @@ -276,6 +277,11 @@ type Server struct { // value is explicitly provided during a request. NoDefaultServerHeader bool + // ConnState specifies an optional callback function that is + // called when a client connection changes state. See the + // ConnState type and associated constants for details. + ConnState func(net.Conn, http.ConnState) + // Logger, which is used by RequestCtx.Logger(). // // By default standard logger from log package is used. @@ -1372,6 +1378,7 @@ func (s *Server) Serve(ln net.Listener) error { MaxWorkersCount: maxWorkersCount, LogAllErrors: s.LogAllErrors, Logger: s.logger(), + connState: s.setState, } wp.Start() @@ -1383,12 +1390,14 @@ func (s *Server) Serve(ln net.Listener) error { } return err } + s.setState(c, http.StateNew) s.wg.Add(1) if !wp.Serve(c) { s.wg.Done() s.writeFastError(c, StatusServiceUnavailable, "The connection cannot be served because Server.Concurrency limit exceeded") c.Close() + s.setState(c, http.StateClosed) if time.Since(lastOverflowErrorTime) > time.Minute { s.logger().Printf("The incoming connection cannot be served, because %d concurrent connections are served. "+ "Try increasing Server.Concurrency", maxWorkersCount) @@ -1634,6 +1643,10 @@ func (s *Server) serveConn(c net.Conn) error { } // reading Headers and Body err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly) + if br.Buffered() > 0 { + // If we read any bytes off the wire, we're active. + s.setState(c, http.StateActive) + } if br.Buffered() == 0 || err != nil { releaseReader(s, br) br = nil @@ -1783,6 +1796,7 @@ func (s *Server) serveConn(c net.Conn) error { } currentTime = time.Now() + s.setState(c, http.StateIdle) } if br != nil { @@ -1795,6 +1809,12 @@ func (s *Server) serveConn(c net.Conn) error { return err } +func (s *Server) setState(nc net.Conn, state http.ConnState) { + if hook := s.ConnState; hook != nil { + hook(nc, state) + } +} + func (s *Server) updateReadDeadline(c net.Conn, ctx *RequestCtx, lastDeadlineTime time.Time) time.Time { readTimeout := s.ReadTimeout currentTime := ctx.time diff --git a/workerpool.go b/workerpool.go index 6ce2778..0e44a27 100644 --- a/workerpool.go +++ b/workerpool.go @@ -2,6 +2,7 @@ package fasthttp import ( "net" + "net/http" "runtime" "strings" "sync" @@ -35,6 +36,8 @@ type workerPool struct { stopCh chan struct{} workerChanPool sync.Pool + + connState func(net.Conn, http.ConnState) } type workerChan struct { @@ -216,8 +219,11 @@ func (wp *workerPool) workerFunc(ch *workerChan) { wp.Logger.Printf("error when serving connection %q<->%q: %s", c.LocalAddr(), c.RemoteAddr(), err) } } - if err != errHijacked { + if err == errHijacked { + wp.connState(c, http.StateHijacked) + } else { c.Close() + wp.connState(c, http.StateClosed) } c = nil