diff --git a/peripconn.go b/peripconn.go index 2568477..cf8e986 100644 --- a/peripconn.go +++ b/peripconn.go @@ -51,12 +51,16 @@ func (c *perIPConn) Close() error { } func getUint32IP(c net.Conn) uint32 { + return ip2uint32(getConnIP4(c)) +} + +func getConnIP4(c net.Conn) net.IP { addr := c.RemoteAddr() ipAddr, ok := addr.(*net.TCPAddr) if !ok { - return 0 + return net.IPv4zero } - return ip2uint32(ipAddr.IP.To4()) + return ipAddr.IP.To4() } func ip2uint32(ip net.IP) uint32 { diff --git a/server.go b/server.go index afc2147..f21fec3 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package fasthttp import ( "bufio" + "errors" "fmt" "io" "log" @@ -131,7 +132,7 @@ type RequestCtx struct { logger ctxLogger s *Server - c io.ReadWriter + c io.ReadWriteCloser fbr firstByteReader // shadow is set by TimeoutError(). @@ -144,7 +145,7 @@ type RequestCtx struct { } type firstByteReader struct { - c io.ReadWriter + c io.ReadWriteCloser ch byte byteRead bool } @@ -285,7 +286,7 @@ func (s *Server) Serve(ln net.Listener) error { // ServeConcurrency blocks until the given listener returns permanent error. // This error is returned from ServeConcurrency. func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error { - ch := make(chan net.Conn, 16*concurrency) + ch := make(chan net.Conn, 2*concurrency) stopCh := make(chan struct{}) go connWorkersMonitor(s, ch, concurrency, stopCh) var lastOverflowErrorTime time.Time @@ -338,16 +339,13 @@ func connWorkersMonitor(s *Server, ch <-chan net.Conn, maxWorkers int, stopCh <- } func connWorker(s *Server, ch <-chan net.Conn) { - var c net.Conn defer func() { if r := recover(); r != nil { s.logger().Printf("panic: %s\nStack trace:\n%s", r, debug.Stack()) } - if c != nil { - c.Close() - } }() + var c net.Conn var tc *time.Timer for { select { @@ -362,14 +360,12 @@ func connWorker(s *Server, ch <-chan net.Conn) { return } } - s.ServeConn(c) - c.Close() + s.serveConn(c) c = nil } } func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) { - maxConnsPerIP := s.MaxConnsPerIP for { c, err := ln.Accept() if err != nil { @@ -383,32 +379,40 @@ func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net. } return nil, err } - if maxConnsPerIP > 0 { - ip := getUint32IP(c) - if ip == 0 { - return c, nil - } - n := s.perIPConnCounter.Register(ip) - if n > maxConnsPerIP { + if s.MaxConnsPerIP > 0 { + pic := wrapPerIPConn(s, c) + if pic == nil { + c.Close() if time.Since(*lastPerIPErrorTime) > time.Minute { - s.logger().Printf("Too many connections from ip %s: %d. MaxConnsPerIP=%d", - uint322ip(ip), n, maxConnsPerIP) + s.logger().Printf("The number of connections from %s exceeds MaxConnsPerIP=%d", + getConnIP4(c), s.MaxConnsPerIP) *lastPerIPErrorTime = time.Now() } - c.Close() - s.perIPConnCounter.Unregister(ip) continue } - return &perIPConn{ - Conn: c, - ip: ip, - perIPConnCounter: &s.perIPConnCounter, - }, nil + return pic, nil } return c, nil } } +func wrapPerIPConn(s *Server, c net.Conn) net.Conn { + ip := getUint32IP(c) + if ip == 0 { + return c + } + n := s.perIPConnCounter.Register(ip) + if n > s.MaxConnsPerIP { + s.perIPConnCounter.Unregister(ip) + return nil + } + return &perIPConn{ + Conn: c, + ip: ip, + perIPConnCounter: &s.perIPConnCounter, + } +} + var defaultLogger = log.New(os.Stderr, "", log.LstdFlags) func (s *Server) logger() Logger { @@ -418,11 +422,33 @@ func (s *Server) logger() Logger { return defaultLogger } +// ErrPerIPConnLimit may be returned from ServeConn if the number of connections +// per ip exceeds Server.MaxConnsPerIP. +var ErrPerIPConnLimit = errors.New("too many connections per ip") + // ServeConn serves HTTP requests from the given connection. // // ServeConn returns nil if all requests from the c are successfully served. // It returns non-nil error otherwise. -func (s *Server) ServeConn(c io.ReadWriter) error { +// +// Connection c must immediately propagate all the data passed to Write() +// to the client. Otherwise requests' processing may hang. +// +// ServeConn closes c before returning. +func (s *Server) ServeConn(c io.ReadWriteCloser) error { + conn, ok := c.(net.Conn) + if ok { + pic := wrapPerIPConn(s, conn) + if pic == nil { + c.Close() + return ErrPerIPConnLimit + } + c = pic + } + return s.serveConn(c) +} + +func (s *Server) serveConn(c io.ReadWriteCloser) error { var rd readDeadliner readTimeout := s.ReadTimeout if readTimeout > 0 { @@ -553,6 +579,11 @@ func (s *Server) ServeConn(c io.ReadWriter) error { ctx.Logger().Printf("Error when serving network connection: %s", err) } s.releaseCtx(ctx) + + err1 := c.Close() + if err == nil { + err = err1 + } return err } @@ -671,7 +702,7 @@ func releaseWriter(ctx *RequestCtx, w *bufio.Writer) { var globalCtxID uint64 -func (s *Server) acquireCtx(c io.ReadWriter) *RequestCtx { +func (s *Server) acquireCtx(c io.ReadWriteCloser) *RequestCtx { v := s.ctxPool.Get() var ctx *RequestCtx if v == nil { diff --git a/server_test.go b/server_test.go index 68994eb..28a2c8d 100644 --- a/server_test.go +++ b/server_test.go @@ -279,10 +279,14 @@ func TestServerRemoteAddr(t *testing.T) { } type readWriterRemoteAddr struct { - rw io.ReadWriter + rw io.ReadWriteCloser addr net.Addr } +func (rw *readWriterRemoteAddr) Close() error { + return rw.rw.Close() +} + func (rw *readWriterRemoteAddr) Read(b []byte) (int, error) { return rw.rw.Read(b) } @@ -414,6 +418,10 @@ type readWriter struct { w bytes.Buffer } +func (rw *readWriter) Close() error { + return nil +} + func (rw *readWriter) Read(b []byte) (int, error) { return rw.r.Read(b) }