diff --git a/client.go b/client.go index 1fa072d..d72d83d 100644 --- a/client.go +++ b/client.go @@ -289,6 +289,11 @@ type Client struct { // By default will not waiting, return ErrNoFreeConns immediately MaxConnWaitTimeout time.Duration + // RetryIf controls whether a retry should be attempted after an error. + // + // By default will use isIdempotent function + RetryIf RetryIfFunc + mLock sync.Mutex m map[string]*HostClient ms map[string]*HostClient @@ -493,6 +498,7 @@ func (c *Client) Do(req *Request, resp *Response) error { DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing, DisablePathNormalizing: c.DisablePathNormalizing, MaxConnWaitTimeout: c.MaxConnWaitTimeout, + RetryIf: c.RetryIf, } m[string(host)] = hc if len(m) == 1 { @@ -560,6 +566,11 @@ const DefaultMaxIdemponentCallAttempts = 5 // - foobar.com:8080 type DialFunc func(addr string) (net.Conn, error) +// RetryIfFunc signature of retry if function +// +// Request argument passed to RetryIfFunc, if there are any request errors. +type RetryIfFunc func(request *Request) bool + // HostClient balances http requests among hosts listed in Addr. // // HostClient may be used for balancing load among multiple upstream hosts. @@ -698,6 +709,11 @@ type HostClient struct { // By default will not waiting, return ErrNoFreeConns immediately MaxConnWaitTimeout time.Duration + // RetryIf controls whether a retry should be attempted after an error. + // + // By default will use isIdempotent function + RetryIf RetryIfFunc + clientName atomic.Value lastUseTime uint32 @@ -1183,6 +1199,10 @@ func (c *HostClient) Do(req *Request, resp *Response) error { if maxAttempts <= 0 { maxAttempts = DefaultMaxIdemponentCallAttempts } + isRequestRetryable := isIdempotent + if c.RetryIf != nil { + isRequestRetryable = c.RetryIf + } attempts := 0 hasBodyStream := req.IsBodyStream() @@ -1196,7 +1216,7 @@ func (c *HostClient) Do(req *Request, resp *Response) error { if hasBodyStream { break } - if !isIdempotent(req) { + if !isRequestRetryable(req) { // Retry non-idempotent requests if the server closes // the connection before sending the response. // diff --git a/client_test.go b/client_test.go index f5c9de9..63c7f7e 100644 --- a/client_test.go +++ b/client_test.go @@ -1648,6 +1648,58 @@ func TestClientIdempotentRequest(t *testing.T) { } } +func TestClientRetryRequestWithCustomDecider(t *testing.T) { + t.Parallel() + + dialsCount := 0 + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + dialsCount++ + switch dialsCount { + case 1: + return &singleReadConn{ + s: "invalid response", + }, nil + case 2: + return &writeErrorConn{}, nil + case 3: + return &readErrorConn{}, nil + case 4: + return &singleReadConn{ + s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456", + }, nil + default: + t.Fatalf("unexpected number of dials: %d", dialsCount) + } + panic("unreachable") + }, + RetryIf: func(req *Request) bool { + return req.URI().String() == "http://foobar/a/b" + }, + } + + var args Args + + // Post must succeed for http://foobar/a/b uri. + statusCode, body, err := c.Post(nil, "http://foobar/a/b", &args) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if statusCode != 345 { + t.Fatalf("unexpected status code: %d. Expecting 345", statusCode) + } + if string(body) != "0123456" { + t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456") + } + + // POST must fail for http://foobar/a/b/c uri. + dialsCount = 0 + _, _, err = c.Post(nil, "http://foobar/a/b/c", &args) + if err == nil { + t.Fatalf("expecting error") + } +} + type writeErrorConn struct { net.Conn }