From 302476f481e09d035853072d2d7e6294810894d9 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Thu, 27 Jul 2017 18:07:27 +0800 Subject: [PATCH] add ssh tunnel support --- gost/client.go | 15 + gost/{ => examples}/cli/cli.go | 21 +- gost/{ => examples}/srv/cert.pem | 0 gost/{ => examples}/srv/key.pem | 0 gost/{ => examples}/srv/srv.go | 34 ++ gost/examples/ssh/sshc.go | 47 ++ gost/examples/ssh/sshd.go | 101 +++++ gost/handler.go | 8 + gost/http2.go | 8 +- gost/kcp.go | 7 +- gost/ssh.go | 734 ++++++++++++++++++++++++------- gost/tls.go | 4 - gost/ws.go | 8 - 13 files changed, 813 insertions(+), 174 deletions(-) rename gost/{ => examples}/cli/cli.go (90%) rename gost/{ => examples}/srv/cert.pem (100%) rename gost/{ => examples}/srv/key.pem (100%) rename gost/{ => examples}/srv/srv.go (86%) create mode 100644 gost/examples/ssh/sshc.go create mode 100644 gost/examples/ssh/sshd.go diff --git a/gost/client.go b/gost/client.go index 9553488..4d54a3e 100644 --- a/gost/client.go +++ b/gost/client.go @@ -3,6 +3,7 @@ package gost import ( "crypto/tls" "net" + "net/url" "time" ) @@ -120,7 +121,9 @@ func ChainDialOption(chain *Chain) DialOption { // HandshakeOptions describes the options for handshake. type HandshakeOptions struct { Addr string + User *url.Userinfo Timeout time.Duration + Interval time.Duration TLSConfig *tls.Config WSOptions *WSOptions KCPConfig *KCPConfig @@ -135,12 +138,24 @@ func AddrHandshakeOption(addr string) HandshakeOption { } } +func UserHandshakeOption(user *url.Userinfo) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.User = user + } +} + func TimeoutHandshakeOption(timeout time.Duration) HandshakeOption { return func(opts *HandshakeOptions) { opts.Timeout = timeout } } +func IntervalHandshakeOption(interval time.Duration) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Interval = interval + } +} + func TLSConfigHandshakeOption(config *tls.Config) HandshakeOption { return func(opts *HandshakeOptions) { opts.TLSConfig = config diff --git a/gost/cli/cli.go b/gost/examples/cli/cli.go similarity index 90% rename from gost/cli/cli.go rename to gost/examples/cli/cli.go index f366fbc..16334cf 100644 --- a/gost/cli/cli.go +++ b/gost/examples/cli/cli.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "net/http/httputil" + "net/url" "sync" "time" @@ -116,12 +117,23 @@ func main() { }, */ - // http+kcp + /* + // http+kcp + gost.Node{ + Addr: "127.0.0.1:18388", + Client: gost.NewClient( + gost.HTTPConnector(nil), + gost.KCPTransporter(nil), + ), + }, + */ + + // http+ssh gost.Node{ - Addr: "127.0.0.1:18388", + Addr: "127.0.0.1:12222", Client: gost.NewClient( - gost.HTTPConnector(nil), - gost.KCPTransporter(nil), + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.SSHTunnelTransporter(), ), }, ) @@ -143,6 +155,7 @@ func main() { duration := time.Since(start) total += concurrency log.Printf("%d/%d/%d requests done (%v/%v)", total, requests, concurrency, duration, duration/time.Duration(concurrency)) + time.Sleep(500 * time.Millisecond) } } diff --git a/gost/srv/cert.pem b/gost/examples/srv/cert.pem similarity index 100% rename from gost/srv/cert.pem rename to gost/examples/srv/cert.pem diff --git a/gost/srv/key.pem b/gost/examples/srv/key.pem similarity index 100% rename from gost/srv/key.pem rename to gost/examples/srv/key.pem diff --git a/gost/srv/srv.go b/gost/examples/srv/srv.go similarity index 86% rename from gost/srv/srv.go rename to gost/examples/srv/srv.go index 0aac7a7..5017275 100644 --- a/gost/srv/srv.go +++ b/gost/examples/srv/srv.go @@ -37,6 +37,8 @@ func main() { // go rtcpForwardServer() // go rudpForwardServer() // go tcpRedirectServer() + // go sshForwardServer() + go sshTunnelServer() // go http2Server() select {} @@ -193,6 +195,38 @@ func tcpRedirectServer() { log.Fatal(s.Serve(ln)) } +func sshForwardServer() { + s := &gost.Server{} + s.Handle( + gost.SSHForwardHandler( + gost.AddrHandlerOption(":1222"), + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + gost.TLSConfigHandlerOption(tlsConfig()), + ), + ) + + ln, err := gost.TCPListener(":1222") + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} + +func sshTunnelServer() { + s := &gost.Server{} + s.Handle( + gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + ) + + ln, err := gost.SSHTunnelListener(":12222", &gost.SSHConfig{TLSConfig: tlsConfig()}) + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} + func http2Server() { // http2.VerboseLogs = true diff --git a/gost/examples/ssh/sshc.go b/gost/examples/ssh/sshc.go new file mode 100644 index 0000000..ac8ba19 --- /dev/null +++ b/gost/examples/ssh/sshc.go @@ -0,0 +1,47 @@ +package main + +import ( + "flag" + "log" + + "github.com/ginuerzh/gost/gost" +) + +var ( + laddr, faddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":18080", "listen address") + flag.StringVar(&faddr, "F", ":12222", "forward address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + chain := gost.NewChain( + gost.Node{ + Addr: faddr, + Client: gost.NewClient( + gost.HTTPConnector(nil), + gost.SSHTunnelTransporter(), + ), + }, + ) + + s := &gost.Server{} + s.Handle(gost.HTTPHandler(gost.ChainHandlerOption(chain))) + ln, err := gost.TCPListener(laddr) + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} diff --git a/gost/examples/ssh/sshd.go b/gost/examples/ssh/sshd.go new file mode 100644 index 0000000..da6189e --- /dev/null +++ b/gost/examples/ssh/sshd.go @@ -0,0 +1,101 @@ +package main + +import ( + "crypto/tls" + "flag" + "log" + + "github.com/ginuerzh/gost/gost" +) + +var ( + laddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":12222", "listen address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + sshTunnelServer() +} + +func sshTunnelServer() { + s := &gost.Server{} + s.Handle( + gost.HTTPHandler(), + ) + + ln, err := gost.SSHTunnelListener(laddr, &gost.SSHConfig{TLSConfig: tlsConfig()}) + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/gost/handler.go b/gost/handler.go index 78282c2..1188787 100644 --- a/gost/handler.go +++ b/gost/handler.go @@ -13,6 +13,7 @@ type Handler interface { // HandlerOptions describes the options for Handler. type HandlerOptions struct { + Addr string Chain *Chain Users []*url.Userinfo TLSConfig *tls.Config @@ -21,6 +22,13 @@ type HandlerOptions struct { // HandlerOption allows a common way to set handler options. type HandlerOption func(opts *HandlerOptions) +// AddrHandlerOption sets the Addr option of HandlerOptions. +func AddrHandlerOption(addr string) HandlerOption { + return func(opts *HandlerOptions) { + opts.Addr = addr + } +} + // ChainHandlerOption sets the Chain option of HandlerOptions. func ChainHandlerOption(chain *Chain) HandlerOption { return func(opts *HandlerOptions) { diff --git a/gost/http2.go b/gost/http2.go index 365ef53..717d620 100644 --- a/gost/http2.go +++ b/gost/http2.go @@ -32,7 +32,7 @@ func HTTP2Connector(user *url.Userinfo) Connector { func (c *http2Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { cc, ok := conn.(*http2DummyConn) if !ok { - return nil, errors.New("conn must be a conn wrapper") + return nil, errors.New("wrong connection type") } pr, pw := io.Pipe() @@ -155,10 +155,8 @@ type http2Handler struct { // HTTP2Handler creates a server Handler for HTTP2 proxy server. func HTTP2Handler(opts ...HandlerOption) Handler { h := &http2Handler{ - server: new(http2.Server), - options: &HandlerOptions{ - Chain: new(Chain), - }, + server: new(http2.Server), + options: new(HandlerOptions), } for _, opt := range opts { opt(h.options) diff --git a/gost/kcp.go b/gost/kcp.go index 585d093..ecda95e 100644 --- a/gost/kcp.go +++ b/gost/kcp.go @@ -214,10 +214,15 @@ func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( defer tr.sessionMutex.Unlock() session, ok := tr.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("kcp: unrecognized connection") + } if !ok || session.session == nil { s, err := tr.initSession(opts.Addr, conn, config) if err != nil { conn.Close() + delete(tr.sessions, opts.Addr) return nil, err } session = s @@ -236,7 +241,7 @@ func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*kcpSession, error) { udpConn, ok := conn.(*net.UDPConn) if !ok { - return nil, errors.New("wrong connection type") + return nil, errors.New("kcp: wrong connection type") } kcpconn, err := kcp.NewConn(addr, diff --git a/gost/ssh.go b/gost/ssh.go index d4930d2..349d815 100644 --- a/gost/ssh.go +++ b/gost/ssh.go @@ -1,11 +1,16 @@ package gost import ( + "crypto/tls" + "encoding/binary" "errors" "fmt" "net" "net/url" "strconv" + "sync" + "time" + "weed-fs/go/glog" "github.com/go-log/log" "golang.org/x/crypto/ssh" @@ -17,65 +22,315 @@ const ( RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1 ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2 CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1 + + GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel ) -type sshListener struct { - net.Listener - config *ssh.ServerConfig - connChan chan net.Conn - errChan chan error +type sshForwardConnector struct { } -func SSHListener(addr string, config *ssh.ServerConfig) (Listener, error) { - ln, err := net.Listen("tcp", addr) +func SSHForwardConnector() Connector { + return &sshForwardConnector{} +} + +func (c *sshForwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { + cc, ok := conn.(*sshNopConn) + if !ok { + return nil, errors.New("ssh: wrong connection type") + } + conn, err := cc.session.client.Dial("tcp", addr) if err != nil { + log.Logf("[ssh-tcp] %s -> %s : %s", cc.session.addr, addr, err) return nil, err } - - if config == nil { - config = &ssh.ServerConfig{} - } - - l := &sshListener{ - Listener: ln, - config: config, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - - go l.listenLoop() - - return l, nil + return conn, nil } -func (l *sshListener) listenLoop() { - for { - conn, err := l.Listener.Accept() +type sshForwardTransporter struct { + sessions map[string]*sshSession + sessionMutex sync.Mutex +} + +func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if !ok { + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } if err != nil { - log.Log("[ssh] accept:", err) - l.errChan <- err - close(l.errChan) return } - go l.serveConn(conn) + session = &sshSession{ + addr: addr, + conn: conn, + } + tr.sessions[addr] = session + } + + return session.conn, nil +} + +func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + config := ssh.ClientConfig{ + Timeout: opts.Timeout, + } + if opts.User != nil { + config.User = opts.User.Username() + password, _ := opts.User.Password() + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[opts.Addr] + if !ok || session.client == nil { + sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + session = &sshSession{ + addr: opts.Addr, + conn: conn, + client: ssh.NewClient(sshConn, chans, reqs), + closed: make(chan struct{}), + } + tr.sessions[opts.Addr] = session + } + return &sshNopConn{session: session}, nil +} + +func (tr *sshForwardTransporter) Multiplex() bool { + return true +} + +type sshTunnelTransporter struct { + sessions map[string]*sshSession + sessionMutex sync.Mutex +} + +// SSHTunnelTransporter creates a Transporter that is used by SSH tunnel client. +func SSHTunnelTransporter() Transporter { + return &sshTunnelTransporter{ + sessions: make(map[string]*sshSession), } } -func (l *sshListener) serveConn(conn net.Conn) { - sshConn, chans, reqs, err := ssh.NewServerConn(conn, l.config) +func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if !ok { + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &sshSession{ + addr: addr, + conn: conn, + } + tr.sessions[addr] = session + } + + return session.conn, nil +} + +func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + config := ssh.ClientConfig{ + Timeout: opts.Timeout, + } + if opts.User != nil { + config.User = opts.User.Username() + password, _ := opts.User.Password() + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("ssh: unrecognized connection") + } + if !ok || session.client == nil { + sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + session = &sshSession{ + addr: opts.Addr, + conn: conn, + client: ssh.NewClient(sshConn, chans, reqs), + closed: make(chan struct{}), + deaded: make(chan struct{}), + } + tr.sessions[opts.Addr] = session + go session.Ping(opts.Interval, 1) + } + + if session.Dead() { + delete(tr.sessions, opts.Addr) + return nil, errors.New("ssh: session is dead") + } + + channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil) if err != nil { - log.Logf("[ssh] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + session.client.Close() + close(session.closed) + delete(tr.sessions, opts.Addr) + return nil, err + } + go ssh.DiscardRequests(reqs) + return &sshConn{channel: channel, conn: conn}, nil +} + +func (tr *sshTunnelTransporter) Multiplex() bool { + return true +} + +type sshSession struct { + addr string + conn net.Conn + client *ssh.Client + closed chan struct{} + deaded chan struct{} +} + +func (s *sshSession) Ping(interval time.Duration, retries int) { + interval = 30 * time.Second + if interval <= 0 { + return + } + defer close(s.deaded) + defer s.client.Close() + + log.Log("[ssh] ping is enabled, interval:", interval) + // baseCtx := context.Background() + t := time.NewTicker(interval) + defer t.Stop() + + for { + select { + case <-t.C: + if Debug { + log.Log("[ssh] sending ping") + } + _, _, err := s.client.SendRequest("ping", true, nil) + if err != nil { + log.Log("[ssh] ping:", err) + return + } + if Debug { + log.Log("[ssh] ping OK") + } + + case <-s.closed: + return + } + } +} + +func (s *sshSession) Dead() bool { + select { + case <-s.deaded: + return true + default: + } + return false +} + +type sshForwardHandler struct { + options *HandlerOptions + config *ssh.ServerConfig +} + +func SSHForwardHandler(opts ...HandlerOption) Handler { + h := &sshForwardHandler{ + options: new(HandlerOptions), + config: new(ssh.ServerConfig), + } + for _, opt := range opts { + opt(h.options) + } + h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Users...) + if len(h.options.Users) == 0 { + h.config.NoClientAuth = true + } + if h.options.TLSConfig != nil && len(h.options.TLSConfig.Certificates) > 0 { + signer, err := ssh.NewSignerFromKey(h.options.TLSConfig.Certificates[0].PrivateKey) + if err != nil { + log.Log("[sshf]", err) + } + h.config.AddHostKey(signer) + } + + return h +} + +func (h *sshForwardHandler) Handle(conn net.Conn) { + sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config) + if err != nil { + log.Logf("[sshf] %s -> %s : %s", conn.RemoteAddr(), h.options.Addr, err) conn.Close() return } defer sshConn.Close() - quit := make(chan interface{}) + log.Logf("[sshf] %s <-> %s", conn.RemoteAddr(), h.options.Addr) + h.handleForward(sshConn, chans, reqs) + log.Logf("[sshf] %s >-< %s", conn.RemoteAddr(), h.options.Addr) +} + +func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { + quit := make(chan struct{}) + defer close(quit) // quit signal + go func() { for req := range reqs { switch req.Type { case RemoteForwardRequest: - // go l.tcpipForwardRequest(conn, req, quit) + go h.tcpipForwardRequest(conn, req, quit) default: log.Log("[ssh] unknown channel type:", req.Type) if req.WantReply { @@ -91,7 +346,6 @@ func (l *sshListener) serveConn(conn net.Conn) { t := newChannel.ChannelType() switch t { case DirectForwardRequest: - /* channel, requests, err := newChannel.Accept() if err != nil { log.Log("[ssh] Could not accept channel:", err) @@ -105,8 +359,7 @@ func (l *sshListener) serveConn(conn net.Conn) { } go ssh.DiscardRequests(requests) - go l.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) - */ + go h.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) default: log.Log("[ssh] Unknown channel type:", t) newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) @@ -114,12 +367,220 @@ func (l *sshListener) serveConn(conn net.Conn) { } }() - sshConn.Wait() - close(quit) - + conn.Wait() } -func (l *sshListener) Accept() (conn net.Conn, err error) { +func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr string) { + defer channel.Close() + + log.Logf("[ssh-tcp] %s - %s", h.options.Addr, raddr) + + //! if !s.Base.Node.Can("tcp", raddr) { + //! glog.Errorf("Unauthorized to tcp connect to %s", raddr) + //! return + //! } + + conn, err := h.options.Chain.Dial(raddr) + if err != nil { + log.Logf("[ssh-tcp] %s - %s : %s", h.options.Addr, raddr, err) + return + } + defer conn.Close() + + log.Logf("[ssh-tcp] %s <-> %s", h.options.Addr, raddr) + transport(conn, channel) + log.Logf("[ssh-tcp] %s >-< %s", h.options.Addr, raddr) +} + +// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request +type tcpipForward struct { + Host string + Port uint32 +} + +func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) { + t := tcpipForward{} + ssh.Unmarshal(req.Payload, &t) + + addr := fmt.Sprintf("%s:%d", t.Host, t.Port) + + //! if !s.Base.Node.Can("rtcp", addr) { + //! glog.Errorf("Unauthorized to tcp bind to %s", addr) + //! req.Reply(false, nil) + //! return + //! } + + log.Log("[ssh-rtcp] listening on tcp", addr) + ln, err := net.Listen("tcp", addr) //tie to the client connection + if err != nil { + log.Log("[ssh-rtcp]", err) + req.Reply(false, nil) + return + } + defer ln.Close() + + replyFunc := func() error { + if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used + _, port, err := getHostPortFromAddr(ln.Addr()) + if err != nil { + return err + } + var b [4]byte + binary.BigEndian.PutUint32(b[:], uint32(port)) + t.Port = uint32(port) + return req.Reply(true, b[:]) + } + return req.Reply(true, nil) + } + if err := replyFunc(); err != nil { + log.Log("[ssh-rtcp]", err) + return + } + + go func() { + for { + conn, err := ln.Accept() + if err != nil { // Unable to accept new connection - listener is likely closed + return + } + + go func(conn net.Conn) { + defer conn.Close() + + p := directForward{} + var err error + + var portnum int + p.Host1 = t.Host + p.Port1 = t.Port + p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr()) + if err != nil { + return + } + + p.Port2 = uint32(portnum) + glog.V(3).Info(p) + ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) + if err != nil { + log.Log("[ssh-rtcp] open forwarded channel:", err) + return + } + defer ch.Close() + go ssh.DiscardRequests(reqs) + + log.Logf("[ssh-rtcp] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + transport(ch, conn) + log.Logf("[ssh-rtcp] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) + }(conn) + } + }() + + <-quit +} + +// SSHConfig holds the SSH tunnel server config +type SSHConfig struct { + Users []*url.Userinfo + TLSConfig *tls.Config +} + +type sshTunnelListener struct { + net.Listener + config *ssh.ServerConfig + connChan chan net.Conn + errChan chan error +} + +// SSHTunnelListener creates a Listener for SSH tunnel server. +func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + if config == nil { + config = &SSHConfig{} + } + + sshConfig := &ssh.ServerConfig{} + sshConfig.PasswordCallback = defaultSSHPasswordCallback(config.Users...) + if len(config.Users) == 0 { + sshConfig.NoClientAuth = true + } + if config.TLSConfig != nil && len(config.TLSConfig.Certificates) > 0 { + signer, err := ssh.NewSignerFromKey(config.TLSConfig.Certificates[0].PrivateKey) + if err != nil { + log.Log("[sshf]", err) + } + sshConfig.AddHostKey(signer) + } + + l := &sshTunnelListener{ + Listener: ln, + config: sshConfig, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + + go l.listenLoop() + + return l, nil +} + +func (l *sshTunnelListener) listenLoop() { + for { + conn, err := l.Listener.Accept() + if err != nil { + log.Log("[ssh] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.serveConn(conn) + } +} + +func (l *sshTunnelListener) serveConn(conn net.Conn) { + sc, chans, reqs, err := ssh.NewServerConn(conn, l.config) + if err != nil { + log.Logf("[ssh] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + conn.Close() + return + } + defer sc.Close() + + go ssh.DiscardRequests(reqs) + go func() { + for newChannel := range chans { + // Check the type of channel + t := newChannel.ChannelType() + switch t { + case GostSSHTunnelRequest: + channel, requests, err := newChannel.Accept() + if err != nil { + log.Log("[ssh] Could not accept channel:", err) + continue + } + go ssh.DiscardRequests(requests) + select { + case l.connChan <- &sshConn{conn: conn, channel: channel}: + default: + log.Logf("[ssh] %s - %s: connection queue is full", conn.RemoteAddr(), l.Addr()) + } + + default: + log.Log("[ssh] Unknown channel type:", t) + newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + } + } + }() + + log.Logf("[ssh] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + sc.Wait() + log.Logf("[ssh] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) +} + +func (l *sshTunnelListener) Accept() (conn net.Conn, err error) { var ok bool select { case conn = <-l.connChan: @@ -143,116 +604,6 @@ func (p directForward) String() string { return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) } -/* -func (l *sshListener) directPortForwardChannel(channel ssh.Channel, raddr string) { - defer channel.Close() - - log.Logf("[ssh-tcp] %s - %s", l.Addr, raddr) - - //! if !s.Base.Node.Can("tcp", raddr) { - //! glog.Errorf("Unauthorized to tcp connect to %s", raddr) - //! return - //! } - - conn, err := h.Base.Chain.Dial(raddr) - if err != nil { - glog.V(LINFO).Infof("[ssh-tcp] %s - %s : %s", s.Addr, raddr, err) - return - } - defer conn.Close() - - glog.V(LINFO).Infof("[ssh-tcp] %s <-> %s", s.Addr, raddr) - Transport(conn, channel) - glog.V(LINFO).Infof("[ssh-tcp] %s >-< %s", s.Addr, raddr) -} - -// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request -type tcpipForward struct { - Host string - Port uint32 -} - -func (s *SSHServer) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan interface{}) { - t := tcpipForward{} - ssh.Unmarshal(req.Payload, &t) - - addr := fmt.Sprintf("%s:%d", t.Host, t.Port) - - if !s.Base.Node.Can("rtcp", addr) { - glog.Errorf("Unauthorized to tcp bind to %s", addr) - req.Reply(false, nil) - return - } - - glog.V(LINFO).Infoln("[ssh-rtcp] listening tcp", addr) - ln, err := net.Listen("tcp", addr) //tie to the client connection - if err != nil { - glog.V(LWARNING).Infoln("[ssh-rtcp]", err) - req.Reply(false, nil) - return - } - defer ln.Close() - - replyFunc := func() error { - if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used - _, port, err := getHostPortFromAddr(ln.Addr()) - if err != nil { - return err - } - var b [4]byte - binary.BigEndian.PutUint32(b[:], uint32(port)) - t.Port = uint32(port) - return req.Reply(true, b[:]) - } - return req.Reply(true, nil) - } - if err := replyFunc(); err != nil { - glog.V(LWARNING).Infoln("[ssh-rtcp]", err) - return - } - - go func() { - for { - conn, err := ln.Accept() - if err != nil { // Unable to accept new connection - listener likely closed - return - } - - go func(conn net.Conn) { - defer conn.Close() - - p := directForward{} - var err error - - var portnum int - p.Host1 = t.Host - p.Port1 = t.Port - p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr()) - if err != nil { - return - } - - p.Port2 = uint32(portnum) - glog.V(3).Info(p) - ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) - if err != nil { - glog.V(1).Infoln("[ssh-rtcp] open forwarded channel:", err) - return - } - defer ch.Close() - go ssh.DiscardRequests(reqs) - - glog.V(LINFO).Infof("[ssh-rtcp] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - Transport(ch, conn) - glog.V(LINFO).Infof("[ssh-rtcp] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) - }(conn) - } - }() - - <-quit -} -*/ - func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { host, portString, err := net.SplitHostPort(addr.String()) if err != nil { @@ -264,7 +615,7 @@ func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) -func SSHPasswordCallback(users []*url.Userinfo) PasswordCallbackFunc { +func defaultSSHPasswordCallback(users ...*url.Userinfo) PasswordCallbackFunc { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { for _, user := range users { u := user.Username() @@ -277,3 +628,82 @@ func SSHPasswordCallback(users []*url.Userinfo) PasswordCallbackFunc { return nil, fmt.Errorf("password rejected for %s", conn.User()) } } + +type sshNopConn struct { + session *sshSession +} + +func (c *sshNopConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *sshNopConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *sshNopConn) Close() error { + return nil +} + +func (c *sshNopConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } +} + +func (c *sshNopConn) RemoteAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } +} + +func (c *sshNopConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshNopConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshNopConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +type sshConn struct { + channel ssh.Channel + conn net.Conn +} + +func (c *sshConn) Read(b []byte) (n int, err error) { + return c.channel.Read(b) +} + +func (c *sshConn) Write(b []byte) (n int, err error) { + return c.channel.Write(b) +} + +func (c *sshConn) Close() error { + return c.channel.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/gost/tls.go b/gost/tls.go index 3b5952f..0140d5b 100644 --- a/gost/tls.go +++ b/gost/tls.go @@ -26,10 +26,6 @@ func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( return tls.Client(conn, opts.TLSConfig), nil } -func (tr *tlsTransporter) Multiplex() bool { - return false -} - type tlsListener struct { net.Listener } diff --git a/gost/ws.go b/gost/ws.go index 191478c..204be1b 100644 --- a/gost/ws.go +++ b/gost/ws.go @@ -122,10 +122,6 @@ func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (n return websocketClientConn(url.String(), conn, nil, wsOptions) } -func (tr *wsTransporter) Multiplex() bool { - return false -} - type wssTransporter struct { *tcpTransporter options *WSOptions @@ -154,10 +150,6 @@ func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) } -func (tr *wssTransporter) Multiplex() bool { - return false -} - type wsListener struct { addr net.Addr upgrader *websocket.Upgrader