bug: InMemoryListener deadlock under high concurrency (#2245) (#2277)

This commit is contained in:
Erik Dubbelboer
2026-06-06 17:28:07 +08:00
committed by GitHub
parent da69ed9d3f
commit 4ef0547d4c
2 changed files with 165 additions and 16 deletions
+75 -16
View File
@@ -16,6 +16,7 @@ var ErrInmemoryListenerClosed = errors.New("InmemoryListener is already closed:
type InmemoryListener struct {
listenerAddr net.Addr
conns chan acceptConn
done chan struct{}
addrLock sync.RWMutex
lock sync.Mutex
closed bool
@@ -30,6 +31,7 @@ type acceptConn struct {
func NewInmemoryListener() *InmemoryListener {
return &InmemoryListener{
conns: make(chan acceptConn, 1024),
done: make(chan struct{}),
}
}
@@ -47,12 +49,25 @@ func (ln *InmemoryListener) SetLocalAddr(localAddr net.Addr) {
//
// Accept returns new connection per each Dial call.
func (ln *InmemoryListener) Accept() (net.Conn, error) {
c, ok := <-ln.conns
if !ok {
select {
case <-ln.done:
return nil, ErrInmemoryListenerClosed
default:
}
select {
case c := <-ln.conns:
select {
case <-ln.done:
_ = c.conn.Close()
return nil, ErrInmemoryListenerClosed
default:
}
close(c.accepted)
return c.conn, nil
case <-ln.done:
return nil, ErrInmemoryListenerClosed
}
close(c.accepted)
return c.conn, nil
}
// Close implements net.Listener's Close.
@@ -61,15 +76,29 @@ func (ln *InmemoryListener) Close() error {
ln.lock.Lock()
if !ln.closed {
close(ln.conns)
close(ln.done)
ln.closed = true
} else {
err = ErrInmemoryListenerClosed
}
ln.lock.Unlock()
if err == nil {
ln.closePendingConns()
}
return err
}
func (ln *InmemoryListener) closePendingConns() {
for {
select {
case c := <-ln.conns:
_ = c.conn.Close()
default:
return
}
}
}
type inmemoryAddr int
func (inmemoryAddr) Network() string {
@@ -115,20 +144,50 @@ func (ln *InmemoryListener) DialWithLocalAddr(local net.Addr) (net.Conn, error)
cConn := pc.Conn1()
sConn := pc.Conn2()
ln.lock.Lock()
accepted := make(chan struct{})
if !ln.closed {
ln.conns <- acceptConn{conn: sConn, accepted: accepted}
// Wait until the connection has been accepted.
<-accepted
} else {
if ln.closed {
ln.lock.Unlock()
_ = sConn.Close()
_ = cConn.Close()
cConn = nil
}
ln.lock.Unlock()
if cConn == nil {
return nil, ErrInmemoryListenerClosed
}
done := ln.done
ln.lock.Unlock()
accepted := make(chan struct{})
select {
case <-done:
_ = sConn.Close()
_ = cConn.Close()
return nil, ErrInmemoryListenerClosed
default:
}
select {
case ln.conns <- acceptConn{conn: sConn, accepted: accepted}:
case <-done:
_ = sConn.Close()
_ = cConn.Close()
return nil, ErrInmemoryListenerClosed
}
// Wait until the connection has been accepted.
select {
case <-accepted:
return cConn, nil
default:
}
select {
case <-accepted:
case <-done:
select {
case <-accepted:
return cConn, nil
default:
}
_ = sConn.Close()
_ = cConn.Close()
return nil, ErrInmemoryListenerClosed
}
return cConn, nil
}
+90
View File
@@ -3,6 +3,7 @@ package fasthttputil
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
@@ -96,6 +97,95 @@ func TestInmemoryListener(t *testing.T) {
}
}
func TestInmemoryListenerCloseUnblocksPendingDial(t *testing.T) {
ln := NewInmemoryListener()
dialCh := make(chan error, 1)
go func() {
conn, err := ln.Dial()
if conn != nil {
conn.Close()
}
dialCh <- err
}()
waitForPendingInmemoryDial(t, ln)
closeCh := make(chan error, 1)
go func() {
closeCh <- ln.Close()
}()
select {
case err := <-closeCh:
if err != nil {
t.Fatalf("unexpected close error: %v", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout waiting for Close")
}
select {
case err := <-dialCh:
if !errors.Is(err, ErrInmemoryListenerClosed) {
t.Fatalf("unexpected dial error: %v. Expecting %v", err, ErrInmemoryListenerClosed)
}
case <-time.After(time.Second):
t.Fatalf("timeout waiting for Dial")
}
}
func TestInmemoryListenerCloseDropsQueuedDial(t *testing.T) {
ln := NewInmemoryListener()
dialCh := make(chan error, 1)
go func() {
conn, err := ln.Dial()
if conn != nil {
conn.Close()
}
dialCh <- err
}()
waitForPendingInmemoryDial(t, ln)
if err := ln.Close(); err != nil {
t.Fatalf("unexpected close error: %v", err)
}
if queued := len(ln.conns); queued != 0 {
t.Fatalf("unexpected queued conns after Close: %d. Expecting 0", queued)
}
conn, err := ln.Accept()
if conn != nil {
conn.Close()
}
if !errors.Is(err, ErrInmemoryListenerClosed) {
t.Fatalf("unexpected accept error: %v. Expecting %v", err, ErrInmemoryListenerClosed)
}
select {
case err := <-dialCh:
if !errors.Is(err, ErrInmemoryListenerClosed) {
t.Fatalf("unexpected dial error: %v. Expecting %v", err, ErrInmemoryListenerClosed)
}
case <-time.After(time.Second):
t.Fatalf("timeout waiting for Dial")
}
}
func waitForPendingInmemoryDial(t *testing.T, ln *InmemoryListener) {
t.Helper()
for range 100 {
if len(ln.conns) > 0 {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("timeout waiting for pending dial")
}
// echoServerHandler implements http.Handler.
type echoServerHandler struct {
t *testing.T