diff --git a/client.go b/client.go index 456ff12..1029f1a 100644 --- a/client.go +++ b/client.go @@ -1162,6 +1162,7 @@ func (c *HostClient) acquireConn() (*clientConn, error) { } else { n-- cc = c.conns[n] + c.conns[n] = nil c.conns = c.conns[:n] } c.connsLock.Unlock() diff --git a/http.go b/http.go index 3472d4e..ba23ce0 100644 --- a/http.go +++ b/http.go @@ -473,6 +473,52 @@ func (req *Request) ReleaseBody(size int) { } } +// SwapBody swaps response body with the given body and returns +// the previous response body. +// +// It is forbidden to use the body passed to SwapBody after +// the function returns. +func (resp *Response) SwapBody(body []byte) []byte { + bb := resp.bodyBuffer() + + if resp.bodyStream != nil { + bb.Reset() + _, err := copyZeroAlloc(bb, resp.bodyStream) + resp.closeBodyStream() + if err != nil { + bb.Reset() + bb.SetString(err.Error()) + } + } + + oldBody := bb.B + bb.B = body + return oldBody +} + +// SwapBody swaps request body with the given body and returns +// the previous request body. +// +// It is forbidden to use the body passed to SwapBody after +// the function returns. +func (req *Request) SwapBody(body []byte) []byte { + bb := req.bodyBuffer() + + if req.bodyStream != nil { + bb.Reset() + _, err := copyZeroAlloc(bb, req.bodyStream) + req.closeBodyStream() + if err != nil { + bb.Reset() + bb.SetString(err.Error()) + } + } + + oldBody := bb.B + bb.B = body + return oldBody +} + // Body returns request body. func (req *Request) Body() []byte { if req.bodyStream != nil { diff --git a/http_test.go b/http_test.go index 5323b26..2a8507d 100644 --- a/http_test.go +++ b/http_test.go @@ -12,6 +12,106 @@ import ( "time" ) +func TestResponseSwapBodySerial(t *testing.T) { + testResponseSwapBody(t) +} + +func TestResponseSwapBodyConcurrent(t *testing.T) { + ch := make(chan struct{}) + for i := 0; i < 10; i++ { + go func() { + testResponseSwapBody(t) + ch <- struct{}{} + }() + } + + for i := 0; i < 10; i++ { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + } +} + +func testResponseSwapBody(t *testing.T) { + var b []byte + r := AcquireResponse() + for i := 0; i < 20; i++ { + bOrig := r.Body() + b = r.SwapBody(b) + if !bytes.Equal(bOrig, b) { + t.Fatalf("unexpected body returned: %q. Expecting %q", b, bOrig) + } + r.AppendBodyString("foobar") + } + + s := "aaaabbbbcccc" + b = b[:0] + for i := 0; i < 10; i++ { + r.SetBodyStream(bytes.NewBufferString(s), len(s)) + b = r.SwapBody(b) + if string(b) != s { + t.Fatalf("unexpected body returned: %q. Expecting %q", b, s) + } + b = r.SwapBody(b) + if len(b) > 0 { + t.Fatalf("unexpected body with non-zero size returned: %q", b) + } + } + ReleaseResponse(r) +} + +func TestRequestSwapBodySerial(t *testing.T) { + testRequestSwapBody(t) +} + +func TestRequestSwapBodyConcurrent(t *testing.T) { + ch := make(chan struct{}) + for i := 0; i < 10; i++ { + go func() { + testRequestSwapBody(t) + ch <- struct{}{} + }() + } + + for i := 0; i < 10; i++ { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + } +} + +func testRequestSwapBody(t *testing.T) { + var b []byte + r := AcquireRequest() + for i := 0; i < 20; i++ { + bOrig := r.Body() + b = r.SwapBody(b) + if !bytes.Equal(bOrig, b) { + t.Fatalf("unexpected body returned: %q. Expecting %q", b, bOrig) + } + r.AppendBodyString("foobar") + } + + s := "aaaabbbbcccc" + b = b[:0] + for i := 0; i < 10; i++ { + r.SetBodyStream(bytes.NewBufferString(s), len(s)) + b = r.SwapBody(b) + if string(b) != s { + t.Fatalf("unexpected body returned: %q. Expecting %q", b, s) + } + b = r.SwapBody(b) + if len(b) > 0 { + t.Fatalf("unexpected body with non-zero size returned: %q", b) + } + } + ReleaseRequest(r) +} + func TestRequestHostFromRequestURI(t *testing.T) { hExpected := "foobar.com" var req Request