fix #352: add pubkey auth support for ssh

This commit is contained in:
ginuerzh 2020-03-02 19:42:27 +08:00
parent 0f8064470f
commit 3a63210845
4 changed files with 122 additions and 23 deletions

View File

@ -240,14 +240,15 @@ func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) {
nodes := c.Nodes() nodes := c.Nodes()
node := nodes[0] node := nodes[0]
cn, err := node.Client.Dial(node.Addr, node.DialOptions...) cc, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil { if err != nil {
node.MarkDead() node.MarkDead()
return return
} }
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...) cn, err := node.Client.Handshake(cc, node.HandshakeOptions...)
if err != nil { if err != nil {
cc.Close()
node.MarkDead() node.MarkDead()
return return
} }

View File

@ -114,6 +114,7 @@ type HandshakeOptions struct {
WSOptions *WSOptions WSOptions *WSOptions
KCPConfig *KCPConfig KCPConfig *KCPConfig
QUICConfig *QUICConfig QUICConfig *QUICConfig
SSHConfig *SSHConfig
} }
// HandshakeOption allows a common way to set HandshakeOptions. // HandshakeOption allows a common way to set HandshakeOptions.
@ -189,6 +190,13 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption {
} }
} }
// SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake.
func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption {
return func(opts *HandshakeOptions) {
opts.SSHConfig = config
}
}
// ConnectOptions describes the options for Connector.Connect. // ConnectOptions describes the options for Connector.Connect.
type ConnectOptions struct { type ConnectOptions struct {
Addr string Addr string

View File

@ -248,6 +248,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
if host == "" { if host == "" {
host = node.Host host = node.Host
} }
sshConfig := &gost.SSHConfig{}
if s := node.Get("ssh_key"); s != "" {
key, err := gost.ParseSSHKeyFile(s)
if err != nil {
return nil, err
}
sshConfig.Key = key
}
handshakeOptions := []gost.HandshakeOption{ handshakeOptions := []gost.HandshakeOption{
gost.AddrHandshakeOption(node.Addr), gost.AddrHandshakeOption(node.Addr),
gost.HostHandshakeOption(host), gost.HostHandshakeOption(host),
@ -256,7 +265,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
gost.IntervalHandshakeOption(node.GetDuration("ping")), gost.IntervalHandshakeOption(node.GetDuration("ping")),
gost.TimeoutHandshakeOption(timeout), gost.TimeoutHandshakeOption(timeout),
gost.RetryHandshakeOption(node.GetInt("retry")), gost.RetryHandshakeOption(node.GetInt("retry")),
gost.SSHConfigHandshakeOption(sshConfig),
} }
node.Client = &gost.Client{ node.Client = &gost.Client{
Connector: connector, Connector: connector,
Transporter: tr, Transporter: tr,
@ -385,6 +396,20 @@ func (r *route) GenRouters() ([]router, error) {
Authenticator: authenticator, Authenticator: authenticator,
TLSConfig: tlsCfg, TLSConfig: tlsCfg,
} }
if s := node.Get("ssh_key"); s != "" {
key, err := gost.ParseSSHKeyFile(s)
if err != nil {
return nil, err
}
config.Key = key
}
if s := node.Get("ssh_authorized_keys"); s != "" {
keys, err := gost.ParseSSHAuthorizedKeysFile(s)
if err != nil {
return nil, err
}
config.AuthorizedKeys = keys
}
if node.Protocol == "forward" { if node.Protocol == "forward" {
ln, err = gost.TCPListener(node.Addr) ln, err = gost.TCPListener(node.Addr)
} else { } else {

91
ssh.go
View File

@ -6,6 +6,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"strconv" "strconv"
"strings" "strings"
@ -30,6 +31,32 @@ var (
errSessionDead = errors.New("session is dead") errSessionDead = errors.New("session is dead")
) )
func ParseSSHKeyFile(fp string) (ssh.Signer, error) {
key, err := ioutil.ReadFile(fp)
if err != nil {
return nil, err
}
return ssh.ParsePrivateKey(key)
}
func ParseSSHAuthorizedKeysFile(fp string) (map[string]bool, error) {
authorizedKeysBytes, err := ioutil.ReadFile(fp)
if err != nil {
return nil, err
}
authorizedKeysMap := make(map[string]bool)
for len(authorizedKeysBytes) > 0 {
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
if err != nil {
return nil, err
}
authorizedKeysMap[string(pubKey.Marshal())] = true
authorizedKeysBytes = rest
}
return authorizedKeysMap, nil
}
type sshDirectForwardConnector struct { type sshDirectForwardConnector struct {
} }
@ -201,11 +228,15 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
} }
if opts.User != nil { if opts.User != nil {
config.User = opts.User.Username() config.User = opts.User.Username()
password, _ := opts.User.Password() if password, _ := opts.User.Password(); password != "" {
config.Auth = []ssh.AuthMethod{ config.Auth = []ssh.AuthMethod{
ssh.Password(password), ssh.Password(password),
} }
} }
}
if opts.SSHConfig != nil && opts.SSHConfig.Key != nil {
config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key))
}
tr.sessionMutex.Lock() tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock() defer tr.sessionMutex.Unlock()
@ -217,6 +248,7 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
if !ok || session.client == nil { if !ok || session.client == nil {
sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config)
if err != nil { if err != nil {
log.Log("ssh", err)
conn.Close() conn.Close()
delete(tr.sessions, opts.Addr) delete(tr.sessions, opts.Addr)
return nil, err return nil, err
@ -305,16 +337,20 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt
} }
config := ssh.ClientConfig{ config := ssh.ClientConfig{
Timeout: timeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
// TODO: support pubkey auth.
if opts.User != nil { if opts.User != nil {
config.User = opts.User.Username() config.User = opts.User.Username()
password, _ := opts.User.Password() if password, _ := opts.User.Password(); password != "" {
config.Auth = []ssh.AuthMethod{ config.Auth = []ssh.AuthMethod{
ssh.Password(password), ssh.Password(password),
} }
} }
}
if opts.SSHConfig != nil && opts.SSHConfig.Key != nil {
config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key))
}
tr.sessionMutex.Lock() tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock() defer tr.sessionMutex.Unlock()
@ -684,6 +720,8 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque
type SSHConfig struct { type SSHConfig struct {
Authenticator Authenticator Authenticator Authenticator
TLSConfig *tls.Config TLSConfig *tls.Config
Key ssh.Signer
AuthorizedKeys map[string]bool
} }
type sshTunnelListener struct { type sshTunnelListener struct {
@ -704,21 +742,22 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) {
config = &SSHConfig{} config = &SSHConfig{}
} }
sshConfig := &ssh.ServerConfig{} sshConfig := &ssh.ServerConfig{
sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Authenticator) PasswordCallback: defaultSSHPasswordCallback(config.Authenticator),
if config.Authenticator == nil { PublicKeyCallback: defaultSSHPublicKeyCallback(config.AuthorizedKeys),
sshConfig.NoClientAuth = true
}
tlsConfig := config.TLSConfig
if tlsConfig == nil {
tlsConfig = DefaultTLSConfig
} }
signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey) if config.Authenticator == nil && len(config.AuthorizedKeys) == 0 {
sshConfig.NoClientAuth = true
}
signer := config.Key
if signer == nil {
signer, err = ssh.NewSignerFromKey(DefaultTLSConfig.Certificates[0].PrivateKey)
if err != nil { if err != nil {
ln.Close() ln.Close()
return nil, err return nil, err
}
} }
sshConfig.AddHostKey(signer) sshConfig.AddHostKey(signer)
@ -823,9 +862,13 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
} }
// PasswordCallbackFunc is a callback function used by SSH server. // PasswordCallbackFunc is a callback function used by SSH server.
// It authenticates user using a password.
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc { func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
if au == nil {
return nil
}
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if au.Authenticate(conn.User(), string(password)) { if au.Authenticate(conn.User(), string(password)) {
return nil, nil return nil, nil
@ -835,6 +878,28 @@ func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
} }
} }
// PublicKeyCallbackFunc is a callback function used by SSH server.
// It offers a public key for authentication.
type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error)
func defaultSSHPublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc {
if len(keys) == 0 {
return nil
}
return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
if keys[string(pubKey.Marshal())] {
return &ssh.Permissions{
// Record the public key used for authentication.
Extensions: map[string]string{
"pubkey-fp": ssh.FingerprintSHA256(pubKey),
},
}, nil
}
return nil, fmt.Errorf("unknown public key for %q", c.User())
}
}
type sshNopConn struct { type sshNopConn struct {
session *sshSession session *sshSession
} }