From e42d27c368d6c6c2aafa286ead0b0d5d0dde7552 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Mon, 6 Nov 2017 17:59:33 +0800 Subject: [PATCH 1/9] fix #165 #179 --- ss.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ss.go b/ss.go index 3cc702d..e180cb6 100644 --- a/ss.go +++ b/ss.go @@ -124,11 +124,15 @@ func (h *shadowHandler) Handle(conn net.Conn) { log.Logf("[ss] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) addr, err := h.getRequest(conn) if err != nil { log.Logf("[ss] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) return } + // clear timer + conn.SetReadDeadline(time.Time{}) + log.Logf("[ss] %s -> %s", conn.RemoteAddr(), addr) if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { @@ -165,19 +169,16 @@ const ( ) // This function is copied from shadowsocks library with some modification. -func (h *shadowHandler) getRequest(conn net.Conn) (host string, err error) { +func (h *shadowHandler) getRequest(r io.Reader) (host string, err error) { // buf size should at least have the same size with the largest possible // request size (when addrType is 3, domain name has at most 256 bytes) // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) buf := make([]byte, smallBufferSize) // read till we get possible domain length field - conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil { + if _, err = io.ReadFull(r, buf[:idType+1]); err != nil { return } - // clear timer - conn.SetReadDeadline(time.Time{}) var reqStart, reqEnd int addrType := buf[idType] @@ -187,16 +188,16 @@ func (h *shadowHandler) getRequest(conn net.Conn) (host string, err error) { case typeIPv6: reqStart, reqEnd = idIP0, idIP0+lenIPv6 case typeDm: - if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil { + if _, err = io.ReadFull(r, buf[idType+1:idDmLen+1]); err != nil { return } - reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase) + reqStart, reqEnd = idDm0, idDm0+int(buf[idDmLen])+lenDmBase default: err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask) return } - if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil { + if _, err = io.ReadFull(r, buf[reqStart:reqEnd]); err != nil { return } @@ -209,7 +210,7 @@ func (h *shadowHandler) getRequest(conn net.Conn) (host string, err error) { case typeIPv6: host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String() case typeDm: - host = string(buf[idDm0 : idDm0+buf[idDmLen]]) + host = string(buf[idDm0 : idDm0+int(buf[idDmLen])]) } // parse port port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd]) From 8cb22691599984909d92021f01fd92407680d067 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Tue, 7 Nov 2017 17:55:18 +0800 Subject: [PATCH 2/9] add node ID --- chain.go | 84 +++++++++++++++++++++--------------------------- cmd/gost/main.go | 21 ++++++++++-- gost.go | 1 + http.go | 7 ++++ node.go | 16 ++++++--- sni.go | 16 ++++++--- 6 files changed, 87 insertions(+), 58 deletions(-) diff --git a/chain.go b/chain.go index 9a409d9..d328f32 100644 --- a/chain.go +++ b/chain.go @@ -1,9 +1,10 @@ package gost import ( + "bytes" "errors" + "fmt" "net" - "strings" "github.com/go-log/log" ) @@ -124,29 +125,6 @@ func (c *Chain) Conn() (conn net.Conn, err error) { return } -func (c *Chain) selectRoute() (route *Chain, err error) { - route = NewChain() - for _, group := range c.nodeGroups { - selector := group.Selector - if selector == nil { - selector = &defaultSelector{} - } - // select node from node group - node, err := selector.Select(group.Nodes(), group.Options...) - if err != nil { - return nil, err - } - if node.Client.Transporter.Multiplex() { - node.DialOptions = append(node.DialOptions, - ChainDialOption(route), - ) - route = NewChain() // cutoff the chain for multiplex - } - route.AddNode(node) - } - return -} - func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { if route.IsEmpty() { err = ErrEmptyChain @@ -155,11 +133,7 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { nodes := route.Nodes() node := nodes[0] - addr, err := selectIP(&node) - if err != nil { - return - } - cn, err := node.Client.Dial(addr, node.DialOptions...) + cn, err := node.Client.Dial(node.Addr, node.DialOptions...) if err != nil { return } @@ -171,13 +145,8 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { preNode := node for _, node := range nodes[1:] { - addr, err = selectIP(&node) - if err != nil { - return - } - var cc net.Conn - cc, err = preNode.Client.Connect(cn, addr) + cc, err = preNode.Client.Connect(cn, node.Addr) if err != nil { cn.Close() return @@ -195,8 +164,37 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { return } +func (c *Chain) selectRoute() (route *Chain, err error) { + buf := bytes.Buffer{} + route = NewChain() + for _, group := range c.nodeGroups { + selector := group.Selector + if selector == nil { + selector = &defaultSelector{} + } + // select node from node group + node, err := selector.Select(group.Nodes(), group.Options...) + if err != nil { + return nil, err + } + if _, err := selectIP(&node); err != nil { + return nil, err + } + buf.WriteString(fmt.Sprintf("%d@%s -> ", node.ID, node.Addr)) + + if node.Client.Transporter.Multiplex() { + node.DialOptions = append(node.DialOptions, + ChainDialOption(route), + ) + route = NewChain() // cutoff the chain for multiplex + } + route.AddNode(node) + } + log.Log("select route:", buf.String()) + return +} + func selectIP(node *Node) (string, error) { - addr := node.Addr s := node.IPSelector if s == nil { s = &RandomIPSelector{} @@ -207,17 +205,9 @@ func selectIP(node *Node) (string, error) { return "", err } if ip != "" { - if !strings.Contains(ip, ":") { - _, sport, err := net.SplitHostPort(addr) - if err != nil { - return "", err - } - ip = ip + ":" + sport - } - addr = ip // override the original address - node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr)) + node.Addr = ip + node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(node.Addr)) } - log.Log("select IP:", node.Addr, node.IPs, addr) - return addr, nil + return node.Addr, nil } diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 5ebbe66..0ecc3d0 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -94,6 +94,9 @@ func initChain() (*gost.Chain, error) { return nil, err } + id := 1 // start from 1 + + node.ID = id ngroup := gost.NewNodeGroup(node) // parse node peers if exists @@ -110,6 +113,8 @@ func initChain() (*gost.Chain, error) { if err != nil { return nil, err } + id++ + node.ID = id ngroup.AddNode(node) } @@ -126,6 +131,15 @@ func parseChainNode(ns string) (node gost.Node, err error) { } node.IPs = parseIP(node.Values.Get("ip")) + for i, ip := range node.IPs { + if !strings.Contains(ip, ":") { + _, sport, _ := net.SplitHostPort(node.Addr) + if sport == "" { + sport = "8080" // default port + } + node.IPs[i] = ip + ":" + sport + } + } node.IPSelector = &gost.RoundRobinIPSelector{} users, err := parseUsers(node.Values.Get("secrets")) @@ -592,11 +606,12 @@ func loadPeerConfig(peer string) (config peerConfig, err error) { func parseStrategy(s string) gost.Strategy { switch s { - case "round": - return &gost.RoundStrategy{} case "random": + return &gost.RandomStrategy{} + case "round": fallthrough default: - return &gost.RandomStrategy{} + return &gost.RoundStrategy{} + } } diff --git a/gost.go b/gost.go index 39b818f..2a8baeb 100644 --- a/gost.go +++ b/gost.go @@ -56,6 +56,7 @@ func SetLogger(logger log.Logger) { log.DefaultLogger = logger } +// GenCertificate generates a random TLS certificate func GenCertificate() (cert tls.Certificate, err error) { rawCert, rawKey, err := generateKeyPair() if err != nil { diff --git a/http.go b/http.go index 0612f50..b88942d 100644 --- a/http.go +++ b/http.go @@ -93,6 +93,13 @@ func (h *httpHandler) Handle(conn net.Conn) { return } + h.handleRequest(conn, req) +} + +func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { + if req == nil { + return + } if Debug { dump, _ := httputil.DumpRequest(req, false) log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), req.Host, string(dump)) diff --git a/node.go b/node.go index 4a2ef1a..a0fa051 100644 --- a/node.go +++ b/node.go @@ -3,10 +3,12 @@ package gost import ( "net/url" "strings" + "sync" ) // Node is a proxy node, mainly used to construct a proxy chain. type Node struct { + ID int Addr string IPs []string Protocol string @@ -81,15 +83,21 @@ func ParseNode(s string) (node Node, err error) { // NodeGroup is a group of nodes. type NodeGroup struct { - nodes []Node - Options []SelectOption - Selector NodeSelector + nodes []Node + Options []SelectOption + Selector NodeSelector + mutex sync.Mutex + mFails map[string]int // node -> fail count + MaxFails int + FailTimeout int + Retries int } // NewNodeGroup creates a node group func NewNodeGroup(nodes ...Node) *NodeGroup { return &NodeGroup{ - nodes: nodes, + nodes: nodes, + mFails: make(map[string]int), } } diff --git a/sni.go b/sni.go index d4ec608..2cb9800 100644 --- a/sni.go +++ b/sni.go @@ -11,6 +11,7 @@ import ( "hash/crc32" "io" "net" + "net/http" "strings" "sync" @@ -53,15 +54,22 @@ func (h *sniHandler) Handle(conn net.Conn) { return } conn = &bufferdConn{br: br, Conn: conn} + defer conn.Close() if hdr[0] != dissector.Handshake { - // We assume that it is HTTP request - HTTPHandler(h.options...).Handle(conn) + // We assume it is an HTTP request + req, err := http.ReadRequest(bufio.NewReader(conn)) + if !req.URL.IsAbs() { + req.URL.Scheme = "http" // make sure that the URL is absolute + } + if err != nil { + log.Logf("[sni] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + HTTPHandler(h.options...).(*httpHandler).handleRequest(conn, req) return } - defer conn.Close() - b, host, err := readClientHelloRecord(conn, "", false) if err != nil { log.Log("[sni]", err) From 11b727d72b478f16c5139d0e611a4d18e4ec4fc9 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Thu, 9 Nov 2017 13:25:35 +0800 Subject: [PATCH 3/9] fix chain route --- chain.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/chain.go b/chain.go index d328f32..8d304ba 100644 --- a/chain.go +++ b/chain.go @@ -16,6 +16,7 @@ var ( // Chain is a proxy chain that holds a list of proxy nodes. type Chain struct { + isRoute bool nodeGroups []*NodeGroup } @@ -28,6 +29,12 @@ func NewChain(nodes ...Node) *Chain { return chain } +func newRoute(nodes ...Node) *Chain { + chain := NewChain(nodes...) + chain.isRoute = true + return chain +} + // Nodes returns the proxy nodes that the chain holds. // If a node is a node group, the first node in the group will be returned. func (c *Chain) Nodes() (nodes []Node) { @@ -165,8 +172,12 @@ func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { } func (c *Chain) selectRoute() (route *Chain, err error) { + if c.isRoute { + return c, nil + } + buf := bytes.Buffer{} - route = NewChain() + route = newRoute() for _, group := range c.nodeGroups { selector := group.Selector if selector == nil { @@ -186,7 +197,7 @@ func (c *Chain) selectRoute() (route *Chain, err error) { node.DialOptions = append(node.DialOptions, ChainDialOption(route), ) - route = NewChain() // cutoff the chain for multiplex + route = newRoute() // cutoff the chain for multiplex } route.AddNode(node) } From 637e6423412335b187c39603c47d2a14eeb67e8c Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Thu, 9 Nov 2017 14:05:53 +0800 Subject: [PATCH 4/9] fix ssh deadline --- ssh.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ssh.go b/ssh.go index 91c1bf3..58c0cbc 100644 --- a/ssh.go +++ b/ssh.go @@ -830,13 +830,13 @@ func (c *sshConn) RemoteAddr() net.Addr { } func (c *sshConn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + return nil } func (c *sshConn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + return nil } func (c *sshConn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} + return nil } From da30584df18594098fb2508801d7282fb0223fd6 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Fri, 10 Nov 2017 11:00:35 +0800 Subject: [PATCH 5/9] add host for client handshake --- chain.go | 14 +++++++------- client.go | 8 ++++++++ cmd/gost/main.go | 3 ++- node.go | 4 +++- obfs.go | 5 +++-- ws.go | 8 ++++---- 6 files changed, 27 insertions(+), 15 deletions(-) diff --git a/chain.go b/chain.go index 8d304ba..1d2a3be 100644 --- a/chain.go +++ b/chain.go @@ -100,7 +100,7 @@ func (c *Chain) IsEmpty() bool { // If the chain is empty, it will use the net.Dial directly. func (c *Chain) Dial(addr string) (net.Conn, error) { if c.IsEmpty() { - return net.Dial("tcp", addr) + return net.DialTimeout("tcp", addr, DialTimeout) } route, err := c.selectRoute() @@ -108,7 +108,7 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { return nil, err } - conn, err := c.getConn(route) + conn, err := route.getConn() if err != nil { return nil, err } @@ -128,16 +128,16 @@ func (c *Chain) Conn() (conn net.Conn, err error) { if err != nil { return nil, err } - conn, err = c.getConn(route) + conn, err = route.getConn() return } -func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { - if route.IsEmpty() { +func (c *Chain) getConn() (conn net.Conn, err error) { + if c.IsEmpty() { err = ErrEmptyChain return } - nodes := route.Nodes() + nodes := c.Nodes() node := nodes[0] cn, err := node.Client.Dial(node.Addr, node.DialOptions...) @@ -206,7 +206,7 @@ func (c *Chain) selectRoute() (route *Chain, err error) { } func selectIP(node *Node) (string, error) { - s := node.IPSelector + s := node.Selector if s == nil { s = &RandomIPSelector{} } diff --git a/client.go b/client.go index e4a05f9..8a74998 100644 --- a/client.go +++ b/client.go @@ -116,6 +116,7 @@ func ChainDialOption(chain *Chain) DialOption { // HandshakeOptions describes the options for handshake. type HandshakeOptions struct { Addr string + Host string User *url.Userinfo Timeout time.Duration Interval time.Duration @@ -136,6 +137,13 @@ func AddrHandshakeOption(addr string) HandshakeOption { } } +// HostHandshakeOption specifies the hostname +func HostHandshakeOption(host string) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Host = host + } +} + // UserHandshakeOption specifies the user used by Transporter.Handshake func UserHandshakeOption(user *url.Userinfo) HandshakeOption { return func(opts *HandshakeOptions) { diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 0ecc3d0..8e69fdd 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -140,7 +140,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { node.IPs[i] = ip + ":" + sport } } - node.IPSelector = &gost.RoundRobinIPSelector{} + node.Selector = &gost.RoundRobinIPSelector{} users, err := parseUsers(node.Values.Get("secrets")) if err != nil { @@ -265,6 +265,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { retry, _ := strconv.Atoi(node.Values.Get("retry")) node.HandshakeOptions = append(node.HandshakeOptions, gost.AddrHandshakeOption(node.Addr), + gost.HostHandshakeOption(node.Host), gost.UserHandshakeOption(node.User), gost.TLSConfigHandshakeOption(tlsCfg), gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), diff --git a/node.go b/node.go index a0fa051..1266108 100644 --- a/node.go +++ b/node.go @@ -11,6 +11,7 @@ type Node struct { ID int Addr string IPs []string + Host string Protocol string Transport string Remote string // remote address, used by tcp/udp port forwarding @@ -19,7 +20,7 @@ type Node struct { DialOptions []DialOption HandshakeOptions []HandshakeOption Client *Client - IPSelector IPSelector + Selector IPSelector } // ParseNode parses the node info. @@ -40,6 +41,7 @@ func ParseNode(s string) (node Node, err error) { node = Node{ Addr: u.Host, + Host: u.Host, Remote: strings.Trim(u.EscapedPath(), "/"), Values: u.Query(), User: u.User, diff --git a/obfs.go b/obfs.go index 424ed58..2c7da09 100644 --- a/obfs.go +++ b/obfs.go @@ -35,7 +35,7 @@ func (tr *obfsHTTPTransporter) Handshake(conn net.Conn, options ...HandshakeOpti for _, option := range options { option(opts) } - return &obfsHTTPConn{Conn: conn}, nil + return &obfsHTTPConn{Conn: conn, host: opts.Host}, nil } type obfsHTTPListener struct { @@ -66,6 +66,7 @@ func (l *obfsHTTPListener) Accept() (net.Conn, error) { type obfsHTTPConn struct { net.Conn + host string request *http.Request response *http.Response rbuf []byte @@ -151,7 +152,7 @@ func (c *obfsHTTPConn) clientHandshake() (err error) { Method: http.MethodGet, ProtoMajor: 1, ProtoMinor: 1, - URL: &url.URL{Scheme: "http", Host: "www.baidu.com"}, + URL: &url.URL{Scheme: "http", Host: c.host}, Header: make(http.Header), } r.Header.Set("Connection", "keep-alive") diff --git a/ws.go b/ws.go index d77c709..fa9c043 100644 --- a/ws.go +++ b/ws.go @@ -129,7 +129,7 @@ func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (n if opts.WSOptions != nil { wsOptions = opts.WSOptions } - url := url.URL{Scheme: "ws", Host: opts.Addr, Path: "/ws"} + url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} return websocketClientConn(url.String(), conn, nil, wsOptions) } @@ -210,7 +210,7 @@ func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *Handshak if opts.WSOptions != nil { wsOptions = opts.WSOptions } - url := url.URL{Scheme: "ws", Host: opts.Addr, Path: "/ws"} + url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} conn, err := websocketClientConn(url.String(), conn, nil, wsOptions) if err != nil { return nil, err @@ -252,7 +252,7 @@ func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( if opts.TLSConfig == nil { opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} } - url := url.URL{Scheme: "wss", Host: opts.Addr, Path: "/ws"} + url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) } @@ -337,7 +337,7 @@ func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *Handsha if tlsConfig == nil { tlsConfig = &tls.Config{InsecureSkipVerify: true} } - url := url.URL{Scheme: "wss", Host: opts.Addr, Path: "/ws"} + url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions) if err != nil { return nil, err From 3befde01282a3b62a39b47f1183bcd8aeeff961f Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Tue, 14 Nov 2017 17:50:02 +0800 Subject: [PATCH 6/9] add node fail filter --- chain.go | 40 +++++----------- cmd/gost/main.go | 112 ++++++++++++++++++++++++++++++--------------- node.go | 92 +++++++++++++++++++++++++++++-------- selector.go | 117 +++++++++++++++++++++++------------------------ 4 files changed, 217 insertions(+), 144 deletions(-) diff --git a/chain.go b/chain.go index 1d2a3be..a09e7e0 100644 --- a/chain.go +++ b/chain.go @@ -142,13 +142,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) { cn, err := node.Client.Dial(node.Addr, node.DialOptions...) if err != nil { + node.MarkDead() return } cn, err = node.Client.Handshake(cn, node.HandshakeOptions...) if err != nil { + node.MarkDead() return } + node.ResetDead() preNode := node for _, node := range nodes[1:] { @@ -156,13 +159,17 @@ func (c *Chain) getConn() (conn net.Conn, err error) { cc, err = preNode.Client.Connect(cn, node.Addr) if err != nil { cn.Close() + node.MarkDead() return } cc, err = node.Client.Handshake(cc, node.HandshakeOptions...) if err != nil { cn.Close() + node.MarkDead() return } + node.ResetDead() + cn = cc preNode = node } @@ -179,46 +186,21 @@ func (c *Chain) selectRoute() (route *Chain, err error) { buf := bytes.Buffer{} route = newRoute() for _, group := range c.nodeGroups { - selector := group.Selector - if selector == nil { - selector = &defaultSelector{} - } - // select node from node group - node, err := selector.Select(group.Nodes(), group.Options...) + node, err := group.Next() if err != nil { return nil, err } - if _, err := selectIP(&node); err != nil { - return nil, err - } - buf.WriteString(fmt.Sprintf("%d@%s -> ", node.ID, node.Addr)) + buf.WriteString(fmt.Sprintf("%s -> ", node.String())) if node.Client.Transporter.Multiplex() { node.DialOptions = append(node.DialOptions, ChainDialOption(route), ) - route = newRoute() // cutoff the chain for multiplex + route = newRoute() // cutoff the chain for multiplex. } + route.AddNode(node) } log.Log("select route:", buf.String()) return } - -func selectIP(node *Node) (string, error) { - s := node.Selector - if s == nil { - s = &RandomIPSelector{} - } - // select IP from IP list - ip, err := s.Select(node.IPs) - if err != nil { - return "", err - } - if ip != "" { - // override the original address - node.Addr = ip - node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(node.Addr)) - } - return node.Addr, nil -} diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 8e69fdd..d15cabd 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -87,35 +87,53 @@ func main() { func initChain() (*gost.Chain, error) { chain := gost.NewChain() + gid := 1 // group ID + for _, ns := range options.ChainNodes { + ngroup := gost.NewNodeGroup() + ngroup.ID = gid + gid++ + // parse the base node - node, err := parseChainNode(ns) + nodes, err := parseChainNode(ns) if err != nil { return nil, err } - id := 1 // start from 1 + nid := 1 // node ID - node.ID = id - ngroup := gost.NewNodeGroup(node) + for i := range nodes { + nodes[i].ID = nid + nid++ + } + ngroup.AddNode(nodes...) - // parse node peers if exists - peerCfg, err := loadPeerConfig(node.Values.Get("peer")) + // parse peer nodes if exists + peerCfg, err := loadPeerConfig(nodes[0].Values.Get("peer")) if err != nil { log.Log(err) } + peerCfg.Validate() ngroup.Options = append(ngroup.Options, - // gost.WithFilter(), + gost.WithFilter(&gost.FailFilter{ + MaxFails: peerCfg.MaxFails, + FailTimeout: time.Duration(peerCfg.FailTimeout) * time.Second, + }), gost.WithStrategy(parseStrategy(peerCfg.Strategy)), ) + for _, s := range peerCfg.Nodes { - node, err = parseChainNode(s) + nodes, err = parseChainNode(s) if err != nil { return nil, err } - id++ - node.ID = id - ngroup.AddNode(node) + + for i := range nodes { + nodes[i].ID = nid + nid++ + } + + ngroup.AddNode(nodes...) } chain.AddNodeGroup(ngroup) @@ -124,24 +142,12 @@ func initChain() (*gost.Chain, error) { return chain, nil } -func parseChainNode(ns string) (node gost.Node, err error) { - node, err = gost.ParseNode(ns) +func parseChainNode(ns string) (nodes []gost.Node, err error) { + node, err := gost.ParseNode(ns) if err != nil { return } - node.IPs = parseIP(node.Values.Get("ip")) - for i, ip := range node.IPs { - if !strings.Contains(ip, ":") { - _, sport, _ := net.SplitHostPort(node.Addr) - if sport == "" { - sport = "8080" // default port - } - node.IPs[i] = ip + ":" + sport - } - } - node.Selector = &gost.RoundRobinIPSelector{} - users, err := parseUsers(node.Values.Get("secrets")) if err != nil { return @@ -149,7 +155,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { if node.User == nil && len(users) > 0 { node.User = users[0] } - serverName, _, _ := net.SplitHostPort(node.Addr) + serverName, sport, _ := net.SplitHostPort(node.Addr) if serverName == "" { serverName = "localhost" // default server name } @@ -191,7 +197,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { */ config, err := parseKCPConfig(node.Values.Get("c")) if err != nil { - return node, err + return nil, err } tr = gost.KCPTransporter(config) case "ssh": @@ -220,7 +226,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { case "obfs4": if err := gost.Obfs4Init(node, false); err != nil { - return node, err + return nil, err } tr = gost.Obfs4Transporter() case "ohttp": @@ -263,20 +269,31 @@ func parseChainNode(ns string) (node gost.Node, err error) { interval, _ := strconv.Atoi(node.Values.Get("ping")) retry, _ := strconv.Atoi(node.Values.Get("retry")) - node.HandshakeOptions = append(node.HandshakeOptions, + handshakeOptions := []gost.HandshakeOption{ gost.AddrHandshakeOption(node.Addr), gost.HostHandshakeOption(node.Host), gost.UserHandshakeOption(node.User), gost.TLSConfigHandshakeOption(tlsCfg), - gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), - gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second), + gost.IntervalHandshakeOption(time.Duration(interval) * time.Second), + gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second), gost.RetryHandshakeOption(retry), - ) + } node.Client = &gost.Client{ Connector: connector, Transporter: tr, } + ips := parseIP(node.Values.Get("ip"), sport) + for _, ip := range ips { + node.Addr = ip + node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip)) + nodes = append(nodes, node) + } + if len(ips) == 0 { + node.HandshakeOptions = handshakeOptions + nodes = []gost.Node{node} + } + return } @@ -559,16 +576,23 @@ func parseUsers(authFile string) (users []*url.Userinfo, err error) { return } -func parseIP(s string) (ips []string) { +func parseIP(s string, port string) (ips []string) { if s == "" { - return nil + return } + if port == "" { + port = "8080" // default port + } + file, err := os.Open(s) if err != nil { ss := strings.Split(s, ",") for _, s := range ss { s = strings.TrimSpace(s) if s != "" { + if !strings.Contains(s, ":") { + s = s + ":" + port + } ips = append(ips, s) } @@ -582,15 +606,20 @@ func parseIP(s string) (ips []string) { if line == "" || strings.HasPrefix(line, "#") { continue } + if !strings.Contains(line, ":") { + line = line + ":" + port + } ips = append(ips, line) } return } type peerConfig struct { - Strategy string `json:"strategy"` - Filters []string `json:"filters"` - Nodes []string `json:"nodes"` + Strategy string `json:"strategy"` + Filters []string `json:"filters"` + MaxFails int `json:"max_fails"` + FailTimeout int `json:"fail_timeout"` + Nodes []string `json:"nodes"` } func loadPeerConfig(peer string) (config peerConfig, err error) { @@ -605,6 +634,15 @@ func loadPeerConfig(peer string) (config peerConfig, err error) { return } +func (cfg *peerConfig) Validate() { + if cfg.MaxFails <= 0 { + cfg.MaxFails = 3 + } + if cfg.FailTimeout <= 0 { + cfg.FailTimeout = 30 // seconds + } +} + func parseStrategy(s string) gost.Strategy { switch s { case "random": diff --git a/node.go b/node.go index 1266108..799b6e4 100644 --- a/node.go +++ b/node.go @@ -1,16 +1,17 @@ package gost import ( + "fmt" "net/url" "strings" - "sync" + "sync/atomic" + "time" ) // Node is a proxy node, mainly used to construct a proxy chain. type Node struct { ID int Addr string - IPs []string Host string Protocol string Transport string @@ -20,7 +21,9 @@ type Node struct { DialOptions []DialOption HandshakeOptions []HandshakeOption Client *Client - Selector IPSelector + group *NodeGroup + failCount uint32 + failTime time.Time } // ParseNode parses the node info. @@ -83,38 +86,89 @@ func ParseNode(s string) (node Node, err error) { return } +// MarkDead marks the node fail status. +func (node *Node) MarkDead() { + atomic.AddUint32(&node.failCount, 1) + node.failTime = time.Now() + + if node.group == nil { + return + } + for i := range node.group.nodes { + if node.group.nodes[i].ID == node.ID { + atomic.AddUint32(&node.group.nodes[i].failCount, 1) + node.group.nodes[i].failTime = time.Now() + break + } + } +} + +// ResetDead resets the node fail status. +func (node *Node) ResetDead() { + atomic.StoreUint32(&node.failCount, 0) + node.failTime = time.Time{} + + if node.group == nil { + return + } + + for i := range node.group.nodes { + if node.group.nodes[i].ID == node.ID { + atomic.StoreUint32(&node.group.nodes[i].failCount, 0) + node.group.nodes[i].failTime = time.Time{} + break + } + } +} + +func (node *Node) String() string { + return fmt.Sprintf("%d@%s", node.ID, node.Addr) +} + // NodeGroup is a group of nodes. type NodeGroup struct { - nodes []Node - Options []SelectOption - Selector NodeSelector - mutex sync.Mutex - mFails map[string]int // node -> fail count - MaxFails int - FailTimeout int - Retries int + ID int + nodes []Node + Options []SelectOption + Selector NodeSelector } // NewNodeGroup creates a node group func NewNodeGroup(nodes ...Node) *NodeGroup { return &NodeGroup{ - nodes: nodes, - mFails: make(map[string]int), + nodes: nodes, } } // AddNode adds node or node list into group -func (ng *NodeGroup) AddNode(node ...Node) { - if ng == nil { +func (group *NodeGroup) AddNode(node ...Node) { + if group == nil { return } - ng.nodes = append(ng.nodes, node...) + group.nodes = append(group.nodes, node...) } // Nodes returns node list in the group -func (ng *NodeGroup) Nodes() []Node { - if ng == nil { +func (group *NodeGroup) Nodes() []Node { + if group == nil { return nil } - return ng.nodes + return group.nodes +} + +// Next selects the next node from group. +// It also selects IP if the IP list exists. +func (group *NodeGroup) Next() (node Node, err error) { + selector := group.Selector + if selector == nil { + selector = &defaultSelector{} + } + // select node from node group + node, err = selector.Select(group.Nodes(), group.Options...) + if err != nil { + return + } + node.group = group + + return } diff --git a/selector.go b/selector.go index 675a254..345b265 100644 --- a/selector.go +++ b/selector.go @@ -2,6 +2,8 @@ package gost import ( "errors" + "math/rand" + "sync" "sync/atomic" "time" ) @@ -37,9 +39,28 @@ func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, erro return sopts.Strategy.Apply(nodes), nil } -// Filter is used to filter a node during the selection process -type Filter interface { - Filter([]Node) []Node +// SelectOption used when making a select call +type SelectOption func(*SelectOptions) + +// SelectOptions is the options for node selection +type SelectOptions struct { + Filters []Filter + Strategy Strategy +} + +// WithFilter adds a filter function to the list of filters +// used during the Select call. +func WithFilter(f ...Filter) SelectOption { + return func(o *SelectOptions) { + o.Filters = append(o.Filters, f...) + } +} + +// WithStrategy sets the selector strategy +func WithStrategy(s Strategy) SelectOption { + return func(o *SelectOptions) { + o.Strategy = s + } } // Strategy is a selection strategy e.g random, round robin @@ -68,82 +89,60 @@ func (s *RoundStrategy) String() string { } // RandomStrategy is a strategy for node selector -type RandomStrategy struct{} +type RandomStrategy struct { + Seed int64 + rand *rand.Rand + once sync.Once +} // Apply applies the random strategy for the nodes func (s *RandomStrategy) Apply(nodes []Node) Node { + s.once.Do(func() { + seed := s.Seed + if seed == 0 { + seed = time.Now().UnixNano() + } + s.rand = rand.New(rand.NewSource(seed)) + }) if len(nodes) == 0 { return Node{} } - return nodes[time.Now().Nanosecond()%len(nodes)] + return nodes[s.rand.Int()%len(nodes)] } func (s *RandomStrategy) String() string { return "random" } -// SelectOption used when making a select call -type SelectOption func(*SelectOptions) - -// SelectOptions is the options for node selection -type SelectOptions struct { - Filters []Filter - Strategy Strategy -} - -// WithFilter adds a filter function to the list of filters -// used during the Select call. -func WithFilter(f ...Filter) SelectOption { - return func(o *SelectOptions) { - o.Filters = append(o.Filters, f...) - } -} - -// WithStrategy sets the selector strategy -func WithStrategy(s Strategy) SelectOption { - return func(o *SelectOptions) { - o.Strategy = s - } -} - -// IPSelector as a mechanism to pick IPs and mark their status. -type IPSelector interface { - Select(ips []string) (string, error) +// Filter is used to filter a node during the selection process +type Filter interface { + Filter([]Node) []Node String() string } -// RandomIPSelector is an IP Selector that selects an IP with random strategy. -type RandomIPSelector struct { +// FailFilter filters the dead node. +// A node is marked as dead if its failed count is greater than MaxFails. +type FailFilter struct { + MaxFails int + FailTimeout time.Duration } -// Select selects an IP from ips list. -func (s *RandomIPSelector) Select(ips []string) (string, error) { - if len(ips) == 0 { - return "", nil +// Filter filters nodes. +func (f *FailFilter) Filter(nodes []Node) []Node { + if f.MaxFails <= 0 { + return nodes } - return ips[time.Now().Nanosecond()%len(ips)], nil -} - -func (s *RandomIPSelector) String() string { - return "random" -} - -// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy. -type RoundRobinIPSelector struct { - count uint64 -} - -// Select selects an IP from ips list. -func (s *RoundRobinIPSelector) Select(ips []string) (string, error) { - if len(ips) == 0 { - return "", nil + nl := []Node{} + for _, node := range nodes { + if node.failCount < uint32(f.MaxFails) || + time.Since(node.failTime) >= f.FailTimeout { + nl = append(nl, node) + } } - old := s.count - atomic.AddUint64(&s.count, 1) - return ips[int(old%uint64(len(ips)))], nil + return nl } -func (s *RoundRobinIPSelector) String() string { - return "round" +func (f *FailFilter) String() string { + return "fail" } From 003c5a5085fdad0b1aa81aa60ee493955e91c41a Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Wed, 15 Nov 2017 14:39:35 +0800 Subject: [PATCH 7/9] enhancement configuration --- chain.go | 8 ++++-- cmd/gost/gost.json | 20 ++++++++++++++ cmd/gost/main.go | 66 +++++++++++++++++++++++++++++++++------------- handler.go | 3 ++- node.go | 30 +++++++++++++++++---- selector.go | 10 +++---- 6 files changed, 105 insertions(+), 32 deletions(-) diff --git a/chain.go b/chain.go index a09e7e0..9ebd14b 100644 --- a/chain.go +++ b/chain.go @@ -17,6 +17,7 @@ var ( // Chain is a proxy chain that holds a list of proxy nodes. type Chain struct { isRoute bool + Retries int nodeGroups []*NodeGroup } @@ -58,8 +59,8 @@ func (c *Chain) LastNode() Node { if c.IsEmpty() { return Node{} } - last := c.nodeGroups[len(c.nodeGroups)-1] - return last.nodes[0] + group := c.nodeGroups[len(c.nodeGroups)-1] + return group.nodes[0].Clone() } // LastNodeGroup returns the last group of the group list. @@ -185,6 +186,8 @@ func (c *Chain) selectRoute() (route *Chain, err error) { buf := bytes.Buffer{} route = newRoute() + route.Retries = c.Retries + for _, group := range c.nodeGroups { node, err := group.Next() if err != nil { @@ -197,6 +200,7 @@ func (c *Chain) selectRoute() (route *Chain, err error) { ChainDialOption(route), ) route = newRoute() // cutoff the chain for multiplex. + route.Retries = c.Retries } route.AddNode(node) diff --git a/cmd/gost/gost.json b/cmd/gost/gost.json index 8602a51..e7217c9 100644 --- a/cmd/gost/gost.json +++ b/cmd/gost/gost.json @@ -1,4 +1,6 @@ { + "Debug": false, + "Retries": 1, "ServeNodes": [ ":8080", "ss://chacha20:12345678@:8338" @@ -6,5 +8,23 @@ "ChainNodes": [ "http://192.168.1.1:8080", "https://10.0.2.1:443" + ], + + "Routes": [ + { + "Retries": 1, + "ServeNodes": [ + "ws://:1443" + ], + "ChainNodes": [ + "socks://:192.168.1.1:1080" + ] + }, + { + "Retries": 3, + "ServeNodes": [ + "quic://:443" + ] + } ] } \ No newline at end of file diff --git a/cmd/gost/main.go b/cmd/gost/main.go index d15cabd..b075cfa 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -22,10 +22,8 @@ import ( ) var ( - options struct { - ChainNodes, ServeNodes stringList - Debug bool - } + options route + routes []route ) func init() { @@ -43,12 +41,17 @@ func init() { flag.BoolVar(&printVersion, "V", false, "print version") flag.Parse() + if len(options.ServeNodes) > 0 { + routes = append(routes, options) + } + gost.Debug = options.Debug + if err := loadConfigureFile(configureFile); err != nil { log.Log(err) os.Exit(1) } - if flag.NFlag() == 0 { + if flag.NFlag() == 0 || len(routes) == 0 { flag.PrintDefaults() os.Exit(0) } @@ -58,7 +61,7 @@ func init() { os.Exit(0) } - gost.Debug = options.Debug + log.Log("Debug:", gost.Debug) } func main() { @@ -72,24 +75,29 @@ func main() { Certificates: []tls.Certificate{cert}, } - chain, err := initChain() - if err != nil { - log.Log(err) - os.Exit(1) - } - if err := serve(chain); err != nil { - log.Log(err) - os.Exit(1) + for _, route := range routes { + if err := route.serve(); err != nil { + log.Log(err) + os.Exit(1) + } } select {} } -func initChain() (*gost.Chain, error) { +type route struct { + ChainNodes, ServeNodes stringList + Retries int + Debug bool +} + +func (r *route) initChain() (*gost.Chain, error) { chain := gost.NewChain() + chain.Retries = r.Retries + gid := 1 // group ID - for _, ns := range options.ChainNodes { + for _, ns := range r.ChainNodes { ngroup := gost.NewNodeGroup() ngroup.ID = gid gid++ @@ -297,8 +305,13 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { return } -func serve(chain *gost.Chain) error { - for _, ns := range options.ServeNodes { +func (r *route) serve() error { + chain, err := r.initChain() + if err != nil { + return err + } + + for _, ns := range r.ServeNodes { node, err := gost.ParseNode(ns) if err != nil { return err @@ -507,9 +520,24 @@ func loadConfigureFile(configureFile string) error { if err != nil { return err } - if err := json.Unmarshal(content, &options); err != nil { + var cfg struct { + route + Routes []route + } + if err := json.Unmarshal(content, &cfg); err != nil { return err } + + if len(cfg.route.ServeNodes) > 0 { + routes = append(routes, cfg.route) + } + for _, route := range cfg.Routes { + if len(cfg.route.ServeNodes) > 0 { + routes = append(routes, route) + } + } + gost.Debug = cfg.Debug + return nil } diff --git a/handler.go b/handler.go index 2b51e4e..84162b3 100644 --- a/handler.go +++ b/handler.go @@ -95,7 +95,8 @@ func (h *autoHandler) Handle(conn net.Conn) { cc := &bufferdConn{Conn: conn, br: br} switch b[0] { case gosocks4.Ver4: - return // SOCKS4(a) does not suppport authentication method, so we ignore it. + cc.Close() + return // SOCKS4(a) does not suppport authentication method, so we ignore it for security reason. case gosocks5.Ver5: SOCKS5Handler(h.options...).Handle(cc) default: // http diff --git a/node.go b/node.go index 799b6e4..14d5b17 100644 --- a/node.go +++ b/node.go @@ -23,7 +23,7 @@ type Node struct { Client *Client group *NodeGroup failCount uint32 - failTime time.Time + failTime int64 } // ParseNode parses the node info. @@ -89,7 +89,7 @@ func ParseNode(s string) (node Node, err error) { // MarkDead marks the node fail status. func (node *Node) MarkDead() { atomic.AddUint32(&node.failCount, 1) - node.failTime = time.Now() + atomic.StoreInt64(&node.failTime, time.Now().Unix()) if node.group == nil { return @@ -97,7 +97,7 @@ func (node *Node) MarkDead() { for i := range node.group.nodes { if node.group.nodes[i].ID == node.ID { atomic.AddUint32(&node.group.nodes[i].failCount, 1) - node.group.nodes[i].failTime = time.Now() + atomic.StoreInt64(&node.group.nodes[i].failTime, time.Now().Unix()) break } } @@ -106,7 +106,7 @@ func (node *Node) MarkDead() { // ResetDead resets the node fail status. func (node *Node) ResetDead() { atomic.StoreUint32(&node.failCount, 0) - node.failTime = time.Time{} + atomic.StoreInt64(&node.failTime, 0) if node.group == nil { return @@ -115,12 +115,32 @@ func (node *Node) ResetDead() { for i := range node.group.nodes { if node.group.nodes[i].ID == node.ID { atomic.StoreUint32(&node.group.nodes[i].failCount, 0) - node.group.nodes[i].failTime = time.Time{} + atomic.StoreInt64(&node.group.nodes[i].failTime, 0) break } } } +// Clone clones the node, it will prevent data race. +func (node *Node) Clone() Node { + return Node{ + ID: node.ID, + Addr: node.Addr, + Host: node.Host, + Protocol: node.Protocol, + Transport: node.Transport, + Remote: node.Remote, + User: node.User, + Values: node.Values, + DialOptions: node.DialOptions, + HandshakeOptions: node.HandshakeOptions, + Client: node.Client, + group: node.group, + failCount: atomic.LoadUint32(&node.failCount), + failTime: atomic.LoadInt64(&node.failTime), + } +} + func (node *Node) String() string { return fmt.Sprintf("%d@%s", node.ID, node.Addr) } diff --git a/selector.go b/selector.go index 345b265..bea2c65 100644 --- a/selector.go +++ b/selector.go @@ -79,7 +79,7 @@ func (s *RoundStrategy) Apply(nodes []Node) Node { if len(nodes) == 0 { return Node{} } - old := s.count + old := atomic.LoadUint64(&s.count) atomic.AddUint64(&s.count, 1) return nodes[int(old%uint64(len(nodes)))] } @@ -134,10 +134,10 @@ func (f *FailFilter) Filter(nodes []Node) []Node { return nodes } nl := []Node{} - for _, node := range nodes { - if node.failCount < uint32(f.MaxFails) || - time.Since(node.failTime) >= f.FailTimeout { - nl = append(nl, node) + for i := range nodes { + if atomic.LoadUint32(&nodes[i].failCount) < uint32(f.MaxFails) || + time.Since(time.Unix(atomic.LoadInt64(&nodes[i].failTime), 0)) >= f.FailTimeout { + nl = append(nl, nodes[i].Clone()) } } return nl From cd26bcb81db693d2557d7cd9bf0c1f0539ea149b Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Wed, 15 Nov 2017 15:54:29 +0800 Subject: [PATCH 8/9] add retry mechanism for chain --- cmd/gost/main.go | 2 -- cmd/gost/peer.json | 10 ++++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) create mode 100644 cmd/gost/peer.json diff --git a/cmd/gost/main.go b/cmd/gost/main.go index b075cfa..14071cb 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -60,8 +60,6 @@ func init() { fmt.Fprintf(os.Stderr, "gost %s (%s)\n", gost.Version, runtime.Version()) os.Exit(0) } - - log.Log("Debug:", gost.Debug) } func main() { diff --git a/cmd/gost/peer.json b/cmd/gost/peer.json new file mode 100644 index 0000000..a4eee76 --- /dev/null +++ b/cmd/gost/peer.json @@ -0,0 +1,10 @@ +{ + "strategy": "round", + "max_fails": 3, + "fail_timeout": 30, + "nodes":[ + "socks5://:1081", + "socks://:1082", + "socks4a://:1083" + ] +} \ No newline at end of file From 6f935fc397e4abfa949c21c8f5f3e30d7675be88 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Wed, 15 Nov 2017 15:55:15 +0800 Subject: [PATCH 9/9] add retry mechanism for chain --- chain.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/chain.go b/chain.go index 9ebd14b..ca7d15b 100644 --- a/chain.go +++ b/chain.go @@ -99,11 +99,21 @@ func (c *Chain) IsEmpty() bool { // Dial connects to the target address addr through the chain. // If the chain is empty, it will use the net.Dial directly. -func (c *Chain) Dial(addr string) (net.Conn, error) { +func (c *Chain) Dial(addr string) (conn net.Conn, err error) { if c.IsEmpty() { return net.DialTimeout("tcp", addr, DialTimeout) } + for i := 0; i < c.Retries+1; i++ { + conn, err = c.dial(addr) + if err == nil { + break + } + } + return +} + +func (c *Chain) dial(addr string) (net.Conn, error) { route, err := c.selectRoute() if err != nil { return nil, err @@ -125,11 +135,19 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { // Conn obtains a handshaked connection to the last node of the chain. // If the chain is empty, it returns an ErrEmptyChain error. func (c *Chain) Conn() (conn net.Conn, err error) { - route, err := c.selectRoute() - if err != nil { - return nil, err + for i := 0; i < c.Retries+1; i++ { + var route *Chain + route, err = c.selectRoute() + if err != nil { + continue + } + conn, err = route.getConn() + if err != nil { + continue + } + + break } - conn, err = route.getConn() return } @@ -205,6 +223,8 @@ func (c *Chain) selectRoute() (route *Chain, err error) { route.AddNode(node) } - log.Log("select route:", buf.String()) + if Debug { + log.Log("select route:", buf.String()) + } return }