mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
This commit is contained in:
@@ -2511,10 +2511,9 @@ type pipelineConnClient struct {
|
||||
|
||||
Dial DialFunc
|
||||
TLSConfig *tls.Config
|
||||
chW chan *pipelineWork
|
||||
chR chan *pipelineWork
|
||||
|
||||
tlsConfig *tls.Config
|
||||
chs *pipelineConnChannels
|
||||
|
||||
Addr string
|
||||
Name string
|
||||
@@ -2536,6 +2535,12 @@ type pipelineConnClient struct {
|
||||
IsTLS bool
|
||||
}
|
||||
|
||||
type pipelineConnChannels struct {
|
||||
chW chan *pipelineWork
|
||||
chR chan *pipelineWork
|
||||
users int
|
||||
}
|
||||
|
||||
type pipelineWork struct {
|
||||
respCopy Response
|
||||
deadline time.Time
|
||||
@@ -2586,13 +2591,14 @@ func (c *PipelineClient) DoDeadline(req *Request, resp *Response, deadline time.
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
|
||||
c.init()
|
||||
|
||||
timeout := time.Until(deadline)
|
||||
if timeout <= 0 {
|
||||
return ErrTimeout
|
||||
}
|
||||
|
||||
chs := c.acquirePipelineConnChannels()
|
||||
defer c.releasePipelineConnChannels(chs)
|
||||
|
||||
if c.DisablePathNormalizing {
|
||||
req.URI().DisablePathNormalizing = true
|
||||
}
|
||||
@@ -2619,12 +2625,12 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t
|
||||
|
||||
// Put the request to outgoing queue
|
||||
select {
|
||||
case c.chW <- w:
|
||||
case chs.chW <- w:
|
||||
// Fast path: len(c.ch) < cap(c.ch)
|
||||
default:
|
||||
// Slow path
|
||||
select {
|
||||
case c.chW <- w:
|
||||
case chs.chW <- w:
|
||||
case <-w.t.C:
|
||||
c.releasePipelineWork(w)
|
||||
return ErrTimeout
|
||||
@@ -2698,7 +2704,8 @@ func (c *PipelineClient) Do(req *Request, resp *Response) error {
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
|
||||
c.init()
|
||||
chs := c.acquirePipelineConnChannels()
|
||||
defer c.releasePipelineConnChannels(chs)
|
||||
|
||||
if c.DisablePathNormalizing {
|
||||
req.URI().DisablePathNormalizing = true
|
||||
@@ -2726,17 +2733,17 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
|
||||
|
||||
// Put the request to outgoing queue
|
||||
select {
|
||||
case c.chW <- w:
|
||||
case chs.chW <- w:
|
||||
default:
|
||||
// Try substituting the oldest w with the current one.
|
||||
select {
|
||||
case wOld := <-c.chW:
|
||||
case wOld := <-chs.chW:
|
||||
wOld.err = ErrPipelineOverflow
|
||||
wOld.done <- struct{}{}
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case c.chW <- w:
|
||||
case chs.chW <- w:
|
||||
default:
|
||||
c.releasePipelineWork(w)
|
||||
return ErrPipelineOverflow
|
||||
@@ -2824,46 +2831,58 @@ var ErrPipelineOverflow = errors.New("pipelined requests' queue has been overflo
|
||||
// for PipelineClient.MaxPendingRequests.
|
||||
const DefaultMaxPendingRequests = 1024
|
||||
|
||||
func (c *pipelineConnClient) init() {
|
||||
func (c *pipelineConnClient) acquirePipelineConnChannels() *pipelineConnChannels {
|
||||
c.chLock.Lock()
|
||||
if c.chR == nil {
|
||||
chs := c.chs
|
||||
if chs == nil {
|
||||
maxPendingRequests := c.MaxPendingRequests
|
||||
if maxPendingRequests <= 0 {
|
||||
maxPendingRequests = DefaultMaxPendingRequests
|
||||
}
|
||||
c.chR = make(chan *pipelineWork, maxPendingRequests)
|
||||
if c.chW == nil {
|
||||
c.chW = make(chan *pipelineWork, maxPendingRequests)
|
||||
chs = &pipelineConnChannels{
|
||||
chR: make(chan *pipelineWork, maxPendingRequests),
|
||||
chW: make(chan *pipelineWork, maxPendingRequests),
|
||||
}
|
||||
go func() {
|
||||
// Keep restarting the worker if it fails (connection errors for example).
|
||||
for {
|
||||
if err := c.worker(); err != nil {
|
||||
c.logger().Printf("error in PipelineClient(%q): %v", c.Addr, err)
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Throttle client reconnections on timeout errors
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
} else {
|
||||
c.chLock.Lock()
|
||||
stop := len(c.chR) == 0 && len(c.chW) == 0
|
||||
if !stop {
|
||||
c.chR = nil
|
||||
c.chW = nil
|
||||
}
|
||||
c.chLock.Unlock()
|
||||
|
||||
if stop {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
c.chs = chs
|
||||
go c.pipelineWorker(chs)
|
||||
}
|
||||
chs.users++
|
||||
c.chLock.Unlock()
|
||||
return chs
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) releasePipelineConnChannels(chs *pipelineConnChannels) {
|
||||
c.chLock.Lock()
|
||||
chs.users--
|
||||
c.chLock.Unlock()
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) worker() error {
|
||||
func (c *pipelineConnClient) pipelineWorker(chs *pipelineConnChannels) {
|
||||
// Keep restarting the worker if it fails (connection errors for example).
|
||||
for {
|
||||
if err := c.worker(chs); err != nil {
|
||||
c.logger().Printf("error in PipelineClient(%q): %v", c.Addr, err)
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Throttle client reconnections on timeout errors
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
} else if c.tryRetirePipelineConnChannels(chs) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) tryRetirePipelineConnChannels(chs *pipelineConnChannels) bool {
|
||||
c.chLock.Lock()
|
||||
stop := c.chs == chs && chs.users == 0 && len(chs.chR) == 0 && len(chs.chW) == 0
|
||||
if stop {
|
||||
c.chs = nil
|
||||
}
|
||||
c.chLock.Unlock()
|
||||
return stop
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) worker(chs *pipelineConnChannels) error {
|
||||
tlsConfig := c.cachedTLSConfig()
|
||||
conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout)
|
||||
if err != nil {
|
||||
@@ -2874,12 +2893,12 @@ func (c *pipelineConnClient) worker() error {
|
||||
stopW := make(chan struct{})
|
||||
doneW := make(chan error)
|
||||
go func() {
|
||||
doneW <- c.writer(conn, stopW)
|
||||
doneW <- c.writer(conn, stopW, chs)
|
||||
}()
|
||||
stopR := make(chan struct{})
|
||||
doneR := make(chan error)
|
||||
go func() {
|
||||
doneR <- c.reader(conn, stopR)
|
||||
doneR <- c.reader(conn, stopR, chs)
|
||||
}()
|
||||
|
||||
// Wait until reader and writer are stopped
|
||||
@@ -2895,8 +2914,8 @@ func (c *pipelineConnClient) worker() error {
|
||||
}
|
||||
|
||||
// Notify pending readers
|
||||
for len(c.chR) > 0 {
|
||||
w := <-c.chR
|
||||
for len(chs.chR) > 0 {
|
||||
w := <-chs.chR
|
||||
w.err = errPipelineConnStopped
|
||||
w.done <- struct{}{}
|
||||
}
|
||||
@@ -2920,15 +2939,15 @@ func (c *pipelineConnClient) cachedTLSConfig() *tls.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error {
|
||||
func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}, chs *pipelineConnChannels) error {
|
||||
writeBufferSize := c.WriteBufferSize
|
||||
if writeBufferSize <= 0 {
|
||||
writeBufferSize = defaultWriteBufferSize
|
||||
}
|
||||
bw := bufio.NewWriterSize(conn, writeBufferSize)
|
||||
defer bw.Flush()
|
||||
chR := c.chR
|
||||
chW := c.chW
|
||||
chR := chs.chR
|
||||
chW := chs.chW
|
||||
writeTimeout := c.WriteTimeout
|
||||
|
||||
maxIdleConnDuration := c.MaxIdleConnDuration
|
||||
@@ -3027,13 +3046,13 @@ func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error
|
||||
}
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}) error {
|
||||
func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}, chs *pipelineConnChannels) error {
|
||||
readBufferSize := c.ReadBufferSize
|
||||
if readBufferSize <= 0 {
|
||||
readBufferSize = defaultReadBufferSize
|
||||
}
|
||||
br := bufio.NewReaderSize(conn, readBufferSize)
|
||||
chR := c.chR
|
||||
chR := chs.chR
|
||||
readTimeout := c.ReadTimeout
|
||||
|
||||
var (
|
||||
@@ -3100,10 +3119,11 @@ func (c *PipelineClient) PendingRequests() int {
|
||||
}
|
||||
|
||||
func (c *pipelineConnClient) PendingRequests() int {
|
||||
c.init()
|
||||
chs := c.acquirePipelineConnChannels()
|
||||
defer c.releasePipelineConnChannels(chs)
|
||||
|
||||
c.chLock.Lock()
|
||||
n := len(c.chR) + len(c.chW)
|
||||
n := len(chs.chR) + len(chs.chW)
|
||||
c.chLock.Unlock()
|
||||
return n
|
||||
}
|
||||
|
||||
+109
@@ -295,6 +295,115 @@ func TestPipelineClientIssue832(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineClientRestartsAfterIdle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
ctx.WriteString("OK") //nolint:errcheck
|
||||
},
|
||||
}
|
||||
|
||||
serverStopCh := make(chan struct{})
|
||||
go func() {
|
||||
if err := s.Serve(ln); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
close(serverStopCh)
|
||||
}()
|
||||
|
||||
c := &PipelineClient{
|
||||
Dial: func(addr string) (net.Conn, error) {
|
||||
return ln.Dial()
|
||||
},
|
||||
MaxIdleConnDuration: 10 * time.Millisecond,
|
||||
MaxPendingRequests: 1,
|
||||
Logger: &testLogger{},
|
||||
}
|
||||
|
||||
testPipelineClientDoOnce(t, c)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
testPipelineClientDoOnce(t, c)
|
||||
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-serverStopCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineClientChannelLifecycleRace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
s := &Server{
|
||||
Handler: func(ctx *RequestCtx) {
|
||||
ctx.WriteString("OK") //nolint:errcheck
|
||||
},
|
||||
}
|
||||
|
||||
serverStopCh := make(chan struct{})
|
||||
go func() {
|
||||
if err := s.Serve(ln); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
close(serverStopCh)
|
||||
}()
|
||||
|
||||
c := &PipelineClient{
|
||||
Dial: func(addr string) (net.Conn, error) {
|
||||
return ln.Dial()
|
||||
},
|
||||
MaxIdleConnDuration: time.Millisecond,
|
||||
MaxPendingRequests: 2,
|
||||
Logger: &testLogger{},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 8 {
|
||||
wg.Go(func() {
|
||||
for range 20 {
|
||||
testPipelineClientDoOnce(t, c)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-serverStopCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func testPipelineClientDoOnce(t *testing.T, c *PipelineClient) {
|
||||
t.Helper()
|
||||
|
||||
req := AcquireRequest()
|
||||
req.SetRequestURI("http://foobar/baz")
|
||||
resp := AcquireResponse()
|
||||
defer ReleaseRequest(req)
|
||||
defer ReleaseResponse(resp)
|
||||
|
||||
if err := c.DoTimeout(req, resp, time.Second); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != StatusOK {
|
||||
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
|
||||
}
|
||||
if body := string(resp.Body()); body != "OK" {
|
||||
t.Fatalf("unexpected body: %q. Expecting %q", body, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientInvalidURI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user