diff --git a/chain.go b/chain.go index d9d6682..934e5db 100644 --- a/chain.go +++ b/chain.go @@ -106,7 +106,7 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error opt(options) } - retries := 1 + retries := 10 //maximum retry if c != nil && c.Retries > 0 { retries = c.Retries } @@ -119,6 +119,9 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error if err == nil { break } + if err == ErrNoneAvailable { + break + } } return } @@ -128,6 +131,7 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e options = &ChainOptions{} } route, err := c.selectRouteFor(addr) + //log.Log("Connecting", addr, "using", route.Nodes()[0].Addr, "failcount", route.Nodes()[0].failCount) if err != nil { return nil, err } @@ -138,17 +142,12 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e return net.DialTimeout("tcp", addr, options.Timeout) } - conn, err := route.getConn() + conn, err := route.getConn(addr) if err != nil { return nil, err } - cc, err := route.LastNode().Client.Connect(conn, addr) - if err != nil { - conn.Close() - return nil, err - } - return cc, nil + return conn, nil } func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { @@ -194,7 +193,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { if err != nil { continue } - conn, err = route.getConn() + conn, err = route.getConn("") if err != nil { log.Log(err) continue @@ -206,7 +205,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { } // getConn obtains a connection to the last node of the chain. -func (c *Chain) getConn() (conn net.Conn, err error) { +func (c *Chain) getConn(addr string) (conn net.Conn, err error) { if c.IsEmpty() { err = ErrEmptyChain return @@ -225,7 +224,10 @@ func (c *Chain) getConn() (conn net.Conn, err error) { node.MarkDead() return } - node.ResetDead() + + if len(nodes) > 1 { + node.ResetDead() // don't reset the last node as we are going to check if it will connect successfully. + } preNode := node for _, node := range nodes[1:] { @@ -242,13 +244,27 @@ func (c *Chain) getConn() (conn net.Conn, err error) { node.MarkDead() return } - node.ResetDead() - + if len(nodes) > 1 { + node.ResetDead() + } cn = cc preNode = node } conn = cn + if addr != "" { + var cc net.Conn + cc, err = node.Client.Connect(conn, addr) + if err != nil { + if _, ok := err.(*net.OpError); ok { + node.MarkDead() + } + conn.Close() + return + } + conn = cc + } + node.ResetDead() return } diff --git a/socks.go b/socks.go index b1bb24a..62cd162 100644 --- a/socks.go +++ b/socks.go @@ -215,9 +215,13 @@ func (c *socks5Connector) Connect(conn net.Conn, addr string) (net.Conn, error) ) cc := gosocks5.ClientConn(conn, selector) + + conn.SetDeadline(time.Now().Add(time.Second * 5)) if err := cc.Handleshake(); err != nil { + conn.SetDeadline(time.Time{}) return nil, err } + conn.SetDeadline(time.Time{}) conn = cc host, port, err := net.SplitHostPort(addr)