Enforce MaxConnsPerIP limit to connections served via Server.ServeConn()

This commit is contained in:
Aliaksandr Valialkin
2015-11-05 10:56:04 +02:00
parent 4a823fa707
commit a7fdc68be0
3 changed files with 74 additions and 31 deletions
+6 -2
View File
@@ -51,12 +51,16 @@ func (c *perIPConn) Close() error {
}
func getUint32IP(c net.Conn) uint32 {
return ip2uint32(getConnIP4(c))
}
func getConnIP4(c net.Conn) net.IP {
addr := c.RemoteAddr()
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return 0
return net.IPv4zero
}
return ip2uint32(ipAddr.IP.To4())
return ipAddr.IP.To4()
}
func ip2uint32(ip net.IP) uint32 {
+59 -28
View File
@@ -2,6 +2,7 @@ package fasthttp
import (
"bufio"
"errors"
"fmt"
"io"
"log"
@@ -131,7 +132,7 @@ type RequestCtx struct {
logger ctxLogger
s *Server
c io.ReadWriter
c io.ReadWriteCloser
fbr firstByteReader
// shadow is set by TimeoutError().
@@ -144,7 +145,7 @@ type RequestCtx struct {
}
type firstByteReader struct {
c io.ReadWriter
c io.ReadWriteCloser
ch byte
byteRead bool
}
@@ -285,7 +286,7 @@ func (s *Server) Serve(ln net.Listener) error {
// ServeConcurrency blocks until the given listener returns permanent error.
// This error is returned from ServeConcurrency.
func (s *Server) ServeConcurrency(ln net.Listener, concurrency int) error {
ch := make(chan net.Conn, 16*concurrency)
ch := make(chan net.Conn, 2*concurrency)
stopCh := make(chan struct{})
go connWorkersMonitor(s, ch, concurrency, stopCh)
var lastOverflowErrorTime time.Time
@@ -338,16 +339,13 @@ func connWorkersMonitor(s *Server, ch <-chan net.Conn, maxWorkers int, stopCh <-
}
func connWorker(s *Server, ch <-chan net.Conn) {
var c net.Conn
defer func() {
if r := recover(); r != nil {
s.logger().Printf("panic: %s\nStack trace:\n%s", r, debug.Stack())
}
if c != nil {
c.Close()
}
}()
var c net.Conn
var tc *time.Timer
for {
select {
@@ -362,14 +360,12 @@ func connWorker(s *Server, ch <-chan net.Conn) {
return
}
}
s.ServeConn(c)
c.Close()
s.serveConn(c)
c = nil
}
}
func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
maxConnsPerIP := s.MaxConnsPerIP
for {
c, err := ln.Accept()
if err != nil {
@@ -383,32 +379,40 @@ func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.
}
return nil, err
}
if maxConnsPerIP > 0 {
ip := getUint32IP(c)
if ip == 0 {
return c, nil
}
n := s.perIPConnCounter.Register(ip)
if n > maxConnsPerIP {
if s.MaxConnsPerIP > 0 {
pic := wrapPerIPConn(s, c)
if pic == nil {
c.Close()
if time.Since(*lastPerIPErrorTime) > time.Minute {
s.logger().Printf("Too many connections from ip %s: %d. MaxConnsPerIP=%d",
uint322ip(ip), n, maxConnsPerIP)
s.logger().Printf("The number of connections from %s exceeds MaxConnsPerIP=%d",
getConnIP4(c), s.MaxConnsPerIP)
*lastPerIPErrorTime = time.Now()
}
c.Close()
s.perIPConnCounter.Unregister(ip)
continue
}
return &perIPConn{
Conn: c,
ip: ip,
perIPConnCounter: &s.perIPConnCounter,
}, nil
return pic, nil
}
return c, nil
}
}
func wrapPerIPConn(s *Server, c net.Conn) net.Conn {
ip := getUint32IP(c)
if ip == 0 {
return c
}
n := s.perIPConnCounter.Register(ip)
if n > s.MaxConnsPerIP {
s.perIPConnCounter.Unregister(ip)
return nil
}
return &perIPConn{
Conn: c,
ip: ip,
perIPConnCounter: &s.perIPConnCounter,
}
}
var defaultLogger = log.New(os.Stderr, "", log.LstdFlags)
func (s *Server) logger() Logger {
@@ -418,11 +422,33 @@ func (s *Server) logger() Logger {
return defaultLogger
}
// ErrPerIPConnLimit may be returned from ServeConn if the number of connections
// per ip exceeds Server.MaxConnsPerIP.
var ErrPerIPConnLimit = errors.New("too many connections per ip")
// ServeConn serves HTTP requests from the given connection.
//
// ServeConn returns nil if all requests from the c are successfully served.
// It returns non-nil error otherwise.
func (s *Server) ServeConn(c io.ReadWriter) error {
//
// Connection c must immediately propagate all the data passed to Write()
// to the client. Otherwise requests' processing may hang.
//
// ServeConn closes c before returning.
func (s *Server) ServeConn(c io.ReadWriteCloser) error {
conn, ok := c.(net.Conn)
if ok {
pic := wrapPerIPConn(s, conn)
if pic == nil {
c.Close()
return ErrPerIPConnLimit
}
c = pic
}
return s.serveConn(c)
}
func (s *Server) serveConn(c io.ReadWriteCloser) error {
var rd readDeadliner
readTimeout := s.ReadTimeout
if readTimeout > 0 {
@@ -553,6 +579,11 @@ func (s *Server) ServeConn(c io.ReadWriter) error {
ctx.Logger().Printf("Error when serving network connection: %s", err)
}
s.releaseCtx(ctx)
err1 := c.Close()
if err == nil {
err = err1
}
return err
}
@@ -671,7 +702,7 @@ func releaseWriter(ctx *RequestCtx, w *bufio.Writer) {
var globalCtxID uint64
func (s *Server) acquireCtx(c io.ReadWriter) *RequestCtx {
func (s *Server) acquireCtx(c io.ReadWriteCloser) *RequestCtx {
v := s.ctxPool.Get()
var ctx *RequestCtx
if v == nil {
+9 -1
View File
@@ -279,10 +279,14 @@ func TestServerRemoteAddr(t *testing.T) {
}
type readWriterRemoteAddr struct {
rw io.ReadWriter
rw io.ReadWriteCloser
addr net.Addr
}
func (rw *readWriterRemoteAddr) Close() error {
return rw.rw.Close()
}
func (rw *readWriterRemoteAddr) Read(b []byte) (int, error) {
return rw.rw.Read(b)
}
@@ -414,6 +418,10 @@ type readWriter struct {
w bytes.Buffer
}
func (rw *readWriter) Close() error {
return nil
}
func (rw *readWriter) Read(b []byte) (int, error) {
return rw.r.Read(b)
}