From eee784158a176a96fd1fa0014b4d9de510fd6d70 Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sat, 6 Jun 2026 17:24:52 +0800 Subject: [PATCH] bug: data race on pipeline client c.chR during worker drain (#2220) (#2272) --- client.go | 122 ++++++++++++++++++++++++++++--------------------- client_test.go | 109 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 51 deletions(-) diff --git a/client.go b/client.go index 9269f51..ca050c8 100644 --- a/client.go +++ b/client.go @@ -2511,10 +2511,9 @@ type pipelineConnClient struct { Dial DialFunc TLSConfig *tls.Config - chW chan *pipelineWork - chR chan *pipelineWork tlsConfig *tls.Config + chs *pipelineConnChannels Addr string Name string @@ -2536,6 +2535,12 @@ type pipelineConnClient struct { IsTLS bool } +type pipelineConnChannels struct { + chW chan *pipelineWork + chR chan *pipelineWork + users int +} + type pipelineWork struct { respCopy Response deadline time.Time @@ -2586,13 +2591,14 @@ func (c *PipelineClient) DoDeadline(req *Request, resp *Response, deadline time. } func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { - c.init() - timeout := time.Until(deadline) if timeout <= 0 { return ErrTimeout } + chs := c.acquirePipelineConnChannels() + defer c.releasePipelineConnChannels(chs) + if c.DisablePathNormalizing { req.URI().DisablePathNormalizing = true } @@ -2619,12 +2625,12 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t // Put the request to outgoing queue select { - case c.chW <- w: + case chs.chW <- w: // Fast path: len(c.ch) < cap(c.ch) default: // Slow path select { - case c.chW <- w: + case chs.chW <- w: case <-w.t.C: c.releasePipelineWork(w) return ErrTimeout @@ -2698,7 +2704,8 @@ func (c *PipelineClient) Do(req *Request, resp *Response) error { } func (c *pipelineConnClient) Do(req *Request, resp *Response) error { - c.init() + chs := c.acquirePipelineConnChannels() + defer c.releasePipelineConnChannels(chs) if c.DisablePathNormalizing { req.URI().DisablePathNormalizing = true @@ -2726,17 +2733,17 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error { // Put the request to outgoing queue select { - case c.chW <- w: + case chs.chW <- w: default: // Try substituting the oldest w with the current one. select { - case wOld := <-c.chW: + case wOld := <-chs.chW: wOld.err = ErrPipelineOverflow wOld.done <- struct{}{} default: } select { - case c.chW <- w: + case chs.chW <- w: default: c.releasePipelineWork(w) return ErrPipelineOverflow @@ -2824,46 +2831,58 @@ var ErrPipelineOverflow = errors.New("pipelined requests' queue has been overflo // for PipelineClient.MaxPendingRequests. const DefaultMaxPendingRequests = 1024 -func (c *pipelineConnClient) init() { +func (c *pipelineConnClient) acquirePipelineConnChannels() *pipelineConnChannels { c.chLock.Lock() - if c.chR == nil { + chs := c.chs + if chs == nil { maxPendingRequests := c.MaxPendingRequests if maxPendingRequests <= 0 { maxPendingRequests = DefaultMaxPendingRequests } - c.chR = make(chan *pipelineWork, maxPendingRequests) - if c.chW == nil { - c.chW = make(chan *pipelineWork, maxPendingRequests) + chs = &pipelineConnChannels{ + chR: make(chan *pipelineWork, maxPendingRequests), + chW: make(chan *pipelineWork, maxPendingRequests), } - go func() { - // Keep restarting the worker if it fails (connection errors for example). - for { - if err := c.worker(); err != nil { - c.logger().Printf("error in PipelineClient(%q): %v", c.Addr, err) - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - // Throttle client reconnections on timeout errors - time.Sleep(time.Second) - } - } else { - c.chLock.Lock() - stop := len(c.chR) == 0 && len(c.chW) == 0 - if !stop { - c.chR = nil - c.chW = nil - } - c.chLock.Unlock() - - if stop { - break - } - } - } - }() + c.chs = chs + go c.pipelineWorker(chs) } + chs.users++ + c.chLock.Unlock() + return chs +} + +func (c *pipelineConnClient) releasePipelineConnChannels(chs *pipelineConnChannels) { + c.chLock.Lock() + chs.users-- c.chLock.Unlock() } -func (c *pipelineConnClient) worker() error { +func (c *pipelineConnClient) pipelineWorker(chs *pipelineConnChannels) { + // Keep restarting the worker if it fails (connection errors for example). + for { + if err := c.worker(chs); err != nil { + c.logger().Printf("error in PipelineClient(%q): %v", c.Addr, err) + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Throttle client reconnections on timeout errors + time.Sleep(time.Second) + } + } else if c.tryRetirePipelineConnChannels(chs) { + return + } + } +} + +func (c *pipelineConnClient) tryRetirePipelineConnChannels(chs *pipelineConnChannels) bool { + c.chLock.Lock() + stop := c.chs == chs && chs.users == 0 && len(chs.chR) == 0 && len(chs.chW) == 0 + if stop { + c.chs = nil + } + c.chLock.Unlock() + return stop +} + +func (c *pipelineConnClient) worker(chs *pipelineConnChannels) error { tlsConfig := c.cachedTLSConfig() conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout) if err != nil { @@ -2874,12 +2893,12 @@ func (c *pipelineConnClient) worker() error { stopW := make(chan struct{}) doneW := make(chan error) go func() { - doneW <- c.writer(conn, stopW) + doneW <- c.writer(conn, stopW, chs) }() stopR := make(chan struct{}) doneR := make(chan error) go func() { - doneR <- c.reader(conn, stopR) + doneR <- c.reader(conn, stopR, chs) }() // Wait until reader and writer are stopped @@ -2895,8 +2914,8 @@ func (c *pipelineConnClient) worker() error { } // Notify pending readers - for len(c.chR) > 0 { - w := <-c.chR + for len(chs.chR) > 0 { + w := <-chs.chR w.err = errPipelineConnStopped w.done <- struct{}{} } @@ -2920,15 +2939,15 @@ func (c *pipelineConnClient) cachedTLSConfig() *tls.Config { return cfg } -func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error { +func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}, chs *pipelineConnChannels) error { writeBufferSize := c.WriteBufferSize if writeBufferSize <= 0 { writeBufferSize = defaultWriteBufferSize } bw := bufio.NewWriterSize(conn, writeBufferSize) defer bw.Flush() - chR := c.chR - chW := c.chW + chR := chs.chR + chW := chs.chW writeTimeout := c.WriteTimeout maxIdleConnDuration := c.MaxIdleConnDuration @@ -3027,13 +3046,13 @@ func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error } } -func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}) error { +func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}, chs *pipelineConnChannels) error { readBufferSize := c.ReadBufferSize if readBufferSize <= 0 { readBufferSize = defaultReadBufferSize } br := bufio.NewReaderSize(conn, readBufferSize) - chR := c.chR + chR := chs.chR readTimeout := c.ReadTimeout var ( @@ -3100,10 +3119,11 @@ func (c *PipelineClient) PendingRequests() int { } func (c *pipelineConnClient) PendingRequests() int { - c.init() + chs := c.acquirePipelineConnChannels() + defer c.releasePipelineConnChannels(chs) c.chLock.Lock() - n := len(c.chR) + len(c.chW) + n := len(chs.chR) + len(chs.chW) c.chLock.Unlock() return n } diff --git a/client_test.go b/client_test.go index 1350eba..45d2f3a 100644 --- a/client_test.go +++ b/client_test.go @@ -295,6 +295,115 @@ func TestPipelineClientIssue832(t *testing.T) { } } +func TestPipelineClientRestartsAfterIdle(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.WriteString("OK") //nolint:errcheck + }, + } + + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + close(serverStopCh) + }() + + c := &PipelineClient{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + MaxIdleConnDuration: 10 * time.Millisecond, + MaxPendingRequests: 1, + Logger: &testLogger{}, + } + + testPipelineClientDoOnce(t, c) + time.Sleep(50 * time.Millisecond) + testPipelineClientDoOnce(t, c) + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + +func TestPipelineClientChannelLifecycleRace(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.WriteString("OK") //nolint:errcheck + }, + } + + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + close(serverStopCh) + }() + + c := &PipelineClient{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + MaxIdleConnDuration: time.Millisecond, + MaxPendingRequests: 2, + Logger: &testLogger{}, + } + + var wg sync.WaitGroup + for range 8 { + wg.Go(func() { + for range 20 { + testPipelineClientDoOnce(t, c) + time.Sleep(time.Millisecond) + } + }) + } + wg.Wait() + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + +func testPipelineClientDoOnce(t *testing.T, c *PipelineClient) { + t.Helper() + + req := AcquireRequest() + req.SetRequestURI("http://foobar/baz") + resp := AcquireResponse() + defer ReleaseRequest(req) + defer ReleaseResponse(resp) + + if err := c.DoTimeout(req, resp, time.Second); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if body := string(resp.Body()); body != "OK" { + t.Fatalf("unexpected body: %q. Expecting %q", body, "OK") + } +} + func TestClientInvalidURI(t *testing.T) { t.Parallel()