fix issue #173
This commit is contained in:
parent
c82f2d904d
commit
e3120ca370
43
chain.go
43
chain.go
@ -3,6 +3,9 @@ package gost
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-log/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -122,13 +125,18 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
|
|||||||
if selector == nil {
|
if selector == nil {
|
||||||
selector = &defaultSelector{}
|
selector = &defaultSelector{}
|
||||||
}
|
}
|
||||||
|
// select node from node group
|
||||||
node, err := selector.Select(groups[0].Nodes(), groups[0].Options...)
|
node, err := selector.Select(groups[0].Nodes(), groups[0].Options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nodes = append(nodes, node)
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
addr, err := selectIP(&node)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cn, err := node.Client.Dial(addr, node.DialOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -154,8 +162,13 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
|
|||||||
}
|
}
|
||||||
nodes = append(nodes, node)
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
|
addr, err = selectIP(&node)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var cc net.Conn
|
var cc net.Conn
|
||||||
cc, err = preNode.Client.Connect(cn, node.Addr)
|
cc, err = preNode.Client.Connect(cn, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cn.Close()
|
cn.Close()
|
||||||
return
|
return
|
||||||
@ -172,3 +185,29 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
|
|||||||
conn = cn
|
conn = cn
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func selectIP(node *Node) (string, error) {
|
||||||
|
addr := node.Addr
|
||||||
|
s := node.IPSelector
|
||||||
|
if s == nil {
|
||||||
|
s = &RandomIPSelector{}
|
||||||
|
}
|
||||||
|
// select IP from IP list
|
||||||
|
ip, err := s.Select(node.IPs)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if ip != "" {
|
||||||
|
if !strings.Contains(ip, ":") {
|
||||||
|
_, sport, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
ip = ip + ":" + sport
|
||||||
|
}
|
||||||
|
addr = ip
|
||||||
|
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr))
|
||||||
|
}
|
||||||
|
log.Log("select IP:", node.Addr, node.IPs, addr)
|
||||||
|
return addr, nil
|
||||||
|
}
|
||||||
|
21
client.go
21
client.go
@ -4,7 +4,6 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -64,7 +63,6 @@ type Transporter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type tcpTransporter struct {
|
type tcpTransporter struct {
|
||||||
count uint64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCPTransporter creates a transporter for TCP proxy client.
|
// TCPTransporter creates a transporter for TCP proxy client.
|
||||||
@ -78,16 +76,6 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er
|
|||||||
option(opts)
|
option(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.IPs) > 0 {
|
|
||||||
count := atomic.AddUint64(&tr.count, 1)
|
|
||||||
_, sport, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
n := uint64(len(opts.IPs))
|
|
||||||
addr = opts.IPs[int(count%n)] + ":" + sport
|
|
||||||
}
|
|
||||||
|
|
||||||
if opts.Chain == nil {
|
if opts.Chain == nil {
|
||||||
return net.DialTimeout("tcp", addr, opts.Timeout)
|
return net.DialTimeout("tcp", addr, opts.Timeout)
|
||||||
}
|
}
|
||||||
@ -106,7 +94,7 @@ func (tr *tcpTransporter) Multiplex() bool {
|
|||||||
type DialOptions struct {
|
type DialOptions struct {
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
Chain *Chain
|
Chain *Chain
|
||||||
IPs []string
|
// IPs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialOption allows a common way to set dial options.
|
// DialOption allows a common way to set dial options.
|
||||||
@ -126,13 +114,6 @@ func ChainDialOption(chain *Chain) DialOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPDialOption specifies an IP list used by Transporter.Dial
|
|
||||||
func IPDialOption(ips ...string) DialOption {
|
|
||||||
return func(opts *DialOptions) {
|
|
||||||
opts.IPs = ips
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandshakeOptions describes the options for handshake.
|
// HandshakeOptions describes the options for handshake.
|
||||||
type HandshakeOptions struct {
|
type HandshakeOptions struct {
|
||||||
Addr string
|
Addr string
|
||||||
|
@ -81,6 +81,10 @@ func initChain() (*gost.Chain, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node.IPs = parseIP(node.Values.Get("ip"))
|
||||||
|
node.IPSelector = &gost.RoundRobinIPSelector{}
|
||||||
|
|
||||||
users, err := parseUsers(node.Values.Get("secrets"))
|
users, err := parseUsers(node.Values.Get("secrets"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -201,7 +205,6 @@ func initChain() (*gost.Chain, error) {
|
|||||||
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
||||||
node.DialOptions = append(node.DialOptions,
|
node.DialOptions = append(node.DialOptions,
|
||||||
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
|
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
|
||||||
gost.IPDialOption(parseIP(node.Values.Get("ip"))...),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
interval, _ := strconv.Atoi(node.Values.Get("ping"))
|
interval, _ := strconv.Atoi(node.Values.Get("ping"))
|
||||||
@ -511,9 +514,11 @@ func parseIP(s string) (ips []string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
ss := strings.Split(s, ",")
|
ss := strings.Split(s, ",")
|
||||||
for _, s := range ss {
|
for _, s := range ss {
|
||||||
if ip := net.ParseIP(s); ip != nil {
|
s = strings.TrimSpace(s)
|
||||||
|
if s != "" {
|
||||||
ips = append(ips, s)
|
ips = append(ips, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -524,9 +529,7 @@ func parseIP(s string) (ips []string) {
|
|||||||
if line == "" || strings.HasPrefix(line, "#") {
|
if line == "" || strings.HasPrefix(line, "#") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if ip := net.ParseIP(line); ip != nil {
|
ips = append(ips, line)
|
||||||
ips = append(ips, line)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
4
node.go
4
node.go
@ -8,6 +8,7 @@ import (
|
|||||||
// Node is a proxy node, mainly used to construct a proxy chain.
|
// Node is a proxy node, mainly used to construct a proxy chain.
|
||||||
type Node struct {
|
type Node struct {
|
||||||
Addr string
|
Addr string
|
||||||
|
IPs []string
|
||||||
Protocol string
|
Protocol string
|
||||||
Transport string
|
Transport string
|
||||||
Remote string // remote address, used by tcp/udp port forwarding
|
Remote string // remote address, used by tcp/udp port forwarding
|
||||||
@ -16,6 +17,7 @@ type Node struct {
|
|||||||
DialOptions []DialOption
|
DialOptions []DialOption
|
||||||
HandshakeOptions []HandshakeOption
|
HandshakeOptions []HandshakeOption
|
||||||
Client *Client
|
Client *Client
|
||||||
|
IPSelector IPSelector
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseNode parses the node info.
|
// ParseNode parses the node info.
|
||||||
@ -81,7 +83,7 @@ func ParseNode(s string) (node Node, err error) {
|
|||||||
type NodeGroup struct {
|
type NodeGroup struct {
|
||||||
nodes []Node
|
nodes []Node
|
||||||
Options []SelectOption
|
Options []SelectOption
|
||||||
Selector Selector
|
Selector NodeSelector
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNodeGroup creates a node group
|
// NewNodeGroup creates a node group
|
||||||
|
51
selector.go
51
selector.go
@ -1,6 +1,10 @@
|
|||||||
package gost
|
package gost
|
||||||
|
|
||||||
import "errors"
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrNoneAvailable indicates there is no node available
|
// ErrNoneAvailable indicates there is no node available
|
||||||
@ -10,8 +14,8 @@ var (
|
|||||||
// SelectOption used when making a select call
|
// SelectOption used when making a select call
|
||||||
type SelectOption func(*SelectOptions)
|
type SelectOption func(*SelectOptions)
|
||||||
|
|
||||||
// Selector as a mechanism to pick nodes and mark their status.
|
// NodeSelector as a mechanism to pick nodes and mark their status.
|
||||||
type Selector interface {
|
type NodeSelector interface {
|
||||||
Select(nodes []Node, opts ...SelectOption) (Node, error)
|
Select(nodes []Node, opts ...SelectOption) (Node, error)
|
||||||
// Mark(node Node)
|
// Mark(node Node)
|
||||||
String() string
|
String() string
|
||||||
@ -71,3 +75,44 @@ func WithStrategy(s Strategy) SelectOption {
|
|||||||
o.Strategy = s
|
o.Strategy = s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IPSelector as a mechanism to pick IPs and mark their status.
|
||||||
|
type IPSelector interface {
|
||||||
|
Select(ips []string) (string, error)
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomIPSelector is an IP Selector that selects an IP with random strategy.
|
||||||
|
type RandomIPSelector struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select selects an IP from ips list.
|
||||||
|
func (s *RandomIPSelector) Select(ips []string) (string, error) {
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return ips[time.Now().Nanosecond()%len(ips)], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RandomIPSelector) String() string {
|
||||||
|
return "random"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy.
|
||||||
|
type RoundRobinIPSelector struct {
|
||||||
|
count uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select selects an IP from ips list.
|
||||||
|
func (s *RoundRobinIPSelector) Select(ips []string) (string, error) {
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
count := atomic.AddUint64(&s.count, 1)
|
||||||
|
return ips[int(count%uint64(len(ips)))], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RoundRobinIPSelector) String() string {
|
||||||
|
return "round"
|
||||||
|
}
|
||||||
|
13
tls.go
13
tls.go
@ -6,7 +6,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-log/log"
|
"github.com/go-log/log"
|
||||||
@ -53,20 +52,10 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
|||||||
option(opts)
|
option(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.IPs) > 0 {
|
|
||||||
count := atomic.AddUint64(&tr.count, 1)
|
|
||||||
_, sport, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
n := uint64(len(opts.IPs))
|
|
||||||
addr = opts.IPs[int(count%n)] + ":" + sport
|
|
||||||
}
|
|
||||||
|
|
||||||
tr.sessionMutex.Lock()
|
tr.sessionMutex.Lock()
|
||||||
defer tr.sessionMutex.Unlock()
|
defer tr.sessionMutex.Unlock()
|
||||||
|
|
||||||
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
|
session, ok := tr.sessions[addr]
|
||||||
if !ok {
|
if !ok {
|
||||||
if opts.Chain == nil {
|
if opts.Chain == nil {
|
||||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||||
|
25
ws.go
25
ws.go
@ -10,7 +10,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -155,20 +154,10 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con
|
|||||||
option(opts)
|
option(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.IPs) > 0 {
|
|
||||||
count := atomic.AddUint64(&tr.count, 1)
|
|
||||||
_, sport, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
n := uint64(len(opts.IPs))
|
|
||||||
addr = opts.IPs[int(count%n)] + ":" + sport
|
|
||||||
}
|
|
||||||
|
|
||||||
tr.sessionMutex.Lock()
|
tr.sessionMutex.Lock()
|
||||||
defer tr.sessionMutex.Unlock()
|
defer tr.sessionMutex.Unlock()
|
||||||
|
|
||||||
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
|
session, ok := tr.sessions[addr]
|
||||||
if !ok {
|
if !ok {
|
||||||
if opts.Chain == nil {
|
if opts.Chain == nil {
|
||||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||||
@ -288,20 +277,10 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
|||||||
option(opts)
|
option(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.IPs) > 0 {
|
|
||||||
count := atomic.AddUint64(&tr.count, 1)
|
|
||||||
_, sport, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
n := uint64(len(opts.IPs))
|
|
||||||
addr = opts.IPs[int(count%n)] + ":" + sport
|
|
||||||
}
|
|
||||||
|
|
||||||
tr.sessionMutex.Lock()
|
tr.sessionMutex.Lock()
|
||||||
defer tr.sessionMutex.Unlock()
|
defer tr.sessionMutex.Unlock()
|
||||||
|
|
||||||
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
|
session, ok := tr.sessions[addr]
|
||||||
if !ok {
|
if !ok {
|
||||||
if opts.Chain == nil {
|
if opts.Chain == nil {
|
||||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||||
|
Loading…
Reference in New Issue
Block a user