Dial all the resolved TCP addresses in round-robin manner until the connection is established

This commit is contained in:
Aliaksandr Valialkin
2016-01-15 19:41:09 +02:00
parent 576ba8868b
commit e3369ec00b
2 changed files with 49 additions and 31 deletions
+4 -5
View File
@@ -1084,16 +1084,15 @@ func (c *HostClient) dialHostHard() (conn net.Conn, err error) {
// It looks like c.addrs isn't initialized yet.
n = 1
}
startTime := time.Now()
for n > 0 {
conn, err = c.dialHost()
if err == nil {
return conn, nil
}
if err == ErrDialTimeout {
// The function already has been blocked for DefaultDialTimeout,
// so let's return the error to the caller instead of waiting
// for unspecified time during dialing the remaining hosts.
return nil, err
if time.Since(startTime) > DefaultDialTimeout {
return nil, ErrDialTimeout
}
n--
}
+45 -26
View File
@@ -20,8 +20,9 @@ var (
//
// * It reduces load on DNS resolver by caching resolved TCP addressed
// for one minute.
// * It uses real round-robin if the given addr is resolved to multiple
// TCP addresses.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// * It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds.
//
@@ -46,8 +47,9 @@ func Dial(addr string) (net.Conn, error) {
//
// * It reduces load on DNS resolver by caching resolved TCP addressed
// for one minute.
// * It uses real round-robin if the given addr is resolved to multiple
// TCP addresses.
// * It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// * It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds.
//
@@ -82,7 +84,7 @@ func (d *tcpDialer) NewDial() DialFunc {
go d.tcpAddrsClean()
return func(addr string) (net.Conn, error) {
tcpAddr, err := d.getTCPAddr(addr)
addrs, idx, err := d.getTCPAddrs(addr)
if err != nil {
return nil, err
}
@@ -90,18 +92,37 @@ func (d *tcpDialer) NewDial() DialFunc {
if d.DualStack {
network = "tcp"
}
ch := make(chan dialResult, 1)
go func() {
var dr dialResult
dr.conn, dr.err = net.DialTCP(network, nil, tcpAddr)
ch <- dr
}()
select {
case dr := <-ch:
return dr.conn, dr.err
case <-time.After(DefaultDialTimeout):
return nil, ErrDialTimeout
var conn net.Conn
startTime := time.Now()
n := uint32(len(addrs))
for n > 0 {
conn, err = tryDial(network, &addrs[idx%n])
if err == nil {
return conn, nil
}
if time.Since(startTime) > DefaultDialTimeout {
return nil, ErrDialTimeout
}
idx++
n--
}
return nil, err
}
}
func tryDial(network string, addr *net.TCPAddr) (net.Conn, error) {
ch := make(chan dialResult, 1)
go func() {
var dr dialResult
dr.conn, dr.err = net.DialTCP(network, nil, addr)
ch <- dr
}()
select {
case dr := <-ch:
return dr.conn, dr.err
case <-time.After(DefaultDialTimeout):
return nil, ErrDialTimeout
}
}
@@ -143,7 +164,7 @@ func (d *tcpDialer) tcpAddrsClean() {
}
}
func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) {
func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) {
d.tcpAddrsLock.Lock()
e := d.tcpAddrsMap[addr]
if e != nil && !e.pending && time.Since(e.resolveTime) > tcpAddrsCacheDuration {
@@ -153,7 +174,7 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) {
d.tcpAddrsLock.Unlock()
if e == nil {
tcpAddrs, err := resolveTCPAddrs(addr, d.DualStack)
addrs, err := resolveTCPAddrs(addr, d.DualStack)
if err != nil {
d.tcpAddrsLock.Lock()
e = d.tcpAddrsMap[addr]
@@ -161,11 +182,11 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) {
e.pending = false
}
d.tcpAddrsLock.Unlock()
return nil, err
return nil, 0, err
}
e = &tcpAddrEntry{
addrs: tcpAddrs,
addrs: addrs,
resolveTime: time.Now(),
}
@@ -174,13 +195,11 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) {
d.tcpAddrsLock.Unlock()
}
tcpAddr := &e.addrs[0]
n := len(e.addrs)
if n > 1 {
n := atomic.AddUint32(&e.addrsIdx, 1)
tcpAddr = &e.addrs[n%uint32(n)]
idx := uint32(0)
if len(e.addrs) > 0 {
idx = atomic.AddUint32(&e.addrsIdx, 1)
}
return tcpAddr, nil
return e.addrs, idx, nil
}
func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) {