mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-13 15:46:49 +03:00
Abstracts the RoundTripper interface and provides a default implement (#1602)
* Abstracts the RoundTripper interface and provides a default implementation for enhanced extensibility (#1601) * test: Add custom transport test case (#1601) * Make default RoundTripper implmention none public Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com> --------- Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
This commit is contained in:
@@ -628,8 +628,10 @@ type DialFunc func(addr string) (net.Conn, error)
|
||||
// Request argument passed to RetryIfFunc, if there are any request errors.
|
||||
type RetryIfFunc func(request *Request) bool
|
||||
|
||||
// TransportFunc wraps every request/response.
|
||||
type TransportFunc func(*Request, *Response) error
|
||||
// RoundTripper wraps every request/response.
|
||||
type RoundTripper interface {
|
||||
RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
|
||||
}
|
||||
|
||||
// ConnPoolStrategyType define strategy of connection pool enqueue/dequeue
|
||||
type ConnPoolStrategyType int
|
||||
@@ -791,7 +793,7 @@ type HostClient struct {
|
||||
RetryIf RetryIfFunc
|
||||
|
||||
// Transport defines a transport-like mechanism that wraps every request/response.
|
||||
Transport TransportFunc
|
||||
Transport RoundTripper
|
||||
|
||||
// Connection pool strategy. Can be either LIFO or FIFO (default).
|
||||
ConnPoolStrategy ConnPoolStrategyType
|
||||
@@ -1343,119 +1345,15 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
|
||||
req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
|
||||
}
|
||||
}
|
||||
if c.Transport != nil {
|
||||
err := c.Transport(req, resp)
|
||||
return err == nil, err
|
||||
}
|
||||
|
||||
var deadline time.Time
|
||||
if req.timeout > 0 {
|
||||
deadline = time.Now().Add(req.timeout)
|
||||
}
|
||||
return c.transport().RoundTrip(c, req, resp)
|
||||
}
|
||||
|
||||
cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
|
||||
if err != nil {
|
||||
return false, err
|
||||
func (c *HostClient) transport() RoundTripper {
|
||||
if c.Transport == nil {
|
||||
return DefaultTransport
|
||||
}
|
||||
conn := cc.c
|
||||
|
||||
resp.parseNetConn(conn)
|
||||
|
||||
writeDeadline := deadline
|
||||
if c.WriteTimeout > 0 {
|
||||
tmpWriteDeadline := time.Now().Add(c.WriteTimeout)
|
||||
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
|
||||
writeDeadline = tmpWriteDeadline
|
||||
}
|
||||
}
|
||||
|
||||
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
|
||||
c.closeConn(cc)
|
||||
return true, err
|
||||
}
|
||||
|
||||
resetConnection := false
|
||||
if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() {
|
||||
req.SetConnectionClose()
|
||||
resetConnection = true
|
||||
}
|
||||
|
||||
bw := c.acquireWriter(conn)
|
||||
err = req.Write(bw)
|
||||
|
||||
if resetConnection {
|
||||
req.Header.ResetConnectionClose()
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = bw.Flush()
|
||||
}
|
||||
c.releaseWriter(bw)
|
||||
|
||||
// Return ErrTimeout on any timeout.
|
||||
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
|
||||
err = ErrTimeout
|
||||
}
|
||||
|
||||
isConnRST := isConnectionReset(err)
|
||||
if err != nil && !isConnRST {
|
||||
c.closeConn(cc)
|
||||
return true, err
|
||||
}
|
||||
|
||||
readDeadline := deadline
|
||||
if c.ReadTimeout > 0 {
|
||||
tmpReadDeadline := time.Now().Add(c.ReadTimeout)
|
||||
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
|
||||
readDeadline = tmpReadDeadline
|
||||
}
|
||||
}
|
||||
|
||||
if err = conn.SetReadDeadline(readDeadline); err != nil {
|
||||
c.closeConn(cc)
|
||||
return true, err
|
||||
}
|
||||
|
||||
if customSkipBody || req.Header.IsHead() {
|
||||
resp.SkipBody = true
|
||||
}
|
||||
if c.DisableHeaderNamesNormalizing {
|
||||
resp.Header.DisableNormalizing()
|
||||
}
|
||||
|
||||
br := c.acquireReader(conn)
|
||||
err = resp.ReadLimitBody(br, c.MaxResponseBodySize)
|
||||
c.releaseReader(br)
|
||||
if err != nil {
|
||||
c.closeConn(cc)
|
||||
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
|
||||
retry := err != ErrBodyTooLarge
|
||||
return retry, err
|
||||
}
|
||||
|
||||
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
|
||||
if customStreamBody && resp.bodyStream != nil {
|
||||
rbs := resp.bodyStream
|
||||
resp.bodyStream = newCloseReader(rbs, func() error {
|
||||
if r, ok := rbs.(*requestStream); ok {
|
||||
releaseRequestStream(r)
|
||||
}
|
||||
if closeConn {
|
||||
c.closeConn(cc)
|
||||
} else {
|
||||
c.releaseConn(cc)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if closeConn {
|
||||
c.closeConn(cc)
|
||||
} else {
|
||||
c.releaseConn(cc)
|
||||
}
|
||||
return false, nil
|
||||
return c.Transport
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -2909,3 +2807,121 @@ func (c *pipelineConnClient) PendingRequests() int {
|
||||
}
|
||||
|
||||
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
|
||||
|
||||
var DefaultTransport RoundTripper = &transport{}
|
||||
|
||||
type transport struct{}
|
||||
|
||||
func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
|
||||
customSkipBody := resp.SkipBody
|
||||
customStreamBody := resp.StreamBody
|
||||
|
||||
var deadline time.Time
|
||||
if req.timeout > 0 {
|
||||
deadline = time.Now().Add(req.timeout)
|
||||
}
|
||||
|
||||
cc, err := hc.acquireConn(req.timeout, req.ConnectionClose())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
conn := cc.c
|
||||
|
||||
resp.parseNetConn(conn)
|
||||
|
||||
writeDeadline := deadline
|
||||
if hc.WriteTimeout > 0 {
|
||||
tmpWriteDeadline := time.Now().Add(hc.WriteTimeout)
|
||||
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
|
||||
writeDeadline = tmpWriteDeadline
|
||||
}
|
||||
}
|
||||
|
||||
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
|
||||
hc.closeConn(cc)
|
||||
return true, err
|
||||
}
|
||||
|
||||
resetConnection := false
|
||||
if hc.MaxConnDuration > 0 && time.Since(cc.createdTime) > hc.MaxConnDuration && !req.ConnectionClose() {
|
||||
req.SetConnectionClose()
|
||||
resetConnection = true
|
||||
}
|
||||
|
||||
bw := hc.acquireWriter(conn)
|
||||
err = req.Write(bw)
|
||||
|
||||
if resetConnection {
|
||||
req.Header.ResetConnectionClose()
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = bw.Flush()
|
||||
}
|
||||
hc.releaseWriter(bw)
|
||||
|
||||
// Return ErrTimeout on any timeout.
|
||||
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
|
||||
err = ErrTimeout
|
||||
}
|
||||
|
||||
isConnRST := isConnectionReset(err)
|
||||
if err != nil && !isConnRST {
|
||||
hc.closeConn(cc)
|
||||
return true, err
|
||||
}
|
||||
|
||||
readDeadline := deadline
|
||||
if hc.ReadTimeout > 0 {
|
||||
tmpReadDeadline := time.Now().Add(hc.ReadTimeout)
|
||||
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
|
||||
readDeadline = tmpReadDeadline
|
||||
}
|
||||
}
|
||||
|
||||
if err = conn.SetReadDeadline(readDeadline); err != nil {
|
||||
hc.closeConn(cc)
|
||||
return true, err
|
||||
}
|
||||
|
||||
if customSkipBody || req.Header.IsHead() {
|
||||
resp.SkipBody = true
|
||||
}
|
||||
if hc.DisableHeaderNamesNormalizing {
|
||||
resp.Header.DisableNormalizing()
|
||||
}
|
||||
|
||||
br := hc.acquireReader(conn)
|
||||
err = resp.ReadLimitBody(br, hc.MaxResponseBodySize)
|
||||
hc.releaseReader(br)
|
||||
if err != nil {
|
||||
hc.closeConn(cc)
|
||||
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
|
||||
needRetry := err != ErrBodyTooLarge
|
||||
return needRetry, err
|
||||
}
|
||||
|
||||
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
|
||||
if customStreamBody && resp.bodyStream != nil {
|
||||
rbs := resp.bodyStream
|
||||
resp.bodyStream = newCloseReader(rbs, func() error {
|
||||
if r, ok := rbs.(*requestStream); ok {
|
||||
releaseRequestStream(r)
|
||||
}
|
||||
if closeConn {
|
||||
hc.closeConn(cc)
|
||||
} else {
|
||||
hc.releaseConn(cc)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if closeConn {
|
||||
hc.closeConn(cc)
|
||||
} else {
|
||||
hc.releaseConn(cc)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
+98
-16
@@ -2111,6 +2111,22 @@ func TestClientRetryRequestWithCustomDecider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type TransportDemo struct {
|
||||
br *bufio.Reader
|
||||
bw *bufio.Writer
|
||||
}
|
||||
|
||||
func (t TransportDemo) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
|
||||
if err = req.Write(t.bw); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err = t.bw.Flush(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
err = res.Read(t.br)
|
||||
return err != nil, err
|
||||
}
|
||||
|
||||
func TestHostClientTransport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2131,23 +2147,13 @@ func TestHostClientTransport(t *testing.T) {
|
||||
|
||||
c := &HostClient{
|
||||
Addr: "foobar",
|
||||
Transport: func() TransportFunc {
|
||||
Transport: func() RoundTripper {
|
||||
c, _ := ln.Dial()
|
||||
|
||||
br := bufio.NewReader(c)
|
||||
bw := bufio.NewWriter(c)
|
||||
|
||||
return func(req *Request, res *Response) error {
|
||||
if err := req.Write(bw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := bw.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return res.Read(br)
|
||||
}
|
||||
return TransportDemo{br: br, bw: bw}
|
||||
}(),
|
||||
}
|
||||
|
||||
@@ -3060,14 +3066,18 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type TransportEmpty struct{}
|
||||
|
||||
func (t TransportEmpty) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func TestHttpsRequestWithoutParsedURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := HostClient{
|
||||
IsTLS: true,
|
||||
Transport: func(r1 *Request, r2 *Response) error {
|
||||
return nil
|
||||
},
|
||||
IsTLS: true,
|
||||
Transport: TransportEmpty{},
|
||||
}
|
||||
|
||||
req := &Request{}
|
||||
@@ -3182,3 +3192,75 @@ func Test_AddMissingPort(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type TransportWrapper struct {
|
||||
base RoundTripper
|
||||
count *int
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (tw *TransportWrapper) RoundTrip(hc *HostClient, req *Request, resp *Response) (bool, error) {
|
||||
req.Header.Set("trace-id", "123")
|
||||
tw.assertRequestLog(req.String())
|
||||
retry, err := tw.transport().RoundTrip(hc, req, resp)
|
||||
resp.Header.Set("trace-id", "124")
|
||||
tw.assertResponseLog(resp.String())
|
||||
*tw.count++
|
||||
return retry, err
|
||||
}
|
||||
|
||||
func (tw *TransportWrapper) transport() RoundTripper {
|
||||
if tw.base == nil {
|
||||
return DefaultTransport
|
||||
}
|
||||
return tw.base
|
||||
}
|
||||
|
||||
func (tw *TransportWrapper) assertRequestLog(reqLog string) {
|
||||
if !strings.Contains(reqLog, "Trace-Id: 123") {
|
||||
tw.t.Errorf("request log should contains: %v", "Trace-Id: 123")
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TransportWrapper) assertResponseLog(respLog string) {
|
||||
if !strings.Contains(respLog, "Trace-Id: 124") {
|
||||
tw.t.Errorf("response log should contains: %v", "Trace-Id: 124")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientTransportEx(t *testing.T) {
|
||||
sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
|
||||
defer sHTTP.Stop()
|
||||
|
||||
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
|
||||
defer sHTTPS.Stop()
|
||||
|
||||
count := 0
|
||||
c := &Client{
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
ConfigureClient: func(hc *HostClient) error {
|
||||
hc.Transport = &TransportWrapper{base: hc.Transport, count: &count, t: t}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
// test transport
|
||||
const loopCount = 4
|
||||
const getCount = 20
|
||||
const postCount = 10
|
||||
for i := 0; i < loopCount; i++ {
|
||||
addr := "http://" + sHTTP.Addr()
|
||||
if i&1 != 0 {
|
||||
addr = "https://" + sHTTPS.Addr()
|
||||
}
|
||||
// test get
|
||||
testClientGet(t, c, addr, getCount)
|
||||
// test post
|
||||
testClientPost(t, c, addr, postCount)
|
||||
}
|
||||
roundTripCount := loopCount * (getCount + postCount)
|
||||
if count != roundTripCount {
|
||||
t.Errorf("round trip count should be: %v", roundTripCount)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user