fix #352: add pubkey auth support for ssh
This commit is contained in:
parent
0f8064470f
commit
3a63210845
5
chain.go
5
chain.go
@ -240,14 +240,15 @@ func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) {
|
||||
nodes := c.Nodes()
|
||||
node := nodes[0]
|
||||
|
||||
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
||||
cc, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
||||
if err != nil {
|
||||
node.MarkDead()
|
||||
return
|
||||
}
|
||||
|
||||
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
|
||||
cn, err := node.Client.Handshake(cc, node.HandshakeOptions...)
|
||||
if err != nil {
|
||||
cc.Close()
|
||||
node.MarkDead()
|
||||
return
|
||||
}
|
||||
|
@ -114,6 +114,7 @@ type HandshakeOptions struct {
|
||||
WSOptions *WSOptions
|
||||
KCPConfig *KCPConfig
|
||||
QUICConfig *QUICConfig
|
||||
SSHConfig *SSHConfig
|
||||
}
|
||||
|
||||
// HandshakeOption allows a common way to set HandshakeOptions.
|
||||
@ -189,6 +190,13 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption {
|
||||
}
|
||||
}
|
||||
|
||||
// SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake.
|
||||
func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption {
|
||||
return func(opts *HandshakeOptions) {
|
||||
opts.SSHConfig = config
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectOptions describes the options for Connector.Connect.
|
||||
type ConnectOptions struct {
|
||||
Addr string
|
||||
|
@ -248,6 +248,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
||||
if host == "" {
|
||||
host = node.Host
|
||||
}
|
||||
|
||||
sshConfig := &gost.SSHConfig{}
|
||||
if s := node.Get("ssh_key"); s != "" {
|
||||
key, err := gost.ParseSSHKeyFile(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sshConfig.Key = key
|
||||
}
|
||||
handshakeOptions := []gost.HandshakeOption{
|
||||
gost.AddrHandshakeOption(node.Addr),
|
||||
gost.HostHandshakeOption(host),
|
||||
@ -256,7 +265,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
||||
gost.IntervalHandshakeOption(node.GetDuration("ping")),
|
||||
gost.TimeoutHandshakeOption(timeout),
|
||||
gost.RetryHandshakeOption(node.GetInt("retry")),
|
||||
gost.SSHConfigHandshakeOption(sshConfig),
|
||||
}
|
||||
|
||||
node.Client = &gost.Client{
|
||||
Connector: connector,
|
||||
Transporter: tr,
|
||||
@ -385,6 +396,20 @@ func (r *route) GenRouters() ([]router, error) {
|
||||
Authenticator: authenticator,
|
||||
TLSConfig: tlsCfg,
|
||||
}
|
||||
if s := node.Get("ssh_key"); s != "" {
|
||||
key, err := gost.ParseSSHKeyFile(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Key = key
|
||||
}
|
||||
if s := node.Get("ssh_authorized_keys"); s != "" {
|
||||
keys, err := gost.ParseSSHAuthorizedKeysFile(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.AuthorizedKeys = keys
|
||||
}
|
||||
if node.Protocol == "forward" {
|
||||
ln, err = gost.TCPListener(node.Addr)
|
||||
} else {
|
||||
|
91
ssh.go
91
ssh.go
@ -6,6 +6,7 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -30,6 +31,32 @@ var (
|
||||
errSessionDead = errors.New("session is dead")
|
||||
)
|
||||
|
||||
func ParseSSHKeyFile(fp string) (ssh.Signer, error) {
|
||||
key, err := ioutil.ReadFile(fp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.ParsePrivateKey(key)
|
||||
}
|
||||
|
||||
func ParseSSHAuthorizedKeysFile(fp string) (map[string]bool, error) {
|
||||
authorizedKeysBytes, err := ioutil.ReadFile(fp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authorizedKeysMap := make(map[string]bool)
|
||||
for len(authorizedKeysBytes) > 0 {
|
||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authorizedKeysMap[string(pubKey.Marshal())] = true
|
||||
authorizedKeysBytes = rest
|
||||
}
|
||||
|
||||
return authorizedKeysMap, nil
|
||||
}
|
||||
|
||||
type sshDirectForwardConnector struct {
|
||||
}
|
||||
|
||||
@ -201,11 +228,15 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
|
||||
}
|
||||
if opts.User != nil {
|
||||
config.User = opts.User.Username()
|
||||
password, _ := opts.User.Password()
|
||||
if password, _ := opts.User.Password(); password != "" {
|
||||
config.Auth = []ssh.AuthMethod{
|
||||
ssh.Password(password),
|
||||
}
|
||||
}
|
||||
}
|
||||
if opts.SSHConfig != nil && opts.SSHConfig.Key != nil {
|
||||
config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key))
|
||||
}
|
||||
|
||||
tr.sessionMutex.Lock()
|
||||
defer tr.sessionMutex.Unlock()
|
||||
@ -217,6 +248,7 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
|
||||
if !ok || session.client == nil {
|
||||
sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config)
|
||||
if err != nil {
|
||||
log.Log("ssh", err)
|
||||
conn.Close()
|
||||
delete(tr.sessions, opts.Addr)
|
||||
return nil, err
|
||||
@ -305,16 +337,20 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt
|
||||
}
|
||||
|
||||
config := ssh.ClientConfig{
|
||||
Timeout: timeout,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
// TODO: support pubkey auth.
|
||||
if opts.User != nil {
|
||||
config.User = opts.User.Username()
|
||||
password, _ := opts.User.Password()
|
||||
if password, _ := opts.User.Password(); password != "" {
|
||||
config.Auth = []ssh.AuthMethod{
|
||||
ssh.Password(password),
|
||||
}
|
||||
}
|
||||
}
|
||||
if opts.SSHConfig != nil && opts.SSHConfig.Key != nil {
|
||||
config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key))
|
||||
}
|
||||
|
||||
tr.sessionMutex.Lock()
|
||||
defer tr.sessionMutex.Unlock()
|
||||
@ -684,6 +720,8 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque
|
||||
type SSHConfig struct {
|
||||
Authenticator Authenticator
|
||||
TLSConfig *tls.Config
|
||||
Key ssh.Signer
|
||||
AuthorizedKeys map[string]bool
|
||||
}
|
||||
|
||||
type sshTunnelListener struct {
|
||||
@ -704,21 +742,22 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) {
|
||||
config = &SSHConfig{}
|
||||
}
|
||||
|
||||
sshConfig := &ssh.ServerConfig{}
|
||||
sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Authenticator)
|
||||
if config.Authenticator == nil {
|
||||
sshConfig.NoClientAuth = true
|
||||
}
|
||||
tlsConfig := config.TLSConfig
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = DefaultTLSConfig
|
||||
sshConfig := &ssh.ServerConfig{
|
||||
PasswordCallback: defaultSSHPasswordCallback(config.Authenticator),
|
||||
PublicKeyCallback: defaultSSHPublicKeyCallback(config.AuthorizedKeys),
|
||||
}
|
||||
|
||||
signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey)
|
||||
if config.Authenticator == nil && len(config.AuthorizedKeys) == 0 {
|
||||
sshConfig.NoClientAuth = true
|
||||
}
|
||||
|
||||
signer := config.Key
|
||||
if signer == nil {
|
||||
signer, err = ssh.NewSignerFromKey(DefaultTLSConfig.Certificates[0].PrivateKey)
|
||||
if err != nil {
|
||||
ln.Close()
|
||||
return nil, err
|
||||
|
||||
}
|
||||
}
|
||||
sshConfig.AddHostKey(signer)
|
||||
|
||||
@ -823,9 +862,13 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
|
||||
}
|
||||
|
||||
// PasswordCallbackFunc is a callback function used by SSH server.
|
||||
// It authenticates user using a password.
|
||||
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
|
||||
|
||||
func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
|
||||
if au == nil {
|
||||
return nil
|
||||
}
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if au.Authenticate(conn.User(), string(password)) {
|
||||
return nil, nil
|
||||
@ -835,6 +878,28 @@ func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// PublicKeyCallbackFunc is a callback function used by SSH server.
|
||||
// It offers a public key for authentication.
|
||||
type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error)
|
||||
|
||||
func defaultSSHPublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if keys[string(pubKey.Marshal())] {
|
||||
return &ssh.Permissions{
|
||||
// Record the public key used for authentication.
|
||||
Extensions: map[string]string{
|
||||
"pubkey-fp": ssh.FingerprintSHA256(pubKey),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unknown public key for %q", c.User())
|
||||
}
|
||||
}
|
||||
|
||||
type sshNopConn struct {
|
||||
session *sshSession
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user