From 4ef0547d4cf16f6d679dcc82023f4cfd68db6c48 Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sat, 6 Jun 2026 17:28:07 +0800 Subject: [PATCH] bug: InMemoryListener deadlock under high concurrency (#2245) (#2277) --- fasthttputil/inmemory_listener.go | 91 +++++++++++++++++++++----- fasthttputil/inmemory_listener_test.go | 90 +++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 16 deletions(-) diff --git a/fasthttputil/inmemory_listener.go b/fasthttputil/inmemory_listener.go index 2df4664..3d345bd 100644 --- a/fasthttputil/inmemory_listener.go +++ b/fasthttputil/inmemory_listener.go @@ -16,6 +16,7 @@ var ErrInmemoryListenerClosed = errors.New("InmemoryListener is already closed: type InmemoryListener struct { listenerAddr net.Addr conns chan acceptConn + done chan struct{} addrLock sync.RWMutex lock sync.Mutex closed bool @@ -30,6 +31,7 @@ type acceptConn struct { func NewInmemoryListener() *InmemoryListener { return &InmemoryListener{ conns: make(chan acceptConn, 1024), + done: make(chan struct{}), } } @@ -47,12 +49,25 @@ func (ln *InmemoryListener) SetLocalAddr(localAddr net.Addr) { // // Accept returns new connection per each Dial call. func (ln *InmemoryListener) Accept() (net.Conn, error) { - c, ok := <-ln.conns - if !ok { + select { + case <-ln.done: + return nil, ErrInmemoryListenerClosed + default: + } + + select { + case c := <-ln.conns: + select { + case <-ln.done: + _ = c.conn.Close() + return nil, ErrInmemoryListenerClosed + default: + } + close(c.accepted) + return c.conn, nil + case <-ln.done: return nil, ErrInmemoryListenerClosed } - close(c.accepted) - return c.conn, nil } // Close implements net.Listener's Close. @@ -61,15 +76,29 @@ func (ln *InmemoryListener) Close() error { ln.lock.Lock() if !ln.closed { - close(ln.conns) + close(ln.done) ln.closed = true } else { err = ErrInmemoryListenerClosed } ln.lock.Unlock() + if err == nil { + ln.closePendingConns() + } return err } +func (ln *InmemoryListener) closePendingConns() { + for { + select { + case c := <-ln.conns: + _ = c.conn.Close() + default: + return + } + } +} + type inmemoryAddr int func (inmemoryAddr) Network() string { @@ -115,20 +144,50 @@ func (ln *InmemoryListener) DialWithLocalAddr(local net.Addr) (net.Conn, error) cConn := pc.Conn1() sConn := pc.Conn2() ln.lock.Lock() - accepted := make(chan struct{}) - if !ln.closed { - ln.conns <- acceptConn{conn: sConn, accepted: accepted} - // Wait until the connection has been accepted. - <-accepted - } else { + if ln.closed { + ln.lock.Unlock() _ = sConn.Close() _ = cConn.Close() - cConn = nil - } - ln.lock.Unlock() - - if cConn == nil { return nil, ErrInmemoryListenerClosed } + done := ln.done + ln.lock.Unlock() + + accepted := make(chan struct{}) + select { + case <-done: + _ = sConn.Close() + _ = cConn.Close() + return nil, ErrInmemoryListenerClosed + default: + } + + select { + case ln.conns <- acceptConn{conn: sConn, accepted: accepted}: + case <-done: + _ = sConn.Close() + _ = cConn.Close() + return nil, ErrInmemoryListenerClosed + } + + // Wait until the connection has been accepted. + select { + case <-accepted: + return cConn, nil + default: + } + select { + case <-accepted: + case <-done: + select { + case <-accepted: + return cConn, nil + default: + } + _ = sConn.Close() + _ = cConn.Close() + return nil, ErrInmemoryListenerClosed + } + return cConn, nil } diff --git a/fasthttputil/inmemory_listener_test.go b/fasthttputil/inmemory_listener_test.go index 6e4b1f1..943b329 100644 --- a/fasthttputil/inmemory_listener_test.go +++ b/fasthttputil/inmemory_listener_test.go @@ -3,6 +3,7 @@ package fasthttputil import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -96,6 +97,95 @@ func TestInmemoryListener(t *testing.T) { } } +func TestInmemoryListenerCloseUnblocksPendingDial(t *testing.T) { + ln := NewInmemoryListener() + + dialCh := make(chan error, 1) + go func() { + conn, err := ln.Dial() + if conn != nil { + conn.Close() + } + dialCh <- err + }() + + waitForPendingInmemoryDial(t, ln) + + closeCh := make(chan error, 1) + go func() { + closeCh <- ln.Close() + }() + + select { + case err := <-closeCh: + if err != nil { + t.Fatalf("unexpected close error: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("timeout waiting for Close") + } + + select { + case err := <-dialCh: + if !errors.Is(err, ErrInmemoryListenerClosed) { + t.Fatalf("unexpected dial error: %v. Expecting %v", err, ErrInmemoryListenerClosed) + } + case <-time.After(time.Second): + t.Fatalf("timeout waiting for Dial") + } +} + +func TestInmemoryListenerCloseDropsQueuedDial(t *testing.T) { + ln := NewInmemoryListener() + + dialCh := make(chan error, 1) + go func() { + conn, err := ln.Dial() + if conn != nil { + conn.Close() + } + dialCh <- err + }() + + waitForPendingInmemoryDial(t, ln) + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected close error: %v", err) + } + if queued := len(ln.conns); queued != 0 { + t.Fatalf("unexpected queued conns after Close: %d. Expecting 0", queued) + } + + conn, err := ln.Accept() + if conn != nil { + conn.Close() + } + if !errors.Is(err, ErrInmemoryListenerClosed) { + t.Fatalf("unexpected accept error: %v. Expecting %v", err, ErrInmemoryListenerClosed) + } + + select { + case err := <-dialCh: + if !errors.Is(err, ErrInmemoryListenerClosed) { + t.Fatalf("unexpected dial error: %v. Expecting %v", err, ErrInmemoryListenerClosed) + } + case <-time.After(time.Second): + t.Fatalf("timeout waiting for Dial") + } +} + +func waitForPendingInmemoryDial(t *testing.T, ln *InmemoryListener) { + t.Helper() + + for range 100 { + if len(ln.conns) > 0 { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("timeout waiting for pending dial") +} + // echoServerHandler implements http.Handler. type echoServerHandler struct { t *testing.T