Add ability to set timeout for handshake (#631)

* Fixed issue with handshake timeout
This commit is contained in:
Mike Faraponov
2019-08-18 15:34:47 +03:00
committed by Erik Dubbelboer
parent 2edabf3b76
commit ce02b85a9c
2 changed files with 83 additions and 4 deletions
+43 -4
View File
@@ -1537,7 +1537,7 @@ func (c *HostClient) dialHostHard() (conn net.Conn, err error) {
for n > 0 {
addr := c.nextAddr()
tlsConfig := c.cachedTLSConfig(addr)
conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig)
conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
if err == nil {
return conn, nil
}
@@ -1568,7 +1568,43 @@ func (c *HostClient) cachedTLSConfig(addr string) *tls.Config {
return cfg
}
func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config) (net.Conn, error) {
var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out")
var timeoutErrorChPool sync.Pool
func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
tc := AcquireTimer(timeout)
defer ReleaseTimer(tc)
var ch chan error
chv := timeoutErrorChPool.Get()
if chv == nil {
chv = make(chan error)
}
ch = chv.(chan error)
defer timeoutErrorChPool.Put(chv)
conn := tls.Client(rawConn, tlsConfig)
go func() {
ch <- conn.Handshake()
}()
select {
case <-tc.C:
rawConn.Close()
<-ch
return nil, ErrTLSHandshakeTimeout
case err := <-ch:
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
}
func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
if dial == nil {
if dialDualStack {
dial = DialDualStack
@@ -1585,7 +1621,10 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
panic("BUG: DialFunc returned (nil, nil)")
}
if isTLS {
conn = tls.Client(conn, tlsConfig)
if timeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, timeout)
}
return conn, nil
}
@@ -1992,7 +2031,7 @@ func (c *pipelineConnClient) init() {
func (c *pipelineConnClient) worker() error {
tlsConfig := c.cachedTLSConfig()
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig)
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
if err != nil {
return err
}
+40
View File
@@ -1773,3 +1773,43 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch
t: t,
}
}
func TestClientTLSHandshakeTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode")
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
addr := listener.Addr().String()
defer listener.Close()
complete := make(chan bool)
defer close(complete)
go func() {
conn, err := listener.Accept()
if err != nil {
t.Error(err)
return
}
<-complete
conn.Close()
}()
client := Client{
WriteTimeout: 1 * time.Second,
ReadTimeout: 1 * time.Second,
}
_, _, err = client.Get(nil, "https://"+addr)
if err == nil {
t.Fatal("tlsClientHandshake completed successfully")
}
if err != ErrTLSHandshakeTimeout {
t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
}
}