From 8aef785c1efe456220fada4b8d4d750adf8a4864 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Sat, 14 Nov 2015 23:18:52 +0200 Subject: [PATCH] Added DoTimeout() to client --- client.go | 105 +++++++++++++++++++++++++++++++++++++++++++++++-- client_test.go | 62 +++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 3acdead..93cbec2 100644 --- a/client.go +++ b/client.go @@ -32,6 +32,22 @@ func Do(req *Request, resp *Response) error { return defaultClient.Do(req, resp) } +// DoTimeout performs the given request and waits for response during +// the given timeout duration. +// +// Request must contain at least non-zero RequestURI with full url (including +// scheme and host) or non-zero Host header + RequestURI. +// +// Client determines the server to be requested in the following order: +// - from RequestURI if it contains full url with scheme and host; +// - from Host header otherwise. +// +// ErrTimeout is returned if the response wasn't returned during +// the given timeout. +func DoTimeout(req *Request, resp *Response, timeout time.Duration) error { + return defaultClient.DoTimeout(req, resp, timeout) +} + // Get fetches url contents into dst. // // Use Do for request customization. @@ -109,6 +125,22 @@ func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, b return clientPostURL(dst, url, postArgs, c) } +// DoTimeout performs the given request and waits for response during +// the given timeout duration. +// +// Request must contain at least non-zero RequestURI with full url (including +// scheme and host) or non-zero Host header + RequestURI. +// +// Client determines the server to be requested in the following order: +// - from RequestURI if it contains full url with scheme and host; +// - from Host header otherwise. +// +// ErrTimeout is returned if the response wasn't returned during +// the given timeout. +func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { + return clientDoTimeout(req, resp, timeout, c) +} + // Do performs the given http request and fills the given http response. // // Request must contain at least non-zero RequestURI with full url (including @@ -379,6 +411,68 @@ func releaseResponse(resp *Response) { responsePool.Put(resp) } +// DoTimeout performs the given request and waits for response during +// the given timeout duration. +// +// Request must contain at least non-zero RequestURI with full url (including +// scheme and host) or non-zero Host header + RequestURI. +// +// ErrTimeout is returned if the response wasn't returned during +// the given timeout. +func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { + return clientDoTimeout(req, resp, timeout, c) +} + +func clientDoTimeout(req *Request, resp *Response, timeout time.Duration, c clientDoer) error { + var ch chan error + chv := errorChPool.Get() + if chv == nil { + ch = make(chan error, 1) + } else { + ch = chv.(chan error) + } + + // make req and resp copies, since on timeout they no longer + // may accessed. + reqCopy := acquireRequest() + req.CopyTo(reqCopy) + respCopy := acquireResponse() + + go func() { + ch <- c.Do(reqCopy, respCopy) + }() + + var tc *time.Timer + tcv := timerPool.Get() + if tcv == nil { + tc = time.NewTimer(timeout) + } else { + tc = tcv.(*time.Timer) + initTimer(tc, timeout) + } + + var err error + select { + case err = <-ch: + resp.CopyTo(respCopy) + releaseResponse(respCopy) + releaseRequest(reqCopy) + errorChPool.Put(chv) + case <-tc.C: + err = ErrTimeout + } + + stopTimer(tc) + timerPool.Put(tcv) + + return err +} + +var ( + errorChPool sync.Pool + timerPool sync.Pool +) + // Do performs the given http request and sets the corresponding response. // // Request must contain at least non-zero RequestURI with full url (including @@ -463,9 +557,14 @@ func (c *HostClient) do(req *Request, resp *Response, newConn bool) (bool, error return false, err } -// ErrNoFreeConns is returned when no free connections available -// to the given host. -var ErrNoFreeConns = errors.New("no free connections available to host") +var ( + // ErrNoFreeConns is returned when no free connections available + // to the given host. + ErrNoFreeConns = errors.New("no free connections available to host") + + // ErrTimeout is returned from timed out calls. + ErrTimeout = errors.New("timeout") +) func (c *HostClient) acquireConn(newConn bool) (*clientConn, error) { var cc *clientConn diff --git a/client_test.go b/client_test.go index 8b22ba6..01817c8 100644 --- a/client_test.go +++ b/client_test.go @@ -12,6 +12,68 @@ import ( "time" ) +func TestClientDoTimeout(t *testing.T) { + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return &readTimeoutConn{t: time.Second}, nil + }, + } + + testClientDoTimeout(t, c, 100) +} + +func TestClientDoTimeoutConcurrent(t *testing.T) { + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return &readTimeoutConn{t: time.Second}, nil + }, + MaxConnsPerHost: 1000, + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + testClientDoTimeout(t, c, 100) + }() + } + wg.Wait() +} + +func testClientDoTimeout(t *testing.T, c *Client, n int) { + var req Request + var resp Response + req.SetRequestURI("http://foobar.com/baz") + for i := 0; i < 20; i++ { + err := c.DoTimeout(&req, &resp, time.Millisecond) + if err == nil { + t.Fatalf("expecting error") + } + if err != ErrTimeout { + t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout) + } + } +} + +type readTimeoutConn struct { + net.Conn + t time.Duration +} + +func (r *readTimeoutConn) Read(p []byte) (int, error) { + time.Sleep(r.t) + return 0, io.EOF +} + +func (r *readTimeoutConn) Write(p []byte) (int, error) { + return len(p), nil +} + +func (r *readTimeoutConn) Close() error { + return nil +} + func TestClientIdempotentRequest(t *testing.T) { dialsCount := 0 c := &Client{