mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-17 16:26:47 +03:00
@@ -16,6 +16,7 @@ var ErrInmemoryListenerClosed = errors.New("InmemoryListener is already closed:
|
||||
type InmemoryListener struct {
|
||||
listenerAddr net.Addr
|
||||
conns chan acceptConn
|
||||
done chan struct{}
|
||||
addrLock sync.RWMutex
|
||||
lock sync.Mutex
|
||||
closed bool
|
||||
@@ -30,6 +31,7 @@ type acceptConn struct {
|
||||
func NewInmemoryListener() *InmemoryListener {
|
||||
return &InmemoryListener{
|
||||
conns: make(chan acceptConn, 1024),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,12 +49,25 @@ func (ln *InmemoryListener) SetLocalAddr(localAddr net.Addr) {
|
||||
//
|
||||
// Accept returns new connection per each Dial call.
|
||||
func (ln *InmemoryListener) Accept() (net.Conn, error) {
|
||||
c, ok := <-ln.conns
|
||||
if !ok {
|
||||
select {
|
||||
case <-ln.done:
|
||||
return nil, ErrInmemoryListenerClosed
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case c := <-ln.conns:
|
||||
select {
|
||||
case <-ln.done:
|
||||
_ = c.conn.Close()
|
||||
return nil, ErrInmemoryListenerClosed
|
||||
default:
|
||||
}
|
||||
close(c.accepted)
|
||||
return c.conn, nil
|
||||
case <-ln.done:
|
||||
return nil, ErrInmemoryListenerClosed
|
||||
}
|
||||
close(c.accepted)
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
// Close implements net.Listener's Close.
|
||||
@@ -61,15 +76,29 @@ func (ln *InmemoryListener) Close() error {
|
||||
|
||||
ln.lock.Lock()
|
||||
if !ln.closed {
|
||||
close(ln.conns)
|
||||
close(ln.done)
|
||||
ln.closed = true
|
||||
} else {
|
||||
err = ErrInmemoryListenerClosed
|
||||
}
|
||||
ln.lock.Unlock()
|
||||
if err == nil {
|
||||
ln.closePendingConns()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ln *InmemoryListener) closePendingConns() {
|
||||
for {
|
||||
select {
|
||||
case c := <-ln.conns:
|
||||
_ = c.conn.Close()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type inmemoryAddr int
|
||||
|
||||
func (inmemoryAddr) Network() string {
|
||||
@@ -115,20 +144,50 @@ func (ln *InmemoryListener) DialWithLocalAddr(local net.Addr) (net.Conn, error)
|
||||
cConn := pc.Conn1()
|
||||
sConn := pc.Conn2()
|
||||
ln.lock.Lock()
|
||||
accepted := make(chan struct{})
|
||||
if !ln.closed {
|
||||
ln.conns <- acceptConn{conn: sConn, accepted: accepted}
|
||||
// Wait until the connection has been accepted.
|
||||
<-accepted
|
||||
} else {
|
||||
if ln.closed {
|
||||
ln.lock.Unlock()
|
||||
_ = sConn.Close()
|
||||
_ = cConn.Close()
|
||||
cConn = nil
|
||||
}
|
||||
ln.lock.Unlock()
|
||||
|
||||
if cConn == nil {
|
||||
return nil, ErrInmemoryListenerClosed
|
||||
}
|
||||
done := ln.done
|
||||
ln.lock.Unlock()
|
||||
|
||||
accepted := make(chan struct{})
|
||||
select {
|
||||
case <-done:
|
||||
_ = sConn.Close()
|
||||
_ = cConn.Close()
|
||||
return nil, ErrInmemoryListenerClosed
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case ln.conns <- acceptConn{conn: sConn, accepted: accepted}:
|
||||
case <-done:
|
||||
_ = sConn.Close()
|
||||
_ = cConn.Close()
|
||||
return nil, ErrInmemoryListenerClosed
|
||||
}
|
||||
|
||||
// Wait until the connection has been accepted.
|
||||
select {
|
||||
case <-accepted:
|
||||
return cConn, nil
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-accepted:
|
||||
case <-done:
|
||||
select {
|
||||
case <-accepted:
|
||||
return cConn, nil
|
||||
default:
|
||||
}
|
||||
_ = sConn.Close()
|
||||
_ = cConn.Close()
|
||||
return nil, ErrInmemoryListenerClosed
|
||||
}
|
||||
|
||||
return cConn, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package fasthttputil
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -96,6 +97,95 @@ func TestInmemoryListener(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInmemoryListenerCloseUnblocksPendingDial(t *testing.T) {
|
||||
ln := NewInmemoryListener()
|
||||
|
||||
dialCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Dial()
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
dialCh <- err
|
||||
}()
|
||||
|
||||
waitForPendingInmemoryDial(t, ln)
|
||||
|
||||
closeCh := make(chan error, 1)
|
||||
go func() {
|
||||
closeCh <- ln.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-closeCh:
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected close error: %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for Close")
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-dialCh:
|
||||
if !errors.Is(err, ErrInmemoryListenerClosed) {
|
||||
t.Fatalf("unexpected dial error: %v. Expecting %v", err, ErrInmemoryListenerClosed)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for Dial")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInmemoryListenerCloseDropsQueuedDial(t *testing.T) {
|
||||
ln := NewInmemoryListener()
|
||||
|
||||
dialCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Dial()
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
dialCh <- err
|
||||
}()
|
||||
|
||||
waitForPendingInmemoryDial(t, ln)
|
||||
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Fatalf("unexpected close error: %v", err)
|
||||
}
|
||||
if queued := len(ln.conns); queued != 0 {
|
||||
t.Fatalf("unexpected queued conns after Close: %d. Expecting 0", queued)
|
||||
}
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
if !errors.Is(err, ErrInmemoryListenerClosed) {
|
||||
t.Fatalf("unexpected accept error: %v. Expecting %v", err, ErrInmemoryListenerClosed)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-dialCh:
|
||||
if !errors.Is(err, ErrInmemoryListenerClosed) {
|
||||
t.Fatalf("unexpected dial error: %v. Expecting %v", err, ErrInmemoryListenerClosed)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for Dial")
|
||||
}
|
||||
}
|
||||
|
||||
func waitForPendingInmemoryDial(t *testing.T, ln *InmemoryListener) {
|
||||
t.Helper()
|
||||
|
||||
for range 100 {
|
||||
if len(ln.conns) > 0 {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
t.Fatalf("timeout waiting for pending dial")
|
||||
}
|
||||
|
||||
// echoServerHandler implements http.Handler.
|
||||
type echoServerHandler struct {
|
||||
t *testing.T
|
||||
|
||||
Reference in New Issue
Block a user