From 92b182c4b17757fe8b6dc71aedb76e38eb45017b Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Wed, 13 Jan 2016 14:06:05 +0200 Subject: [PATCH] Added ability to limit the maximum connection duration in HostClient --- client.go | 32 ++++++++++++++++++++-------- client_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index c63dfee..b9cb5f2 100644 --- a/client.go +++ b/client.go @@ -366,6 +366,11 @@ type HostClient struct { // DefaultMaxConnsPerHost is used if not set. MaxConns int + // Maximum duration for each keep-alive connection before closing. + // + // By default connection duration is unlimited. + MaxConnDuration time.Duration + // Per-connection buffer size for responses' reading. // This also limits the maximum header size. // @@ -411,8 +416,9 @@ type HostClient struct { } type clientConn struct { - t time.Time - c net.Conn + c net.Conn + createdTime time.Time + lastUseTime time.Time } var startTimeUnix = time.Now().Unix() @@ -812,6 +818,12 @@ func (c *HostClient) do(req *Request, resp *Response, newConn bool) (bool, error } } + resetConnection := false + if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() { + req.SetConnectionClose() + resetConnection = true + } + userAgentOld := req.Header.UserAgent() if len(userAgentOld) == 0 { req.Header.userAgent = c.getClientName() @@ -822,6 +834,10 @@ func (c *HostClient) do(req *Request, resp *Response, newConn bool) (bool, error req.Header.userAgent = userAgentOld } + if resetConnection { + req.Header.ResetConnectionClose() + } + if err != nil { c.releaseWriter(bw) c.closeConn(cc) @@ -865,7 +881,7 @@ func (c *HostClient) do(req *Request, resp *Response, newConn bool) (bool, error } c.releaseReader(br) - if req.Header.ConnectionClose() || resp.Header.ConnectionClose() { + if resetConnection || req.ConnectionClose() || resp.ConnectionClose() { c.closeConn(cc) } else { c.releaseConn(cc) @@ -939,7 +955,7 @@ func (c *HostClient) connsCleaner() { t := time.Now() c.connsLock.Lock() conns := c.conns - for len(conns) > 0 && t.Sub(conns[0].t) > 10*time.Second { + for len(conns) > 0 && t.Sub(conns[0].lastUseTime) > 10*time.Second { cc := conns[0] c.connsCount-- cc.c.Close() @@ -980,13 +996,11 @@ func (c *HostClient) decConnsCount() { func acquireClientConn(conn net.Conn) *clientConn { v := clientConnPool.Get() if v == nil { - cc := &clientConn{ - c: conn, - } - return cc + v = &clientConn{} } cc := v.(*clientConn) cc.c = conn + cc.createdTime = time.Now() return cc } @@ -998,7 +1012,7 @@ func releaseClientConn(cc *clientConn) { var clientConnPool sync.Pool func (c *HostClient) releaseConn(cc *clientConn) { - cc.t = time.Now() + cc.lastUseTime = time.Now() c.connsLock.Lock() c.conns = append(c.conns, cc) c.connsLock.Unlock() diff --git a/client_test.go b/client_test.go index 6dd2868..11dcf00 100644 --- a/client_test.go +++ b/client_test.go @@ -8,12 +8,69 @@ import ( "os" "strings" "sync" + "sync/atomic" "testing" "time" "github.com/valyala/fasthttp/fasthttputil" ) +func TestHostClientMaxConnDuration(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + + connectionCloseCount := uint32(0) + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.WriteString("abcd") + if ctx.Request.Header.ConnectionCloseReal() { + atomic.AddUint32(&connectionCloseCount, 1) + } + }, + } + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Fatalf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &HostClient{ + Addr: "foobar", + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + MaxConnDuration: 10 * time.Millisecond, + } + + for i := 0; i < 5; i++ { + statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if statusCode != StatusOK { + t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK) + } + if string(body) != "abcd" { + t.Fatalf("unexpected body %q. Expecting %q", body, "abcd") + } + time.Sleep(c.MaxConnDuration) + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + if connectionCloseCount == 0 { + t.Fatalf("expecting at least one 'Connection: close' request header") + } +} + func TestHostClientMultipleAddrs(t *testing.T) { ln := fasthttputil.NewInmemoryListener()