From 0f07c8d5ddb4c2a4de5c4e2331b0b1750fc95b84 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Fri, 11 Aug 2017 16:49:10 +0800 Subject: [PATCH] add IP list support for TCP transport --- client.go | 24 +++++++++++++----------- cmd/gost/main.go | 6 +----- ws.go | 18 ------------------ 3 files changed, 14 insertions(+), 34 deletions(-) diff --git a/client.go b/client.go index 03ea503..5f43c50 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net" "net/url" + "sync/atomic" "time" ) @@ -63,6 +64,7 @@ type Transporter interface { } type tcpTransporter struct { + count uint64 } // TCPTransporter creates a transporter for TCP proxy client. @@ -75,6 +77,17 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er for _, option := range options { option(opts) } + + if len(opts.IPs) > 0 { + count := atomic.AddUint64(&tr.count, 1) + _, sport, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + n := uint64(len(opts.IPs)) + addr = opts.IPs[int(count%n)] + ":" + sport + } + if opts.Chain == nil { return net.DialTimeout("tcp", addr, opts.Timeout) } @@ -94,17 +107,6 @@ type DialOptions struct { Timeout time.Duration Chain *Chain IPs []string - // count uint32 -} - -func (o *DialOptions) getIP() string { - n := len(o.IPs) - if n == 0 { - return "" - } - return o.IPs[int(time.Now().Nanosecond())%n] - // count := atomic.AddUint32(&o.count, 1) - //return o.IPs[int(count)%n] } // DialOption allows a common way to set dial options. diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 38eb325..8740ee8 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -112,11 +112,6 @@ func initChain() (*gost.Chain, error) { wsOpts.UserAgent = node.Values.Get("agent") tr = gost.WSTransporter(wsOpts) case "wss": - ips := strings.Split(node.Values.Get("ip"), ",") - node.DialOptions = append(node.DialOptions, - gost.IPDialOption(ips...), - ) - wsOpts := &gost.WSOptions{} wsOpts.EnableCompression = toBool(node.Values.Get("compression")) wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) @@ -204,6 +199,7 @@ func initChain() (*gost.Chain, error) { timeout, _ := strconv.Atoi(node.Values.Get("timeout")) node.DialOptions = append(node.DialOptions, gost.TimeoutDialOption(time.Duration(timeout)*time.Second), + gost.IPDialOption(strings.Split(node.Values.Get("ip"), ",")...), ) interval, _ := strconv.Atoi(node.Values.Get("ping")) diff --git a/ws.go b/ws.go index fda58a7..69be7dc 100644 --- a/ws.go +++ b/ws.go @@ -139,24 +139,6 @@ func WSSTransporter(opts *WSOptions) Transporter { } } -func (tr *wssTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - if ip := opts.getIP(); ip != "" { - _, sport, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - addr = ip + ":" + sport - } - if opts.Chain == nil { - return net.DialTimeout("tcp", addr, opts.Timeout) - } - return opts.Chain.Dial(addr) -} - func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { opts := &HandshakeOptions{} for _, option := range options {