diff --git a/chain.go b/chain.go index 4c6ea2a..db1fb7b 100644 --- a/chain.go +++ b/chain.go @@ -3,6 +3,9 @@ package gost import ( "errors" "net" + "strings" + + "github.com/go-log/log" ) var ( @@ -122,13 +125,18 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) { if selector == nil { selector = &defaultSelector{} } + // select node from node group node, err := selector.Select(groups[0].Nodes(), groups[0].Options...) if err != nil { return } nodes = append(nodes, node) - cn, err := node.Client.Dial(node.Addr, node.DialOptions...) + addr, err := selectIP(&node) + if err != nil { + return + } + cn, err := node.Client.Dial(addr, node.DialOptions...) if err != nil { return } @@ -154,8 +162,13 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) { } nodes = append(nodes, node) + addr, err = selectIP(&node) + if err != nil { + return + } + var cc net.Conn - cc, err = preNode.Client.Connect(cn, node.Addr) + cc, err = preNode.Client.Connect(cn, addr) if err != nil { cn.Close() return @@ -172,3 +185,29 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) { conn = cn return } + +func selectIP(node *Node) (string, error) { + addr := node.Addr + s := node.IPSelector + if s == nil { + s = &RandomIPSelector{} + } + // select IP from IP list + ip, err := s.Select(node.IPs) + if err != nil { + return "", err + } + if ip != "" { + if !strings.Contains(ip, ":") { + _, sport, err := net.SplitHostPort(addr) + if err != nil { + return "", err + } + ip = ip + ":" + sport + } + addr = ip + node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr)) + } + log.Log("select IP:", node.Addr, node.IPs, addr) + return addr, nil +} diff --git a/client.go b/client.go index 810b01e..c45d9cf 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "net" "net/url" - "sync/atomic" "time" ) @@ -64,7 +63,6 @@ type Transporter interface { } type tcpTransporter struct { - count uint64 } // TCPTransporter creates a transporter for TCP proxy client. @@ -78,16 +76,6 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er option(opts) } - if len(opts.IPs) > 0 { - count := atomic.AddUint64(&tr.count, 1) - _, sport, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - n := uint64(len(opts.IPs)) - addr = opts.IPs[int(count%n)] + ":" + sport - } - if opts.Chain == nil { return net.DialTimeout("tcp", addr, opts.Timeout) } @@ -106,7 +94,7 @@ func (tr *tcpTransporter) Multiplex() bool { type DialOptions struct { Timeout time.Duration Chain *Chain - IPs []string + // IPs []string } // DialOption allows a common way to set dial options. @@ -126,13 +114,6 @@ func ChainDialOption(chain *Chain) DialOption { } } -// IPDialOption specifies an IP list used by Transporter.Dial -func IPDialOption(ips ...string) DialOption { - return func(opts *DialOptions) { - opts.IPs = ips - } -} - // HandshakeOptions describes the options for handshake. type HandshakeOptions struct { Addr string diff --git a/cmd/gost/main.go b/cmd/gost/main.go index b8b7b20..a1c7e39 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -81,6 +81,10 @@ func initChain() (*gost.Chain, error) { if err != nil { return nil, err } + + node.IPs = parseIP(node.Values.Get("ip")) + node.IPSelector = &gost.RoundRobinIPSelector{} + users, err := parseUsers(node.Values.Get("secrets")) if err != nil { return nil, err @@ -201,7 +205,6 @@ func initChain() (*gost.Chain, error) { timeout, _ := strconv.Atoi(node.Values.Get("timeout")) node.DialOptions = append(node.DialOptions, gost.TimeoutDialOption(time.Duration(timeout)*time.Second), - gost.IPDialOption(parseIP(node.Values.Get("ip"))...), ) interval, _ := strconv.Atoi(node.Values.Get("ping")) @@ -511,9 +514,11 @@ func parseIP(s string) (ips []string) { if err != nil { ss := strings.Split(s, ",") for _, s := range ss { - if ip := net.ParseIP(s); ip != nil { + s = strings.TrimSpace(s) + if s != "" { ips = append(ips, s) } + } return } @@ -524,9 +529,7 @@ func parseIP(s string) (ips []string) { if line == "" || strings.HasPrefix(line, "#") { continue } - if ip := net.ParseIP(line); ip != nil { - ips = append(ips, line) - } + ips = append(ips, line) } return } diff --git a/node.go b/node.go index 61af6c2..4a2ef1a 100644 --- a/node.go +++ b/node.go @@ -8,6 +8,7 @@ import ( // Node is a proxy node, mainly used to construct a proxy chain. type Node struct { Addr string + IPs []string Protocol string Transport string Remote string // remote address, used by tcp/udp port forwarding @@ -16,6 +17,7 @@ type Node struct { DialOptions []DialOption HandshakeOptions []HandshakeOption Client *Client + IPSelector IPSelector } // ParseNode parses the node info. @@ -81,7 +83,7 @@ func ParseNode(s string) (node Node, err error) { type NodeGroup struct { nodes []Node Options []SelectOption - Selector Selector + Selector NodeSelector } // NewNodeGroup creates a node group diff --git a/selector.go b/selector.go index 5a48982..6f49987 100644 --- a/selector.go +++ b/selector.go @@ -1,6 +1,10 @@ package gost -import "errors" +import ( + "errors" + "sync/atomic" + "time" +) var ( // ErrNoneAvailable indicates there is no node available @@ -10,8 +14,8 @@ var ( // SelectOption used when making a select call type SelectOption func(*SelectOptions) -// Selector as a mechanism to pick nodes and mark their status. -type Selector interface { +// NodeSelector as a mechanism to pick nodes and mark their status. +type NodeSelector interface { Select(nodes []Node, opts ...SelectOption) (Node, error) // Mark(node Node) String() string @@ -71,3 +75,44 @@ func WithStrategy(s Strategy) SelectOption { o.Strategy = s } } + +// IPSelector as a mechanism to pick IPs and mark their status. +type IPSelector interface { + Select(ips []string) (string, error) + String() string +} + +// RandomIPSelector is an IP Selector that selects an IP with random strategy. +type RandomIPSelector struct { +} + +// Select selects an IP from ips list. +func (s *RandomIPSelector) Select(ips []string) (string, error) { + if len(ips) == 0 { + return "", nil + } + return ips[time.Now().Nanosecond()%len(ips)], nil +} + +func (s *RandomIPSelector) String() string { + return "random" +} + +// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy. +type RoundRobinIPSelector struct { + count uint64 +} + +// Select selects an IP from ips list. +func (s *RoundRobinIPSelector) Select(ips []string) (string, error) { + if len(ips) == 0 { + return "", nil + } + + count := atomic.AddUint64(&s.count, 1) + return ips[int(count%uint64(len(ips)))], nil +} + +func (s *RoundRobinIPSelector) String() string { + return "round" +} diff --git a/tls.go b/tls.go index edd0c82..f444330 100644 --- a/tls.go +++ b/tls.go @@ -6,7 +6,6 @@ import ( "errors" "net" "sync" - "sync/atomic" "time" "github.com/go-log/log" @@ -53,20 +52,10 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co option(opts) } - if len(opts.IPs) > 0 { - count := atomic.AddUint64(&tr.count, 1) - _, sport, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - n := uint64(len(opts.IPs)) - addr = opts.IPs[int(count%n)] + ":" + sport - } - tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() - session, ok := tr.sessions[addr] // TODO: the addr may be changed. + session, ok := tr.sessions[addr] if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout) diff --git a/ws.go b/ws.go index 0583154..d77c709 100644 --- a/ws.go +++ b/ws.go @@ -10,7 +10,6 @@ import ( "net/http" "net/http/httputil" "sync" - "sync/atomic" "time" "net/url" @@ -155,20 +154,10 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con option(opts) } - if len(opts.IPs) > 0 { - count := atomic.AddUint64(&tr.count, 1) - _, sport, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - n := uint64(len(opts.IPs)) - addr = opts.IPs[int(count%n)] + ":" + sport - } - tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() - session, ok := tr.sessions[addr] // TODO: the addr may be changed. + session, ok := tr.sessions[addr] if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout) @@ -288,20 +277,10 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co option(opts) } - if len(opts.IPs) > 0 { - count := atomic.AddUint64(&tr.count, 1) - _, sport, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - n := uint64(len(opts.IPs)) - addr = opts.IPs[int(count%n)] + ":" + sport - } - tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() - session, ok := tr.sessions[addr] // TODO: the addr may be changed. + session, ok := tr.sessions[addr] if !ok { if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout)