diff --git a/gost/forward.go b/gost/forward.go index 2f3a19c..bd7f1a0 100644 --- a/gost/forward.go +++ b/gost/forward.go @@ -82,7 +82,7 @@ func (h *udpForwardHandler) Handle(conn net.Conn) { } } else { var err error - cc, err = h.getUDPTunnel() + cc, err = getSOCKS5UDPTunnel(h.options.Chain) if err != nil { log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) return @@ -97,40 +97,6 @@ func (h *udpForwardHandler) Handle(conn net.Conn) { log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), h.raddr) } -func (h *udpForwardHandler) getUDPTunnel() (net.Conn, error) { - conn, err := h.options.Chain.Conn() - if err != nil { - return nil, err - } - cc, err := socks5Handshake(conn, h.options.Chain.LastNode().User) - if err != nil { - conn.Close() - return nil, err - } - conn = cc - - conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err = gosocks5.NewRequest(CmdUDPTun, nil).Write(conn); err != nil { - conn.Close() - return nil, err - } - conn.SetWriteDeadline(time.Time{}) - - conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - reply, err := gosocks5.ReadReply(conn) - if err != nil { - conn.Close() - return nil, err - } - conn.SetReadDeadline(time.Time{}) - - if reply.Rep != gosocks5.Succeeded { - conn.Close() - return nil, errors.New("UDP tunnel failure") - } - return conn, nil -} - type rtcpForwardHandler struct { raddr string options *HandlerOptions @@ -542,7 +508,7 @@ func (l *rtcpForwardListener) Accept() (net.Conn, error) { if max := 6 * time.Second; tempDelay > max { tempDelay = max } - log.Logf("[ssh-rtcp] Accept error: %v; retrying in %v", err, tempDelay) + log.Logf("[rtcp] Accept error: %v; retrying in %v", err, tempDelay) time.Sleep(tempDelay) continue } @@ -619,9 +585,9 @@ func (l *rtcpForwardListener) Close() error { } type rudpForwardListener struct { - addr net.Addr - chain *Chain - close chan struct{} + addr *net.UDPAddr + chain *Chain + closed chan struct{} } // RUDPForwardListener creates a Listener for UDP remote port forwarding server. @@ -632,57 +598,47 @@ func RUDPForwardListener(addr string, chain *Chain) (Listener, error) { } return &rudpForwardListener{ - addr: laddr, - chain: chain, - close: make(chan struct{}), + addr: laddr, + chain: chain, + closed: make(chan struct{}), }, nil } func (l *rudpForwardListener) Accept() (net.Conn, error) { select { - case <-l.close: + case <-l.closed: return nil, errors.New("closed") default: } - conn, err := l.chain.Conn() - if err != nil { - return nil, err + var tempDelay time.Duration + for { + conn, err := l.accept() + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + log.Logf("[rudp] Accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) + continue + } + return conn, nil } - cc, err := l.handshake(conn) - if err != nil { - conn.Close() - return nil, err - } - conn = cc - - return cc, nil } -func (l *rudpForwardListener) handshake(conn net.Conn) (net.Conn, error) { - conn, err := socks5Handshake(conn, l.chain.LastNode().User) - if err != nil { - return nil, err +func (l *rudpForwardListener) accept() (conn net.Conn, err error) { + lastNode := l.chain.LastNode() + if lastNode.Protocol == "socks5" { + conn, err = getSOCKS5UDPTunnel(l.chain) + } else { + conn, err = net.ListenUDP("udp", l.addr) } - req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(l.addr)) - if err := req.Write(conn); err != nil { - log.Log("[rudp] SOCKS5 UDP relay request: ", err) - return nil, err - } - - conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - rep, err := gosocks5.ReadReply(conn) - if err != nil { - log.Log("[rudp] SOCKS5 UDP relay reply: ", err) - return nil, err - } - conn.SetReadDeadline(time.Time{}) - if rep.Rep != gosocks5.Succeeded { - log.Logf("[rudp] bind on %s failure: %d", l.addr, rep.Rep) - return nil, fmt.Errorf("Bind on %s failure", l.addr.String()) - } - log.Logf("[rudp] BIND ON %s OK", rep.Addr) - return conn, nil + return } func (l *rudpForwardListener) Addr() net.Addr { @@ -690,6 +646,6 @@ func (l *rudpForwardListener) Addr() net.Addr { } func (l *rudpForwardListener) Close() error { - close(l.close) + close(l.closed) return nil } diff --git a/gost/socks.go b/gost/socks.go index 9d8799d..cdc343f 100644 --- a/gost/socks.go +++ b/gost/socks.go @@ -872,23 +872,6 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { log.Logf("[socks5-udp] %s >-< %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) } -func socks5Handshake(conn net.Conn, user *url.Userinfo) (net.Conn, error) { - selector := &clientSelector{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - User: user, - } - selector.AddMethod( - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - ) - cc := gosocks5.ClientConn(conn, selector) - if err := cc.Handleshake(); err != nil { - return nil, err - } - return cc, nil -} - func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error) { errc := make(chan error, 2) @@ -1073,3 +1056,54 @@ func (h *socks4Handler) handleBind(conn net.Conn, req *gosocks4.Request) { transport(conn, cc) log.Logf("[socks4-bind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) } + +func getSOCKS5UDPTunnel(chain *Chain) (net.Conn, error) { + conn, err := chain.Conn() + if err != nil { + return nil, err + } + cc, err := socks5Handshake(conn, chain.LastNode().User) + if err != nil { + conn.Close() + return nil, err + } + conn = cc + + conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if err = gosocks5.NewRequest(CmdUDPTun, nil).Write(conn); err != nil { + conn.Close() + return nil, err + } + conn.SetWriteDeadline(time.Time{}) + + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) + reply, err := gosocks5.ReadReply(conn) + if err != nil { + conn.Close() + return nil, err + } + conn.SetReadDeadline(time.Time{}) + + if reply.Rep != gosocks5.Succeeded { + conn.Close() + return nil, errors.New("UDP tunnel failure") + } + return conn, nil +} + +func socks5Handshake(conn net.Conn, user *url.Userinfo) (net.Conn, error) { + selector := &clientSelector{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + User: user, + } + selector.AddMethod( + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + MethodTLS, + ) + cc := gosocks5.ClientConn(conn, selector) + if err := cc.Handleshake(); err != nil { + return nil, err + } + return cc, nil +}