Abstracts the RoundTripper interface and provides a default implement (#1602)

* Abstracts the RoundTripper interface and provides a default implementation for enhanced extensibility (#1601)

* test: Add custom transport test case (#1601)

* Make default RoundTripper implmention none public

Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>

---------

Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
This commit is contained in:
Tim
2023-08-10 15:43:26 +08:00
committed by GitHub
parent e181af17c7
commit 54fdc7a73c
2 changed files with 227 additions and 129 deletions
+129 -113
View File
@@ -628,8 +628,10 @@ type DialFunc func(addr string) (net.Conn, error)
// Request argument passed to RetryIfFunc, if there are any request errors.
type RetryIfFunc func(request *Request) bool
// TransportFunc wraps every request/response.
type TransportFunc func(*Request, *Response) error
// RoundTripper wraps every request/response.
type RoundTripper interface {
RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
}
// ConnPoolStrategyType define strategy of connection pool enqueue/dequeue
type ConnPoolStrategyType int
@@ -791,7 +793,7 @@ type HostClient struct {
RetryIf RetryIfFunc
// Transport defines a transport-like mechanism that wraps every request/response.
Transport TransportFunc
Transport RoundTripper
// Connection pool strategy. Can be either LIFO or FIFO (default).
ConnPoolStrategy ConnPoolStrategyType
@@ -1343,119 +1345,15 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
}
}
if c.Transport != nil {
err := c.Transport(req, resp)
return err == nil, err
}
var deadline time.Time
if req.timeout > 0 {
deadline = time.Now().Add(req.timeout)
}
return c.transport().RoundTrip(c, req, resp)
}
cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
if err != nil {
return false, err
func (c *HostClient) transport() RoundTripper {
if c.Transport == nil {
return DefaultTransport
}
conn := cc.c
resp.parseNetConn(conn)
writeDeadline := deadline
if c.WriteTimeout > 0 {
tmpWriteDeadline := time.Now().Add(c.WriteTimeout)
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
writeDeadline = tmpWriteDeadline
}
}
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
c.closeConn(cc)
return true, err
}
resetConnection := false
if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() {
req.SetConnectionClose()
resetConnection = true
}
bw := c.acquireWriter(conn)
err = req.Write(bw)
if resetConnection {
req.Header.ResetConnectionClose()
}
if err == nil {
err = bw.Flush()
}
c.releaseWriter(bw)
// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
err = ErrTimeout
}
isConnRST := isConnectionReset(err)
if err != nil && !isConnRST {
c.closeConn(cc)
return true, err
}
readDeadline := deadline
if c.ReadTimeout > 0 {
tmpReadDeadline := time.Now().Add(c.ReadTimeout)
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
readDeadline = tmpReadDeadline
}
}
if err = conn.SetReadDeadline(readDeadline); err != nil {
c.closeConn(cc)
return true, err
}
if customSkipBody || req.Header.IsHead() {
resp.SkipBody = true
}
if c.DisableHeaderNamesNormalizing {
resp.Header.DisableNormalizing()
}
br := c.acquireReader(conn)
err = resp.ReadLimitBody(br, c.MaxResponseBodySize)
c.releaseReader(br)
if err != nil {
c.closeConn(cc)
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
retry := err != ErrBodyTooLarge
return retry, err
}
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReader(rbs, func() error {
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}
return nil
})
return false, nil
}
if closeConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}
return false, nil
return c.Transport
}
var (
@@ -2909,3 +2807,121 @@ func (c *pipelineConnClient) PendingRequests() int {
}
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
var DefaultTransport RoundTripper = &transport{}
type transport struct{}
func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
customSkipBody := resp.SkipBody
customStreamBody := resp.StreamBody
var deadline time.Time
if req.timeout > 0 {
deadline = time.Now().Add(req.timeout)
}
cc, err := hc.acquireConn(req.timeout, req.ConnectionClose())
if err != nil {
return false, err
}
conn := cc.c
resp.parseNetConn(conn)
writeDeadline := deadline
if hc.WriteTimeout > 0 {
tmpWriteDeadline := time.Now().Add(hc.WriteTimeout)
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
writeDeadline = tmpWriteDeadline
}
}
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
hc.closeConn(cc)
return true, err
}
resetConnection := false
if hc.MaxConnDuration > 0 && time.Since(cc.createdTime) > hc.MaxConnDuration && !req.ConnectionClose() {
req.SetConnectionClose()
resetConnection = true
}
bw := hc.acquireWriter(conn)
err = req.Write(bw)
if resetConnection {
req.Header.ResetConnectionClose()
}
if err == nil {
err = bw.Flush()
}
hc.releaseWriter(bw)
// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
err = ErrTimeout
}
isConnRST := isConnectionReset(err)
if err != nil && !isConnRST {
hc.closeConn(cc)
return true, err
}
readDeadline := deadline
if hc.ReadTimeout > 0 {
tmpReadDeadline := time.Now().Add(hc.ReadTimeout)
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
readDeadline = tmpReadDeadline
}
}
if err = conn.SetReadDeadline(readDeadline); err != nil {
hc.closeConn(cc)
return true, err
}
if customSkipBody || req.Header.IsHead() {
resp.SkipBody = true
}
if hc.DisableHeaderNamesNormalizing {
resp.Header.DisableNormalizing()
}
br := hc.acquireReader(conn)
err = resp.ReadLimitBody(br, hc.MaxResponseBodySize)
hc.releaseReader(br)
if err != nil {
hc.closeConn(cc)
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
needRetry := err != ErrBodyTooLarge
return needRetry, err
}
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReader(rbs, func() error {
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn {
hc.closeConn(cc)
} else {
hc.releaseConn(cc)
}
return nil
})
return false, nil
}
if closeConn {
hc.closeConn(cc)
} else {
hc.releaseConn(cc)
}
return false, nil
}
+98 -16
View File
@@ -2111,6 +2111,22 @@ func TestClientRetryRequestWithCustomDecider(t *testing.T) {
}
}
type TransportDemo struct {
br *bufio.Reader
bw *bufio.Writer
}
func (t TransportDemo) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
if err = req.Write(t.bw); err != nil {
return false, err
}
if err = t.bw.Flush(); err != nil {
return false, err
}
err = res.Read(t.br)
return err != nil, err
}
func TestHostClientTransport(t *testing.T) {
t.Parallel()
@@ -2131,23 +2147,13 @@ func TestHostClientTransport(t *testing.T) {
c := &HostClient{
Addr: "foobar",
Transport: func() TransportFunc {
Transport: func() RoundTripper {
c, _ := ln.Dial()
br := bufio.NewReader(c)
bw := bufio.NewWriter(c)
return func(req *Request, res *Response) error {
if err := req.Write(bw); err != nil {
return err
}
if err := bw.Flush(); err != nil {
return err
}
return res.Read(br)
}
return TransportDemo{br: br, bw: bw}
}(),
}
@@ -3060,14 +3066,18 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
}
}
type TransportEmpty struct{}
func (t TransportEmpty) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
return false, nil
}
func TestHttpsRequestWithoutParsedURL(t *testing.T) {
t.Parallel()
client := HostClient{
IsTLS: true,
Transport: func(r1 *Request, r2 *Response) error {
return nil
},
IsTLS: true,
Transport: TransportEmpty{},
}
req := &Request{}
@@ -3182,3 +3192,75 @@ func Test_AddMissingPort(t *testing.T) {
})
}
}
type TransportWrapper struct {
base RoundTripper
count *int
t *testing.T
}
func (tw *TransportWrapper) RoundTrip(hc *HostClient, req *Request, resp *Response) (bool, error) {
req.Header.Set("trace-id", "123")
tw.assertRequestLog(req.String())
retry, err := tw.transport().RoundTrip(hc, req, resp)
resp.Header.Set("trace-id", "124")
tw.assertResponseLog(resp.String())
*tw.count++
return retry, err
}
func (tw *TransportWrapper) transport() RoundTripper {
if tw.base == nil {
return DefaultTransport
}
return tw.base
}
func (tw *TransportWrapper) assertRequestLog(reqLog string) {
if !strings.Contains(reqLog, "Trace-Id: 123") {
tw.t.Errorf("request log should contains: %v", "Trace-Id: 123")
}
}
func (tw *TransportWrapper) assertResponseLog(respLog string) {
if !strings.Contains(respLog, "Trace-Id: 124") {
tw.t.Errorf("response log should contains: %v", "Trace-Id: 124")
}
}
func TestClientTransportEx(t *testing.T) {
sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
defer sHTTP.Stop()
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
defer sHTTPS.Stop()
count := 0
c := &Client{
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
ConfigureClient: func(hc *HostClient) error {
hc.Transport = &TransportWrapper{base: hc.Transport, count: &count, t: t}
return nil
},
}
// test transport
const loopCount = 4
const getCount = 20
const postCount = 10
for i := 0; i < loopCount; i++ {
addr := "http://" + sHTTP.Addr()
if i&1 != 0 {
addr = "https://" + sHTTPS.Addr()
}
// test get
testClientGet(t, c, addr, getCount)
// test post
testClientPost(t, c, addr, postCount)
}
roundTripCount := loopCount * (getCount + postCount)
if count != roundTripCount {
t.Errorf("round trip count should be: %v", roundTripCount)
}
}