add timeout for TLS handshaking (#316)

This commit is contained in:
ginuerzh 2018-11-03 11:24:25 +08:00
parent bebb2f824a
commit 31257903a3
2 changed files with 13 additions and 5 deletions

View File

@ -114,7 +114,7 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn,
if err != nil { if err != nil {
return nil, err return nil, err
} }
return wrapTLSClient(conn, cfg) return wrapTLSClient(conn, cfg, opts.Timeout)
}, },
} }
client = &http.Client{ client = &http.Client{
@ -182,7 +182,7 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err
if tr.tlsConfig == nil { if tr.tlsConfig == nil {
return conn, nil return conn, nil
} }
return wrapTLSClient(conn, cfg) return wrapTLSClient(conn, cfg, opts.Timeout)
}, },
} }
client = &http.Client{ client = &http.Client{

14
tls.go
View File

@ -30,7 +30,7 @@ func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
if opts.TLSConfig == nil { if opts.TLSConfig == nil {
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
} }
return wrapTLSClient(conn, opts.TLSConfig) return wrapTLSClient(conn, opts.TLSConfig, opts.Timeout)
} }
type mtlsTransporter struct { type mtlsTransporter struct {
@ -113,7 +113,7 @@ func (tr *mtlsTransporter) initSession(addr string, conn net.Conn, opts *Handsha
if opts.TLSConfig == nil { if opts.TLSConfig == nil {
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
} }
conn, err := wrapTLSClient(conn, opts.TLSConfig) conn, err := wrapTLSClient(conn, opts.TLSConfig, opts.Timeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -248,7 +248,7 @@ func (l *mtlsListener) Close() error {
// //
// This code is taken from consul: // This code is taken from consul:
// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go // https://github.com/hashicorp/consul/blob/master/tlsutil/config.go
func wrapTLSClient(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { func wrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
var err error var err error
var tlsConn *tls.Conn var tlsConn *tls.Conn
@ -264,6 +264,12 @@ func wrapTLSClient(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
return tlsConn, nil return tlsConn, nil
} }
if timeout <= 0 {
timeout = 10 * time.Second // default timeout
}
tlsConn.SetDeadline(time.Now().Add(timeout))
// Otherwise perform handshake, but don't verify the domain // Otherwise perform handshake, but don't verify the domain
// //
// The following is lightly-modified from the doFullHandshake // The following is lightly-modified from the doFullHandshake
@ -273,6 +279,8 @@ func wrapTLSClient(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
return nil, err return nil, err
} }
tlsConn.SetDeadline(time.Time{}) // clear timeout
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs, Roots: tlsConfig.RootCAs,
CurrentTime: time.Now(), CurrentTime: time.Now(),