mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-25 17:45:28 +03:00
Enforce MaxConnsPerIP limit to connections served via Server.ServeConn()
This commit is contained in:
+6
-2
@@ -51,12 +51,16 @@ func (c *perIPConn) Close() error {
|
||||
}
|
||||
|
||||
func getUint32IP(c net.Conn) uint32 {
|
||||
return ip2uint32(getConnIP4(c))
|
||||
}
|
||||
|
||||
func getConnIP4(c net.Conn) net.IP {
|
||||
addr := c.RemoteAddr()
|
||||
ipAddr, ok := addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return 0
|
||||
return net.IPv4zero
|
||||
}
|
||||
return ip2uint32(ipAddr.IP.To4())
|
||||
return ipAddr.IP.To4()
|
||||
}
|
||||
|
||||
func ip2uint32(ip net.IP) uint32 {
|
||||
|
||||
@@ -2,6 +2,7 @@ package fasthttp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -131,7 +132,7 @@ type RequestCtx struct {
|
||||
|
||||
logger ctxLogger
|
||||
s *Server
|
||||
c io.ReadWriter
|
||||
c io.ReadWriteCloser
|
||||
fbr firstByteReader
|
||||
|
||||
// shadow is set by TimeoutError().
|
||||
@@ -144,7 +145,7 @@ type RequestCtx struct {
|
||||
}
|
||||
|
||||
type firstByteReader struct {
|
||||
c io.ReadWriter
|
||||
c io.ReadWriteCloser
|
||||
ch byte
|
||||
byteRead bool
|
||||
}
|
||||
@@ -285,7 +286,7 @@ func (s *Server) Serve(ln net.Listener) error {
|
||||
// ServeConcurrency blocks until the given listener returns permanent error.
|
||||
// This error is returned from ServeConcurrency.
|
||||
func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error {
|
||||
ch := make(chan net.Conn, 16*concurrency)
|
||||
ch := make(chan net.Conn, 2*concurrency)
|
||||
stopCh := make(chan struct{})
|
||||
go connWorkersMonitor(s, ch, concurrency, stopCh)
|
||||
var lastOverflowErrorTime time.Time
|
||||
@@ -338,16 +339,13 @@ func connWorkersMonitor(s *Server, ch <-chan net.Conn, maxWorkers int, stopCh <-
|
||||
}
|
||||
|
||||
func connWorker(s *Server, ch <-chan net.Conn) {
|
||||
var c net.Conn
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
s.logger().Printf("panic: %s\nStack trace:\n%s", r, debug.Stack())
|
||||
}
|
||||
if c != nil {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
var c net.Conn
|
||||
var tc *time.Timer
|
||||
for {
|
||||
select {
|
||||
@@ -362,14 +360,12 @@ func connWorker(s *Server, ch <-chan net.Conn) {
|
||||
return
|
||||
}
|
||||
}
|
||||
s.ServeConn(c)
|
||||
c.Close()
|
||||
s.serveConn(c)
|
||||
c = nil
|
||||
}
|
||||
}
|
||||
|
||||
func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
|
||||
maxConnsPerIP := s.MaxConnsPerIP
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
@@ -383,32 +379,40 @@ func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if maxConnsPerIP > 0 {
|
||||
ip := getUint32IP(c)
|
||||
if ip == 0 {
|
||||
return c, nil
|
||||
}
|
||||
n := s.perIPConnCounter.Register(ip)
|
||||
if n > maxConnsPerIP {
|
||||
if s.MaxConnsPerIP > 0 {
|
||||
pic := wrapPerIPConn(s, c)
|
||||
if pic == nil {
|
||||
c.Close()
|
||||
if time.Since(*lastPerIPErrorTime) > time.Minute {
|
||||
s.logger().Printf("Too many connections from ip %s: %d. MaxConnsPerIP=%d",
|
||||
uint322ip(ip), n, maxConnsPerIP)
|
||||
s.logger().Printf("The number of connections from %s exceeds MaxConnsPerIP=%d",
|
||||
getConnIP4(c), s.MaxConnsPerIP)
|
||||
*lastPerIPErrorTime = time.Now()
|
||||
}
|
||||
c.Close()
|
||||
s.perIPConnCounter.Unregister(ip)
|
||||
continue
|
||||
}
|
||||
return &perIPConn{
|
||||
Conn: c,
|
||||
ip: ip,
|
||||
perIPConnCounter: &s.perIPConnCounter,
|
||||
}, nil
|
||||
return pic, nil
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
func wrapPerIPConn(s *Server, c net.Conn) net.Conn {
|
||||
ip := getUint32IP(c)
|
||||
if ip == 0 {
|
||||
return c
|
||||
}
|
||||
n := s.perIPConnCounter.Register(ip)
|
||||
if n > s.MaxConnsPerIP {
|
||||
s.perIPConnCounter.Unregister(ip)
|
||||
return nil
|
||||
}
|
||||
return &perIPConn{
|
||||
Conn: c,
|
||||
ip: ip,
|
||||
perIPConnCounter: &s.perIPConnCounter,
|
||||
}
|
||||
}
|
||||
|
||||
var defaultLogger = log.New(os.Stderr, "", log.LstdFlags)
|
||||
|
||||
func (s *Server) logger() Logger {
|
||||
@@ -418,11 +422,33 @@ func (s *Server) logger() Logger {
|
||||
return defaultLogger
|
||||
}
|
||||
|
||||
// ErrPerIPConnLimit may be returned from ServeConn if the number of connections
|
||||
// per ip exceeds Server.MaxConnsPerIP.
|
||||
var ErrPerIPConnLimit = errors.New("too many connections per ip")
|
||||
|
||||
// ServeConn serves HTTP requests from the given connection.
|
||||
//
|
||||
// ServeConn returns nil if all requests from the c are successfully served.
|
||||
// It returns non-nil error otherwise.
|
||||
func (s *Server) ServeConn(c io.ReadWriter) error {
|
||||
//
|
||||
// Connection c must immediately propagate all the data passed to Write()
|
||||
// to the client. Otherwise requests' processing may hang.
|
||||
//
|
||||
// ServeConn closes c before returning.
|
||||
func (s *Server) ServeConn(c io.ReadWriteCloser) error {
|
||||
conn, ok := c.(net.Conn)
|
||||
if ok {
|
||||
pic := wrapPerIPConn(s, conn)
|
||||
if pic == nil {
|
||||
c.Close()
|
||||
return ErrPerIPConnLimit
|
||||
}
|
||||
c = pic
|
||||
}
|
||||
return s.serveConn(c)
|
||||
}
|
||||
|
||||
func (s *Server) serveConn(c io.ReadWriteCloser) error {
|
||||
var rd readDeadliner
|
||||
readTimeout := s.ReadTimeout
|
||||
if readTimeout > 0 {
|
||||
@@ -553,6 +579,11 @@ func (s *Server) ServeConn(c io.ReadWriter) error {
|
||||
ctx.Logger().Printf("Error when serving network connection: %s", err)
|
||||
}
|
||||
s.releaseCtx(ctx)
|
||||
|
||||
err1 := c.Close()
|
||||
if err == nil {
|
||||
err = err1
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -671,7 +702,7 @@ func releaseWriter(ctx *RequestCtx, w *bufio.Writer) {
|
||||
|
||||
var globalCtxID uint64
|
||||
|
||||
func (s *Server) acquireCtx(c io.ReadWriter) *RequestCtx {
|
||||
func (s *Server) acquireCtx(c io.ReadWriteCloser) *RequestCtx {
|
||||
v := s.ctxPool.Get()
|
||||
var ctx *RequestCtx
|
||||
if v == nil {
|
||||
|
||||
+9
-1
@@ -279,10 +279,14 @@ func TestServerRemoteAddr(t *testing.T) {
|
||||
}
|
||||
|
||||
type readWriterRemoteAddr struct {
|
||||
rw io.ReadWriter
|
||||
rw io.ReadWriteCloser
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func (rw *readWriterRemoteAddr) Close() error {
|
||||
return rw.rw.Close()
|
||||
}
|
||||
|
||||
func (rw *readWriterRemoteAddr) Read(b []byte) (int, error) {
|
||||
return rw.rw.Read(b)
|
||||
}
|
||||
@@ -414,6 +418,10 @@ type readWriter struct {
|
||||
w bytes.Buffer
|
||||
}
|
||||
|
||||
func (rw *readWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *readWriter) Read(b []byte) (int, error) {
|
||||
return rw.r.Read(b)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user