add ConnectOptions for Connector.Connect

This commit is contained in:
ginuerzh 2018-11-10 12:16:46 +08:00
parent 3ebf423e87
commit b16f878c39
12 changed files with 64 additions and 34 deletions

View File

@ -132,10 +132,10 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e
return nil, err return nil, err
} }
addr = c.resolve(addr, options.Resolver, options.Hosts) ipAddr := c.resolve(addr, options.Resolver, options.Hosts)
if route.IsEmpty() { if route.IsEmpty() {
return net.DialTimeout("tcp", addr, options.Timeout) return net.DialTimeout("tcp", ipAddr, options.Timeout)
} }
conn, err := route.getConn() conn, err := route.getConn()
@ -143,7 +143,7 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e
return nil, err return nil, err
} }
cc, err := route.LastNode().Client.Connect(conn, addr) cc, err := route.LastNode().Client.Connect(conn, addr, IPAddrConnectOption(ipAddr))
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err

View File

@ -27,8 +27,8 @@ func (c *Client) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn,
} }
// Connect connects to the address addr via the proxy over connection conn. // Connect connects to the address addr via the proxy over connection conn.
func (c *Client) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *Client) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return c.Connector.Connect(conn, addr) return c.Connector.Connect(conn, addr, options...)
} }
// DefaultClient is a standard HTTP proxy client. // DefaultClient is a standard HTTP proxy client.
@ -51,7 +51,7 @@ func Connect(conn net.Conn, addr string) (net.Conn, error) {
// Connector is responsible for connecting to the destination address. // Connector is responsible for connecting to the destination address.
type Connector interface { type Connector interface {
Connect(conn net.Conn, addr string) (net.Conn, error) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error)
} }
// Transporter is responsible for handshaking with the proxy server. // Transporter is responsible for handshaking with the proxy server.
@ -96,7 +96,7 @@ type DialOptions struct {
Chain *Chain Chain *Chain
} }
// DialOption allows a common way to set dial options. // DialOption allows a common way to set DialOptions.
type DialOption func(opts *DialOptions) type DialOption func(opts *DialOptions)
// TimeoutDialOption specifies the timeout used by Transporter.Dial // TimeoutDialOption specifies the timeout used by Transporter.Dial
@ -127,7 +127,7 @@ type HandshakeOptions struct {
QUICConfig *QUICConfig QUICConfig *QUICConfig
} }
// HandshakeOption allows a common way to set handshake options. // HandshakeOption allows a common way to set HandshakeOptions.
type HandshakeOption func(opts *HandshakeOptions) type HandshakeOption func(opts *HandshakeOptions)
// AddrHandshakeOption specifies the server address // AddrHandshakeOption specifies the server address
@ -199,3 +199,18 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption {
opts.QUICConfig = config opts.QUICConfig = config
} }
} }
// ConnectOptions describes the options for Connector.Connect.
type ConnectOptions struct {
IPAddr string
}
// ConnectOption allows a common way to set ConnectOptions.
type ConnectOption func(opts *ConnectOptions)
// IPAddrConnectOption specifies the corresponding IP:PORT of the connected target address.
func IPAddrConnectOption(ipAddr string) ConnectOption {
return func(opts *ConnectOptions) {
opts.IPAddr = ipAddr
}
}

View File

@ -22,7 +22,7 @@ func ForwardConnector() Connector {
return &forwardConnector{} return &forwardConnector{}
} }
func (c *forwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *forwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return conn, nil return conn, nil
} }

View File

@ -14,7 +14,7 @@ import (
) )
// Version is the gost version. // Version is the gost version.
const Version = "2.6" const Version = "2.7-dev"
// Debug is a flag that enables the debug log. // Debug is a flag that enables the debug log.
var Debug bool var Debug bool

View File

@ -145,7 +145,7 @@ func (h *autoHandler) Handle(conn net.Conn) {
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
b, err := br.Peek(1) b, err := br.Peek(1)
if err != nil { if err != nil {
log.Log(err) log.Logf("[auto] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
conn.Close() conn.Close()
return return
} }

View File

@ -24,7 +24,7 @@ func HTTPConnector(user *url.Userinfo) Connector {
return &httpConnector{User: user} return &httpConnector{User: user}
} }
func (c *httpConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
req := &http.Request{ req := &http.Request{
Method: http.MethodConnect, Method: http.MethodConnect,
URL: &url.URL{Host: addr}, URL: &url.URL{Host: addr},

View File

@ -28,7 +28,12 @@ func HTTP2Connector(user *url.Userinfo) Connector {
return &http2Connector{User: user} return &http2Connector{User: user}
} }
func (c *http2Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
var cOpts ConnectOptions
for _, opt := range options {
opt(&cOpts)
}
cc, ok := conn.(*http2ClientConn) cc, ok := conn.(*http2ClientConn)
if !ok { if !ok {
return nil, errors.New("wrong connection type") return nil, errors.New("wrong connection type")
@ -75,6 +80,10 @@ func (c *http2Connector) Connect(conn net.Conn, addr string) (net.Conn, error) {
w: pw, w: pw,
closed: make(chan struct{}), closed: make(chan struct{}),
} }
if cOpts.IPAddr != "" {
addr = cOpts.IPAddr
}
hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr) hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr)
hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr) hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr)
@ -526,7 +535,7 @@ func H2CListener(addr string) (Listener, error) {
l := &h2Listener{ l := &h2Listener{
Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, Listener: tcpKeepAliveListener{ln.(*net.TCPListener)},
server: &http2.Server{ server: &http2.Server{
// MaxConcurrentStreams: 1000, // MaxConcurrentStreams: 1000,
}, },
connChan: make(chan net.Conn, 1024), connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1), errChan: make(chan error, 1),

View File

@ -30,9 +30,7 @@ func PeriodReload(r Reloader, configFile string) error {
} }
mt := finfo.ModTime() mt := finfo.ModTime()
if !mt.Equal(lastMod) { if !mt.Equal(lastMod) {
if Debug { log.Log("[reload]", configFile)
log.Log("[reload]", configFile)
}
r.Reload(f) r.Reload(f)
lastMod = mt lastMod = mt
} }

2
sni.go
View File

@ -28,7 +28,7 @@ func SNIConnector(host string) Connector {
return &sniConnector{host: host} return &sniConnector{host: host}
} }
func (c *sniConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *sniConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return &sniClientConn{addr: addr, host: c.host, Conn: conn}, nil return &sniClientConn{addr: addr, host: c.host, Conn: conn}, nil
} }

