add IP list support for TCP transport

This commit is contained in:
rui.zheng 2017-08-11 16:49:10 +08:00
parent 2446f5d0a2
commit 0f07c8d5dd
3 changed files with 14 additions and 34 deletions

View File

@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"net" "net"
"net/url" "net/url"
"sync/atomic"
"time" "time"
) )
@ -63,6 +64,7 @@ type Transporter interface {
} }
type tcpTransporter struct { type tcpTransporter struct {
count uint64
} }
// TCPTransporter creates a transporter for TCP proxy client. // TCPTransporter creates a transporter for TCP proxy client.
@ -75,6 +77,17 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er
for _, option := range options { for _, option := range options {
option(opts) option(opts)
} }
if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}
if opts.Chain == nil { if opts.Chain == nil {
return net.DialTimeout("tcp", addr, opts.Timeout) return net.DialTimeout("tcp", addr, opts.Timeout)
} }
@ -94,17 +107,6 @@ type DialOptions struct {
Timeout time.Duration Timeout time.Duration
Chain *Chain Chain *Chain
IPs []string IPs []string
// count uint32
}
func (o *DialOptions) getIP() string {
n := len(o.IPs)
if n == 0 {
return ""
}
return o.IPs[int(time.Now().Nanosecond())%n]
// count := atomic.AddUint32(&o.count, 1)
//return o.IPs[int(count)%n]
} }
// DialOption allows a common way to set dial options. // DialOption allows a common way to set dial options.

View File

@ -112,11 +112,6 @@ func initChain() (*gost.Chain, error) {
wsOpts.UserAgent = node.Values.Get("agent") wsOpts.UserAgent = node.Values.Get("agent")
tr = gost.WSTransporter(wsOpts) tr = gost.WSTransporter(wsOpts)
case "wss": case "wss":
ips := strings.Split(node.Values.Get("ip"), ",")
node.DialOptions = append(node.DialOptions,
gost.IPDialOption(ips...),
)
wsOpts := &gost.WSOptions{} wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression")) wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
@ -204,6 +199,7 @@ func initChain() (*gost.Chain, error) {
timeout, _ := strconv.Atoi(node.Values.Get("timeout")) timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
node.DialOptions = append(node.DialOptions, node.DialOptions = append(node.DialOptions,
gost.TimeoutDialOption(time.Duration(timeout)*time.Second), gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
gost.IPDialOption(strings.Split(node.Values.Get("ip"), ",")...),
) )
interval, _ := strconv.Atoi(node.Values.Get("ping")) interval, _ := strconv.Atoi(node.Values.Get("ping"))

18
ws.go
View File

@ -139,24 +139,6 @@ func WSSTransporter(opts *WSOptions) Transporter {
} }
} }
func (tr *wssTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
opts := &DialOptions{}
for _, option := range options {
option(opts)
}
if ip := opts.getIP(); ip != "" {
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
addr = ip + ":" + sport
}
if opts.Chain == nil {
return net.DialTimeout("tcp", addr, opts.Timeout)
}
return opts.Chain.Dial(addr)
}
func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
opts := &HandshakeOptions{} opts := &HandshakeOptions{}
for _, option := range options { for _, option := range options {