This commit is contained in:
rui.zheng 2017-11-02 17:27:45 +08:00
parent c82f2d904d
commit e3120ca370
7 changed files with 104 additions and 66 deletions

View File

@ -3,6 +3,9 @@ package gost
import (
"errors"
"net"
"strings"
"github.com/go-log/log"
)
var (
@ -122,13 +125,18 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
if selector == nil {
selector = &defaultSelector{}
}
// select node from node group
node, err := selector.Select(groups[0].Nodes(), groups[0].Options...)
if err != nil {
return
}
nodes = append(nodes, node)
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
addr, err := selectIP(&node)
if err != nil {
return
}
cn, err := node.Client.Dial(addr, node.DialOptions...)
if err != nil {
return
}
@ -154,8 +162,13 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
}
nodes = append(nodes, node)
addr, err = selectIP(&node)
if err != nil {
return
}
var cc net.Conn
cc, err = preNode.Client.Connect(cn, node.Addr)
cc, err = preNode.Client.Connect(cn, addr)
if err != nil {
cn.Close()
return
@ -172,3 +185,29 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
conn = cn
return
}
func selectIP(node *Node) (string, error) {
addr := node.Addr
s := node.IPSelector
if s == nil {
s = &RandomIPSelector{}
}
// select IP from IP list
ip, err := s.Select(node.IPs)
if err != nil {
return "", err
}
if ip != "" {
if !strings.Contains(ip, ":") {
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return "", err
}
ip = ip + ":" + sport
}
addr = ip
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr))
}
log.Log("select IP:", node.Addr, node.IPs, addr)
return addr, nil
}

View File

@ -4,7 +4,6 @@ import (
"crypto/tls"
"net"
"net/url"
"sync/atomic"
"time"
)
@ -64,7 +63,6 @@ type Transporter interface {
}
type tcpTransporter struct {
count uint64
}
// TCPTransporter creates a transporter for TCP proxy client.
@ -78,16 +76,6 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er
option(opts)
}
if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}
if opts.Chain == nil {
return net.DialTimeout("tcp", addr, opts.Timeout)
}
@ -106,7 +94,7 @@ func (tr *tcpTransporter) Multiplex() bool {
type DialOptions struct {
Timeout time.Duration
Chain *Chain
IPs []string
// IPs []string
}
// DialOption allows a common way to set dial options.
@ -126,13 +114,6 @@ func ChainDialOption(chain *Chain) DialOption {
}
}
// IPDialOption specifies an IP list used by Transporter.Dial
func IPDialOption(ips ...string) DialOption {
return func(opts *DialOptions) {
opts.IPs = ips
}
}
// HandshakeOptions describes the options for handshake.
type HandshakeOptions struct {
Addr string

View File

@ -81,6 +81,10 @@ func initChain() (*gost.Chain, error) {
if err != nil {
return nil, err
}
node.IPs = parseIP(node.Values.Get("ip"))
node.IPSelector = &gost.RoundRobinIPSelector{}
users, err := parseUsers(node.Values.Get("secrets"))
if err != nil {
return nil, err
@ -201,7 +205,6 @@ func initChain() (*gost.Chain, error) {
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
node.DialOptions = append(node.DialOptions,
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
gost.IPDialOption(parseIP(node.Values.Get("ip"))...),
)
interval, _ := strconv.Atoi(node.Values.Get("ping"))
@ -511,9 +514,11 @@ func parseIP(s string) (ips []string) {
if err != nil {
ss := strings.Split(s, ",")
for _, s := range ss {
if ip := net.ParseIP(s); ip != nil {
s = strings.TrimSpace(s)
if s != "" {
ips = append(ips, s)
}
}
return
}
@ -524,9 +529,7 @@ func parseIP(s string) (ips []string) {
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if ip := net.ParseIP(line); ip != nil {
ips = append(ips, line)
}
ips = append(ips, line)
}
return
}

View File

@ -8,6 +8,7 @@ import (
// Node is a proxy node, mainly used to construct a proxy chain.
type Node struct {
Addr string
IPs []string
Protocol string
Transport string
Remote string // remote address, used by tcp/udp port forwarding
@ -16,6 +17,7 @@ type Node struct {
DialOptions []DialOption
HandshakeOptions []HandshakeOption
Client *Client
IPSelector IPSelector
}
// ParseNode parses the node info.
@ -81,7 +83,7 @@ func ParseNode(s string) (node Node, err error) {
type NodeGroup struct {
nodes []Node
Options []SelectOption
Selector Selector
Selector NodeSelector
}
// NewNodeGroup creates a node group

View File

@ -1,6 +1,10 @@
package gost
import "errors"
import (
"errors"
"sync/atomic"
"time"
)
var (
// ErrNoneAvailable indicates there is no node available
@ -10,8 +14,8 @@ var (
// SelectOption used when making a select call
type SelectOption func(*SelectOptions)
// Selector as a mechanism to pick nodes and mark their status.
type Selector interface {
// NodeSelector as a mechanism to pick nodes and mark their status.
type NodeSelector interface {
Select(nodes []Node, opts ...SelectOption) (Node, error)
// Mark(node Node)
String() string
@ -71,3 +75,44 @@ func WithStrategy(s Strategy) SelectOption {
o.Strategy = s
}
}
// IPSelector as a mechanism to pick IPs and mark their status.
type IPSelector interface {
Select(ips []string) (string, error)
String() string
}
// RandomIPSelector is an IP Selector that selects an IP with random strategy.
type RandomIPSelector struct {
}
// Select selects an IP from ips list.
func (s *RandomIPSelector) Select(ips []string) (string, error) {
if len(ips) == 0 {
return "", nil
}
return ips[time.Now().Nanosecond()%len(ips)], nil
}
func (s *RandomIPSelector) String() string {
return "random"
}
// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy.
type RoundRobinIPSelector struct {
count uint64
}
// Select selects an IP from ips list.
func (s *RoundRobinIPSelector) Select(ips []string) (string, error) {
if len(ips) == 0 {
return "", nil
}
count := atomic.AddUint64(&s.count, 1)
return ips[int(count%uint64(len(ips)))], nil
}
func (s *RoundRobinIPSelector) String() string {
return "round"
}

13
tls.go
View File

@ -6,7 +6,6 @@ import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/go-log/log"
@ -53,20 +52,10 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
option(opts)
}
if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
session, ok := tr.sessions[addr]
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)

25
ws.go
View File

@ -10,7 +10,6 @@ import (
"net/http"
"net/http/httputil"
"sync"
"sync/atomic"
"time"
"net/url"
@ -155,20 +154,10 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con
option(opts)
}
if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
session, ok := tr.sessions[addr]
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
@ -288,20 +277,10 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co
option(opts)
}
if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
session, ok := tr.sessions[addr]
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)