diff --git a/gost/examples/ssh/sshc.go b/gost/examples/ssh/sshc.go index cce27d5..a5fc674 100644 --- a/gost/examples/ssh/sshc.go +++ b/gost/examples/ssh/sshc.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "flag" "log" @@ -29,19 +30,80 @@ func init() { func main() { chain := gost.NewChain( gost.Node{ - Addr: faddr, + Protocol: "socks5", + Transport: "ssh", + Addr: faddr, Client: gost.NewClient( - gost.HTTPConnector(nil), + gost.SOCKS5Connector(nil), gost.SSHTunnelTransporter(), ), }, ) s := &gost.Server{} - s.Handle(gost.SOCKS5Handler(gost.ChainHandlerOption(chain))) + s.Handle(gost.SOCKS5Handler( + gost.ChainHandlerOption(chain), + gost.TLSConfigHandlerOption(tlsConfig()), + )) ln, err := gost.TCPListener(laddr) if err != nil { log.Fatal(err) } log.Fatal(s.Serve(ln)) } + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/gost/examples/ssh/sshd.go b/gost/examples/ssh/sshd.go index da6189e..adfd970 100644 --- a/gost/examples/ssh/sshd.go +++ b/gost/examples/ssh/sshd.go @@ -34,7 +34,7 @@ func main() { func sshTunnelServer() { s := &gost.Server{} s.Handle( - gost.HTTPHandler(), + gost.SOCKS5Handler(gost.TLSConfigHandlerOption(tlsConfig())), ) ln, err := gost.SSHTunnelListener(laddr, &gost.SSHConfig{TLSConfig: tlsConfig()}) diff --git a/gost/forward.go b/gost/forward.go index fc8ce0a..31b6c68 100644 --- a/gost/forward.go +++ b/gost/forward.go @@ -1,7 +1,6 @@ package gost import ( - "crypto/tls" "errors" "net" "time" @@ -218,10 +217,9 @@ func (l *udpForwardListener) Close() error { } type rtcpForwardListener struct { - addr net.Addr - chain *Chain - selector *clientSelector - close chan struct{} + addr net.Addr + chain *Chain + close chan struct{} } // RTCPForwardListener creates a Listener for TCP remote port forwarding server. @@ -230,24 +228,11 @@ func RTCPForwardListener(addr string, chain *Chain) (Listener, error) { if err != nil { return nil, err } - if chain.IsEmpty() || chain.LastNode().Protocol != "socks5" { - return nil, errors.New("invalid chain") - } - selector := &clientSelector{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - User: chain.LastNode().User, - } - selector.AddMethod( - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - ) return &rtcpForwardListener{ - addr: laddr, - chain: chain, - selector: selector, - close: make(chan struct{}), + addr: laddr, + chain: chain, + close: make(chan struct{}), }, nil } @@ -258,10 +243,6 @@ func (l *rtcpForwardListener) Accept() (net.Conn, error) { default: } - if l.chain.IsEmpty() || l.chain.LastNode().Protocol != "socks5" { - return nil, errors.New("invalid chain") - } - conn, err := l.chain.Conn() if err != nil { return nil, err @@ -276,12 +257,10 @@ func (l *rtcpForwardListener) Accept() (net.Conn, error) { } func (l *rtcpForwardListener) handshake(conn net.Conn) (net.Conn, error) { - cc := gosocks5.ClientConn(conn, l.selector) - if err := cc.Handleshake(); err != nil { + conn, err := socks5Handshake(conn, l.chain.LastNode().User) + if err != nil { return nil, err } - conn = cc - req := gosocks5.NewRequest(gosocks5.CmdBind, toSocksAddr(l.addr)) if err := req.Write(conn); err != nil { log.Log("[rtcp] SOCKS5 BIND request: ", err) @@ -327,10 +306,9 @@ func (l *rtcpForwardListener) Close() error { } type rudpForwardListener struct { - addr net.Addr - chain *Chain - selector *clientSelector - close chan struct{} + addr net.Addr + chain *Chain + close chan struct{} } // RUDPForwardListener creates a Listener for UDP remote port forwarding server. @@ -339,24 +317,11 @@ func RUDPForwardListener(addr string, chain *Chain) (Listener, error) { if err != nil { return nil, err } - if chain.IsEmpty() || chain.LastNode().Protocol != "socks5" { - return nil, errors.New("invalid chain") - } - selector := &clientSelector{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - User: chain.LastNode().User, - } - selector.AddMethod( - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - ) return &rudpForwardListener{ - addr: laddr, - chain: chain, - selector: selector, - close: make(chan struct{}), + addr: laddr, + chain: chain, + close: make(chan struct{}), }, nil } @@ -382,12 +347,10 @@ func (l *rudpForwardListener) Accept() (net.Conn, error) { } func (l *rudpForwardListener) handshake(conn net.Conn) (net.Conn, error) { - cc := gosocks5.ClientConn(conn, l.selector) - if err := cc.Handleshake(); err != nil { + conn, err := socks5Handshake(conn, l.chain.LastNode().User) + if err != nil { return nil, err } - conn = cc - req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(l.addr)) if err := req.Write(conn); err != nil { log.Log("[rudp] SOCKS5 UDP relay request: ", err) diff --git a/gost/gost.go b/gost/gost.go index 751989e..c9c16dc 100644 --- a/gost/gost.go +++ b/gost/gost.go @@ -1,6 +1,7 @@ package gost import ( + "errors" "time" "github.com/go-log/log" @@ -36,6 +37,10 @@ var ( defaultTTL = 60 ) +var ( + ErrSessionDead = errors.New("session is dead") +) + func init() { log.DefaultLogger = &LogLogger{} } diff --git a/gost/socks.go b/gost/socks.go index 42ba94f..c4f340a 100644 --- a/gost/socks.go +++ b/gost/socks.go @@ -822,7 +822,7 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { bindAddr, _ := net.ResolveUDPAddr("udp", addr) uc, err := net.ListenUDP("udp", bindAddr) if err != nil { - log.Logf("[socks5-rudp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) return } defer uc.Close() @@ -831,38 +831,59 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) if err := reply.Write(conn); err != nil { - log.Logf("[socks5-rudp] %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) + log.Logf("[socks5-udp] %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) return } if Debug { - log.Logf("[socks5-rudp] %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply) + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply) } - log.Logf("[socks5-rudp] %s <-> %s", conn.RemoteAddr(), socksAddr) + log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) h.tunnelServerUDP(conn, uc) - log.Logf("[socks5-rudp] %s >-< %s", conn.RemoteAddr(), socksAddr) + log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) return } cc, err := h.options.Chain.Conn() // connection error if err != nil { - log.Logf("[socks5-rudp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) reply := gosocks5.NewReply(gosocks5.Failure, nil) reply.Write(conn) - log.Logf("[socks5-rudp] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply) + log.Logf("[socks5-udp] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply) return } - defer cc.Close() + cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + return + } // tunnel <-> tunnel, direct forwarding // note: this type of request forwarding is defined when starting server // so we don't need to authenticate it, as it's as explicit as whitelisting req.Write(cc) - log.Logf("[socks5-rudp] %s <-> %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) + log.Logf("[socks5-udp] %s <-> %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) transport(conn, cc) - log.Logf("[socks5-rudp] %s >-< %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) + log.Logf("[socks5-udp] %s >-< %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) +} + +func socks5Handshake(conn net.Conn, user *url.Userinfo) (net.Conn, error) { + selector := &clientSelector{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + User: user, + } + selector.AddMethod( + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + MethodTLS, + ) + cc := gosocks5.ClientConn(conn, selector) + if err := cc.Handleshake(); err != nil { + return nil, err + } + return cc, nil } func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error) { @@ -874,7 +895,7 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error for { n, addr, err := uc.ReadFromUDP(b) if err != nil { - log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) + // log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) errc <- err return } @@ -897,7 +918,7 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error for { dgram, err := gosocks5.ReadUDPDatagram(cc) if err != nil { - log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) + // log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) errc <- err return } diff --git a/gost/ssh.go b/gost/ssh.go index 0861960..81cf47c 100644 --- a/gost/ssh.go +++ b/gost/ssh.go @@ -1,6 +1,7 @@ package gost import ( + "context" "crypto/tls" "encoding/binary" "errors" @@ -145,7 +146,10 @@ func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn n defer tr.sessionMutex.Unlock() session, ok := tr.sessions[addr] - if !ok { + if !ok || session.Closed() { + if session != nil { + session.client.Close() + } if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, opts.Timeout) } else { @@ -206,18 +210,17 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt } tr.sessions[opts.Addr] = session go session.Ping(opts.Interval, 1) + go session.Wait() } - if session.Dead() { + if session.Closed() { + session.client.Close() delete(tr.sessions, opts.Addr) - return nil, errors.New("ssh: session is dead") + return nil, ErrSessionDead } channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil) if err != nil { - session.client.Close() - close(session.closed) - delete(tr.sessions, opts.Addr) return nil, err } go ssh.DiscardRequests(reqs) @@ -242,10 +245,9 @@ func (s *sshSession) Ping(interval time.Duration, retries int) { return } defer close(s.deaded) - defer s.client.Close() log.Log("[ssh] ping is enabled, interval:", interval) - // baseCtx := context.Background() + baseCtx := context.Background() t := time.NewTicker(interval) defer t.Stop() @@ -256,7 +258,14 @@ func (s *sshSession) Ping(interval time.Duration, retries int) { //if Debug { log.Log("[ssh] sending ping") //} - _, _, err := s.client.SendRequest("ping", true, nil) + ctx, cancel := context.WithTimeout(baseCtx, time.Second*30) + var err error + select { + case err = <-s.sendPing(): + case <-ctx.Done(): + err = errors.New("Timeout") + } + cancel() if err != nil { log.Log("[ssh] ping:", err) return @@ -271,10 +280,26 @@ func (s *sshSession) Ping(interval time.Duration, retries int) { } } -func (s *sshSession) Dead() bool { +func (s *sshSession) sendPing() <-chan error { + ch := make(chan error, 1) + if _, _, err := s.client.SendRequest("ping", true, nil); err != nil { + ch <- err + } + close(ch) + return ch +} + +func (s *sshSession) Wait() error { + defer close(s.closed) + return s.client.Wait() +} + +func (s *sshSession) Closed() bool { select { case <-s.deaded: return true + case <-s.closed: + return true default: } return false