fix tests

This commit is contained in:
ginuerzh 2020-02-08 19:08:33 +08:00
parent 94dcfcab8c
commit ece79946b3

View File

@ -361,6 +361,10 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) {
} }
func (l *tcpRemoteForwardListener) isChainValid() bool { func (l *tcpRemoteForwardListener) isChainValid() bool {
if l.chain.IsEmpty() {
return false
}
lastNode := l.chain.LastNode() lastNode := l.chain.LastNode()
if (lastNode.Protocol == "forward" && lastNode.Transport == "ssh") || if (lastNode.Protocol == "forward" && lastNode.Transport == "ssh") ||
lastNode.Protocol == "socks5" || lastNode.Protocol == "" { lastNode.Protocol == "socks5" || lastNode.Protocol == "" {
@ -429,7 +433,7 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
return l.chain.Dial(l.addr.String()) return l.chain.Dial(l.addr.String())
} }
if lastNode.Protocol == "socks5" || lastNode.Protocol == "" { if l.isChainValid() {
if lastNode.GetBool("mbind") { if lastNode.GetBool("mbind") {
return l.muxAccept() // multiplexing support for binding. return l.muxAccept() // multiplexing support for binding.
} }
@ -588,6 +592,8 @@ type udpRemoteForwardListener struct {
ln *net.UDPConn ln *net.UDPConn
ttl time.Duration ttl time.Duration
closed chan struct{} closed chan struct{}
ready chan struct{}
once sync.Once
closeMux sync.Mutex closeMux sync.Mutex
config *UDPListenConfig config *UDPListenConfig
} }
@ -613,18 +619,25 @@ func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPListenConfig) (
chain: chain, chain: chain,
connMap: new(udpConnMap), connMap: new(udpConnMap),
connChan: make(chan net.Conn, backlog), connChan: make(chan net.Conn, backlog),
ready: make(chan struct{}),
closed: make(chan struct{}), closed: make(chan struct{}),
config: cfg, config: cfg,
} }
go ln.listenLoop() go ln.listenLoop()
<-ln.ready
return ln, err return ln, err
} }
func (l *udpRemoteForwardListener) isChainValid() bool { func (l *udpRemoteForwardListener) isChainValid() bool {
if l.chain.IsEmpty() {
return false
}
lastNode := l.chain.LastNode() lastNode := l.chain.LastNode()
return lastNode.Protocol == "socks5" return lastNode.Protocol == "socks5" || lastNode.Protocol == ""
} }
func (l *udpRemoteForwardListener) listenLoop() { func (l *udpRemoteForwardListener) listenLoop() {
@ -635,6 +648,10 @@ func (l *udpRemoteForwardListener) listenLoop() {
return return
} }
l.once.Do(func() {
close(l.ready)
})
func() { func() {
defer conn.Close() defer conn.Close()
@ -691,8 +708,7 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
default: default:
} }
lastNode := l.chain.LastNode() if l.isChainValid() {
if lastNode.Protocol == "socks5" || lastNode.Protocol == "" {
var cc net.Conn var cc net.Conn
cc, err = getSocks5UDPTunnel(l.chain, l.addr) cc, err = getSocks5UDPTunnel(l.chain, l.addr)
if err != nil { if err != nil {