View File

@ -148,11 +148,11 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
req, err := gosocks5.ReadUserPassRequest(conn) req, err := gosocks5.ReadUserPassRequest(conn)
if err != nil { if err != nil {
log.Log("[socks5]", err) log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return nil, err return nil, err
} }
if Debug { if Debug {
log.Log("[socks5]", req.String()) log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String())
} }
valid := false valid := false
for _, user := range selector.Users { for _, user := range selector.Users {
@ -168,23 +168,23 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
if len(selector.Users) > 0 && !valid { if len(selector.Users) > 0 && !valid {
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
if err := resp.Write(conn); err != nil { if err := resp.Write(conn); err != nil {
log.Log("[socks5]", err) log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return nil, err return nil, err
} }
if Debug { if Debug {
log.Log("[socks5]", resp) log.Log("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), resp)
} }
log.Log("[socks5] proxy authentication required") log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr())
return nil, gosocks5.ErrAuthFailure return nil, gosocks5.ErrAuthFailure
} }
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded)
if err := resp.Write(conn); err != nil { if err := resp.Write(conn); err != nil {
log.Log("[socks5]", err) log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return nil, err return nil, err
} }
if Debug { if Debug {
log.Log("[socks5]", resp) log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), resp)
} }
case gosocks5.MethodNoAcceptable: case gosocks5.MethodNoAcceptable:
return nil, gosocks5.ErrBadMethod return nil, gosocks5.ErrBadMethod
@ -203,7 +203,7 @@ func SOCKS5Connector(user *url.Userinfo) Connector {
return &socks5Connector{User: user} return &socks5Connector{User: user}
} }
func (c *socks5Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
selector := &clientSelector{ selector := &clientSelector{
TLSConfig: &tls.Config{InsecureSkipVerify: true}, TLSConfig: &tls.Config{InsecureSkipVerify: true},
User: c.User, User: c.User,
@ -261,7 +261,15 @@ func SOCKS4Connector() Connector {
return &socks4Connector{} return &socks4Connector{}
} }
func (c *socks4Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *socks4Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
var cOpts ConnectOptions
for _, opt := range options {
opt(&cOpts)
}
if cOpts.IPAddr != "" {
addr = cOpts.IPAddr
}
taddr, err := net.ResolveTCPAddr("tcp4", addr) taddr, err := net.ResolveTCPAddr("tcp4", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -308,7 +316,7 @@ func SOCKS4AConnector() Connector {
return &socks4aConnector{} return &socks4aConnector{}
} }
func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *socks4aConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -322,7 +330,7 @@ func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error)
} }
if Debug { if Debug {
log.Logf("[socks4] %s", req) log.Logf("[socks4a] %s", req)
} }
reply, err := gosocks4.ReadReply(conn) reply, err := gosocks4.ReadReply(conn)
@ -331,11 +339,11 @@ func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error)
} }
if Debug { if Debug {
log.Logf("[socks4] %s", reply) log.Logf("[socks4a] %s", reply)
} }
if reply.Code != gosocks4.Granted { if reply.Code != gosocks4.Granted {
return nil, fmt.Errorf("[socks4] %d", reply.Code) return nil, fmt.Errorf("[socks4a] %d", reply.Code)
} }
return conn, nil return conn, nil

2
ss.go
View File

@ -67,7 +67,7 @@ func ShadowConnector(cipher *url.Userinfo) Connector {
return &shadowConnector{Cipher: cipher} return &shadowConnector{Cipher: cipher}
} }
func (c *shadowConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
rawaddr, err := ss.RawAddr(addr) rawaddr, err := ss.RawAddr(addr)
if err != nil { if err != nil {
return nil, err return nil, err

4
ssh.go
View File

@ -39,7 +39,7 @@ func SSHDirectForwardConnector() Connector {
return &sshDirectForwardConnector{} return &sshDirectForwardConnector{}
} }
func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string) (net.Conn, error) { func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string, options ...ConnectOption) (net.Conn, error) {
cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution.
if !ok { if !ok {
return nil, errors.New("ssh: wrong connection type") return nil, errors.New("ssh: wrong connection type")
@ -60,7 +60,7 @@ func SSHRemoteForwardConnector() Connector {
return &sshRemoteForwardConnector{} return &sshRemoteForwardConnector{}
} }
func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution.
if !ok { if !ok {
return nil, errors.New("ssh: wrong connection type") return nil, errors.New("ssh: wrong connection type")