mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
Added ability to limit the number of concurrent client connections per ip
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
package fasthttp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type perIPConnCounter struct {
|
||||
lock sync.Mutex
|
||||
m map[uint32]int
|
||||
}
|
||||
|
||||
func (cc *perIPConnCounter) Register(ip uint32) int {
|
||||
cc.lock.Lock()
|
||||
if cc.m == nil {
|
||||
cc.m = make(map[uint32]int)
|
||||
}
|
||||
n := cc.m[ip] + 1
|
||||
cc.m[ip] = n
|
||||
cc.lock.Unlock()
|
||||
return n
|
||||
}
|
||||
|
||||
func (cc *perIPConnCounter) Unregister(ip uint32) {
|
||||
cc.lock.Lock()
|
||||
if cc.m == nil {
|
||||
cc.lock.Unlock()
|
||||
panic("BUG: perIPConnCounter.Register() wasn't called")
|
||||
}
|
||||
n := cc.m[ip] - 1
|
||||
if n < 0 {
|
||||
cc.lock.Unlock()
|
||||
panic(fmt.Sprintf("BUG: negative per-ip counter=%d for ip=%d", n, ip))
|
||||
}
|
||||
cc.m[ip] = n
|
||||
cc.lock.Unlock()
|
||||
}
|
||||
|
||||
type perIPConn struct {
|
||||
net.Conn
|
||||
|
||||
ip uint32
|
||||
perIPConnCounter *perIPConnCounter
|
||||
}
|
||||
|
||||
func (c *perIPConn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
c.perIPConnCounter.Unregister(c.ip)
|
||||
return err
|
||||
}
|
||||
|
||||
func getUint32IP(c net.Conn) uint32 {
|
||||
addr := c.RemoteAddr()
|
||||
ipAddr, ok := addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return ip2uint32(ipAddr.IP.To4())
|
||||
}
|
||||
|
||||
func ip2uint32(ip net.IP) uint32 {
|
||||
if len(ip) != 4 {
|
||||
return 0
|
||||
}
|
||||
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
|
||||
}
|
||||
|
||||
func uint322ip(ip uint32) net.IP {
|
||||
b := make([]byte, 4)
|
||||
b[0] = byte(ip >> 24)
|
||||
b[1] = byte(ip >> 16)
|
||||
b[2] = byte(ip >> 8)
|
||||
b[3] = byte(ip)
|
||||
return b
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package fasthttp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPxUint32(t *testing.T) {
|
||||
testIPxUint32(t, 0)
|
||||
testIPxUint32(t, 10)
|
||||
testIPxUint32(t, 0x12892392)
|
||||
}
|
||||
|
||||
func testIPxUint32(t *testing.T, n uint32) {
|
||||
ip := uint322ip(n)
|
||||
nn := ip2uint32(ip)
|
||||
if n != nn {
|
||||
t.Fatalf("Unexpected value=%d for ip=%s. Expected %d", nn, ip, n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerIPConnCounter(t *testing.T) {
|
||||
var cc perIPConnCounter
|
||||
|
||||
expectPanic(t, func() { cc.Unregister(123) })
|
||||
|
||||
for i := 1; i < 100; i++ {
|
||||
if n := cc.Register(123); n != i {
|
||||
t.Fatalf("Unexpected counter value=%d. Expected %d", n, i)
|
||||
}
|
||||
}
|
||||
|
||||
n := cc.Register(456)
|
||||
if n != 1 {
|
||||
t.Fatalf("Unexpected counter value=%d. Expected 1", n)
|
||||
}
|
||||
|
||||
for i := 1; i < 100; i++ {
|
||||
cc.Unregister(123)
|
||||
}
|
||||
cc.Unregister(456)
|
||||
|
||||
expectPanic(t, func() { cc.Unregister(123) })
|
||||
expectPanic(t, func() { cc.Unregister(456) })
|
||||
|
||||
n = cc.Register(123)
|
||||
if n != 1 {
|
||||
t.Fatalf("Unexpected counter value=%d. Expected 1", n)
|
||||
}
|
||||
cc.Unregister(123)
|
||||
}
|
||||
|
||||
func expectPanic(t *testing.T, f func()) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatalf("Expecting panic")
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}
|
||||
@@ -55,13 +55,20 @@ type Server struct {
|
||||
// By default response write timeout is unlimited.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// Maximum number of concurrent client connections allowed per IP.
|
||||
//
|
||||
// By default unlimited number of concurrent connections
|
||||
// may be established to the server from a single IP address.
|
||||
MaxConnsPerIP int
|
||||
|
||||
// Logger, which is used by ServerCtx.Logger().
|
||||
//
|
||||
// By default standard logger from log package is used.
|
||||
Logger Logger
|
||||
|
||||
serverName atomic.Value
|
||||
ctxPool sync.Pool
|
||||
perIPConnCounter perIPConnCounter
|
||||
serverName atomic.Value
|
||||
ctxPool sync.Pool
|
||||
}
|
||||
|
||||
// TimeoutHandler creates RequestHandler, which returns StatusRequestTimeout
|
||||
@@ -182,8 +189,6 @@ func (ctx *RequestCtx) RemoteAddr() net.Addr {
|
||||
}
|
||||
|
||||
// RemoteIP returns client ip for the given request.
|
||||
//
|
||||
// Nil is returned if client ip cannot be determined.
|
||||
func (ctx *RequestCtx) RemoteIP() net.IP {
|
||||
x, ok := ctx.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
@@ -261,8 +266,9 @@ func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error {
|
||||
stopCh := make(chan struct{})
|
||||
go connWorkersMonitor(s, ch, concurrency, stopCh)
|
||||
var lastOverflowErrorTime time.Time
|
||||
var lastPerIPErrorTime time.Time
|
||||
for {
|
||||
c, err := acceptConn(s, ln)
|
||||
c, err := acceptConn(s, ln, &lastPerIPErrorTime)
|
||||
if err != nil {
|
||||
close(stopCh)
|
||||
return err
|
||||
@@ -271,7 +277,7 @@ func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error {
|
||||
case ch <- c:
|
||||
default:
|
||||
c.Close()
|
||||
if time.Since(lastOverflowErrorTime) > time.Second*10 {
|
||||
if time.Since(lastOverflowErrorTime) > time.Minute {
|
||||
s.logger().Printf("The incoming connection cannot be served, because all %d workers are busy. "+
|
||||
"Try increasing concurrency in Server.ServeWorkers()", concurrency)
|
||||
lastOverflowErrorTime = time.Now()
|
||||
@@ -342,7 +348,8 @@ func connWorker(s *Server, ch <-chan net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func acceptConn(s *Server, ln net.Listener) (net.Conn, error) {
|
||||
func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
|
||||
maxConnsPerIP := s.MaxConnsPerIP
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
@@ -356,6 +363,28 @@ func acceptConn(s *Server, ln net.Listener) (net.Conn, error) {
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if maxConnsPerIP > 0 {
|
||||
ip := getUint32IP(c)
|
||||
if ip == 0 {
|
||||
return c, nil
|
||||
}
|
||||
n := s.perIPConnCounter.Register(ip)
|
||||
if n > maxConnsPerIP {
|
||||
if time.Since(*lastPerIPErrorTime) > time.Minute {
|
||||
s.logger().Printf("Too many connections from ip %s: %d. MaxConnsPerIP=%d",
|
||||
uint322ip(ip), n, maxConnsPerIP)
|
||||
*lastPerIPErrorTime = time.Now()
|
||||
}
|
||||
c.Close()
|
||||
s.perIPConnCounter.Unregister(ip)
|
||||
continue
|
||||
}
|
||||
return &perIPConn{
|
||||
Conn: c,
|
||||
ip: ip,
|
||||
perIPConnCounter: &s.perIPConnCounter,
|
||||
}, nil
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
+20
-1
@@ -111,6 +111,22 @@ func BenchmarkNetHTTPServerGet10KReqPerConn1KClients(b *testing.B) {
|
||||
benchmarkNetHTTPServerGet(b, 1000, 10000)
|
||||
}
|
||||
|
||||
func BenchmarkServerMaxConnsPerIP(b *testing.B) {
|
||||
clientsCount := 1000
|
||||
requestsPerConn := 10
|
||||
ch := make(chan struct{}, b.N)
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
ctx.Success("foobar", []byte("123"))
|
||||
registerServedRequest(b, ch)
|
||||
},
|
||||
MaxConnsPerIP: clientsCount * 2,
|
||||
}
|
||||
req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n"
|
||||
benchmarkServer(b, &testServer{s, clientsCount}, clientsCount, requestsPerConn, req)
|
||||
verifyRequestsServed(b, ch)
|
||||
}
|
||||
|
||||
func BenchmarkServerTimeoutError(b *testing.B) {
|
||||
clientsCount := 1
|
||||
requestsPerConn := 10
|
||||
@@ -165,7 +181,10 @@ func (c *fakeServerConn) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
var fakeAddr net.TCPAddr
|
||||
var fakeAddr = net.TCPAddr{
|
||||
IP: []byte{1, 2, 3, 4},
|
||||
Port: 12345,
|
||||
}
|
||||
|
||||
func (c *fakeServerConn) RemoteAddr() net.Addr {
|
||||
return &fakeAddr
|
||||
|
||||
Reference in New Issue
Block a user