mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
Add ability to set timeout for handshake (#631)
* Fixed issue with handshake timeout
This commit is contained in:
committed by
Erik Dubbelboer
parent
2edabf3b76
commit
ce02b85a9c
@@ -1537,7 +1537,7 @@ func (c *HostClient) dialHostHard() (conn net.Conn, err error) {
|
||||
for n > 0 {
|
||||
addr := c.nextAddr()
|
||||
tlsConfig := c.cachedTLSConfig(addr)
|
||||
conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig)
|
||||
conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
@@ -1568,7 +1568,43 @@ func (c *HostClient) cachedTLSConfig(addr string) *tls.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out")
|
||||
|
||||
var timeoutErrorChPool sync.Pool
|
||||
|
||||
func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
|
||||
tc := AcquireTimer(timeout)
|
||||
defer ReleaseTimer(tc)
|
||||
|
||||
var ch chan error
|
||||
chv := timeoutErrorChPool.Get()
|
||||
if chv == nil {
|
||||
chv = make(chan error)
|
||||
}
|
||||
ch = chv.(chan error)
|
||||
defer timeoutErrorChPool.Put(chv)
|
||||
|
||||
conn := tls.Client(rawConn, tlsConfig)
|
||||
|
||||
go func() {
|
||||
ch <- conn.Handshake()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-tc.C:
|
||||
rawConn.Close()
|
||||
<-ch
|
||||
return nil, ErrTLSHandshakeTimeout
|
||||
case err := <-ch:
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
|
||||
if dial == nil {
|
||||
if dialDualStack {
|
||||
dial = DialDualStack
|
||||
@@ -1585,7 +1621,10 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
|
||||
panic("BUG: DialFunc returned (nil, nil)")
|
||||
}
|
||||
if isTLS {
|
||||
conn = tls.Client(conn, tlsConfig)
|
||||
if timeout == 0 {
|
||||
return tls.Client(conn, tlsConfig), nil
|
||||
}
|
||||
return tlsClientHandshake(conn, tlsConfig, timeout)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
@@ -1992,7 +2031,7 @@ func (c *pipelineConnClient) init() {
|
||||
|
||||
func (c *pipelineConnClient) worker() error {
|
||||
tlsConfig := c.cachedTLSConfig()
|
||||
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig)
|
||||
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1773,3 +1773,43 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientTLSHandshakeTimeout(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode")
|
||||
}
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := listener.Addr().String()
|
||||
defer listener.Close()
|
||||
|
||||
complete := make(chan bool)
|
||||
defer close(complete)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
<-complete
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
client := Client{
|
||||
WriteTimeout: 1 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
_, _, err = client.Get(nil, "https://"+addr)
|
||||
if err == nil {
|
||||
t.Fatal("tlsClientHandshake completed successfully")
|
||||
}
|
||||
|
||||
if err != ErrTLSHandshakeTimeout {
|
||||
t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user