fix issue #173
This commit is contained in:
parent
c82f2d904d
commit
e3120ca370
43
chain.go
43
chain.go
@ -3,6 +3,9 @@ package gost
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/go-log/log"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -122,13 +125,18 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
|
||||
if selector == nil {
|
||||
selector = &defaultSelector{}
|
||||
}
|
||||
// select node from node group
|
||||
node, err := selector.Select(groups[0].Nodes(), groups[0].Options...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
nodes = append(nodes, node)
|
||||
|
||||
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
||||
addr, err := selectIP(&node)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cn, err := node.Client.Dial(addr, node.DialOptions...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -154,8 +162,13 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
|
||||
}
|
||||
nodes = append(nodes, node)
|
||||
|
||||
addr, err = selectIP(&node)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var cc net.Conn
|
||||
cc, err = preNode.Client.Connect(cn, node.Addr)
|
||||
cc, err = preNode.Client.Connect(cn, addr)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
return
|
||||
@ -172,3 +185,29 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
|
||||
conn = cn
|
||||
return
|
||||
}
|
||||
|
||||
func selectIP(node *Node) (string, error) {
|
||||
addr := node.Addr
|
||||
s := node.IPSelector
|
||||
if s == nil {
|
||||
s = &RandomIPSelector{}
|
||||
}
|
||||
// select IP from IP list
|
||||
ip, err := s.Select(node.IPs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ip != "" {
|
||||
if !strings.Contains(ip, ":") {
|
||||
_, sport, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ip = ip + ":" + sport
|
||||
}
|
||||
addr = ip
|
||||
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr))
|
||||
}
|
||||
log.Log("select IP:", node.Addr, node.IPs, addr)
|
||||
return addr, nil
|
||||
}
|
||||
|
21
client.go
21
client.go
@ -4,7 +4,6 @@ import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -64,7 +63,6 @@ type Transporter interface {
|
||||
}
|
||||
|
||||
type tcpTransporter struct {
|
||||
count uint64
|
||||
}
|
||||
|
||||
// TCPTransporter creates a transporter for TCP proxy client.
|
||||
@ -78,16 +76,6 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er
|
||||
option(opts)
|
||||
}
|
||||
|
||||
if len(opts.IPs) > 0 {
|
||||
count := atomic.AddUint64(&tr.count, 1)
|
||||
_, sport, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := uint64(len(opts.IPs))
|
||||
addr = opts.IPs[int(count%n)] + ":" + sport
|
||||
}
|
||||
|
||||
if opts.Chain == nil {
|
||||
return net.DialTimeout("tcp", addr, opts.Timeout)
|
||||
}
|
||||
@ -106,7 +94,7 @@ func (tr *tcpTransporter) Multiplex() bool {
|
||||
type DialOptions struct {
|
||||
Timeout time.Duration
|
||||
Chain *Chain
|
||||
IPs []string
|
||||
// IPs []string
|
||||
}
|
||||
|
||||
// DialOption allows a common way to set dial options.
|
||||
@ -126,13 +114,6 @@ func ChainDialOption(chain *Chain) DialOption {
|
||||
}
|
||||
}
|
||||
|
||||
// IPDialOption specifies an IP list used by Transporter.Dial
|
||||
func IPDialOption(ips ...string) DialOption {
|
||||
return func(opts *DialOptions) {
|
||||
opts.IPs = ips
|
||||
}
|
||||
}
|
||||
|
||||
// HandshakeOptions describes the options for handshake.
|
||||
type HandshakeOptions struct {
|
||||
Addr string
|
||||
|
@ -81,6 +81,10 @@ func initChain() (*gost.Chain, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node.IPs = parseIP(node.Values.Get("ip"))
|
||||
node.IPSelector = &gost.RoundRobinIPSelector{}
|
||||
|
||||
users, err := parseUsers(node.Values.Get("secrets"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -201,7 +205,6 @@ func initChain() (*gost.Chain, error) {
|
||||
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
|
||||
node.DialOptions = append(node.DialOptions,
|
||||
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
|
||||
gost.IPDialOption(parseIP(node.Values.Get("ip"))...),
|
||||
)
|
||||
|
||||
interval, _ := strconv.Atoi(node.Values.Get("ping"))
|
||||
@ -511,9 +514,11 @@ func parseIP(s string) (ips []string) {
|
||||
if err != nil {
|
||||
ss := strings.Split(s, ",")
|
||||
for _, s := range ss {
|
||||
if ip := net.ParseIP(s); ip != nil {
|
||||
s = strings.TrimSpace(s)
|
||||
if s != "" {
|
||||
ips = append(ips, s)
|
||||
}
|
||||
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -524,9 +529,7 @@ func parseIP(s string) (ips []string) {
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
if ip := net.ParseIP(line); ip != nil {
|
||||
ips = append(ips, line)
|
||||
}
|
||||
ips = append(ips, line)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
4
node.go
4
node.go
@ -8,6 +8,7 @@ import (
|
||||
// Node is a proxy node, mainly used to construct a proxy chain.
|
||||
type Node struct {
|
||||
Addr string
|
||||
IPs []string
|
||||
Protocol string
|
||||
Transport string
|
||||
Remote string // remote address, used by tcp/udp port forwarding
|
||||
@ -16,6 +17,7 @@ type Node struct {
|
||||
DialOptions []DialOption
|
||||
HandshakeOptions []HandshakeOption
|
||||
Client *Client
|
||||
IPSelector IPSelector
|
||||
}
|
||||
|
||||
// ParseNode parses the node info.
|
||||
@ -81,7 +83,7 @@ func ParseNode(s string) (node Node, err error) {
|
||||
type NodeGroup struct {
|
||||
nodes []Node
|
||||
Options []SelectOption
|
||||
Selector Selector
|
||||
Selector NodeSelector
|
||||
}
|
||||
|
||||
// NewNodeGroup creates a node group
|
||||
|
51
selector.go
51
selector.go
@ -1,6 +1,10 @@
|
||||
package gost
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoneAvailable indicates there is no node available
|
||||
@ -10,8 +14,8 @@ var (
|
||||
// SelectOption used when making a select call
|
||||
type SelectOption func(*SelectOptions)
|
||||
|
||||
// Selector as a mechanism to pick nodes and mark their status.
|
||||
type Selector interface {
|
||||
// NodeSelector as a mechanism to pick nodes and mark their status.
|
||||
type NodeSelector interface {
|
||||
Select(nodes []Node, opts ...SelectOption) (Node, error)
|
||||
// Mark(node Node)
|
||||
String() string
|
||||
@ -71,3 +75,44 @@ func WithStrategy(s Strategy) SelectOption {
|
||||
o.Strategy = s
|
||||
}
|
||||
}
|
||||
|
||||
// IPSelector as a mechanism to pick IPs and mark their status.
|
||||
type IPSelector interface {
|
||||
Select(ips []string) (string, error)
|
||||
String() string
|
||||
}
|
||||
|
||||
// RandomIPSelector is an IP Selector that selects an IP with random strategy.
|
||||
type RandomIPSelector struct {
|
||||
}
|
||||
|
||||
// Select selects an IP from ips list.
|
||||
func (s *RandomIPSelector) Select(ips []string) (string, error) {
|
||||
if len(ips) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return ips[time.Now().Nanosecond()%len(ips)], nil
|
||||
}
|
||||
|
||||
func (s *RandomIPSelector) String() string {
|
||||
return "random"
|
||||
}
|
||||
|
||||
// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy.
|
||||
type RoundRobinIPSelector struct {
|
||||
count uint64
|
||||
}
|
||||
|
||||
// Select selects an IP from ips list.
|
||||
func (s *RoundRobinIPSelector) Select(ips []string) (string, error) {
|
||||
if len(ips) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
count := atomic.AddUint64(&s.count, 1)
|
||||
return ips[int(count%uint64(len(ips)))], nil
|
||||
}
|
||||
|
||||
func (s *RoundRobinIPSelector) String() string {
|
||||
return "round"
|
||||
}
|
||||
|
13
tls.go
13
tls.go
@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-log/log"
|
||||
@ -53,20 +52,10 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
||||
option(opts)
|
||||
}
|
||||
|
||||
if len(opts.IPs) > 0 {
|
||||
count := atomic.AddUint64(&tr.count, 1)
|
||||
_, sport, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := uint64(len(opts.IPs))
|
||||
addr = opts.IPs[int(count%n)] + ":" + sport
|
||||
}
|
||||
|
||||
tr.sessionMutex.Lock()
|
||||
defer tr.sessionMutex.Unlock()
|
||||
|
||||
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
|
||||
session, ok := tr.sessions[addr]
|
||||
if !ok {
|
||||
if opts.Chain == nil {
|
||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||
|
25
ws.go
25
ws.go
@ -10,7 +10,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"net/url"
|
||||
@ -155,20 +154,10 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con
|
||||
option(opts)
|
||||
}
|
||||
|
||||
if len(opts.IPs) > 0 {
|
||||
count := atomic.AddUint64(&tr.count, 1)
|
||||
_, sport, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := uint64(len(opts.IPs))
|
||||
addr = opts.IPs[int(count%n)] + ":" + sport
|
||||
}
|
||||
|
||||
tr.sessionMutex.Lock()
|
||||
defer tr.sessionMutex.Unlock()
|
||||
|
||||
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
|
||||
session, ok := tr.sessions[addr]
|
||||
if !ok {
|
||||
if opts.Chain == nil {
|
||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||
@ -288,20 +277,10 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co
|
||||
option(opts)
|
||||
}
|
||||
|
||||
if len(opts.IPs) > 0 {
|
||||
count := atomic.AddUint64(&tr.count, 1)
|
||||
_, sport, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := uint64(len(opts.IPs))
|
||||
addr = opts.IPs[int(count%n)] + ":" + sport
|
||||
}
|
||||
|
||||
tr.sessionMutex.Lock()
|
||||
defer tr.sessionMutex.Unlock()
|
||||
|
||||
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
|
||||
session, ok := tr.sessions[addr]
|
||||
if !ok {
|
||||
if opts.Chain == nil {
|
||||
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||
|
Loading…
Reference in New Issue
Block a user