diff --git a/peripconn.go b/peripconn.go new file mode 100644 index 0000000..2568477 --- /dev/null +++ b/peripconn.go @@ -0,0 +1,76 @@ +package fasthttp + +import ( + "fmt" + "net" + "sync" +) + +type perIPConnCounter struct { + lock sync.Mutex + m map[uint32]int +} + +func (cc *perIPConnCounter) Register(ip uint32) int { + cc.lock.Lock() + if cc.m == nil { + cc.m = make(map[uint32]int) + } + n := cc.m[ip] + 1 + cc.m[ip] = n + cc.lock.Unlock() + return n +} + +func (cc *perIPConnCounter) Unregister(ip uint32) { + cc.lock.Lock() + if cc.m == nil { + cc.lock.Unlock() + panic("BUG: perIPConnCounter.Register() wasn't called") + } + n := cc.m[ip] - 1 + if n < 0 { + cc.lock.Unlock() + panic(fmt.Sprintf("BUG: negative per-ip counter=%d for ip=%d", n, ip)) + } + cc.m[ip] = n + cc.lock.Unlock() +} + +type perIPConn struct { + net.Conn + + ip uint32 + perIPConnCounter *perIPConnCounter +} + +func (c *perIPConn) Close() error { + err := c.Conn.Close() + c.perIPConnCounter.Unregister(c.ip) + return err +} + +func getUint32IP(c net.Conn) uint32 { + addr := c.RemoteAddr() + ipAddr, ok := addr.(*net.TCPAddr) + if !ok { + return 0 + } + return ip2uint32(ipAddr.IP.To4()) +} + +func ip2uint32(ip net.IP) uint32 { + if len(ip) != 4 { + return 0 + } + return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) +} + +func uint322ip(ip uint32) net.IP { + b := make([]byte, 4) + b[0] = byte(ip >> 24) + b[1] = byte(ip >> 16) + b[2] = byte(ip >> 8) + b[3] = byte(ip) + return b +} diff --git a/peripconn_test.go b/peripconn_test.go new file mode 100644 index 0000000..7cfb0d8 --- /dev/null +++ b/peripconn_test.go @@ -0,0 +1,59 @@ +package fasthttp + +import ( + "testing" +) + +func TestIPxUint32(t *testing.T) { + testIPxUint32(t, 0) + testIPxUint32(t, 10) + testIPxUint32(t, 0x12892392) +} + +func testIPxUint32(t *testing.T, n uint32) { + ip := uint322ip(n) + nn := ip2uint32(ip) + if n != nn { + t.Fatalf("Unexpected value=%d for ip=%s. Expected %d", nn, ip, n) + } +} + +func TestPerIPConnCounter(t *testing.T) { + var cc perIPConnCounter + + expectPanic(t, func() { cc.Unregister(123) }) + + for i := 1; i < 100; i++ { + if n := cc.Register(123); n != i { + t.Fatalf("Unexpected counter value=%d. Expected %d", n, i) + } + } + + n := cc.Register(456) + if n != 1 { + t.Fatalf("Unexpected counter value=%d. Expected 1", n) + } + + for i := 1; i < 100; i++ { + cc.Unregister(123) + } + cc.Unregister(456) + + expectPanic(t, func() { cc.Unregister(123) }) + expectPanic(t, func() { cc.Unregister(456) }) + + n = cc.Register(123) + if n != 1 { + t.Fatalf("Unexpected counter value=%d. Expected 1", n) + } + cc.Unregister(123) +} + +func expectPanic(t *testing.T, f func()) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("Expecting panic") + } + }() + f() +} diff --git a/server.go b/server.go index b2b146b..1933cca 100644 --- a/server.go +++ b/server.go @@ -55,13 +55,20 @@ type Server struct { // By default response write timeout is unlimited. WriteTimeout time.Duration + // Maximum number of concurrent client connections allowed per IP. + // + // By default unlimited number of concurrent connections + // may be established to the server from a single IP address. + MaxConnsPerIP int + // Logger, which is used by ServerCtx.Logger(). // // By default standard logger from log package is used. Logger Logger - serverName atomic.Value - ctxPool sync.Pool + perIPConnCounter perIPConnCounter + serverName atomic.Value + ctxPool sync.Pool } // TimeoutHandler creates RequestHandler, which returns StatusRequestTimeout @@ -182,8 +189,6 @@ func (ctx *RequestCtx) RemoteAddr() net.Addr { } // RemoteIP returns client ip for the given request. -// -// Nil is returned if client ip cannot be determined. func (ctx *RequestCtx) RemoteIP() net.IP { x, ok := ctx.RemoteAddr().(*net.TCPAddr) if !ok { @@ -261,8 +266,9 @@ func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error { stopCh := make(chan struct{}) go connWorkersMonitor(s, ch, concurrency, stopCh) var lastOverflowErrorTime time.Time + var lastPerIPErrorTime time.Time for { - c, err := acceptConn(s, ln) + c, err := acceptConn(s, ln, &lastPerIPErrorTime) if err != nil { close(stopCh) return err @@ -271,7 +277,7 @@ func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error { case ch <- c: default: c.Close() - if time.Since(lastOverflowErrorTime) > time.Second*10 { + if time.Since(lastOverflowErrorTime) > time.Minute { s.logger().Printf("The incoming connection cannot be served, because all %d workers are busy. "+ "Try increasing concurrency in Server.ServeWorkers()", concurrency) lastOverflowErrorTime = time.Now() @@ -342,7 +348,8 @@ func connWorker(s *Server, ch <-chan net.Conn) { } } -func acceptConn(s *Server, ln net.Listener) (net.Conn, error) { +func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) { + maxConnsPerIP := s.MaxConnsPerIP for { c, err := ln.Accept() if err != nil { @@ -356,6 +363,28 @@ func acceptConn(s *Server, ln net.Listener) (net.Conn, error) { } return nil, err } + if maxConnsPerIP > 0 { + ip := getUint32IP(c) + if ip == 0 { + return c, nil + } + n := s.perIPConnCounter.Register(ip) + if n > maxConnsPerIP { + if time.Since(*lastPerIPErrorTime) > time.Minute { + s.logger().Printf("Too many connections from ip %s: %d. MaxConnsPerIP=%d", + uint322ip(ip), n, maxConnsPerIP) + *lastPerIPErrorTime = time.Now() + } + c.Close() + s.perIPConnCounter.Unregister(ip) + continue + } + return &perIPConn{ + Conn: c, + ip: ip, + perIPConnCounter: &s.perIPConnCounter, + }, nil + } return c, nil } } diff --git a/server_timing_test.go b/server_timing_test.go index 634b4ea..8d68890 100644 --- a/server_timing_test.go +++ b/server_timing_test.go @@ -111,6 +111,22 @@ func BenchmarkNetHTTPServerGet10KReqPerConn1KClients(b *testing.B) { benchmarkNetHTTPServerGet(b, 1000, 10000) } +func BenchmarkServerMaxConnsPerIP(b *testing.B) { + clientsCount := 1000 + requestsPerConn := 10 + ch := make(chan struct{}, b.N) + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Success("foobar", []byte("123")) + registerServedRequest(b, ch) + }, + MaxConnsPerIP: clientsCount * 2, + } + req := "GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" + benchmarkServer(b, &testServer{s, clientsCount}, clientsCount, requestsPerConn, req) + verifyRequestsServed(b, ch) +} + func BenchmarkServerTimeoutError(b *testing.B) { clientsCount := 1 requestsPerConn := 10 @@ -165,7 +181,10 @@ func (c *fakeServerConn) Write(b []byte) (int, error) { return len(b), nil } -var fakeAddr net.TCPAddr +var fakeAddr = net.TCPAddr{ + IP: []byte{1, 2, 3, 4}, + Port: 12345, +} func (c *fakeServerConn) RemoteAddr() net.Addr { return &fakeAddr