diff --git a/forward.go b/forward.go index 73e85b9..9e25fa3 100644 --- a/forward.go +++ b/forward.go @@ -361,6 +361,10 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) { } func (l *tcpRemoteForwardListener) isChainValid() bool { + if l.chain.IsEmpty() { + return false + } + lastNode := l.chain.LastNode() if (lastNode.Protocol == "forward" && lastNode.Transport == "ssh") || lastNode.Protocol == "socks5" || lastNode.Protocol == "" { @@ -429,7 +433,7 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { return l.chain.Dial(l.addr.String()) } - if lastNode.Protocol == "socks5" || lastNode.Protocol == "" { + if l.isChainValid() { if lastNode.GetBool("mbind") { return l.muxAccept() // multiplexing support for binding. } @@ -588,6 +592,8 @@ type udpRemoteForwardListener struct { ln *net.UDPConn ttl time.Duration closed chan struct{} + ready chan struct{} + once sync.Once closeMux sync.Mutex config *UDPListenConfig } @@ -613,18 +619,25 @@ func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPListenConfig) ( chain: chain, connMap: new(udpConnMap), connChan: make(chan net.Conn, backlog), + ready: make(chan struct{}), closed: make(chan struct{}), config: cfg, } go ln.listenLoop() + <-ln.ready + return ln, err } func (l *udpRemoteForwardListener) isChainValid() bool { + if l.chain.IsEmpty() { + return false + } + lastNode := l.chain.LastNode() - return lastNode.Protocol == "socks5" + return lastNode.Protocol == "socks5" || lastNode.Protocol == "" } func (l *udpRemoteForwardListener) listenLoop() { @@ -635,6 +648,10 @@ func (l *udpRemoteForwardListener) listenLoop() { return } + l.once.Do(func() { + close(l.ready) + }) + func() { defer conn.Close() @@ -691,8 +708,7 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { default: } - lastNode := l.chain.LastNode() - if lastNode.Protocol == "socks5" || lastNode.Protocol == "" { + if l.isChainValid() { var cc net.Conn cc, err = getSocks5UDPTunnel(l.chain, l.addr) if err != nil {