fix #352: add pubkey auth support for ssh
This commit is contained in:
parent
0f8064470f
commit
3a63210845
5
chain.go
5
chain.go
@ -240,14 +240,15 @@ func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) {
|
|||||||
nodes := c.Nodes()
|
nodes := c.Nodes()
|
||||||
node := nodes[0]
|
node := nodes[0]
|
||||||
|
|
||||||
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
cc, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
node.MarkDead()
|
node.MarkDead()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
|
cn, err := node.Client.Handshake(cc, node.HandshakeOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cc.Close()
|
||||||
node.MarkDead()
|
node.MarkDead()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -114,6 +114,7 @@ type HandshakeOptions struct {
|
|||||||
WSOptions *WSOptions
|
WSOptions *WSOptions
|
||||||
KCPConfig *KCPConfig
|
KCPConfig *KCPConfig
|
||||||
QUICConfig *QUICConfig
|
QUICConfig *QUICConfig
|
||||||
|
SSHConfig *SSHConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandshakeOption allows a common way to set HandshakeOptions.
|
// HandshakeOption allows a common way to set HandshakeOptions.
|
||||||
@ -189,6 +190,13 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake.
|
||||||
|
func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption {
|
||||||
|
return func(opts *HandshakeOptions) {
|
||||||
|
opts.SSHConfig = config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ConnectOptions describes the options for Connector.Connect.
|
// ConnectOptions describes the options for Connector.Connect.
|
||||||
type ConnectOptions struct {
|
type ConnectOptions struct {
|
||||||
Addr string
|
Addr string
|
||||||
|
@ -248,6 +248,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
|||||||
if host == "" {
|
if host == "" {
|
||||||
host = node.Host
|
host = node.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sshConfig := &gost.SSHConfig{}
|
||||||
|
if s := node.Get("ssh_key"); s != "" {
|
||||||
|
key, err := gost.ParseSSHKeyFile(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sshConfig.Key = key
|
||||||
|
}
|
||||||
handshakeOptions := []gost.HandshakeOption{
|
handshakeOptions := []gost.HandshakeOption{
|
||||||
gost.AddrHandshakeOption(node.Addr),
|
gost.AddrHandshakeOption(node.Addr),
|
||||||
gost.HostHandshakeOption(host),
|
gost.HostHandshakeOption(host),
|
||||||
@ -256,7 +265,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
|||||||
gost.IntervalHandshakeOption(node.GetDuration("ping")),
|
gost.IntervalHandshakeOption(node.GetDuration("ping")),
|
||||||
gost.TimeoutHandshakeOption(timeout),
|
gost.TimeoutHandshakeOption(timeout),
|
||||||
gost.RetryHandshakeOption(node.GetInt("retry")),
|
gost.RetryHandshakeOption(node.GetInt("retry")),
|
||||||
|
gost.SSHConfigHandshakeOption(sshConfig),
|
||||||
}
|
}
|
||||||
|
|
||||||
node.Client = &gost.Client{
|
node.Client = &gost.Client{
|
||||||
Connector: connector,
|
Connector: connector,
|
||||||
Transporter: tr,
|
Transporter: tr,
|
||||||
@ -385,6 +396,20 @@ func (r *route) GenRouters() ([]router, error) {
|
|||||||
Authenticator: authenticator,
|
Authenticator: authenticator,
|
||||||
TLSConfig: tlsCfg,
|
TLSConfig: tlsCfg,
|
||||||
}
|
}
|
||||||
|
if s := node.Get("ssh_key"); s != "" {
|
||||||
|
key, err := gost.ParseSSHKeyFile(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
config.Key = key
|
||||||
|
}
|
||||||
|
if s := node.Get("ssh_authorized_keys"); s != "" {
|
||||||
|
keys, err := gost.ParseSSHAuthorizedKeysFile(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
config.AuthorizedKeys = keys
|
||||||
|
}
|
||||||
if node.Protocol == "forward" {
|
if node.Protocol == "forward" {
|
||||||
ln, err = gost.TCPListener(node.Addr)
|
ln, err = gost.TCPListener(node.Addr)
|
||||||
} else {
|
} else {
|
||||||
|
91
ssh.go
91
ssh.go
@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -30,6 +31,32 @@ var (
|
|||||||
errSessionDead = errors.New("session is dead")
|
errSessionDead = errors.New("session is dead")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func ParseSSHKeyFile(fp string) (ssh.Signer, error) {
|
||||||
|
key, err := ioutil.ReadFile(fp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ssh.ParsePrivateKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseSSHAuthorizedKeysFile(fp string) (map[string]bool, error) {
|
||||||
|
authorizedKeysBytes, err := ioutil.ReadFile(fp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
authorizedKeysMap := make(map[string]bool)
|
||||||
|
for len(authorizedKeysBytes) > 0 {
|
||||||
|
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
authorizedKeysMap[string(pubKey.Marshal())] = true
|
||||||
|
authorizedKeysBytes = rest
|
||||||
|
}
|
||||||
|
|
||||||
|
return authorizedKeysMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
type sshDirectForwardConnector struct {
|
type sshDirectForwardConnector struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -201,11 +228,15 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
|
|||||||
}
|
}
|
||||||
if opts.User != nil {
|
if opts.User != nil {
|
||||||
config.User = opts.User.Username()
|
config.User = opts.User.Username()
|
||||||
password, _ := opts.User.Password()
|
if password, _ := opts.User.Password(); password != "" {
|
||||||
config.Auth = []ssh.AuthMethod{
|
config.Auth = []ssh.AuthMethod{
|
||||||
ssh.Password(password),
|
ssh.Password(password),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if opts.SSHConfig != nil && opts.SSHConfig.Key != nil {
|
||||||
|
config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key))
|
||||||
|
}
|
||||||
|
|
||||||
tr.sessionMutex.Lock()
|
tr.sessionMutex.Lock()
|
||||||
defer tr.sessionMutex.Unlock()
|
defer tr.sessionMutex.Unlock()
|
||||||
@ -217,6 +248,7 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
|
|||||||
if !ok || session.client == nil {
|
if !ok || session.client == nil {
|
||||||
sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config)
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Log("ssh", err)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
delete(tr.sessions, opts.Addr)
|
delete(tr.sessions, opts.Addr)
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -305,16 +337,20 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt
|
|||||||
}
|
}
|
||||||
|
|
||||||
config := ssh.ClientConfig{
|
config := ssh.ClientConfig{
|
||||||
|
Timeout: timeout,
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
}
|
}
|
||||||
// TODO: support pubkey auth.
|
|
||||||
if opts.User != nil {
|
if opts.User != nil {
|
||||||
config.User = opts.User.Username()
|
config.User = opts.User.Username()
|
||||||
password, _ := opts.User.Password()
|
if password, _ := opts.User.Password(); password != "" {
|
||||||
config.Auth = []ssh.AuthMethod{
|
config.Auth = []ssh.AuthMethod{
|
||||||
ssh.Password(password),
|
ssh.Password(password),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if opts.SSHConfig != nil && opts.SSHConfig.Key != nil {
|
||||||
|
config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key))
|
||||||
|
}
|
||||||
|
|
||||||
tr.sessionMutex.Lock()
|
tr.sessionMutex.Lock()
|
||||||
defer tr.sessionMutex.Unlock()
|
defer tr.sessionMutex.Unlock()
|
||||||
@ -684,6 +720,8 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque
|
|||||||
type SSHConfig struct {
|
type SSHConfig struct {
|
||||||
Authenticator Authenticator
|
Authenticator Authenticator
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
|
Key ssh.Signer
|
||||||
|
AuthorizedKeys map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type sshTunnelListener struct {
|
type sshTunnelListener struct {
|
||||||
@ -704,21 +742,22 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) {
|
|||||||
config = &SSHConfig{}
|
config = &SSHConfig{}
|
||||||
}
|
}
|
||||||
|
|
||||||
sshConfig := &ssh.ServerConfig{}
|
sshConfig := &ssh.ServerConfig{
|
||||||
sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Authenticator)
|
PasswordCallback: defaultSSHPasswordCallback(config.Authenticator),
|
||||||
if config.Authenticator == nil {
|
PublicKeyCallback: defaultSSHPublicKeyCallback(config.AuthorizedKeys),
|
||||||
sshConfig.NoClientAuth = true
|
|
||||||
}
|
|
||||||
tlsConfig := config.TLSConfig
|
|
||||||
if tlsConfig == nil {
|
|
||||||
tlsConfig = DefaultTLSConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey)
|
if config.Authenticator == nil && len(config.AuthorizedKeys) == 0 {
|
||||||
|
sshConfig.NoClientAuth = true
|
||||||
|
}
|
||||||
|
|
||||||
|
signer := config.Key
|
||||||
|
if signer == nil {
|
||||||
|
signer, err = ssh.NewSignerFromKey(DefaultTLSConfig.Certificates[0].PrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ln.Close()
|
ln.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sshConfig.AddHostKey(signer)
|
sshConfig.AddHostKey(signer)
|
||||||
|
|
||||||
@ -823,9 +862,13 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PasswordCallbackFunc is a callback function used by SSH server.
|
// PasswordCallbackFunc is a callback function used by SSH server.
|
||||||
|
// It authenticates user using a password.
|
||||||
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
|
type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
|
||||||
|
|
||||||
func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
|
func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
|
||||||
|
if au == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||||
if au.Authenticate(conn.User(), string(password)) {
|
if au.Authenticate(conn.User(), string(password)) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -835,6 +878,28 @@ func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PublicKeyCallbackFunc is a callback function used by SSH server.
|
||||||
|
// It offers a public key for authentication.
|
||||||
|
type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error)
|
||||||
|
|
||||||
|
func defaultSSHPublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
if keys[string(pubKey.Marshal())] {
|
||||||
|
return &ssh.Permissions{
|
||||||
|
// Record the public key used for authentication.
|
||||||
|
Extensions: map[string]string{
|
||||||
|
"pubkey-fp": ssh.FingerprintSHA256(pubKey),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unknown public key for %q", c.User())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type sshNopConn struct {
|
type sshNopConn struct {
|
||||||
session *sshSession
|
session *sshSession
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user