From c8f577c7f1d8408ecb95c45a6b2ce7d14259b843 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Tue, 12 Jan 2016 11:08:24 +0200 Subject: [PATCH] Added ability to balance requests among multiple upstream hosts via HostClient --- TODO | 1 - client.go | 44 +++++++++++++++++++++++------------- client_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 16 deletions(-) diff --git a/TODO b/TODO index ce1e2bb..399593d 100644 --- a/TODO +++ b/TODO @@ -1,5 +1,4 @@ - SessionClient with referer and cookies support. -- Load balancing client for multiple upstream hosts. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . diff --git a/client.go b/client.go index e751b55..c63dfee 100644 --- a/client.go +++ b/client.go @@ -99,7 +99,7 @@ type Client struct { // Callback for establishing new connections to hosts. // - // Default TCPDialer is used if not set. + // Default Dial is used if not set. Dial DialFunc // Attempt to connect to both ipv4 and ipv6 addresses if set to true. @@ -301,7 +301,7 @@ func (c *Client) mCleaner(m map[string]*HostClient) { // Maximum number of concurrent connections http client may establish per host // by default. -const DefaultMaxConnsPerHost = 4 * 1024 +const DefaultMaxConnsPerHost = 512 // DialFunc must establish connection to addr. // @@ -317,16 +317,18 @@ const DefaultMaxConnsPerHost = 4 * 1024 // - foobar.com:8080 type DialFunc func(addr string) (net.Conn, error) -// HostClient is a single-host http client. It can make http requests -// to the given Addr only. +// HostClient balances http requests among hosts listed in Addr. +// +// HostClient may be used for balancing load among multiple upstream hosts. // // It is forbidden copying HostClient instances. Create new instances instead. // // It is safe calling HostClient methods from concurrently running goroutines. type HostClient struct { - // HTTP server host address, which is passed to Dial. + // Comma-separated list of upstream HTTP server host addresses, + // which are passed to Dial in round-robin manner. // - // The address may contain port if default dialer is used. + // Each address may contain port if default dialer is used. // For example, // // - foobar.com:80 @@ -339,7 +341,7 @@ type HostClient struct { // Callback for establishing new connection to the host. // - // Default TCPDialer is used if not set. + // Default Dial is used if not set. Dial DialFunc // Attempt to connect to both ipv4 and ipv6 host addresses @@ -358,7 +360,8 @@ type HostClient struct { // Optional TLS config. TLSConfig *tls.Config - // Maximum number of connections to the host which may be established. + // Maximum number of connections which may be established to all hosts + // listed in Addr. // // DefaultMaxConnsPerHost is used if not set. MaxConns int @@ -399,12 +402,9 @@ type HostClient struct { connsCount int conns []*clientConn - // dns caching stuff for default dialer. - tcpAddrsLock sync.Mutex - tcpAddrs []net.TCPAddr - tcpAddrsPending bool - tcpAddrsResolveTime time.Time - tcpAddrsIdx uint32 + addrsLock sync.Mutex + addrs []string + addrIdx uint32 readerPool sync.Pool writerPool sync.Pool @@ -1044,9 +1044,23 @@ var defaultTLSConfig = &tls.Config{ InsecureSkipVerify: true, } +func (c *HostClient) nextAddr() string { + c.addrsLock.Lock() + if c.addrs == nil { + c.addrs = strings.Split(c.Addr, ",") + } + addr := c.addrs[0] + if len(c.addrs) > 1 { + addr = c.addrs[c.addrIdx%uint32(len(c.addrs))] + c.addrIdx++ + } + c.addrsLock.Unlock() + return addr +} + func (c *HostClient) dialHost() (net.Conn, error) { dial := c.Dial - addr := c.Addr + addr := c.nextAddr() if dial == nil { if c.DialDualStack { dial = DialDualStack diff --git a/client_test.go b/client_test.go index f0d89de..6dd2868 100644 --- a/client_test.go +++ b/client_test.go @@ -10,8 +10,68 @@ import ( "sync" "testing" "time" + + "github.com/valyala/fasthttp/fasthttputil" ) +func TestHostClientMultipleAddrs(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Write(ctx.Host()) + ctx.SetConnectionClose() + }, + } + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Fatalf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + dialsCount := make(map[string]int) + c := &HostClient{ + Addr: "foo,bar,baz", + Dial: func(addr string) (net.Conn, error) { + dialsCount[addr]++ + return ln.Dial() + }, + } + + for i := 0; i < 9; i++ { + statusCode, body, err := c.Get(nil, "http://foobar/baz/aaa?bbb=ddd") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if statusCode != StatusOK { + t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK) + } + if string(body) != "foobar" { + t.Fatalf("unexpected body %q. Expecting %q", body, "foobar") + } + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + if len(dialsCount) != 3 { + t.Fatalf("unexpected dialsCount size %d. Expecting 3", len(dialsCount)) + } + for _, k := range []string{"foo", "bar", "baz"} { + if dialsCount[k] != 3 { + t.Fatalf("unexpected dialsCount for %q. Expecting 3", k) + } + } +} + func TestClientFollowRedirects(t *testing.T) { addr := "127.0.0.1:55234" s := &Server{