diff --git a/forward.go b/forward.go index 40c5154..23c4d64 100644 --- a/forward.go +++ b/forward.go @@ -187,14 +187,15 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { return } + raddr, err := net.ResolveUDPAddr("udp", node.Addr) + if err != nil { + node.MarkDead() + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) + return + } + var cc net.Conn if h.options.Chain.IsEmpty() { - raddr, err := net.ResolveUDPAddr("udp", node.Addr) - if err != nil { - node.MarkDead() - log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) - return - } cc, err = net.DialUDP("udp", nil, raddr) if err != nil { node.MarkDead() @@ -208,7 +209,8 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) return } - cc = &udpTunnelConn{Conn: cc, raddr: node.Addr} + + cc = &udpTunnelConn{Conn: cc, raddr: raddr} } defer cc.Close() @@ -763,7 +765,7 @@ func (l *tcpRemoteForwardListener) getSession() (s *muxSession, err error) { conn.SetDeadline(time.Now().Add(HandshakeTimeout)) defer conn.SetDeadline(time.Time{}) - conn, err = socks5Handshake(conn, nil, l.chain.LastNode().User) + conn, err = socks5Handshake(conn, userSocks5HandshakeOption(l.chain.LastNode().User)) if err != nil { return nil, err } @@ -798,7 +800,7 @@ func (l *tcpRemoteForwardListener) getSession() (s *muxSession, err error) { } func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) { - conn, err := socks5Handshake(conn, nil, l.chain.LastNode().User) + conn, err := socks5Handshake(conn, userSocks5HandshakeOption(l.chain.LastNode().User)) if err != nil { return nil, err } diff --git a/gost.go b/gost.go index 9a71fc0..4a0b171 100644 --- a/gost.go +++ b/gost.go @@ -20,7 +20,7 @@ import ( ) // Version is the gost version. -const Version = "2.9.0" +const Version = "2.9.1-dev" // Debug is a flag that enables the debug log. var Debug bool diff --git a/snapcraft.yaml b/snapcraft.yaml index 639fa56..13ab1d1 100644 --- a/snapcraft.yaml +++ b/snapcraft.yaml @@ -1,6 +1,6 @@ name: gost type: app -version: '2.9.0' +version: '2.9.1' title: GO Simple Tunnel summary: A simple security tunnel written in golang description: | diff --git a/socks.go b/socks.go index d9a94ba..cfc4480 100644 --- a/socks.go +++ b/socks.go @@ -213,7 +213,9 @@ func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...Connect if user == nil { user = c.User } - cc, err := socks5Handshake(conn, opts.Selector, user) + cc, err := socks5Handshake(conn, + selectorSocks5HandshakeOption(opts.Selector), + userSocks5HandshakeOption(user)) if err != nil { return nil, err } @@ -281,7 +283,9 @@ func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...Con if user == nil { user = c.User } - cc, err := socks5Handshake(conn, opts.Selector, user) + cc, err := socks5Handshake(conn, + selectorSocks5HandshakeOption(opts.Selector), + userSocks5HandshakeOption(user)) if err != nil { return nil, err } @@ -431,7 +435,7 @@ func (tr *socks5MuxBindTransporter) initSession(conn net.Conn, addr string, opts opts = &HandshakeOptions{} } - cc, err := socks5Handshake(conn, nil, opts.User) + cc, err := socks5Handshake(conn, userSocks5HandshakeOption(opts.User)) if err != nil { return nil, err } @@ -515,7 +519,9 @@ func (c *socks5UDPConnector) Connect(conn net.Conn, addr string, options ...Conn if user == nil { user = c.User } - cc, err := socks5Handshake(conn, opts.Selector, user) + cc, err := socks5Handshake(conn, + selectorSocks5HandshakeOption(opts.Selector), + userSocks5HandshakeOption(user)) if err != nil { return nil, err } @@ -594,7 +600,9 @@ func (c *socks5UDPTunConnector) Connect(conn net.Conn, addr string, options ...C if user == nil { user = c.User } - cc, err := socks5Handshake(conn, opts.Selector, user) + cc, err := socks5Handshake(conn, + selectorSocks5HandshakeOption(opts.Selector), + userSocks5HandshakeOption(user)) if err != nil { return nil, err } @@ -636,7 +644,7 @@ func (c *socks5UDPTunConnector) Connect(conn net.Conn, addr string, options ...C } log.Logf("[socks5] udp-tun associate on %s OK", baddr) - return &udpTunnelConn{Conn: conn, raddr: taddr.String()}, nil + return &udpTunnelConn{Conn: conn, raddr: taddr}, nil } type socks4Connector struct{} @@ -1122,7 +1130,7 @@ func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) { } defer cc.Close() - cc, err = socks5Handshake(cc, nil, h.options.Chain.LastNode().User) + cc, err = socks5Handshake(cc, userSocks5HandshakeOption(h.options.Chain.LastNode().User)) if err != nil { log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err) return @@ -1386,7 +1394,7 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { } defer cc.Close() - cc, err = socks5Handshake(cc, nil, h.options.Chain.LastNode().User) + cc, err = socks5Handshake(cc, userSocks5HandshakeOption(h.options.Chain.LastNode().User)) if err != nil { log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) return @@ -1780,7 +1788,11 @@ func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { conn.SetDeadline(time.Now().Add(HandshakeTimeout)) defer conn.SetDeadline(time.Time{}) - cc, err := socks5Handshake(conn, nil, chain.LastNode().User) + node := chain.LastNode() + cc, err := socks5Handshake(conn, + userSocks5HandshakeOption(node.User), + noTLSSocks5HandshakeOption(node.GetBool("notls")), + ) if err != nil { conn.Close() return nil, err @@ -1813,17 +1825,51 @@ func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { return conn, nil } -func socks5Handshake(conn net.Conn, selector gosocks5.Selector, user *url.Userinfo) (net.Conn, error) { +type socks5HandshakeOptions struct { + selector gosocks5.Selector + user *url.Userinfo + tlsConfig *tls.Config + noTLS bool +} + +type socks5HandshakeOption func(opts *socks5HandshakeOptions) + +func selectorSocks5HandshakeOption(selector gosocks5.Selector) socks5HandshakeOption { + return func(opts *socks5HandshakeOptions) { + opts.selector = selector + } +} + +func userSocks5HandshakeOption(user *url.Userinfo) socks5HandshakeOption { + return func(opts *socks5HandshakeOptions) { + opts.user = user + } +} + +func noTLSSocks5HandshakeOption(noTLS bool) socks5HandshakeOption { + return func(opts *socks5HandshakeOptions) { + opts.noTLS = noTLS + } +} + +func socks5Handshake(conn net.Conn, opts ...socks5HandshakeOption) (net.Conn, error) { + options := socks5HandshakeOptions{} + for _, opt := range opts { + opt(&options) + } + selector := options.selector if selector == nil { cs := &clientSelector{ TLSConfig: &tls.Config{InsecureSkipVerify: true}, - User: user, + User: options.user, } cs.AddMethod( gosocks5.MethodNoAuth, gosocks5.MethodUserPass, - MethodTLS, ) + if !options.noTLS { + cs.AddMethod(MethodTLS) + } selector = cs } @@ -1835,7 +1881,7 @@ func socks5Handshake(conn net.Conn, selector gosocks5.Selector, user *url.Userin } type udpTunnelConn struct { - raddr string + raddr net.Addr net.Conn } @@ -1859,11 +1905,7 @@ func (c *udpTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { } func (c *udpTunnelConn) Write(b []byte) (n int, err error) { - addr, err := net.ResolveUDPAddr("udp", c.raddr) - if err != nil { - return - } - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b) + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(c.raddr)), b) if err = dgram.Write(c.Conn); err != nil { return } diff --git a/tuntap.go b/tuntap.go index c48f650..bae285f 100644 --- a/tuntap.go +++ b/tuntap.go @@ -114,12 +114,14 @@ func (l *tunListener) Close() error { type tunHandler struct { options *HandlerOptions routes sync.Map + chExit chan struct{} } // TunHandler creates a handler for tun tunnel. func TunHandler(opts ...HandlerOption) Handler { h := &tunHandler{ options: &HandlerOptions{}, + chExit: make(chan struct{}, 1), } for _, opt := range opts { opt(h.options) @@ -141,44 +143,76 @@ func (h *tunHandler) Handle(conn net.Conn) { defer os.Exit(0) defer conn.Close() - laddr, raddr := h.options.Node.Addr, h.options.Node.Remote - var pc net.PacketConn var err error - if h.options.TCPMode { - if raddr != "" { - pc, err = tcpraw.Dial("tcp", raddr) - } else { - pc, err = tcpraw.Listen("tcp", laddr) + var raddr net.Addr + if addr := h.options.Node.Remote; addr != "" { + raddr, err = net.ResolveUDPAddr("udp", addr) + if err != nil { + log.Logf("[tun] %s: remote addr: %v", conn.LocalAddr(), err) + return } - } else { - addr, _ := net.ResolveUDPAddr("udp", laddr) - pc, err = net.ListenUDP("udp", addr) - } - if err != nil { - log.Logf("[tun] %s: %v", conn.LocalAddr(), err) - return } + var tempDelay time.Duration + for { + err := func() error { + var err error + var pc net.PacketConn + if raddr != nil && !h.options.Chain.IsEmpty() { + var cc net.Conn + cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) + pc = &udpTunnelConn{Conn: cc, raddr: raddr} + } else { + laddr, _ := net.ResolveUDPAddr("udp", h.options.Node.Addr) + pc, err = net.ListenUDP("udp", laddr) + } + if err != nil { + return err + } + + pc, err = h.initTunnelConn(pc) + if err != nil { + return err + } + + return h.transportTun(conn, pc, raddr) + }() + if err != nil { + log.Logf("[tun] %s: %v", conn.LocalAddr(), err) + } + + select { + case <-h.chExit: + return + default: + } + + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + tempDelay = 0 + } +} + +func (h *tunHandler) initTunnelConn(pc net.PacketConn) (net.PacketConn, error) { if len(h.options.Users) > 0 && h.options.Users[0] != nil { passwd, _ := h.options.Users[0].Password() cipher, err := core.PickCipher(h.options.Users[0].Username(), nil, passwd) if err != nil { - log.Logf("[tun] %s - %s cipher: %v", conn.LocalAddr(), pc.LocalAddr(), err) - return + return nil, err } pc = cipher.PacketConn(pc) } - - var ra net.Addr - if raddr != "" { - ra, err = net.ResolveUDPAddr("udp", raddr) - if err != nil { - log.Logf("[tun] %s - %s: remote addr: %v", conn.LocalAddr(), pc.LocalAddr(), err) - return - } - } - - h.transportTun(conn, pc, ra) + return pc, nil } func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.Addr) error { @@ -192,6 +226,10 @@ func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.A n, err := tun.Read(b) if err != nil { + select { + case h.chExit <- struct{}{}: + default: + } return err } @@ -323,6 +361,10 @@ func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.A } if _, err := tun.Write(b[:n]); err != nil { + select { + case h.chExit <- struct{}{}: + default: + } return err } return nil @@ -339,7 +381,6 @@ func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.A if err != nil && err == io.EOF { err = nil } - log.Logf("[tun] %s - %s: %v", tun.LocalAddr(), conn.LocalAddr(), err) return err }