diff --git a/client.go b/client.go index 898fa05..52f1d80 100644 --- a/client.go +++ b/client.go @@ -51,14 +51,28 @@ func DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return defaultClient.DoTimeout(req, resp, timeout) } -// Get fetches url contents into dst and returns it as body. +// Get appends url contents to dst and returns it as body. +// +// New body buffer is allocated if dst is nil. func Get(dst []byte, url string) (statusCode int, body []byte, err error) { return defaultClient.Get(dst, url) } +// GetTimeout appends url contents to dst and returns it as body. +// +// New body buffer is allocated if dst is nil. +// +// ErrTimeout error is returned if url contents couldn't be fetched +// during the given timeout. +func GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) { + return defaultClient.GetTimeout(dst, url, timeout) +} + // Post sends POST request to the given url with the given POST arguments. // -// Response body is written to dst, which is returned as body. +// Response body is appended to dst, which is returned as body. +// +// New body buffer is allocated if dst is nil. // // Empty POST body is sent if postArgs is nil. func Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { @@ -120,14 +134,28 @@ type Client struct { ms map[string]*HostClient } -// Get fetches url contents into dst and returns it as body. +// Get appends url contents to dst and returns it as body. +// +// New body buffer is allocated if dst is nil. func (c *Client) Get(dst []byte, url string) (statusCode int, body []byte, err error) { return clientGetURL(dst, url, c) } +// GetTimeout appends url contents to dst and returns it as body. +// +// New body buffer is allocated if dst is nil. +// +// ErrTimeout error is returned if url contents couldn't be fetched +// during the given timeout. +func (c *Client) GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) { + return clientGetURLTimeout(dst, url, timeout, c) +} + // Post sends POST request to the given url with the given POST arguments. // -// Response body is written to dst, which is returned as body. +// Response body is appended to dst, which is returned as body. +// +// New body buffer is allocated if dst is nil. // // Empty POST body is sent if postArgs is nil. func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { @@ -352,14 +380,28 @@ func (c *HostClient) LastUseTime() time.Time { return time.Unix(int64(n), 0) } -// Get fetches url contents into dst and returns it as body. +// Get appends url contents to dst and returns it as body. +// +// New body buffer is allocated if dst is nil. func (c *HostClient) Get(dst []byte, url string) (statusCode int, body []byte, err error) { return clientGetURL(dst, url, c) } +// GetTimeout appends url contents to dst and returns it as body. +// +// New body buffer is allocated if dst is nil. +// +// ErrTimeout error is returned if url contents couldn't be fetched +// during the given timeout. +func (c *HostClient) GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) { + return clientGetURLTimeout(dst, url, timeout, c) +} + // Post sends POST request to the given url with the given POST arguments. // -// Response body is written to dst, which is returned as body. +// Response body is appended to dst, which is returned as body. +// +// New body buffer is allocated if dst is nil. // // Empty POST body is sent if postArgs is nil. func (c *HostClient) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { @@ -379,6 +421,85 @@ func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body [] return statusCode, body, err } +func clientGetURLTimeout(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) { + if timeout <= 0 { + return 0, dst, ErrTimeout + } + + deadline := time.Now().Add(timeout) + for { + statusCode, body, err = clientGetURLTimeoutFreeConn(dst, url, timeout, c) + if err != ErrNoFreeConns { + return statusCode, body, err + } + timeout = -time.Since(deadline) + if timeout <= 0 { + return 0, dst, ErrTimeout + } + sleepTime := (10 + time.Duration(rand.Intn(100))) * time.Millisecond + if sleepTime > timeout { + sleepTime = timeout + } + time.Sleep(sleepTime) + timeout = -time.Since(deadline) + if timeout <= 0 { + return 0, dst, ErrTimeout + } + } +} + +func clientGetURLTimeoutFreeConn(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) { + var ch chan error + chv := errorChPool.Get() + if chv == nil { + ch = make(chan error, 1) + } else { + ch = chv.(chan error) + } + + req := acquireRequest() + + // Note that the request continues execution on ErrTimeout until + // client-specific ReadTimeout exceeds. This helps limiting load + // on slow hosts by MaxConns* concurrent requests. + // + // Without this 'hack' the load on slow host could exceed MaxConns* + // concurrent requests, since timed out requests on client side + // usually continue execution on the host. + var statusCodeCopy int + var bodyCopy []byte + go func() { + var errCopy error + statusCodeCopy, bodyCopy, errCopy = doRequest(req, dst, url, c) + ch <- errCopy + }() + + var tc *time.Timer + tcv := timerPool.Get() + if tcv == nil { + tc = time.NewTimer(timeout) + } else { + tc = tcv.(*time.Timer) + initTimer(tc, timeout) + } + + select { + case err = <-ch: + releaseRequest(req) + errorChPool.Put(chv) + statusCode = statusCodeCopy + body = bodyCopy + case <-tc.C: + body = dst + err = ErrTimeout + } + + stopTimer(tc) + timerPool.Put(tcv) + + return statusCode, body, err +} + func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (statusCode int, body []byte, err error) { req := acquireRequest() req.Header.SetMethodBytes(strPost) @@ -400,7 +521,7 @@ func doRequest(req *Request, dst []byte, url string, c clientDoer) (statusCode i oldBody := resp.body resp.body = dst if err = c.Do(req, resp); err != nil { - return 0, nil, err + return 0, dst, err } statusCode = resp.Header.StatusCode() body = resp.body @@ -489,12 +610,19 @@ func clientDoTimeoutFreeConn(req *Request, resp *Response, timeout time.Duration ch = chv.(chan error) } - // make req and resp copies, since on timeout they no longer - // may accessed. + // Make req and resp copies, since on timeout they no longer + // may be accessed. reqCopy := acquireRequest() req.CopyTo(reqCopy) respCopy := acquireResponse() + // Note that the request continues execution on ErrTimeout until + // client-specific ReadTimeout exceeds. This helps limiting load + // on slow hosts by MaxConns* concurrent requests. + // + // Without this 'hack' the load on slow host could exceed MaxConns* + // concurrent requests, since timed out requests on client side + // usually continue execution on the host. go func() { ch <- c.Do(reqCopy, respCopy) }() diff --git a/client_test.go b/client_test.go index d6069d7..6b81355 100644 --- a/client_test.go +++ b/client_test.go @@ -12,17 +12,43 @@ import ( "time" ) -func TestClientDoTimeout(t *testing.T) { +func TestClientGetTimeoutSuccess(t *testing.T) { + addr := "127.0.0.1:56889" + s := startEchoServer(t, "tcp", addr) + defer s.Stop() + + addr = "http://" + addr + testClientGetTimeoutSuccess(t, &defaultClient, addr, 100) +} + +func TestClientGetTimeoutSuccessConcurrent(t *testing.T) { + addr := "127.0.0.1:56989" + s := startEchoServer(t, "tcp", addr) + defer s.Stop() + + addr = "http://" + addr + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + testClientGetTimeoutSuccess(t, &defaultClient, addr, 100) + }() + } + wg.Wait() +} + +func TestClientGetTimeoutError(t *testing.T) { c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil }, } - testClientDoTimeout(t, c, 100) + testClientGetTimeoutError(t, c, 100) } -func TestClientDoTimeoutConcurrent(t *testing.T) { +func TestClientGetTimeoutErrorConcurrent(t *testing.T) { c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil @@ -35,17 +61,46 @@ func TestClientDoTimeoutConcurrent(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - testClientDoTimeout(t, c, 100) + testClientGetTimeoutError(t, c, 100) }() } wg.Wait() } -func testClientDoTimeout(t *testing.T, c *Client, n int) { +func TestClientDoTimeoutError(t *testing.T) { + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return &readTimeoutConn{t: time.Second}, nil + }, + } + + testClientDoTimeoutError(t, c, 100) +} + +func TestClientDoTimeoutErrorConcurrent(t *testing.T) { + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return &readTimeoutConn{t: time.Second}, nil + }, + MaxConnsPerHost: 1000, + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + testClientDoTimeoutError(t, c, 100) + }() + } + wg.Wait() +} + +func testClientDoTimeoutError(t *testing.T, c *Client, n int) { var req Request var resp Response req.SetRequestURI("http://foobar.com/baz") - for i := 0; i < 20; i++ { + for i := 0; i < n; i++ { err := c.DoTimeout(&req, &resp, time.Millisecond) if err == nil { t.Fatalf("expecting error") @@ -56,6 +111,25 @@ func testClientDoTimeout(t *testing.T, c *Client, n int) { } } +func testClientGetTimeoutError(t *testing.T, c *Client, n int) { + buf := make([]byte, 10) + for i := 0; i < n; i++ { + statusCode, body, err := c.GetTimeout(buf, "http://foobar.com/baz", time.Millisecond) + if err == nil { + t.Fatalf("expecting error") + } + if err != ErrTimeout { + t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout) + } + if statusCode != 0 { + t.Fatalf("unexpected statusCode=%d. Expecting %d", statusCode, 0) + } + if body == nil { + t.Fatalf("body must be non-nil") + } + } +} + type readTimeoutConn struct { net.Conn t time.Duration @@ -225,7 +299,7 @@ func TestClientPost(t *testing.T) { } func TestClientConcurrent(t *testing.T) { - addr := "127.0.0.1:56780" + addr := "127.0.0.1:55780" s := startEchoServer(t, "tcp", addr) defer s.Stop() @@ -300,6 +374,28 @@ func testClientGet(t *testing.T, c clientGetter, addr string, n int) { } } +func testClientGetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) { + var buf []byte + for i := 0; i < n; i++ { + uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i) + statusCode, body, err := c.GetTimeout(buf, uri, time.Second) + buf = body + if err != nil { + t.Fatalf("unexpected error when doing http request: %s", err) + } + if statusCode != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) + } + resultURI := string(body) + if strings.HasPrefix(uri, "https") { + resultURI = uri[:5] + resultURI[4:] + } + if resultURI != uri { + t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri) + } + } +} + func testClientPost(t *testing.T, c clientPoster, addr string, n int) { var buf []byte var args Args