diff --git a/chain.go b/chain.go index 4a511cb..bedeb6f 100644 --- a/chain.go +++ b/chain.go @@ -240,14 +240,15 @@ func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) { nodes := c.Nodes() node := nodes[0] - cn, err := node.Client.Dial(node.Addr, node.DialOptions...) + cc, err := node.Client.Dial(node.Addr, node.DialOptions...) if err != nil { node.MarkDead() return } - cn, err = node.Client.Handshake(cn, node.HandshakeOptions...) + cn, err := node.Client.Handshake(cc, node.HandshakeOptions...) if err != nil { + cc.Close() node.MarkDead() return } diff --git a/client.go b/client.go index 586d565..f18bcaf 100644 --- a/client.go +++ b/client.go @@ -114,6 +114,7 @@ type HandshakeOptions struct { WSOptions *WSOptions KCPConfig *KCPConfig QUICConfig *QUICConfig + SSHConfig *SSHConfig } // HandshakeOption allows a common way to set HandshakeOptions. @@ -189,6 +190,13 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption { } } +// SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake. +func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.SSHConfig = config + } +} + // ConnectOptions describes the options for Connector.Connect. type ConnectOptions struct { Addr string diff --git a/cmd/gost/route.go b/cmd/gost/route.go index f3a259e..979b39b 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -248,6 +248,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { if host == "" { host = node.Host } + + sshConfig := &gost.SSHConfig{} + if s := node.Get("ssh_key"); s != "" { + key, err := gost.ParseSSHKeyFile(s) + if err != nil { + return nil, err + } + sshConfig.Key = key + } handshakeOptions := []gost.HandshakeOption{ gost.AddrHandshakeOption(node.Addr), gost.HostHandshakeOption(host), @@ -256,7 +265,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { gost.IntervalHandshakeOption(node.GetDuration("ping")), gost.TimeoutHandshakeOption(timeout), gost.RetryHandshakeOption(node.GetInt("retry")), + gost.SSHConfigHandshakeOption(sshConfig), } + node.Client = &gost.Client{ Connector: connector, Transporter: tr, @@ -385,6 +396,20 @@ func (r *route) GenRouters() ([]router, error) { Authenticator: authenticator, TLSConfig: tlsCfg, } + if s := node.Get("ssh_key"); s != "" { + key, err := gost.ParseSSHKeyFile(s) + if err != nil { + return nil, err + } + config.Key = key + } + if s := node.Get("ssh_authorized_keys"); s != "" { + keys, err := gost.ParseSSHAuthorizedKeysFile(s) + if err != nil { + return nil, err + } + config.AuthorizedKeys = keys + } if node.Protocol == "forward" { ln, err = gost.TCPListener(node.Addr) } else { diff --git a/ssh.go b/ssh.go index 9fd7877..e4b1ab6 100644 --- a/ssh.go +++ b/ssh.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "io/ioutil" "net" "strconv" "strings" @@ -30,6 +31,32 @@ var ( errSessionDead = errors.New("session is dead") ) +func ParseSSHKeyFile(fp string) (ssh.Signer, error) { + key, err := ioutil.ReadFile(fp) + if err != nil { + return nil, err + } + return ssh.ParsePrivateKey(key) +} + +func ParseSSHAuthorizedKeysFile(fp string) (map[string]bool, error) { + authorizedKeysBytes, err := ioutil.ReadFile(fp) + if err != nil { + return nil, err + } + authorizedKeysMap := make(map[string]bool) + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, err + } + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + return authorizedKeysMap, nil +} + type sshDirectForwardConnector struct { } @@ -201,11 +228,15 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp } if opts.User != nil { config.User = opts.User.Username() - password, _ := opts.User.Password() - config.Auth = []ssh.AuthMethod{ - ssh.Password(password), + if password, _ := opts.User.Password(); password != "" { + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } } } + if opts.SSHConfig != nil && opts.SSHConfig.Key != nil { + config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key)) + } tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() @@ -217,6 +248,7 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp if !ok || session.client == nil { sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) if err != nil { + log.Log("ssh", err) conn.Close() delete(tr.sessions, opts.Addr) return nil, err @@ -305,16 +337,20 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt } config := ssh.ClientConfig{ + Timeout: timeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - // TODO: support pubkey auth. if opts.User != nil { config.User = opts.User.Username() - password, _ := opts.User.Password() - config.Auth = []ssh.AuthMethod{ - ssh.Password(password), + if password, _ := opts.User.Password(); password != "" { + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } } } + if opts.SSHConfig != nil && opts.SSHConfig.Key != nil { + config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key)) + } tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() @@ -682,8 +718,10 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque // SSHConfig holds the SSH tunnel server config type SSHConfig struct { - Authenticator Authenticator - TLSConfig *tls.Config + Authenticator Authenticator + TLSConfig *tls.Config + Key ssh.Signer + AuthorizedKeys map[string]bool } type sshTunnelListener struct { @@ -704,21 +742,22 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) { config = &SSHConfig{} } - sshConfig := &ssh.ServerConfig{} - sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Authenticator) - if config.Authenticator == nil { + sshConfig := &ssh.ServerConfig{ + PasswordCallback: defaultSSHPasswordCallback(config.Authenticator), + PublicKeyCallback: defaultSSHPublicKeyCallback(config.AuthorizedKeys), + } + + if config.Authenticator == nil && len(config.AuthorizedKeys) == 0 { sshConfig.NoClientAuth = true } - tlsConfig := config.TLSConfig - if tlsConfig == nil { - tlsConfig = DefaultTLSConfig - } - - signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey) - if err != nil { - ln.Close() - return nil, err + signer := config.Key + if signer == nil { + signer, err = ssh.NewSignerFromKey(DefaultTLSConfig.Certificates[0].PrivateKey) + if err != nil { + ln.Close() + return nil, err + } } sshConfig.AddHostKey(signer) @@ -823,9 +862,13 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { } // PasswordCallbackFunc is a callback function used by SSH server. +// It authenticates user using a password. type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc { + if au == nil { + return nil + } return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if au.Authenticate(conn.User(), string(password)) { return nil, nil @@ -835,6 +878,28 @@ func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc { } } +// PublicKeyCallbackFunc is a callback function used by SSH server. +// It offers a public key for authentication. +type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) + +func defaultSSHPublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc { + if len(keys) == 0 { + return nil + } + + return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if keys[string(pubKey.Marshal())] { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": ssh.FingerprintSHA256(pubKey), + }, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + } +} + type sshNopConn struct { session *sshSession }