fix #352: add pubkey auth support for ssh

This commit is contained in:
ginuerzh 2020-03-02 19:42:27 +08:00
parent 0f8064470f
commit 3a63210845
4 changed files with 122 additions and 23 deletions

View File

@ -240,14 +240,15 @@ func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) {
nodes := c.Nodes()
node := nodes[0]
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
cc, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil {
node.MarkDead()
return
}
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
cn, err := node.Client.Handshake(cc, node.HandshakeOptions...)
if err != nil {
cc.Close()
node.MarkDead()
return
}

View File

@ -114,6 +114,7 @@ type HandshakeOptions struct {
WSOptions *WSOptions
KCPConfig *KCPConfig
QUICConfig *QUICConfig
SSHConfig *SSHConfig
}
// HandshakeOption allows a common way to set HandshakeOptions.
@ -189,6 +190,13 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption {
}
}
// SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake.
func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption {
return func(opts *HandshakeOptions) {
opts.SSHConfig = config
}
}
// ConnectOptions describes the options for Connector.Connect.
type ConnectOptions struct {
Addr string

View File

@ -248,6 +248,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
if host == "" {
host = node.Host
}
sshConfig := &gost.SSHConfig{}
if s := node.Get("ssh_key"); s != "" {
key, err := gost.ParseSSHKeyFile(s)
if err != nil {
return nil, err
}
sshConfig.Key = key
}
handshakeOptions := []gost.HandshakeOption{
gost.AddrHandshakeOption(node.Addr),
gost.HostHandshakeOption(host),
@ -256,7 +265,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
gost.IntervalHandshakeOption(node.GetDuration("ping")),
gost.TimeoutHandshakeOption(timeout),
gost.RetryHandshakeOption(node.GetInt("retry")),
gost.SSHConfigHandshakeOption(sshConfig),
}
node.Client = &gost.Client{
Connector: connector,
Transporter: tr,
@ -385,6 +396,20 @@ func (r *route) GenRouters() ([]router, error) {
Authenticator: authenticator,
TLSConfig: tlsCfg,
}
if s := node.Get("ssh_key"); s != "" {
key, err := gost.ParseSSHKeyFile(s)
if err != nil {
return nil, err
}
config.Key = key
}
if s := node.Get("ssh_authorized_keys"); s != "" {
keys, err := gost.ParseSSHAuthorizedKeysFile(s)
if err != nil {
return nil, err
}
config.AuthorizedKeys = keys
}
if node.Protocol == "forward" {
ln, err = gost.TCPListener(node.Addr)
} else {

91
ssh.go
View File

@ -6,6 +6,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"io/ioutil"
"net"
"strconv"
"strings"
@ -30,6 +31,32 @@ var (
errSessionDead = errors.New("session is dead")
)
func ParseSSHKeyFile(fp string) (ssh.Signer, error) {
key, err := ioutil.ReadFile(fp)
if err != nil {
return nil, err
}
return ssh.ParsePrivateKey(key)
}
func ParseSSHAuthorizedKeysFile(fp string) (map[string]bool, error) {
authorizedKeysBytes, err := ioutil.ReadFile(fp)
if err != nil {
return nil, err
}
authorizedKeysMap := make(map[string]bool)
for len(authorizedKeysBytes) > 0 {
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
if err != nil {
return nil, err
}
authorizedKeysMap[string(pubKey.Marshal())] = true
authorizedKeysBytes = rest
}
return authorizedKeysMap, nil
}
type sshDirectForwardConnector struct {
}
@ -201,11 +228,15 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
}
if opts.User != nil {
config.User = opts.User.Username()
password, _ := opts.User.Password()
if password, _ := opts.User.Password(); password != "" {
config.Auth = []ssh.AuthMethod{
ssh.Password(password),
}
}
}
if opts.SSHConfig != nil && opts.SSHConfig.Key != nil {
config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key))
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
@ -217,6 +248,7 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
if !ok || session.client == nil {
sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config)
if err != nil {
log.Log("ssh", err)
conn.Close()
delete(tr.sessions, opts.Addr)
return nil, err
@ -305,16 +337,20 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt
}
config := ssh.ClientConfig{
Timeout: timeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
// TODO: support pubkey auth.
if opts.User != nil {
config.User = opts.User.Username()
password, _ := opts.User.Password()
if password, _ := opts.User.Password(); password != "" {
config.Auth = []ssh.AuthMethod{
ssh.Password(password),
}
}
}
if opts.SSHConfig != nil && opts.SSHConfig.Key != nil {
config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key))
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
@ -684,6 +720,8 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque
type SSHConfig struct {
Authenticator Authenticator
TLSConfig *tls.Config
Key ssh.Signer
AuthorizedKeys map[string]bool
}
type sshTunnelListener struct {
@ -704,21 +742,22 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) {
config = &SSHConfig{}
}
sshConfig := &ssh.ServerConfig{}
sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Authenticator)
if config.Authenticator == nil {
sshConfig.NoClientAuth = true
}
tlsConfig := config.TLSConfig
if tlsConfig == nil {
tlsConfig = DefaultTLSConfig
sshConfig := &ssh.ServerConfig{
PasswordCallback: defaultSSHPasswordCallback(config.Authenticator),
PublicKeyCallback: defaultSSHPublicKeyCallback(config.AuthorizedKeys),
}
signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey)
if config.Authenticator == nil && len(config.AuthorizedKeys) == 0 {
sshConfig.NoClientAuth = true
}
signer := config.Key
if signer == nil {
signer, err = ssh.NewSignerFromKey(DefaultTLSConfig.Certificates[0].PrivateKey)
if err != nil {
ln.Close()
return nil, err
}
}
sshConfig.AddHostKey(signer)
@ -823,9 +862,13 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
}
// PasswordCallbackFunc is a callback function used by SSH server.
// It authenticates user using a password.
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
if au == nil {
return nil
}
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if au.Authenticate(conn.User(), string(password)) {
return nil, nil
@ -835,6 +878,28 @@ func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
}
}
// PublicKeyCallbackFunc is a callback function used by SSH server.
// It offers a public key for authentication.
type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error)
func defaultSSHPublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc {
if len(keys) == 0 {
return nil
}
return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
if keys[string(pubKey.Marshal())] {
return &ssh.Permissions{
// Record the public key used for authentication.
Extensions: map[string]string{
"pubkey-fp": ssh.FingerprintSHA256(pubKey),
},
}, nil
}
return nil, fmt.Errorf("unknown public key for %q", c.User())
}
}
type sshNopConn struct {
session *sshSession
}