diff --git a/chain.go b/chain.go index 8d304ba..1d2a3be 100644 --- a/chain.go +++ b/chain.go @@ -100,7 +100,7 @@ func (c *Chain) IsEmpty() bool { // If the chain is empty, it will use the net.Dial directly. func (c *Chain) Dial(addr string) (net.Conn, error) { if c.IsEmpty() { - return net.Dial("tcp", addr) + return net.DialTimeout("tcp", addr, DialTimeout) } route, err := c.selectRoute() @@ -108,7 +108,7 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { return nil, err } - conn, err := c.getConn(route) + conn, err := route.getConn() if err != nil { return nil, err } @@ -128,16 +128,16 @@ func (c *Chain) Conn() (conn net.Conn, err error) { if err != nil { return nil, err } - conn, err = c.getConn(route) + conn, err = route.getConn() return } -func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) { - if route.IsEmpty() { +func (c *Chain) getConn() (conn net.Conn, err error) { + if c.IsEmpty() { err = ErrEmptyChain return } - nodes := route.Nodes() + nodes := c.Nodes() node := nodes[0] cn, err := node.Client.Dial(node.Addr, node.DialOptions...) @@ -206,7 +206,7 @@ func (c *Chain) selectRoute() (route *Chain, err error) { } func selectIP(node *Node) (string, error) { - s := node.IPSelector + s := node.Selector if s == nil { s = &RandomIPSelector{} } diff --git a/client.go b/client.go index e4a05f9..8a74998 100644 --- a/client.go +++ b/client.go @@ -116,6 +116,7 @@ func ChainDialOption(chain *Chain) DialOption { // HandshakeOptions describes the options for handshake. type HandshakeOptions struct { Addr string + Host string User *url.Userinfo Timeout time.Duration Interval time.Duration @@ -136,6 +137,13 @@ func AddrHandshakeOption(addr string) HandshakeOption { } } +// HostHandshakeOption specifies the hostname +func HostHandshakeOption(host string) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Host = host + } +} + // UserHandshakeOption specifies the user used by Transporter.Handshake func UserHandshakeOption(user *url.Userinfo) HandshakeOption { return func(opts *HandshakeOptions) { diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 0ecc3d0..8e69fdd 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -140,7 +140,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { node.IPs[i] = ip + ":" + sport } } - node.IPSelector = &gost.RoundRobinIPSelector{} + node.Selector = &gost.RoundRobinIPSelector{} users, err := parseUsers(node.Values.Get("secrets")) if err != nil { @@ -265,6 +265,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { retry, _ := strconv.Atoi(node.Values.Get("retry")) node.HandshakeOptions = append(node.HandshakeOptions, gost.AddrHandshakeOption(node.Addr), + gost.HostHandshakeOption(node.Host), gost.UserHandshakeOption(node.User), gost.TLSConfigHandshakeOption(tlsCfg), gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), diff --git a/node.go b/node.go index a0fa051..1266108 100644 --- a/node.go +++ b/node.go @@ -11,6 +11,7 @@ type Node struct { ID int Addr string IPs []string + Host string Protocol string Transport string Remote string // remote address, used by tcp/udp port forwarding @@ -19,7 +20,7 @@ type Node struct { DialOptions []DialOption HandshakeOptions []HandshakeOption Client *Client - IPSelector IPSelector + Selector IPSelector } // ParseNode parses the node info. @@ -40,6 +41,7 @@ func ParseNode(s string) (node Node, err error) { node = Node{ Addr: u.Host, + Host: u.Host, Remote: strings.Trim(u.EscapedPath(), "/"), Values: u.Query(), User: u.User, diff --git a/obfs.go b/obfs.go index 424ed58..2c7da09 100644 --- a/obfs.go +++ b/obfs.go @@ -35,7 +35,7 @@ func (tr *obfsHTTPTransporter) Handshake(conn net.Conn, options ...HandshakeOpti for _, option := range options { option(opts) } - return &obfsHTTPConn{Conn: conn}, nil + return &obfsHTTPConn{Conn: conn, host: opts.Host}, nil } type obfsHTTPListener struct { @@ -66,6 +66,7 @@ func (l *obfsHTTPListener) Accept() (net.Conn, error) { type obfsHTTPConn struct { net.Conn + host string request *http.Request response *http.Response rbuf []byte @@ -151,7 +152,7 @@ func (c *obfsHTTPConn) clientHandshake() (err error) { Method: http.MethodGet, ProtoMajor: 1, ProtoMinor: 1, - URL: &url.URL{Scheme: "http", Host: "www.baidu.com"}, + URL: &url.URL{Scheme: "http", Host: c.host}, Header: make(http.Header), } r.Header.Set("Connection", "keep-alive") diff --git a/ws.go b/ws.go index d77c709..fa9c043 100644 --- a/ws.go +++ b/ws.go @@ -129,7 +129,7 @@ func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (n if opts.WSOptions != nil { wsOptions = opts.WSOptions } - url := url.URL{Scheme: "ws", Host: opts.Addr, Path: "/ws"} + url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} return websocketClientConn(url.String(), conn, nil, wsOptions) } @@ -210,7 +210,7 @@ func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *Handshak if opts.WSOptions != nil { wsOptions = opts.WSOptions } - url := url.URL{Scheme: "ws", Host: opts.Addr, Path: "/ws"} + url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} conn, err := websocketClientConn(url.String(), conn, nil, wsOptions) if err != nil { return nil, err @@ -252,7 +252,7 @@ func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( if opts.TLSConfig == nil { opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} } - url := url.URL{Scheme: "wss", Host: opts.Addr, Path: "/ws"} + url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) } @@ -337,7 +337,7 @@ func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *Handsha if tlsConfig == nil { tlsConfig = &tls.Config{InsecureSkipVerify: true} } - url := url.URL{Scheme: "wss", Host: opts.Addr, Path: "/ws"} + url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions) if err != nil { return nil, err