diff --git a/tcpdialer.go b/tcpdialer.go index d62bfe7..5c7531e 100644 --- a/tcpdialer.go +++ b/tcpdialer.go @@ -280,7 +280,8 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne go d.tcpAddrsClean() }) - addrs, idx, err := d.getTCPAddrs(addr, dualStack) + deadline := time.Now().Add(timeout) + addrs, idx, err := d.getTCPAddrs(addr, dualStack, deadline) if err != nil { return nil, err } @@ -291,7 +292,6 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne var conn net.Conn n := uint32(len(addrs)) - deadline := time.Now().Add(timeout) for n > 0 { conn, err = d.tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh) if err == nil { @@ -379,7 +379,7 @@ func (d *TCPDialer) tcpAddrsClean() { } } -func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uint32, error) { +func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool, deadline time.Time) ([]net.TCPAddr, uint32, error) { item, exist := d.tcpAddrsMap.Load(addr) e, ok := item.(*tcpAddrEntry) if exist && ok && e != nil && time.Since(e.resolveTime) > d.DNSCacheDuration { @@ -390,7 +390,7 @@ func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uin } if e == nil { - addrs, err := resolveTCPAddrs(addr, dualStack, d.Resolver) + addrs, err := resolveTCPAddrs(addr, dualStack, d.Resolver, deadline) if err != nil { item, exist := d.tcpAddrsMap.Load(addr) e, ok = item.(*tcpAddrEntry) @@ -412,7 +412,7 @@ func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uin return e.addrs, idx, nil } -func resolveTCPAddrs(addr string, dualStack bool, resolver Resolver) ([]net.TCPAddr, error) { +func resolveTCPAddrs(addr string, dualStack bool, resolver Resolver, deadline time.Time) ([]net.TCPAddr, error) { host, portS, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -426,7 +426,8 @@ func resolveTCPAddrs(addr string, dualStack bool, resolver Resolver) ([]net.TCPA resolver = net.DefaultResolver } - ctx := context.Background() + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() ipaddrs, err := resolver.LookupIPAddr(ctx, host) if err != nil { return nil, err