From dedd08530aba650790d4853b949ac817f17ba541 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Fri, 3 Nov 2017 17:42:05 +0800 Subject: [PATCH] #141: Add load balancing support for proxy chain --- chain.go | 72 +++++----- client.go | 1 - cmd/gost/main.go | 333 ++++++++++++++++++++++++++++------------------- gost.go | 26 ++-- quic.go | 1 + selector.go | 69 +++++++--- 6 files changed, 301 insertions(+), 201 deletions(-) diff --git a/chain.go b/chain.go index db1fb7b..9a409d9 100644 --- a/chain.go +++ b/chain.go @@ -95,12 +95,17 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { return net.Dial("tcp", addr) } - conn, nodes, err := c.getConn() + route, err := c.selectRoute() if err != nil { return nil, err } - cc, err := nodes[len(nodes)-1].Client.Connect(conn, addr) + conn, err := c.getConn(route) + if err != nil { + return nil, err + } + + cc, err := route.LastNode().Client.Connect(conn, addr) if err != nil { conn.Close() return nil, err @@ -111,26 +116,44 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { // Conn obtains a handshaked connection to the last node of the chain. // If the chain is empty, it returns an ErrEmptyChain error. func (c *Chain) Conn() (conn net.Conn, err error) { - conn, _, err = c.getConn() + route, err := c.selectRoute() + if err != nil { + return nil, err + } + conn, err = c.getConn(route) return } -func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) { - if c.IsEmpty() { +func (c *Chain) selectRoute() (route *Chain, err error) { + route = NewChain() + for _, group := range c.nodeGroups { + selector := group.Selector + if selector == nil { + selector = &defaultSelector{} + } + // select node from node group + node, err := selector.Select(group.Nodes(), group.Options...) + if err != nil { + return nil, err + } + if node.Client.Transporter.Multiplex() { + node.DialOptions = append(node.DialOptions, + ChainDialOption(route), + ) + route = NewChain() // cutoff the chain for multiplex + } + route.AddNode(node) + } + return +} + +func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { + if route.IsEmpty() { err = ErrEmptyChain return } - groups := c.nodeGroups - selector := groups[0].Selector - if selector == nil { - selector = &defaultSelector{} - } - // select node from node group - node, err := selector.Select(groups[0].Nodes(), groups[0].Options...) - if err != nil { - return - } - nodes = append(nodes, node) + nodes := route.Nodes() + node := nodes[0] addr, err := selectIP(&node) if err != nil { @@ -147,21 +170,7 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) { } preNode := node - for i := range groups { - if i == len(groups)-1 { - break - } - selector = groups[i+1].Selector - if selector == nil { - selector = &defaultSelector{} - } - node, err = selector.Select(groups[i+1].Nodes(), groups[i+1].Options...) - if err != nil { - cn.Close() - return - } - nodes = append(nodes, node) - + for _, node := range nodes[1:] { addr, err = selectIP(&node) if err != nil { return @@ -206,6 +215,7 @@ func selectIP(node *Node) (string, error) { ip = ip + ":" + sport } addr = ip + // override the original address node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr)) } log.Log("select IP:", node.Addr, node.IPs, addr) diff --git a/client.go b/client.go index c45d9cf..e4a05f9 100644 --- a/client.go +++ b/client.go @@ -94,7 +94,6 @@ func (tr *tcpTransporter) Multiplex() bool { type DialOptions struct { Timeout time.Duration Chain *Chain - // IPs []string } // DialOption allows a common way to set dial options. diff --git a/cmd/gost/main.go b/cmd/gost/main.go index a1c7e39..5ebbe66 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -62,6 +62,16 @@ func init() { } func main() { + // generate random self-signed certificate. + cert, err := gost.GenCertificate() + if err != nil { + log.Log(err) + os.Exit(1) + } + gost.DefaultTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + chain, err := initChain() if err != nil { log.Log(err) @@ -71,162 +81,190 @@ func main() { log.Log(err) os.Exit(1) } + select {} } func initChain() (*gost.Chain, error) { chain := gost.NewChain() for _, ns := range options.ChainNodes { - node, err := gost.ParseNode(ns) + // parse the base node + node, err := parseChainNode(ns) if err != nil { return nil, err } - node.IPs = parseIP(node.Values.Get("ip")) - node.IPSelector = &gost.RoundRobinIPSelector{} + ngroup := gost.NewNodeGroup(node) - users, err := parseUsers(node.Values.Get("secrets")) + // parse node peers if exists + peerCfg, err := loadPeerConfig(node.Values.Get("peer")) if err != nil { - return nil, err + log.Log(err) } - if node.User == nil && len(users) > 0 { - node.User = users[0] - } - serverName, _, _ := net.SplitHostPort(node.Addr) - if serverName == "" { - serverName = "localhost" // default server name - } - - rootCAs, err := loadCA(node.Values.Get("ca")) - if err != nil { - return nil, err - } - tlsCfg := &tls.Config{ - ServerName: serverName, - InsecureSkipVerify: !toBool(node.Values.Get("secure")), - RootCAs: rootCAs, - } - wsOpts := &gost.WSOptions{} - wsOpts.EnableCompression = toBool(node.Values.Get("compression")) - wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) - wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) - wsOpts.UserAgent = node.Values.Get("agent") - - var tr gost.Transporter - switch node.Transport { - case "tls": - tr = gost.TLSTransporter() - case "mtls": - tr = gost.MTLSTransporter() - case "ws": - tr = gost.WSTransporter(wsOpts) - case "mws": - tr = gost.MWSTransporter(wsOpts) - case "wss": - tr = gost.WSSTransporter(wsOpts) - case "mwss": - tr = gost.MWSSTransporter(wsOpts) - case "kcp": - if !chain.IsEmpty() { - return nil, errors.New("KCP must be the first node in the proxy chain") - } - config, err := parseKCPConfig(node.Values.Get("c")) + ngroup.Options = append(ngroup.Options, + // gost.WithFilter(), + gost.WithStrategy(parseStrategy(peerCfg.Strategy)), + ) + for _, s := range peerCfg.Nodes { + node, err = parseChainNode(s) if err != nil { return nil, err } - tr = gost.KCPTransporter(config) - case "ssh": - if node.Protocol == "direct" || node.Protocol == "remote" { - tr = gost.SSHForwardTransporter() - } else { - tr = gost.SSHTunnelTransporter() - } - case "quic": - if !chain.IsEmpty() { - return nil, errors.New("QUIC must be the first node in the proxy chain") - } - config := &gost.QUICConfig{ - TLSConfig: tlsCfg, - KeepAlive: toBool(node.Values.Get("keepalive")), - } - tr = gost.QUICTransporter(config) - case "http2": - tr = gost.HTTP2Transporter(tlsCfg) - case "h2": - tr = gost.H2Transporter(tlsCfg) - case "h2c": - tr = gost.H2CTransporter() - - case "obfs4": - if err := gost.Obfs4Init(node, false); err != nil { - return nil, err - } - tr = gost.Obfs4Transporter() - case "ohttp": - tr = gost.ObfsHTTPTransporter() - default: - tr = gost.TCPTransporter() + ngroup.AddNode(node) } - if tr.Multiplex() { - node.DialOptions = append(node.DialOptions, - gost.ChainDialOption(chain), - ) - chain = gost.NewChain() // cutoff the chain for multiplex - } - - var connector gost.Connector - switch node.Protocol { - case "http2": - connector = gost.HTTP2Connector(node.User) - case "socks", "socks5": - connector = gost.SOCKS5Connector(node.User) - case "socks4": - connector = gost.SOCKS4Connector() - case "socks4a": - connector = gost.SOCKS4AConnector() - case "ss": - connector = gost.ShadowConnector(node.User) - case "direct": - connector = gost.SSHDirectForwardConnector() - case "remote": - connector = gost.SSHRemoteForwardConnector() - case "forward": - connector = gost.ForwardConnector() - case "sni": - connector = gost.SNIConnector(node.Values.Get("host")) - case "http": - fallthrough - default: - node.Protocol = "http" // default protocol is HTTP - connector = gost.HTTPConnector(node.User) - } - - timeout, _ := strconv.Atoi(node.Values.Get("timeout")) - node.DialOptions = append(node.DialOptions, - gost.TimeoutDialOption(time.Duration(timeout)*time.Second), - ) - - interval, _ := strconv.Atoi(node.Values.Get("ping")) - retry, _ := strconv.Atoi(node.Values.Get("retry")) - node.HandshakeOptions = append(node.HandshakeOptions, - gost.AddrHandshakeOption(node.Addr), - gost.UserHandshakeOption(node.User), - gost.TLSConfigHandshakeOption(tlsCfg), - gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), - gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second), - gost.RetryHandshakeOption(retry), - ) - node.Client = &gost.Client{ - Connector: connector, - Transporter: tr, - } - chain.AddNode(node) + chain.AddNodeGroup(ngroup) } return chain, nil } +func parseChainNode(ns string) (node gost.Node, err error) { + node, err = gost.ParseNode(ns) + if err != nil { + return + } + + node.IPs = parseIP(node.Values.Get("ip")) + node.IPSelector = &gost.RoundRobinIPSelector{} + + users, err := parseUsers(node.Values.Get("secrets")) + if err != nil { + return + } + if node.User == nil && len(users) > 0 { + node.User = users[0] + } + serverName, _, _ := net.SplitHostPort(node.Addr) + if serverName == "" { + serverName = "localhost" // default server name + } + + rootCAs, err := loadCA(node.Values.Get("ca")) + if err != nil { + return + } + tlsCfg := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: !toBool(node.Values.Get("secure")), + RootCAs: rootCAs, + } + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = toBool(node.Values.Get("compression")) + wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) + wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + wsOpts.UserAgent = node.Values.Get("agent") + + var tr gost.Transporter + switch node.Transport { + case "tls": + tr = gost.TLSTransporter() + case "mtls": + tr = gost.MTLSTransporter() + case "ws": + tr = gost.WSTransporter(wsOpts) + case "mws": + tr = gost.MWSTransporter(wsOpts) + case "wss": + tr = gost.WSSTransporter(wsOpts) + case "mwss": + tr = gost.MWSSTransporter(wsOpts) + case "kcp": + /* + if !chain.IsEmpty() { + return nil, errors.New("KCP must be the first node in the proxy chain") + } + */ + config, err := parseKCPConfig(node.Values.Get("c")) + if err != nil { + return node, err + } + tr = gost.KCPTransporter(config) + case "ssh": + if node.Protocol == "direct" || node.Protocol == "remote" { + tr = gost.SSHForwardTransporter() + } else { + tr = gost.SSHTunnelTransporter() + } + case "quic": + /* + if !chain.IsEmpty() { + return nil, errors.New("QUIC must be the first node in the proxy chain") + } + */ + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: toBool(node.Values.Get("keepalive")), + } + tr = gost.QUICTransporter(config) + case "http2": + tr = gost.HTTP2Transporter(tlsCfg) + case "h2": + tr = gost.H2Transporter(tlsCfg) + case "h2c": + tr = gost.H2CTransporter() + + case "obfs4": + if err := gost.Obfs4Init(node, false); err != nil { + return node, err + } + tr = gost.Obfs4Transporter() + case "ohttp": + tr = gost.ObfsHTTPTransporter() + default: + tr = gost.TCPTransporter() + } + + var connector gost.Connector + switch node.Protocol { + case "http2": + connector = gost.HTTP2Connector(node.User) + case "socks", "socks5": + connector = gost.SOCKS5Connector(node.User) + case "socks4": + connector = gost.SOCKS4Connector() + case "socks4a": + connector = gost.SOCKS4AConnector() + case "ss": + connector = gost.ShadowConnector(node.User) + case "direct": + connector = gost.SSHDirectForwardConnector() + case "remote": + connector = gost.SSHRemoteForwardConnector() + case "forward": + connector = gost.ForwardConnector() + case "sni": + connector = gost.SNIConnector(node.Values.Get("host")) + case "http": + fallthrough + default: + node.Protocol = "http" // default protocol is HTTP + connector = gost.HTTPConnector(node.User) + } + + timeout, _ := strconv.Atoi(node.Values.Get("timeout")) + node.DialOptions = append(node.DialOptions, + gost.TimeoutDialOption(time.Duration(timeout)*time.Second), + ) + + interval, _ := strconv.Atoi(node.Values.Get("ping")) + retry, _ := strconv.Atoi(node.Values.Get("retry")) + node.HandshakeOptions = append(node.HandshakeOptions, + gost.AddrHandshakeOption(node.Addr), + gost.UserHandshakeOption(node.User), + gost.TLSConfigHandshakeOption(tlsCfg), + gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), + gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second), + gost.RetryHandshakeOption(retry), + ) + node.Client = &gost.Client{ + Connector: connector, + Transporter: tr, + } + + return +} + func serve(chain *gost.Chain) error { for _, ns := range options.ServeNodes { node, err := gost.ParseNode(ns) @@ -533,3 +571,32 @@ func parseIP(s string) (ips []string) { } return } + +type peerConfig struct { + Strategy string `json:"strategy"` + Filters []string `json:"filters"` + Nodes []string `json:"nodes"` +} + +func loadPeerConfig(peer string) (config peerConfig, err error) { + if peer == "" { + return + } + content, err := ioutil.ReadFile(peer) + if err != nil { + return + } + err = json.Unmarshal(content, &config) + return +} + +func parseStrategy(s string) gost.Strategy { + switch s { + case "round": + return &gost.RoundStrategy{} + case "random": + fallthrough + default: + return &gost.RandomStrategy{} + } +} diff --git a/gost.go b/gost.go index 057a9f2..39b818f 100644 --- a/gost.go +++ b/gost.go @@ -38,7 +38,7 @@ var ( // PingTimeout is the timeout for pinging. PingTimeout = 30 * time.Second // PingRetries is the reties of ping. - PingRetries = 3 + PingRetries = 1 // default udp node TTL in second for udp port forwarding. defaultTTL = 60 * time.Second ) @@ -51,27 +51,19 @@ var ( DefaultUserAgent = "Chrome/60.0.3112.90" ) -func init() { - rawCert, rawKey, err := generateKeyPair() - if err != nil { - panic(err) - } - cert, err := tls.X509KeyPair(rawCert, rawKey) - if err != nil { - panic(err) - } - DefaultTLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - - // log.DefaultLogger = &LogLogger{} -} - // SetLogger sets a new logger for internal log system func SetLogger(logger log.Logger) { log.DefaultLogger = logger } +func GenCertificate() (cert tls.Certificate, err error) { + rawCert, rawKey, err := generateKeyPair() + if err != nil { + return + } + return tls.X509KeyPair(rawCert, rawKey) +} + func generateKeyPair() (rawCert, rawKey []byte, err error) { // Create private key and self-signed certificate // Adapted from https://golang.org/src/crypto/tls/generate_cert.go diff --git a/quic.go b/quic.go index 8e218f0..56a61c8 100644 --- a/quic.go +++ b/quic.go @@ -194,6 +194,7 @@ func (l *quicListener) sessionLoop(session quic.Session) { stream, err := session.AcceptStream() if err != nil { log.Log("[quic] accept stream:", err) + session.Close(err) return } diff --git a/selector.go b/selector.go index 6f49987..675a254 100644 --- a/selector.go +++ b/selector.go @@ -11,14 +11,10 @@ var ( ErrNoneAvailable = errors.New("none available") ) -// SelectOption used when making a select call -type SelectOption func(*SelectOptions) - // NodeSelector as a mechanism to pick nodes and mark their status. type NodeSelector interface { Select(nodes []Node, opts ...SelectOption) (Node, error) // Mark(node Node) - String() string } type defaultSelector struct { @@ -26,35 +22,70 @@ type defaultSelector struct { func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, error) { sopts := SelectOptions{ - Strategy: defaultStrategy, + Strategy: &RoundStrategy{}, } for _, opt := range opts { opt(&sopts) } for _, filter := range sopts.Filters { - nodes = filter(nodes) + nodes = filter.Filter(nodes) } if len(nodes) == 0 { return Node{}, ErrNoneAvailable } - return sopts.Strategy(nodes), nil -} - -func (s *defaultSelector) String() string { - return "default" + return sopts.Strategy.Apply(nodes), nil } // Filter is used to filter a node during the selection process -type Filter func([]Node) []Node +type Filter interface { + Filter([]Node) []Node +} // Strategy is a selection strategy e.g random, round robin -type Strategy func([]Node) Node - -func defaultStrategy(nodes []Node) Node { - return nodes[0] +type Strategy interface { + Apply([]Node) Node + String() string } +// RoundStrategy is a strategy for node selector +type RoundStrategy struct { + count uint64 +} + +// Apply applies the round robin strategy for the nodes +func (s *RoundStrategy) Apply(nodes []Node) Node { + if len(nodes) == 0 { + return Node{} + } + old := s.count + atomic.AddUint64(&s.count, 1) + return nodes[int(old%uint64(len(nodes)))] +} + +func (s *RoundStrategy) String() string { + return "round" +} + +// RandomStrategy is a strategy for node selector +type RandomStrategy struct{} + +// Apply applies the random strategy for the nodes +func (s *RandomStrategy) Apply(nodes []Node) Node { + if len(nodes) == 0 { + return Node{} + } + + return nodes[time.Now().Nanosecond()%len(nodes)] +} + +func (s *RandomStrategy) String() string { + return "random" +} + +// SelectOption used when making a select call +type SelectOption func(*SelectOptions) + // SelectOptions is the options for node selection type SelectOptions struct { Filters []Filter @@ -108,9 +139,9 @@ func (s *RoundRobinIPSelector) Select(ips []string) (string, error) { if len(ips) == 0 { return "", nil } - - count := atomic.AddUint64(&s.count, 1) - return ips[int(count%uint64(len(ips)))], nil + old := s.count + atomic.AddUint64(&s.count, 1) + return ips[int(old%uint64(len(ips)))], nil } func (s *RoundRobinIPSelector) String() string {