fix: The client write operation did not immediately return upon encountering an RST packet. (#1849)

The current client implementation does not immediately return when encountering an RST packet while sending a request, but instead ignores it. This behavior is inconsistent with the net/http package and does not make logical sense.
This commit is contained in:
newacorn
2024-08-31 20:52:13 +08:00
committed by GitHub
parent d31f4ef7d5
commit d5c7d8953d
6 changed files with 71 additions and 165 deletions
+2 -3
View File
@@ -2989,8 +2989,7 @@ func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (ret
err = ErrTimeout
}
isConnRST := isConnectionReset(err)
if err != nil && !isConnRST {
if err != nil {
hc.closeConn(cc)
return true, err
}
@@ -3025,7 +3024,7 @@ func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (ret
return needRetry, err
}
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose()
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReaderWithError(rbs, func(wErr error) error {
+69
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"regexp"
@@ -15,6 +16,7 @@ import (
"strings"
"sync"
"sync/atomic"
"syscall"
"testing"
"time"
@@ -3531,3 +3533,70 @@ func TestClientHeadWithBody(t *testing.T) {
t.Error(err)
}
}
func TestRevertPull1233(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
t.Parallel()
const expectedStatus = http.StatusTeapot
ln, err := net.Listen("tcp", "127.0.0.1:8089")
defer func() { ln.Close() }()
if err != nil {
t.Fatal(err.Error())
}
go func() {
for {
conn, err := ln.Accept()
if err != nil {
if !strings.Contains(err.Error(), "closed") {
t.Errorf(err.Error())
}
return
}
_, err = conn.Write([]byte("HTTP/1.1 418 Teapot\r\n\r\n"))
if err != nil {
t.Error(err)
}
err = conn.(*net.TCPConn).SetLinger(0)
if err != nil {
t.Errorf(err.Error())
}
conn.Close()
}
}()
reqURL := "http://" + ln.Addr().String()
reqStrBody := "hello 2323 23323 2323 2323 232323 323 2323 2333333 hello 2323 23323 2323 2323 232323 323 2323 2333333 hello 2323 23323 2323 2323 232323 323 2323 2333333 hello 2323 23323 2323 2323 232323 323 2323 2333333 hello 2323 23323 2323 2323 232323 323 2323 2333333"
req2 := AcquireRequest()
resp2 := AcquireResponse()
defer func() {
ReleaseRequest(req2)
ReleaseResponse(resp2)
}()
req2.SetRequestURI(reqURL)
req2.SetBodyStream(F{strings.NewReader(reqStrBody)}, -1)
cli2 := Client{}
err = cli2.Do(req2, resp2)
if !errors.Is(err, syscall.EPIPE) && !errors.Is(err, syscall.ECONNRESET) {
t.Errorf("expected error %v or %v, but got nil", syscall.EPIPE, syscall.ECONNRESET)
}
if expectedStatus == resp2.StatusCode() {
t.Errorf("Not Expected status code %d", resp2.StatusCode())
}
}
type F struct {
*strings.Reader
}
func (f F) Read(p []byte) (n int, err error) {
if len(p) > 10 {
p = p[:10]
}
// Ensure that subsequent segments can see the RST packet caused by sending previous
// segments to a closed connection.
time.Sleep(500 * time.Microsecond)
return f.Reader.Read(p)
}
-134
View File
@@ -1,134 +0,0 @@
//go:build !windows
package fasthttp
import (
"io"
"net"
"net/http"
"strings"
"testing"
)
// See issue #1232.
func TestRstConnResponseWhileSending(t *testing.T) {
const expectedStatus = http.StatusTeapot
const payload = "payload"
srv, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer srv.Close()
go func() {
for {
conn, err := srv.Accept()
if err != nil {
return
}
// Read at least one byte of the header
// Otherwise we would have an unsolicited response
_, err = io.ReadAll(io.LimitReader(conn, 1))
if err != nil {
t.Error(err)
}
// Respond
_, err = conn.Write([]byte("HTTP/1.1 418 Teapot\r\n\r\n"))
if err != nil {
t.Error(err)
}
// Forcefully close connection
err = conn.(*net.TCPConn).SetLinger(0)
if err != nil {
t.Error(err)
}
conn.Close()
}
}()
srvURL := "http://" + srv.Addr().String()
client := HostClient{Addr: srv.Addr().String()}
for i := 0; i < 100; i++ {
req := AcquireRequest()
defer ReleaseRequest(req)
resp := AcquireResponse()
defer ReleaseResponse(resp)
req.Header.SetMethod("POST")
req.SetBodyStream(strings.NewReader(payload), len(payload))
req.SetRequestURI(srvURL)
err = client.Do(req, resp)
if err != nil {
t.Fatal(err)
}
if expectedStatus != resp.StatusCode() {
t.Fatalf("Expected %d status code, but got %d", expectedStatus, resp.StatusCode())
}
}
}
// See issue #1232.
func TestRstConnClosedWithoutResponse(t *testing.T) {
const payload = "payload"
srv, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer srv.Close()
go func() {
for {
conn, err := srv.Accept()
if err != nil {
return
}
// Read at least one byte of the header
// Otherwise we would have an unsolicited response
_, err = io.ReadAll(io.LimitReader(conn, 1))
if err != nil {
t.Error(err)
}
// Respond with incomplete header
_, err = conn.Write([]byte("Http"))
if err != nil {
t.Error(err)
}
// Forcefully close connection
err = conn.(*net.TCPConn).SetLinger(0)
if err != nil {
t.Error(err)
}
conn.Close()
}
}()
srvURL := "http://" + srv.Addr().String()
client := HostClient{Addr: srv.Addr().String()}
for i := 0; i < 100; i++ {
req := AcquireRequest()
defer ReleaseRequest(req)
resp := AcquireResponse()
defer ReleaseResponse(resp)
req.Header.SetMethod("POST")
req.SetBodyStream(strings.NewReader(payload), len(payload))
req.SetRequestURI(srvURL)
err = client.Do(req, resp)
if !isConnectionReset(err) {
t.Fatal("Expected connection reset error")
}
}
}
-6
View File
@@ -1439,9 +1439,6 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
if !resp.mustSkipBody() {
err = resp.ReadBody(r, maxBodySize)
if err != nil {
if isConnectionReset(err) {
return nil
}
return err
}
}
@@ -1450,9 +1447,6 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
if resp.Header.ContentLength() == -1 && !resp.StreamBody && !resp.mustSkipBody() {
err = resp.Header.ReadTrailer(r)
if err != nil && err != io.EOF {
if isConnectionReset(err) {
return nil
}
return err
}
}
-12
View File
@@ -1,12 +0,0 @@
//go:build !windows
package fasthttp
import (
"errors"
"syscall"
)
func isConnectionReset(err error) bool {
return errors.Is(err, syscall.ECONNRESET)
}
-10
View File
@@ -1,10 +0,0 @@
package fasthttp
import (
"errors"
"syscall"
)
func isConnectionReset(err error) bool {
return errors.Is(err, syscall.WSAECONNRESET)
}