rudp: fix panic

This commit is contained in:
ginuerzh 2020-02-08 17:27:30 +08:00
parent abe4043413
commit 94dcfcab8c
5 changed files with 80 additions and 57 deletions

View File

@ -9,7 +9,6 @@ import (
"net/url"
"os"
"strings"
"time"
"github.com/ginuerzh/gost"
"github.com/go-log/log"
@ -136,6 +135,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
wsOpts.UserAgent = node.Get("agent")
wsOpts.Path = node.Get("path")
timeout := node.GetDuration("timeout")
var host string
var tr gost.Transporter
@ -175,8 +176,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
config := &gost.QUICConfig{
TLSConfig: tlsCfg,
KeepAlive: node.GetBool("keepalive"),
Timeout: time.Duration(node.GetInt("timeout")) * time.Second,
IdleTimeout: time.Duration(node.GetInt("idle")) * time.Second,
Timeout: timeout,
IdleTimeout: node.GetDuration("idle"),
}
if cipher := node.Get("cipher"); cipher != "" {
@ -232,9 +233,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
connector = gost.AutoConnector(node.User)
}
timeout := node.GetInt("timeout")
node.DialOptions = append(node.DialOptions,
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
gost.TimeoutDialOption(timeout),
)
node.ConnectOptions = []gost.ConnectOption{
@ -250,8 +250,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
gost.HostHandshakeOption(host),
gost.UserHandshakeOption(node.User),
gost.TLSConfigHandshakeOption(tlsCfg),
gost.IntervalHandshakeOption(time.Duration(node.GetInt("ping")) * time.Second),
gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second),
gost.IntervalHandshakeOption(node.GetDuration("ping")),
gost.TimeoutHandshakeOption(timeout),
gost.RetryHandshakeOption(node.GetInt("retry")),
}
node.Client = &gost.Client{
@ -339,10 +339,8 @@ func (r *route) GenRouters() ([]router, error) {
wsOpts.WriteBufferSize = node.GetInt("wbuf")
wsOpts.Path = node.Get("path")
ttl, err := time.ParseDuration(node.Get("ttl"))
if err != nil {
ttl = time.Duration(node.GetInt("ttl")) * time.Second
}
ttl := node.GetDuration("ttl")
timeout := node.GetDuration("timeout")
tunRoutes := parseIPRoutes(node.Get("route"))
gw := net.ParseIP(node.Get("gw")) // default gateway
@ -393,8 +391,8 @@ func (r *route) GenRouters() ([]router, error) {
config := &gost.QUICConfig{
TLSConfig: tlsCfg,
KeepAlive: node.GetBool("keepalive"),
Timeout: time.Duration(node.GetInt("timeout")) * time.Second,
IdleTimeout: time.Duration(node.GetInt("idle")) * time.Second,
Timeout: timeout,
IdleTimeout: node.GetDuration("idle"),
}
if cipher := node.Get("cipher"); cipher != "" {
sum := sha256.Sum256([]byte(cipher))
@ -555,7 +553,11 @@ func (r *route) GenRouters() ([]router, error) {
resolver := parseResolver(node.Get("dns"))
if resolver != nil {
resolver.Init(gost.ChainResolverOption(chain))
resolver.Init(
gost.ChainResolverOption(chain),
gost.TimeoutResolverOption(timeout),
gost.TTLResolverOption(ttl),
)
}
handler.Init(
@ -573,7 +575,7 @@ func (r *route) GenRouters() ([]router, error) {
gost.ResolverHandlerOption(resolver),
gost.HostsHandlerOption(hosts),
gost.RetryHandlerOption(node.GetInt("retry")), // override the global retry option.
gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second),
gost.TimeoutHandlerOption(timeout),
gost.ProbeResistHandlerOption(node.Get("probe_resist")),
gost.KnockingHandlerOption(node.Get("knock")),
gost.NodeHandlerOption(node),

View File

@ -334,7 +334,6 @@ type tcpRemoteForwardListener struct {
sessionMux sync.Mutex
closed chan struct{}
closeMux sync.Mutex
errChan chan error
}
// TCPRemoteForwardListener creates a Listener for TCP remote port forwarding server.
@ -349,7 +348,6 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) {
chain: chain,
connChan: make(chan net.Conn, 1024),
closed: make(chan struct{}),
errChan: make(chan error),
}
if !ln.isChainValid() {
@ -365,7 +363,7 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) {
func (l *tcpRemoteForwardListener) isChainValid() bool {
lastNode := l.chain.LastNode()
if (lastNode.Protocol == "forward" && lastNode.Transport == "ssh") ||
lastNode.Protocol == "socks5" {
lastNode.Protocol == "socks5" || lastNode.Protocol == "" {
return true
}
return false
@ -431,7 +429,7 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
return l.chain.Dial(l.addr.String())
}
if lastNode.Protocol == "socks5" {
if lastNode.Protocol == "socks5" || lastNode.Protocol == "" {
if lastNode.GetBool("mbind") {
return l.muxAccept() // multiplexing support for binding.
}
@ -588,11 +586,9 @@ type udpRemoteForwardListener struct {
connMap *udpConnMap
connChan chan net.Conn
ln *net.UDPConn
errChan chan error
ttl time.Duration
closed chan struct{}
closeMux sync.Mutex
once sync.Once
config *UDPListenConfig
}
@ -617,15 +613,12 @@ func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPListenConfig) (
chain: chain,
connMap: new(udpConnMap),
connChan: make(chan net.Conn, backlog),
errChan: make(chan error, 1),
closed: make(chan struct{}),
config: cfg,
}
go ln.listenLoop()
err = <-ln.errChan
return ln, err
}
@ -699,7 +692,7 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
}
lastNode := l.chain.LastNode()
if lastNode.Protocol == "socks5" {
if lastNode.Protocol == "socks5" || lastNode.Protocol == "" {
var cc net.Conn
cc, err = getSocks5UDPTunnel(l.chain, l.addr)
if err != nil {
@ -716,11 +709,6 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
}
}
l.once.Do(func() {
l.errChan <- err
close(l.errChan)
})
if err != nil {
if tempDelay == 0 {
tempDelay = 1000 * time.Millisecond

View File

@ -157,13 +157,16 @@ func (node *Node) GetBool(key string) bool {
// GetInt converts node parameter value to int.
func (node *Node) GetInt(key string) int {
n, _ := strconv.Atoi(node.Values.Get(key))
n, _ := strconv.Atoi(node.Get(key))
return n
}
// GetDuration converts node parameter value to time.Duration.
func (node *Node) GetDuration(key string) time.Duration {
d, _ := time.ParseDuration(node.Values.Get(key))
d, err := time.ParseDuration(node.Get(key))
if err != nil {
d = time.Duration(node.GetInt(key)) * time.Second
}
return d
}

View File

@ -120,6 +120,8 @@ func (ns *NameServer) String() string {
type resolverOptions struct {
chain *Chain
timeout time.Duration
ttl time.Duration
}
// ResolverOption allows a common way to set Resolver options.
@ -132,6 +134,20 @@ func ChainResolverOption(chain *Chain) ResolverOption {
}
}
// TimeoutResolverOption sets the timeout for Resolver.
func TimeoutResolverOption(timeout time.Duration) ResolverOption {
return func(opts *resolverOptions) {
opts.timeout = timeout
}
}
// TTLResolverOption sets the timeout for Resolver.
func TTLResolverOption(ttl time.Duration) ResolverOption {
return func(opts *resolverOptions) {
opts.ttl = ttl
}
}
// Resolver is a name resolver for domain name.
// It contains a list of name servers.
type Resolver interface {
@ -191,10 +207,17 @@ func (r *resolver) Init(opts ...ResolverOption) error {
}
timeout := r.timeout
if r.options.timeout != 0 {
timeout = r.options.timeout
}
if timeout <= 0 {
timeout = DefaultResolverTimeout
}
if r.options.ttl != 0 {
r.ttl = r.options.ttl
}
var nss []NameServer
for _, ns := range r.servers {
if err := ns.Init( // init all name servers
@ -279,7 +302,7 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error)
}
func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
mr, err := r.exchangeMsg(ctx, ex, mq)
mr, _, err := r.exchangeMsg(ctx, ex, mq)
if err != nil {
return
}
@ -302,12 +325,20 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er
return
}
var qs string
if len(mq.Question) > 0 {
qs = mq.Question[0].String()
}
var mr *dns.Msg
for _, ns := range r.copyServers() {
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
var cache bool
mr, cache, err = r.exchangeMsg(ctx, ns.exchanger, mq)
log.Logf("[dns] exchange message %d via %s (cache hit: %v): %s", mq.Id, ns.String(), cache, qs)
if err == nil {
break
}
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err)
}
if err != nil {
return
@ -315,12 +346,13 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er
return mr.Pack()
}
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, cache bool, err error) {
// Only cache for single question.
if len(mq.Question) == 1 {
key := newResolverCacheKey(&mq.Question[0])
mr = r.cache.loadCache(key)
if mr != nil {
cache = true
mr.Id = mq.Id
return
}
@ -340,9 +372,7 @@ func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (
}
mr = &dns.Msg{}
if err = mr.Unpack(reply); err != nil {
return nil, err
}
err = mr.Unpack(reply)
return
}

View File

@ -1400,14 +1400,14 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) {
addr := req.Addr.String()
if !Can("rudp", addr, h.options.Whitelist, h.options.Blacklist) {
log.Logf("[socks5-udp] Unauthorized to udp bind to %s", addr)
log.Logf("[socks5] udp-tun Unauthorized to udp bind to %s", addr)
return
}
bindAddr, _ := net.ResolveUDPAddr("udp", addr)
uc, err := net.ListenUDP("udp", bindAddr)
if err != nil {
log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
return
}
defer uc.Close()
@ -1416,32 +1416,32 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) {
socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr)
if err := reply.Write(conn); err != nil {
log.Logf("[socks5-udp] %s <- %s : %s", conn.RemoteAddr(), socksAddr, err)
log.Logf("[socks5] udp-tun %s <- %s : %s", conn.RemoteAddr(), socksAddr, err)
return
}
if Debug {
log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply)
log.Logf("[socks5] udp-tun %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply)
}
log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr)
log.Logf("[socks5] udp-tun %s <-> %s", conn.RemoteAddr(), socksAddr)
h.tunnelServerUDP(conn, uc)
log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr)
log.Logf("[socks5] udp-tun %s >-< %s", conn.RemoteAddr(), socksAddr)
return
}
cc, err := h.options.Chain.Conn()
// connection error
if err != nil {
log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn)
log.Logf("[socks5-udp] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply)
log.Logf("[socks5] udp-tun %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply)
return
}
defer cc.Close()
cc, err = socks5Handshake(cc, userSocks5HandshakeOption(h.options.Chain.LastNode().User))
if err != nil {
log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
return
}
// tunnel <-> tunnel, direct forwarding
@ -1449,9 +1449,9 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) {
// so we don't need to authenticate it, as it's as explicit as whitelisting
req.Write(cc)
log.Logf("[socks5-udp] %s <-> %s [tun]", conn.RemoteAddr(), cc.RemoteAddr())
log.Logf("[socks5] udp-tun %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr())
transport(conn, cc)
log.Logf("[socks5-udp] %s >-< %s [tun]", conn.RemoteAddr(), cc.RemoteAddr())
log.Logf("[socks5] udp-tun %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr())
}
func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err error) {
@ -1469,7 +1469,7 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err
return
}
if h.options.Bypass.Contains(addr.String()) {
log.Log("[udp-tun] [bypass] read from", addr)
log.Log("[socks5] udp-tun bypass read from", addr)
continue // bypass
}
@ -1477,12 +1477,12 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err
dgram := gosocks5.NewUDPDatagram(
gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n])
if err := dgram.Write(cc); err != nil {
log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err)
log.Logf("[socks5] udp-tun %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err)
errc <- err
return
}
if Debug {
log.Logf("[udp-tun] %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data))
log.Logf("[socks5] udp-tun %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data))
}
}
}()
@ -1502,16 +1502,16 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err
continue // drop silently
}
if h.options.Bypass.Contains(addr.String()) {
log.Log("[udp-tun] [bypass] write to", addr)
log.Log("[socks5] udp-tun bypass write to", addr)
continue // bypass
}
if _, err := pc.WriteTo(dgram.Data, addr); err != nil {
log.Logf("[udp-tun] %s -> %s : %s", cc.RemoteAddr(), addr, err)
log.Logf("[socks5] udp-tun %s -> %s : %s", cc.RemoteAddr(), addr, err)
errc <- err
return
}
if Debug {
log.Logf("[udp-tun] %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data))
log.Logf("[socks5] udp-tun %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data))
}
}
}()
@ -1895,7 +1895,7 @@ func getSocks5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) {
if err != nil {
c.Close()
}
return conn, nil
return conn, err
}
type socks5UDPTunnelConn struct {