diff --git a/client.go b/client.go index 9d50924..fed4d16 100644 --- a/client.go +++ b/client.go @@ -1084,16 +1084,15 @@ func (c *HostClient) dialHostHard() (conn net.Conn, err error) { // It looks like c.addrs isn't initialized yet. n = 1 } + + startTime := time.Now() for n > 0 { conn, err = c.dialHost() if err == nil { return conn, nil } - if err == ErrDialTimeout { - // The function already has been blocked for DefaultDialTimeout, - // so let's return the error to the caller instead of waiting - // for unspecified time during dialing the remaining hosts. - return nil, err + if time.Since(startTime) > DefaultDialTimeout { + return nil, ErrDialTimeout } n-- } diff --git a/tcpdialer.go b/tcpdialer.go index 082deac..38bed87 100644 --- a/tcpdialer.go +++ b/tcpdialer.go @@ -20,8 +20,9 @@ var ( // // * It reduces load on DNS resolver by caching resolved TCP addressed // for one minute. -// * It uses real round-robin if the given addr is resolved to multiple -// TCP addresses. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. // * It returns ErrDialTimeout if connection cannot be established during // DefaultDialTimeout seconds. // @@ -46,8 +47,9 @@ func Dial(addr string) (net.Conn, error) { // // * It reduces load on DNS resolver by caching resolved TCP addressed // for one minute. -// * It uses real round-robin if the given addr is resolved to multiple -// TCP addresses. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. // * It returns ErrDialTimeout if connection cannot be established during // DefaultDialTimeout seconds. // @@ -82,7 +84,7 @@ func (d *tcpDialer) NewDial() DialFunc { go d.tcpAddrsClean() return func(addr string) (net.Conn, error) { - tcpAddr, err := d.getTCPAddr(addr) + addrs, idx, err := d.getTCPAddrs(addr) if err != nil { return nil, err } @@ -90,18 +92,37 @@ func (d *tcpDialer) NewDial() DialFunc { if d.DualStack { network = "tcp" } - ch := make(chan dialResult, 1) - go func() { - var dr dialResult - dr.conn, dr.err = net.DialTCP(network, nil, tcpAddr) - ch <- dr - }() - select { - case dr := <-ch: - return dr.conn, dr.err - case <-time.After(DefaultDialTimeout): - return nil, ErrDialTimeout + + var conn net.Conn + startTime := time.Now() + n := uint32(len(addrs)) + for n > 0 { + conn, err = tryDial(network, &addrs[idx%n]) + if err == nil { + return conn, nil + } + if time.Since(startTime) > DefaultDialTimeout { + return nil, ErrDialTimeout + } + idx++ + n-- } + return nil, err + } +} + +func tryDial(network string, addr *net.TCPAddr) (net.Conn, error) { + ch := make(chan dialResult, 1) + go func() { + var dr dialResult + dr.conn, dr.err = net.DialTCP(network, nil, addr) + ch <- dr + }() + select { + case dr := <-ch: + return dr.conn, dr.err + case <-time.After(DefaultDialTimeout): + return nil, ErrDialTimeout } } @@ -143,7 +164,7 @@ func (d *tcpDialer) tcpAddrsClean() { } } -func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) { +func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) { d.tcpAddrsLock.Lock() e := d.tcpAddrsMap[addr] if e != nil && !e.pending && time.Since(e.resolveTime) > tcpAddrsCacheDuration { @@ -153,7 +174,7 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) { d.tcpAddrsLock.Unlock() if e == nil { - tcpAddrs, err := resolveTCPAddrs(addr, d.DualStack) + addrs, err := resolveTCPAddrs(addr, d.DualStack) if err != nil { d.tcpAddrsLock.Lock() e = d.tcpAddrsMap[addr] @@ -161,11 +182,11 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) { e.pending = false } d.tcpAddrsLock.Unlock() - return nil, err + return nil, 0, err } e = &tcpAddrEntry{ - addrs: tcpAddrs, + addrs: addrs, resolveTime: time.Now(), } @@ -174,13 +195,11 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) { d.tcpAddrsLock.Unlock() } - tcpAddr := &e.addrs[0] - n := len(e.addrs) - if n > 1 { - n := atomic.AddUint32(&e.addrsIdx, 1) - tcpAddr = &e.addrs[n%uint32(n)] + idx := uint32(0) + if len(e.addrs) > 0 { + idx = atomic.AddUint32(&e.addrsIdx, 1) } - return tcpAddr, nil + return e.addrs, idx, nil } func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) {