mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-15 16:07:51 +03:00
Dial all the resolved TCP addresses in round-robin manner until the connection is established
This commit is contained in:
@@ -1084,16 +1084,15 @@ func (c *HostClient) dialHostHard() (conn net.Conn, err error) {
|
||||
// It looks like c.addrs isn't initialized yet.
|
||||
n = 1
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
for n > 0 {
|
||||
conn, err = c.dialHost()
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
if err == ErrDialTimeout {
|
||||
// The function already has been blocked for DefaultDialTimeout,
|
||||
// so let's return the error to the caller instead of waiting
|
||||
// for unspecified time during dialing the remaining hosts.
|
||||
return nil, err
|
||||
if time.Since(startTime) > DefaultDialTimeout {
|
||||
return nil, ErrDialTimeout
|
||||
}
|
||||
n--
|
||||
}
|
||||
|
||||
+45
-26
@@ -20,8 +20,9 @@ var (
|
||||
//
|
||||
// * It reduces load on DNS resolver by caching resolved TCP addressed
|
||||
// for one minute.
|
||||
// * It uses real round-robin if the given addr is resolved to multiple
|
||||
// TCP addresses.
|
||||
// * It dials all the resolved TCP addresses in round-robin manner until
|
||||
// connection is established. This may be useful if certain addresses
|
||||
// are temporarily unreachable.
|
||||
// * It returns ErrDialTimeout if connection cannot be established during
|
||||
// DefaultDialTimeout seconds.
|
||||
//
|
||||
@@ -46,8 +47,9 @@ func Dial(addr string) (net.Conn, error) {
|
||||
//
|
||||
// * It reduces load on DNS resolver by caching resolved TCP addressed
|
||||
// for one minute.
|
||||
// * It uses real round-robin if the given addr is resolved to multiple
|
||||
// TCP addresses.
|
||||
// * It dials all the resolved TCP addresses in round-robin manner until
|
||||
// connection is established. This may be useful if certain addresses
|
||||
// are temporarily unreachable.
|
||||
// * It returns ErrDialTimeout if connection cannot be established during
|
||||
// DefaultDialTimeout seconds.
|
||||
//
|
||||
@@ -82,7 +84,7 @@ func (d *tcpDialer) NewDial() DialFunc {
|
||||
go d.tcpAddrsClean()
|
||||
|
||||
return func(addr string) (net.Conn, error) {
|
||||
tcpAddr, err := d.getTCPAddr(addr)
|
||||
addrs, idx, err := d.getTCPAddrs(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,18 +92,37 @@ func (d *tcpDialer) NewDial() DialFunc {
|
||||
if d.DualStack {
|
||||
network = "tcp"
|
||||
}
|
||||
ch := make(chan dialResult, 1)
|
||||
go func() {
|
||||
var dr dialResult
|
||||
dr.conn, dr.err = net.DialTCP(network, nil, tcpAddr)
|
||||
ch <- dr
|
||||
}()
|
||||
select {
|
||||
case dr := <-ch:
|
||||
return dr.conn, dr.err
|
||||
case <-time.After(DefaultDialTimeout):
|
||||
return nil, ErrDialTimeout
|
||||
|
||||
var conn net.Conn
|
||||
startTime := time.Now()
|
||||
n := uint32(len(addrs))
|
||||
for n > 0 {
|
||||
conn, err = tryDial(network, &addrs[idx%n])
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
if time.Since(startTime) > DefaultDialTimeout {
|
||||
return nil, ErrDialTimeout
|
||||
}
|
||||
idx++
|
||||
n--
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func tryDial(network string, addr *net.TCPAddr) (net.Conn, error) {
|
||||
ch := make(chan dialResult, 1)
|
||||
go func() {
|
||||
var dr dialResult
|
||||
dr.conn, dr.err = net.DialTCP(network, nil, addr)
|
||||
ch <- dr
|
||||
}()
|
||||
select {
|
||||
case dr := <-ch:
|
||||
return dr.conn, dr.err
|
||||
case <-time.After(DefaultDialTimeout):
|
||||
return nil, ErrDialTimeout
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,7 +164,7 @@ func (d *tcpDialer) tcpAddrsClean() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) {
|
||||
func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) {
|
||||
d.tcpAddrsLock.Lock()
|
||||
e := d.tcpAddrsMap[addr]
|
||||
if e != nil && !e.pending && time.Since(e.resolveTime) > tcpAddrsCacheDuration {
|
||||
@@ -153,7 +174,7 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) {
|
||||
d.tcpAddrsLock.Unlock()
|
||||
|
||||
if e == nil {
|
||||
tcpAddrs, err := resolveTCPAddrs(addr, d.DualStack)
|
||||
addrs, err := resolveTCPAddrs(addr, d.DualStack)
|
||||
if err != nil {
|
||||
d.tcpAddrsLock.Lock()
|
||||
e = d.tcpAddrsMap[addr]
|
||||
@@ -161,11 +182,11 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) {
|
||||
e.pending = false
|
||||
}
|
||||
d.tcpAddrsLock.Unlock()
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
e = &tcpAddrEntry{
|
||||
addrs: tcpAddrs,
|
||||
addrs: addrs,
|
||||
resolveTime: time.Now(),
|
||||
}
|
||||
|
||||
@@ -174,13 +195,11 @@ func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) {
|
||||
d.tcpAddrsLock.Unlock()
|
||||
}
|
||||
|
||||
tcpAddr := &e.addrs[0]
|
||||
n := len(e.addrs)
|
||||
if n > 1 {
|
||||
n := atomic.AddUint32(&e.addrsIdx, 1)
|
||||
tcpAddr = &e.addrs[n%uint32(n)]
|
||||
idx := uint32(0)
|
||||
if len(e.addrs) > 0 {
|
||||
idx = atomic.AddUint32(&e.addrsIdx, 1)
|
||||
}
|
||||
return tcpAddr, nil
|
||||
return e.addrs, idx, nil
|
||||
}
|
||||
|
||||
func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) {
|
||||
|
||||
Reference in New Issue
Block a user