bug: data race on pipeline client c.chR during worker drain (#2220) (#2272)

This commit is contained in:
Erik Dubbelboer
2026-06-06 17:24:52 +08:00
committed by GitHub
parent 289229aad3
commit eee784158a
2 changed files with 180 additions and 51 deletions
+71 -51
View File
@@ -2511,10 +2511,9 @@ type pipelineConnClient struct {
Dial DialFunc
TLSConfig *tls.Config
chW chan *pipelineWork
chR chan *pipelineWork
tlsConfig *tls.Config
chs *pipelineConnChannels
Addr string
Name string
@@ -2536,6 +2535,12 @@ type pipelineConnClient struct {
IsTLS bool
}
type pipelineConnChannels struct {
chW chan *pipelineWork
chR chan *pipelineWork
users int
}
type pipelineWork struct {
respCopy Response
deadline time.Time
@@ -2586,13 +2591,14 @@ func (c *PipelineClient) DoDeadline(req *Request, resp *Response, deadline time.
}
func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
c.init()
timeout := time.Until(deadline)
if timeout <= 0 {
return ErrTimeout
}
chs := c.acquirePipelineConnChannels()
defer c.releasePipelineConnChannels(chs)
if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
@@ -2619,12 +2625,12 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t
// Put the request to outgoing queue
select {
case c.chW <- w:
case chs.chW <- w:
// Fast path: len(c.ch) < cap(c.ch)
default:
// Slow path
select {
case c.chW <- w:
case chs.chW <- w:
case <-w.t.C:
c.releasePipelineWork(w)
return ErrTimeout
@@ -2698,7 +2704,8 @@ func (c *PipelineClient) Do(req *Request, resp *Response) error {
}
func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
c.init()
chs := c.acquirePipelineConnChannels()
defer c.releasePipelineConnChannels(chs)
if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
@@ -2726,17 +2733,17 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
// Put the request to outgoing queue
select {
case c.chW <- w:
case chs.chW <- w:
default:
// Try substituting the oldest w with the current one.
select {
case wOld := <-c.chW:
case wOld := <-chs.chW:
wOld.err = ErrPipelineOverflow
wOld.done <- struct{}{}
default:
}
select {
case c.chW <- w:
case chs.chW <- w:
default:
c.releasePipelineWork(w)
return ErrPipelineOverflow
@@ -2824,46 +2831,58 @@ var ErrPipelineOverflow = errors.New("pipelined requests' queue has been overflo
// for PipelineClient.MaxPendingRequests.
const DefaultMaxPendingRequests = 1024
func (c *pipelineConnClient) init() {
func (c *pipelineConnClient) acquirePipelineConnChannels() *pipelineConnChannels {
c.chLock.Lock()
if c.chR == nil {
chs := c.chs
if chs == nil {
maxPendingRequests := c.MaxPendingRequests
if maxPendingRequests <= 0 {
maxPendingRequests = DefaultMaxPendingRequests
}
c.chR = make(chan *pipelineWork, maxPendingRequests)
if c.chW == nil {
c.chW = make(chan *pipelineWork, maxPendingRequests)
chs = &pipelineConnChannels{
chR: make(chan *pipelineWork, maxPendingRequests),
chW: make(chan *pipelineWork, maxPendingRequests),
}
go func() {
// Keep restarting the worker if it fails (connection errors for example).
for {
if err := c.worker(); err != nil {
c.logger().Printf("error in PipelineClient(%q): %v", c.Addr, err)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Throttle client reconnections on timeout errors
time.Sleep(time.Second)
}
} else {
c.chLock.Lock()
stop := len(c.chR) == 0 && len(c.chW) == 0
if !stop {
c.chR = nil
c.chW = nil
}
c.chLock.Unlock()
if stop {
break
}
}
}
}()
c.chs = chs
go c.pipelineWorker(chs)
}
chs.users++
c.chLock.Unlock()
return chs
}
func (c *pipelineConnClient) releasePipelineConnChannels(chs *pipelineConnChannels) {
c.chLock.Lock()
chs.users--
c.chLock.Unlock()
}
func (c *pipelineConnClient) worker() error {
func (c *pipelineConnClient) pipelineWorker(chs *pipelineConnChannels) {
// Keep restarting the worker if it fails (connection errors for example).
for {
if err := c.worker(chs); err != nil {
c.logger().Printf("error in PipelineClient(%q): %v", c.Addr, err)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Throttle client reconnections on timeout errors
time.Sleep(time.Second)
}
} else if c.tryRetirePipelineConnChannels(chs) {
return
}
}
}
func (c *pipelineConnClient) tryRetirePipelineConnChannels(chs *pipelineConnChannels) bool {
c.chLock.Lock()
stop := c.chs == chs && chs.users == 0 && len(chs.chR) == 0 && len(chs.chW) == 0
if stop {
c.chs = nil
}
c.chLock.Unlock()
return stop
}
func (c *pipelineConnClient) worker(chs *pipelineConnChannels) error {
tlsConfig := c.cachedTLSConfig()
conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout)
if err != nil {
@@ -2874,12 +2893,12 @@ func (c *pipelineConnClient) worker() error {
stopW := make(chan struct{})
doneW := make(chan error)
go func() {
doneW <- c.writer(conn, stopW)
doneW <- c.writer(conn, stopW, chs)
}()
stopR := make(chan struct{})
doneR := make(chan error)
go func() {
doneR <- c.reader(conn, stopR)
doneR <- c.reader(conn, stopR, chs)
}()
// Wait until reader and writer are stopped
@@ -2895,8 +2914,8 @@ func (c *pipelineConnClient) worker() error {
}
// Notify pending readers
for len(c.chR) > 0 {
w := <-c.chR
for len(chs.chR) > 0 {
w := <-chs.chR
w.err = errPipelineConnStopped
w.done <- struct{}{}
}
@@ -2920,15 +2939,15 @@ func (c *pipelineConnClient) cachedTLSConfig() *tls.Config {
return cfg
}
func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error {
func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}, chs *pipelineConnChannels) error {
writeBufferSize := c.WriteBufferSize
if writeBufferSize <= 0 {
writeBufferSize = defaultWriteBufferSize
}
bw := bufio.NewWriterSize(conn, writeBufferSize)
defer bw.Flush()
chR := c.chR
chW := c.chW
chR := chs.chR
chW := chs.chW
writeTimeout := c.WriteTimeout
maxIdleConnDuration := c.MaxIdleConnDuration
@@ -3027,13 +3046,13 @@ func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error
}
}
func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}) error {
func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}, chs *pipelineConnChannels) error {
readBufferSize := c.ReadBufferSize
if readBufferSize <= 0 {
readBufferSize = defaultReadBufferSize
}
br := bufio.NewReaderSize(conn, readBufferSize)
chR := c.chR
chR := chs.chR
readTimeout := c.ReadTimeout
var (
@@ -3100,10 +3119,11 @@ func (c *PipelineClient) PendingRequests() int {
}
func (c *pipelineConnClient) PendingRequests() int {
c.init()
chs := c.acquirePipelineConnChannels()
defer c.releasePipelineConnChannels(chs)
c.chLock.Lock()
n := len(c.chR) + len(c.chW)
n := len(chs.chR) + len(chs.chW)
c.chLock.Unlock()
return n
}
+109
View File
@@ -295,6 +295,115 @@ func TestPipelineClientIssue832(t *testing.T) {
}
}
func TestPipelineClientRestartsAfterIdle(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("OK") //nolint:errcheck
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
c := &PipelineClient{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxIdleConnDuration: 10 * time.Millisecond,
MaxPendingRequests: 1,
Logger: &testLogger{},
}
testPipelineClientDoOnce(t, c)
time.Sleep(50 * time.Millisecond)
testPipelineClientDoOnce(t, c)
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func TestPipelineClientChannelLifecycleRace(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("OK") //nolint:errcheck
},
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
c := &PipelineClient{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
MaxIdleConnDuration: time.Millisecond,
MaxPendingRequests: 2,
Logger: &testLogger{},
}
var wg sync.WaitGroup
for range 8 {
wg.Go(func() {
for range 20 {
testPipelineClientDoOnce(t, c)
time.Sleep(time.Millisecond)
}
})
}
wg.Wait()
if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
case <-time.After(time.Second):
t.Fatalf("timeout")
}
}
func testPipelineClientDoOnce(t *testing.T, c *PipelineClient) {
t.Helper()
req := AcquireRequest()
req.SetRequestURI("http://foobar/baz")
resp := AcquireResponse()
defer ReleaseRequest(req)
defer ReleaseResponse(resp)
if err := c.DoTimeout(req, resp, time.Second); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if body := string(resp.Body()); body != "OK" {
t.Fatalf("unexpected body: %q. Expecting %q", body, "OK")
}
}
func TestClientInvalidURI(t *testing.T) {
t.Parallel()