diff --git a/client.go b/client.go index f63cff4..9b9ae8a 100644 --- a/client.go +++ b/client.go @@ -76,8 +76,12 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } if opts.Chain == nil { - return net.DialTimeout("tcp", addr, opts.Timeout) + return net.DialTimeout("tcp", addr, timeout) } return opts.Chain.Dial(addr) } @@ -103,7 +107,13 @@ func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er for _, option := range options { option(opts) } - return net.DialTimeout("udp", addr, opts.Timeout) + + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + return net.DialTimeout("udp", addr, timeout) } func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { diff --git a/forward.go b/forward.go index a185bf6..d5bc64e 100644 --- a/forward.go +++ b/forward.go @@ -662,9 +662,9 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) { go ln.listenLoop() - if err = <-ln.errChan; err != nil { - ln.Close() - } + // if err = <-ln.errChan; err != nil { + // ln.Close() + // } return ln, err } @@ -680,19 +680,22 @@ func (l *tcpRemoteForwardListener) isChainValid() bool { func (l *tcpRemoteForwardListener) listenLoop() { var tempDelay time.Duration - var once sync.Once + // var once sync.Once for { conn, err := l.accept() - once.Do(func() { - l.errChan <- err - close(l.errChan) - }) + // once.Do(func() { + // l.errChan <- err + // log.Log("once.Do error:", err) + // close(l.errChan) + // }) select { case <-l.closed: - conn.Close() + if conn != nil { + conn.Close() + } return default: } @@ -706,7 +709,7 @@ func (l *tcpRemoteForwardListener) listenLoop() { if max := 6 * time.Second; tempDelay > max { tempDelay = max } - log.Logf("[rtcp] Accept error: %v; retrying in %v", err, tempDelay) + log.Logf("[rtcp] accept error: %v; retrying in %v", err, tempDelay) time.Sleep(tempDelay) continue } diff --git a/gost.go b/gost.go index f3e9e8b..ab222b3 100644 --- a/gost.go +++ b/gost.go @@ -51,6 +51,8 @@ var ( KeepAliveTime = 180 * time.Second // DialTimeout is the timeout of dial. DialTimeout = 5 * time.Second + // HandshakeTimeout is the timeout of handshake. + HandshakeTimeout = 5 * time.Second // ReadTimeout is the timeout for reading. ReadTimeout = 5 * time.Second // WriteTimeout is the timeout for writing. diff --git a/http2.go b/http2.go index 0f5c88b..337afeb 100644 --- a/http2.go +++ b/http2.go @@ -126,6 +126,10 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, } conn.Close() + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } transport := http2.Transport{ TLSClientConfig: tr.tlsConfig, DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { @@ -133,12 +137,12 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, if err != nil { return nil, err } - return wrapTLSClient(conn, cfg, opts.Timeout) + return wrapTLSClient(conn, cfg, timeout) }, } client = &http.Client{ Transport: &transport, - Timeout: opts.Timeout, + Timeout: timeout, } tr.clients[addr] = client } @@ -190,6 +194,11 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err tr.clientMutex.Lock() client, ok := tr.clients[addr] if !ok { + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + transport := http2.Transport{ TLSClientConfig: tr.tlsConfig, DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { @@ -200,12 +209,12 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err if tr.tlsConfig == nil { return conn, nil } - return wrapTLSClient(conn, cfg, opts.Timeout) + return wrapTLSClient(conn, cfg, timeout) }, } client = &http.Client{ Transport: &transport, - Timeout: opts.Timeout, + Timeout: timeout, } tr.clients[addr] = client } diff --git a/kcp.go b/kcp.go index bcf72c2..0a2d18c 100644 --- a/kcp.go +++ b/kcp.go @@ -114,9 +114,9 @@ func KCPTransporter(config *KCPConfig) Transporter { } func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - uaddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return + opts := &DialOptions{} + for _, option := range options { + option(opts) } tr.sessionMutex.Lock() @@ -124,7 +124,11 @@ func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Con session, ok := tr.sessions[addr] if !ok { - conn, err = net.DialUDP("udp", nil, uaddr) + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + conn, err = net.DialTimeout("udp", addr, timeout) if err != nil { return } @@ -146,6 +150,13 @@ func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + session, ok := tr.sessions[opts.Addr] if !ok || session.session == nil { s, err := tr.initSession(opts.Addr, conn, config) diff --git a/quic.go b/quic.go index 0a362f2..1e7cf4d 100644 --- a/quic.go +++ b/quic.go @@ -54,6 +54,11 @@ func QUICTransporter(config *QUICConfig) Transporter { } func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() @@ -92,6 +97,13 @@ func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + session, ok := tr.sessions[opts.Addr] if session != nil && session.conn != conn { conn.Close() diff --git a/ssh.go b/ssh.go index 113d30d..c2dc5a5 100644 --- a/ssh.go +++ b/ssh.go @@ -126,10 +126,15 @@ func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + session, ok := tr.sessions[addr] if !ok || session.Closed() { if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + conn, err = net.DialTimeout("tcp", addr, timeout) } else { conn, err = opts.Chain.Dial(addr) } @@ -152,8 +157,13 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + config := ssh.ClientConfig{ - Timeout: opts.Timeout, + Timeout: timeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } if opts.User != nil { @@ -222,10 +232,15 @@ func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn n tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + session, ok := tr.sessions[addr] if !ok || session.Closed() { if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + conn, err = net.DialTimeout("tcp", addr, timeout) } else { conn, err = opts.Chain.Dial(addr) } @@ -248,8 +263,13 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + config := ssh.ClientConfig{ - Timeout: opts.Timeout, + Timeout: timeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } // TODO: support pubkey auth. @@ -318,7 +338,7 @@ func (s *sshSession) Ping(interval, timeout time.Duration, retries int) { return } if timeout <= 0 { - timeout = 10 * time.Second + timeout = PingTimeout } if retries == 0 { diff --git a/tls.go b/tls.go index 9ad4594..f8e45cc 100644 --- a/tls.go +++ b/tls.go @@ -30,7 +30,13 @@ func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( if opts.TLSConfig == nil { opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} } - return wrapTLSClient(conn, opts.TLSConfig, opts.Timeout) + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + return wrapTLSClient(conn, opts.TLSConfig, timeout) } type mtlsTransporter struct { @@ -52,6 +58,11 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() @@ -63,7 +74,7 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co } if !ok { if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + conn, err = net.DialTimeout("tcp", addr, timeout) } else { conn, err = opts.Chain.Dial(addr) } @@ -82,9 +93,17 @@ func (tr *mtlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + session, ok := tr.sessions[opts.Addr] if !ok || session.session == nil { s, err := tr.initSession(opts.Addr, conn, opts) @@ -265,7 +284,7 @@ func wrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) } if timeout <= 0 { - timeout = 10 * time.Second // default timeout + timeout = HandshakeTimeout // default timeout } tlsConn.SetDeadline(time.Now().Add(timeout)) diff --git a/ws.go b/ws.go index fdc97b4..24775f2 100644 --- a/ws.go +++ b/ws.go @@ -37,11 +37,17 @@ func websocketClientConn(url string, conn net.Conn, tlsConfig *tls.Config, optio if options == nil { options = &WSOptions{} } + + timeout := options.HandshakeTimeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + dialer := websocket.Dialer{ ReadBufferSize: options.ReadBufferSize, WriteBufferSize: options.WriteBufferSize, TLSClientConfig: tlsConfig, - HandshakeTimeout: options.HandshakeTimeout, + HandshakeTimeout: timeout, EnableCompression: options.EnableCompression, NetDial: func(net, addr string) (net.Conn, error) { return conn, nil @@ -154,6 +160,11 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() @@ -165,7 +176,7 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con } if !ok { if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + conn, err = net.DialTimeout("tcp", addr, timeout) } else { conn, err = opts.Chain.Dial(addr) } @@ -184,9 +195,17 @@ func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + session, ok := tr.sessions[opts.Addr] if !ok || session.session == nil { s, err := tr.initSession(opts.Addr, conn, opts) @@ -283,6 +302,11 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() @@ -294,7 +318,7 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co } if !ok { if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + conn, err = net.DialTimeout("tcp", addr, timeout) } else { conn, err = opts.Chain.Dial(addr) } @@ -313,9 +337,17 @@ func (tr *mwssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) option(opts) } + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + session, ok := tr.sessions[opts.Addr] if !ok || session.session == nil { s, err := tr.initSession(opts.Addr, conn, opts)