Files
fasthttp/peripconn.go
T
2026-03-20 07:27:24 +01:00

144 lines
2.4 KiB
Go

package fasthttp
import (
"crypto/tls"
"encoding/binary"
"net"
"sync"
)
// perIPConnCounter tracks how many connections are currently registered per
// IPv4 address (packed into a uint32). It also owns the pools used to recycle
// the perIPConn / perIPTLSConn wrappers. The zero value is ready to use: the
// counts map is allocated lazily by Register.
type perIPConnCounter struct {
	// Recycled *perIPConn wrappers (see acquirePerIPConn / perIPConn.Close).
	perIPConnPool sync.Pool
	// Recycled *perIPTLSConn wrappers.
	perIPTLSConnPool sync.Pool
	// Current registration count per IP; only IPs with a positive count are
	// kept. Guarded by lock.
	m map[uint32]int
	// Guards m.
	lock sync.Mutex
}

// Register increments the count for ip and returns the new count.
func (cc *perIPConnCounter) Register(ip uint32) int {
	cc.lock.Lock()
	if cc.m == nil {
		cc.m = make(map[uint32]int)
	}
	n := cc.m[ip] + 1
	cc.m[ip] = n
	cc.lock.Unlock()
	return n
}

// Unregister decrements the count for ip, never letting it go below zero.
//
// When the count reaches zero the entry is removed from the map rather than
// stored as 0. Otherwise the map would keep one entry for every client IP
// ever seen and grow without bound on a long-running server. Register treats
// a missing key as 0, so this does not change any observable count.
//
// It panics if Register has never been called on cc.
func (cc *perIPConnCounter) Unregister(ip uint32) {
	cc.lock.Lock()
	defer cc.lock.Unlock()
	if cc.m == nil {
		// developer safeguard
		panic("BUG: perIPConnCounter.Register() wasn't called")
	}
	if n := cc.m[ip] - 1; n > 0 {
		cc.m[ip] = n
		return
	}
	// Count dropped to zero (or ip was never registered): drop the entry.
	delete(cc.m, ip)
}
// perIPConn wraps a non-TLS net.Conn. Its Close also unregisters ip from
// perIPConnCounter and puts the wrapper back into
// perIPConnCounter.perIPConnPool. Instances are produced by acquirePerIPConn.
type perIPConn struct {
// Embedded connection. Methods not defined on perIPConn itself (this file
// defines Close) are promoted from it. Close sets it to nil.
net.Conn
// Counter whose pool this wrapper is recycled into; Close unregisters ip
// from it.
perIPConnCounter *perIPConnCounter
// IP as passed to acquirePerIPConn; presumably packed via
// getUint32IP/ip2uint32 by the caller — confirm.
ip uint32
// Guards the read-and-clear of Conn in Close.
lock sync.Mutex
}
// perIPTLSConn is the *tls.Conn counterpart of perIPConn. Embedding *tls.Conn
// (rather than net.Conn) keeps the TLS-specific methods of *tls.Conn promoted
// on the wrapper, presumably so callers can still reach them — confirm.
// Close unregisters ip from perIPConnCounter and puts the wrapper back into
// perIPConnCounter.perIPTLSConnPool.
type perIPTLSConn struct {
// Embedded TLS connection. Close sets it to nil.
*tls.Conn
// Counter whose pool this wrapper is recycled into; Close unregisters ip
// from it.
perIPConnCounter *perIPConnCounter
// IP as passed to acquirePerIPConn.
ip uint32
// Guards the read-and-clear of Conn in Close.
lock sync.Mutex
}
// acquirePerIPConn wraps conn so that closing the returned connection
// unregisters ip from counter.
//
// A *tls.Conn is wrapped in a *perIPTLSConn, which keeps the concrete TLS type
// embedded. Any other conn, including a nil interface, is wrapped in a
// *perIPConn. Wrappers are reused from the counter's pools when available and
// allocated otherwise.
//
// This function does not call counter.Register itself. Presumably the caller
// has already done so — confirm against callers.
func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn {
	// Pooled wrappers keep their old perIPConnCounter field. That is correct
	// because Close only ever puts a wrapper back into the pool of its own
	// counter, so anything taken from counter's pool already points at counter.
	switch inner := conn.(type) {
	case *tls.Conn:
		pooled := counter.perIPTLSConnPool.Get()
		if pooled == nil {
			return &perIPTLSConn{
				perIPConnCounter: counter,
				Conn:             inner,
				ip:               ip,
			}
		}
		wrapped := pooled.(*perIPTLSConn)
		wrapped.Conn = inner
		wrapped.ip = ip
		return wrapped
	default:
		pooled := counter.perIPConnPool.Get()
		if pooled == nil {
			return &perIPConn{
				perIPConnCounter: counter,
				Conn:             conn,
				ip:               ip,
			}
		}
		wrapped := pooled.(*perIPConn)
		wrapped.Conn = conn
		wrapped.ip = ip
		return wrapped
	}
}
// Close closes the wrapped connection, unregisters c.ip from the owning
// counter, returns the wrapper to perIPConnCounter.perIPConnPool and reports
// the error from closing the underlying connection.
//
// The read-and-clear of c.Conn happens under c.lock. If Conn is already nil,
// for example because an earlier or overlapping Close got there first, Close
// does nothing and returns nil.
//
// Caveat: once Close has returned, the wrapper is back in the pool and may be
// handed out again by acquirePerIPConn for another connection. Callers must
// not use it afterwards. A late second Close could then close that other
// connection.
func (c *perIPConn) Close() error {
	c.lock.Lock()
	inner := c.Conn
	c.Conn = nil
	c.lock.Unlock()

	if inner != nil {
		closeErr := inner.Close()
		counter := c.perIPConnCounter
		counter.Unregister(c.ip)
		counter.perIPConnPool.Put(c)
		return closeErr
	}
	// Nothing left to release.
	return nil
}
// Close closes the wrapped TLS connection, unregisters c.ip from the owning
// counter, returns the wrapper to perIPConnCounter.perIPTLSConnPool and
// reports the error from closing the underlying connection.
//
// The read-and-clear of c.Conn happens under c.lock. If Conn is already nil,
// for example because an earlier or overlapping Close got there first, Close
// does nothing and returns nil.
//
// Caveat: once Close has returned, the wrapper is back in the pool and may be
// handed out again by acquirePerIPConn for another connection. Callers must
// not use it afterwards.
func (c *perIPTLSConn) Close() error {
	c.lock.Lock()
	inner := c.Conn
	c.Conn = nil
	c.lock.Unlock()

	if inner != nil {
		closeErr := inner.Close()
		counter := c.perIPConnCounter
		counter.Unregister(c.ip)
		counter.perIPTLSConnPool.Put(c)
		return closeErr
	}
	// Nothing left to release.
	return nil
}
func getUint32IP(c net.Conn) uint32 {
return ip2uint32(getConnIP4(c))
}
// getConnIP4 returns c's remote address in 4-byte IPv4 form.
//
// If RemoteAddr is not a *net.TCPAddr (including a nil RemoteAddr), it
// returns net.IPv4zero. That is a shared package-level value, so callers must
// not modify the result. For a TCP address that is neither IPv4 nor
// IPv4-mapped IPv6, the result is nil (see net.IP.To4).
func getConnIP4(c net.Conn) net.IP {
	if tcpAddr, isTCP := c.RemoteAddr().(*net.TCPAddr); isTCP {
		return tcpAddr.IP.To4()
	}
	return net.IPv4zero
}
// ip2uint32 packs an IPv4 address into a uint32 in big-endian (network) byte
// order, so 1.2.3.4 becomes 0x01020304. It is the inverse of uint322ip.
//
// It accepts both the 4-byte form and the 16-byte IPv4-mapped IPv6 form. The
// latter is what net.ParseIP returns for dotted-quad input and previously
// yielded 0. Any other input, including nil, malformed lengths and genuine
// IPv6 addresses, yields 0.
func ip2uint32(ip net.IP) uint32 {
	ip4 := ip.To4()
	if ip4 == nil {
		return 0
	}
	return binary.BigEndian.Uint32(ip4)
}
// uint322ip unpacks a big-endian uint32 into a freshly allocated 4-byte
// net.IP, so 0x01020304 becomes 1.2.3.4. It is the inverse of ip2uint32.
func uint322ip(ip uint32) net.IP {
	var octets [net.IPv4len]byte
	binary.BigEndian.PutUint32(octets[:], ip)
	return net.IP(octets[:])
}