diff --git a/peripconn.go b/peripconn.go index 123c55e..46bddbf 100644 --- a/peripconn.go +++ b/peripconn.go @@ -1,14 +1,16 @@ package fasthttp import ( + "crypto/tls" "net" "sync" ) type perIPConnCounter struct { - pool sync.Pool - lock sync.Mutex - m map[uint32]int + perIPConnPool sync.Pool + perIPTLSConnPool sync.Pool + lock sync.Mutex + m map[uint32]int } func (cc *perIPConnCounter) Register(ip uint32) int { @@ -43,8 +45,30 @@ type perIPConn struct { perIPConnCounter *perIPConnCounter } -func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn { - v := counter.pool.Get() +type perIPTLSConn struct { + *tls.Conn + + ip uint32 + perIPConnCounter *perIPConnCounter +} + +func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn { + if tlcConn, ok := conn.(*tls.Conn); ok { + v := counter.perIPTLSConnPool.Get() + if v == nil { + return &perIPTLSConn{ + perIPConnCounter: counter, + Conn: tlcConn, + ip: ip, + } + } + c := v.(*perIPConn) + c.Conn = conn + c.ip = ip + return c + } + + v := counter.perIPConnPool.Get() if v == nil { return &perIPConn{ perIPConnCounter: counter, @@ -58,15 +82,19 @@ func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perI return c } -func releasePerIPConn(c *perIPConn) { - c.Conn = nil - c.perIPConnCounter.pool.Put(c) -} - func (c *perIPConn) Close() error { err := c.Conn.Close() c.perIPConnCounter.Unregister(c.ip) - releasePerIPConn(c) + c.Conn = nil + c.perIPConnCounter.perIPConnPool.Put(c) + return err +} + +func (c *perIPTLSConn) Close() error { + err := c.Conn.Close() + c.perIPConnCounter.Unregister(c.ip) + c.Conn = nil + c.perIPConnCounter.perIPTLSConnPool.Put(c) return err } diff --git a/peripconn_test.go b/peripconn_test.go index 5571654..6bfccf1 100644 --- a/peripconn_test.go +++ b/peripconn_test.go @@ -4,6 +4,8 @@ import ( "testing" ) +var _ connTLSer = &perIPTLSConn{} + func TestIPxUint32(t *testing.T) { t.Parallel()