Added ability to limit the number of concurrent client connections per ip

This commit is contained in:
Aliaksandr Valialkin
2015-11-02 15:09:45 +02:00
parent 7c83bade48
commit 3eaecd9c6c
4 changed files with 191 additions and 8 deletions
+76
View File
@@ -0,0 +1,76 @@
package fasthttp
import (
"fmt"
"net"
"sync"
)
type perIPConnCounter struct {
lock sync.Mutex
m map[uint32]int
}
func (cc *perIPConnCounter) Register(ip uint32) int {
cc.lock.Lock()
if cc.m == nil {
cc.m = make(map[uint32]int)
}
n := cc.m[ip] + 1
cc.m[ip] = n
cc.lock.Unlock()
return n
}
func (cc *perIPConnCounter) Unregister(ip uint32) {
cc.lock.Lock()
if cc.m == nil {
cc.lock.Unlock()
panic("BUG: perIPConnCounter.Register() wasn't called")
}
n := cc.m[ip] - 1
if n < 0 {
cc.lock.Unlock()
panic(fmt.Sprintf("BUG: negative per-ip counter=%d for ip=%d", n, ip))
}
cc.m[ip] = n
cc.lock.Unlock()
}
type perIPConn struct {
net.Conn
ip uint32
perIPConnCounter *perIPConnCounter
}
func (c *perIPConn) Close() error {
err := c.Conn.Close()
c.perIPConnCounter.Unregister(c.ip)
return err
}
func getUint32IP(c net.Conn) uint32 {
addr := c.RemoteAddr()
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return 0
}
return ip2uint32(ipAddr.IP.To4())
}
func ip2uint32(ip net.IP) uint32 {
if len(ip) != 4 {
return 0
}
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}
func uint322ip(ip uint32) net.IP {
b := make([]byte, 4)
b[0] = byte(ip >> 24)
b[1] = byte(ip >> 16)
b[2] = byte(ip >> 8)
b[3] = byte(ip)
return b
}
+59
View File
@@ -0,0 +1,59 @@
package fasthttp
import (
"testing"
)
func TestIPxUint32(t *testing.T) {
testIPxUint32(t, 0)
testIPxUint32(t, 10)
testIPxUint32(t, 0x12892392)
}
func testIPxUint32(t *testing.T, n uint32) {
ip := uint322ip(n)
nn := ip2uint32(ip)
if n != nn {
t.Fatalf("Unexpected value=%d for ip=%s. Expected %d", nn, ip, n)
}
}
func TestPerIPConnCounter(t *testing.T) {
var cc perIPConnCounter
expectPanic(t, func() { cc.Unregister(123) })
for i := 1; i < 100; i++ {
if n := cc.Register(123); n != i {
t.Fatalf("Unexpected counter value=%d. Expected %d", n, i)
}
}
n := cc.Register(456)
if n != 1 {
t.Fatalf("Unexpected counter value=%d. Expected 1", n)
}
for i := 1; i < 100; i++ {
cc.Unregister(123)
}
cc.Unregister(456)
expectPanic(t, func() { cc.Unregister(123) })
expectPanic(t, func() { cc.Unregister(456) })
n = cc.Register(123)
if n != 1 {
t.Fatalf("Unexpected counter value=%d. Expected 1", n)
}
cc.Unregister(123)
}
func expectPanic(t *testing.T, f func()) {
defer func() {
if r := recover(); r == nil {
t.Fatalf("Expecting panic")
}
}()
f()
}
+36 -7
View File
@@ -55,13 +55,20 @@ type Server struct {
// By default response write timeout is unlimited.
WriteTimeout time.Duration
// Maximum number of concurrent client connections allowed per IP.
//
// By default unlimited number of concurrent connections
// may be established to the server from a single IP address.
MaxConnsPerIP int
// Logger, which is used by ServerCtx.Logger().
//
// By default standard logger from log package is used.
Logger Logger
serverName atomic.Value
ctxPool sync.Pool
perIPConnCounter perIPConnCounter
serverName atomic.Value
ctxPool sync.Pool
}
// TimeoutHandler creates RequestHandler, which returns StatusRequestTimeout
@@ -182,8 +189,6 @@ func (ctx *RequestCtx) RemoteAddr() net.Addr {
}
// RemoteIP returns client ip for the given request.
//
// Nil is returned if client ip cannot be determined.
func (ctx *RequestCtx) RemoteIP() net.IP {
x, ok := ctx.RemoteAddr().(*net.TCPAddr)
if !ok {
@@ -261,8 +266,9 @@ func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error {
stopCh := make(chan struct{})
go connWorkersMonitor(s, ch, concurrency, stopCh)
var lastOverflowErrorTime time.Time
var lastPerIPErrorTime time.Time
for {
c, err := acceptConn(s, ln)
c, err := acceptConn(s, ln, &lastPerIPErrorTime)
if err != nil {
close(stopCh)
return err
@@ -271,7 +277,7 @@ func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error {
case ch <- c:
default:
c.Close()
if time.Since(lastOverflowErrorTime) > time.Second*10 {
if time.Since(lastOverflowErrorTime) > time.Minute {
s.logger().Printf("The incoming connection cannot be served, because all %d workers are busy. "+
"Try increasing concurrency in Server.ServeWorkers()", concurrency)
lastOverflowErrorTime = time.Now()
@@ -342,7 +348,8 @@ func connWorker(s *Server, ch <-chan net.Conn) {
}
}
func acceptConn(s *Server, ln net.Listener) (net.Conn, error) {
func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
maxConnsPerIP := s.MaxConnsPerIP
for {
c, err := ln.Accept()
if err != nil {
@@ -356,6 +363,28 @@ func acceptConn(s *Server, ln net.Listener) (net.Conn, error) {
}
return nil, err
}
if maxConnsPerIP > 0 {
ip := getUint32IP(c)
if ip == 0 {
return c, nil
}
n := s.perIPConnCounter.Register(ip)
if n > maxConnsPerIP {
if time.Since(*lastPerIPErrorTime) > time.Minute {
s.logger().Printf("Too many connections from ip %s: %d. MaxConnsPerIP=%d",
uint322ip(ip), n, maxConnsPerIP)
*lastPerIPErrorTime = time.Now()
}
c.Close()
s.perIPConnCounter.Unregister(ip)
continue
}
return &perIPConn{
Conn: c,
ip: ip,
perIPConnCounter: &s.perIPConnCounter,
}, nil
}
return c, nil
}
}
+20 -1
View File
@@ -111,6 +111,22 @@ func BenchmarkNetHTTPServerGet10KReqPerConn1KClients(b *testing.B) {
benchmarkNetHTTPServerGet(b, 1000, 10000)
}
func BenchmarkServerMaxConnsPerIP(b *testing.B) {
clientsCount := 1000
requestsPerConn := 10
ch := make(chan struct{}, b.N)
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("foobar", []byte("123"))
registerServedRequest(b, ch)
},
MaxConnsPerIP: clientsCount * 2,
}
req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n"
benchmarkServer(b, &testServer{s, clientsCount}, clientsCount, requestsPerConn, req)
verifyRequestsServed(b, ch)
}
func BenchmarkServerTimeoutError(b *testing.B) {
clientsCount := 1
requestsPerConn := 10
@@ -165,7 +181,10 @@ func (c *fakeServerConn) Write(b []byte) (int, error) {
return len(b), nil
}
var fakeAddr net.TCPAddr
var fakeAddr = net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 12345,
}
func (c *fakeServerConn) RemoteAddr() net.Addr {
return &fakeAddr