diff --git a/cmd/gost/route.go b/cmd/gost/route.go index d6a4f18..2a8a6a4 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -9,7 +9,6 @@ import ( "net/url" "os" "strings" - "time" "github.com/ginuerzh/gost" "github.com/go-log/log" @@ -136,6 +135,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { wsOpts.UserAgent = node.Get("agent") wsOpts.Path = node.Get("path") + timeout := node.GetDuration("timeout") + var host string var tr gost.Transporter @@ -175,8 +176,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { config := &gost.QUICConfig{ TLSConfig: tlsCfg, KeepAlive: node.GetBool("keepalive"), - Timeout: time.Duration(node.GetInt("timeout")) * time.Second, - IdleTimeout: time.Duration(node.GetInt("idle")) * time.Second, + Timeout: timeout, + IdleTimeout: node.GetDuration("idle"), } if cipher := node.Get("cipher"); cipher != "" { @@ -232,9 +233,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { connector = gost.AutoConnector(node.User) } - timeout := node.GetInt("timeout") node.DialOptions = append(node.DialOptions, - gost.TimeoutDialOption(time.Duration(timeout)*time.Second), + gost.TimeoutDialOption(timeout), ) node.ConnectOptions = []gost.ConnectOption{ @@ -250,8 +250,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { gost.HostHandshakeOption(host), gost.UserHandshakeOption(node.User), gost.TLSConfigHandshakeOption(tlsCfg), - gost.IntervalHandshakeOption(time.Duration(node.GetInt("ping")) * time.Second), - gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second), + gost.IntervalHandshakeOption(node.GetDuration("ping")), + gost.TimeoutHandshakeOption(timeout), gost.RetryHandshakeOption(node.GetInt("retry")), } node.Client = &gost.Client{ @@ -339,10 +339,8 @@ func (r *route) GenRouters() ([]router, error) { wsOpts.WriteBufferSize = node.GetInt("wbuf") wsOpts.Path = node.Get("path") - ttl, err := time.ParseDuration(node.Get("ttl")) - if err != nil { - ttl = time.Duration(node.GetInt("ttl")) * time.Second - } + ttl := node.GetDuration("ttl") + timeout := node.GetDuration("timeout") tunRoutes := parseIPRoutes(node.Get("route")) gw := net.ParseIP(node.Get("gw")) // default gateway @@ -393,8 +391,8 @@ func (r *route) GenRouters() ([]router, error) { config := &gost.QUICConfig{ TLSConfig: tlsCfg, KeepAlive: node.GetBool("keepalive"), - Timeout: time.Duration(node.GetInt("timeout")) * time.Second, - IdleTimeout: time.Duration(node.GetInt("idle")) * time.Second, + Timeout: timeout, + IdleTimeout: node.GetDuration("idle"), } if cipher := node.Get("cipher"); cipher != "" { sum := sha256.Sum256([]byte(cipher)) @@ -555,7 +553,11 @@ func (r *route) GenRouters() ([]router, error) { resolver := parseResolver(node.Get("dns")) if resolver != nil { - resolver.Init(gost.ChainResolverOption(chain)) + resolver.Init( + gost.ChainResolverOption(chain), + gost.TimeoutResolverOption(timeout), + gost.TTLResolverOption(ttl), + ) } handler.Init( @@ -573,7 +575,7 @@ func (r *route) GenRouters() ([]router, error) { gost.ResolverHandlerOption(resolver), gost.HostsHandlerOption(hosts), gost.RetryHandlerOption(node.GetInt("retry")), // override the global retry option. - gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second), + gost.TimeoutHandlerOption(timeout), gost.ProbeResistHandlerOption(node.Get("probe_resist")), gost.KnockingHandlerOption(node.Get("knock")), gost.NodeHandlerOption(node), diff --git a/forward.go b/forward.go index f11ae34..73e85b9 100644 --- a/forward.go +++ b/forward.go @@ -334,7 +334,6 @@ type tcpRemoteForwardListener struct { sessionMux sync.Mutex closed chan struct{} closeMux sync.Mutex - errChan chan error } // TCPRemoteForwardListener creates a Listener for TCP remote port forwarding server. @@ -349,7 +348,6 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) { chain: chain, connChan: make(chan net.Conn, 1024), closed: make(chan struct{}), - errChan: make(chan error), } if !ln.isChainValid() { @@ -365,7 +363,7 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) { func (l *tcpRemoteForwardListener) isChainValid() bool { lastNode := l.chain.LastNode() if (lastNode.Protocol == "forward" && lastNode.Transport == "ssh") || - lastNode.Protocol == "socks5" { + lastNode.Protocol == "socks5" || lastNode.Protocol == "" { return true } return false @@ -431,7 +429,7 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { return l.chain.Dial(l.addr.String()) } - if lastNode.Protocol == "socks5" { + if lastNode.Protocol == "socks5" || lastNode.Protocol == "" { if lastNode.GetBool("mbind") { return l.muxAccept() // multiplexing support for binding. } @@ -588,11 +586,9 @@ type udpRemoteForwardListener struct { connMap *udpConnMap connChan chan net.Conn ln *net.UDPConn - errChan chan error ttl time.Duration closed chan struct{} closeMux sync.Mutex - once sync.Once config *UDPListenConfig } @@ -617,15 +613,12 @@ func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPListenConfig) ( chain: chain, connMap: new(udpConnMap), connChan: make(chan net.Conn, backlog), - errChan: make(chan error, 1), closed: make(chan struct{}), config: cfg, } go ln.listenLoop() - err = <-ln.errChan - return ln, err } @@ -699,7 +692,7 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { } lastNode := l.chain.LastNode() - if lastNode.Protocol == "socks5" { + if lastNode.Protocol == "socks5" || lastNode.Protocol == "" { var cc net.Conn cc, err = getSocks5UDPTunnel(l.chain, l.addr) if err != nil { @@ -716,11 +709,6 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { } } - l.once.Do(func() { - l.errChan <- err - close(l.errChan) - }) - if err != nil { if tempDelay == 0 { tempDelay = 1000 * time.Millisecond diff --git a/node.go b/node.go index 0d45fb2..336ac15 100644 --- a/node.go +++ b/node.go @@ -157,13 +157,16 @@ func (node *Node) GetBool(key string) bool { // GetInt converts node parameter value to int. func (node *Node) GetInt(key string) int { - n, _ := strconv.Atoi(node.Values.Get(key)) + n, _ := strconv.Atoi(node.Get(key)) return n } // GetDuration converts node parameter value to time.Duration. func (node *Node) GetDuration(key string) time.Duration { - d, _ := time.ParseDuration(node.Values.Get(key)) + d, err := time.ParseDuration(node.Get(key)) + if err != nil { + d = time.Duration(node.GetInt(key)) * time.Second + } return d } diff --git a/resolver.go b/resolver.go index 99529a0..e7b3f91 100644 --- a/resolver.go +++ b/resolver.go @@ -119,7 +119,9 @@ func (ns *NameServer) String() string { } type resolverOptions struct { - chain *Chain + chain *Chain + timeout time.Duration + ttl time.Duration } // ResolverOption allows a common way to set Resolver options. @@ -132,6 +134,20 @@ func ChainResolverOption(chain *Chain) ResolverOption { } } +// TimeoutResolverOption sets the timeout for Resolver. +func TimeoutResolverOption(timeout time.Duration) ResolverOption { + return func(opts *resolverOptions) { + opts.timeout = timeout + } +} + +// TTLResolverOption sets the timeout for Resolver. +func TTLResolverOption(ttl time.Duration) ResolverOption { + return func(opts *resolverOptions) { + opts.ttl = ttl + } +} + // Resolver is a name resolver for domain name. // It contains a list of name servers. type Resolver interface { @@ -191,10 +207,17 @@ func (r *resolver) Init(opts ...ResolverOption) error { } timeout := r.timeout + if r.options.timeout != 0 { + timeout = r.options.timeout + } if timeout <= 0 { timeout = DefaultResolverTimeout } + if r.options.ttl != 0 { + r.ttl = r.options.ttl + } + var nss []NameServer for _, ns := range r.servers { if err := ns.Init( // init all name servers @@ -279,7 +302,7 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) } func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { - mr, err := r.exchangeMsg(ctx, ex, mq) + mr, _, err := r.exchangeMsg(ctx, ex, mq) if err != nil { return } @@ -302,12 +325,20 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er return } + var qs string + if len(mq.Question) > 0 { + qs = mq.Question[0].String() + } + var mr *dns.Msg for _, ns := range r.copyServers() { - mr, err = r.exchangeMsg(ctx, ns.exchanger, mq) + var cache bool + mr, cache, err = r.exchangeMsg(ctx, ns.exchanger, mq) + log.Logf("[dns] exchange message %d via %s (cache hit: %v): %s", mq.Id, ns.String(), cache, qs) if err == nil { break } + log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err) } if err != nil { return @@ -315,12 +346,13 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er return mr.Pack() } -func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { +func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, cache bool, err error) { // Only cache for single question. if len(mq.Question) == 1 { key := newResolverCacheKey(&mq.Question[0]) mr = r.cache.loadCache(key) if mr != nil { + cache = true mr.Id = mq.Id return } @@ -340,9 +372,7 @@ func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) ( } mr = &dns.Msg{} - if err = mr.Unpack(reply); err != nil { - return nil, err - } + err = mr.Unpack(reply) return } diff --git a/socks.go b/socks.go index acd3a39..39a7a9e 100644 --- a/socks.go +++ b/socks.go @@ -1400,14 +1400,14 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { addr := req.Addr.String() if !Can("rudp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks5-udp] Unauthorized to udp bind to %s", addr) + log.Logf("[socks5] udp-tun Unauthorized to udp bind to %s", addr) return } bindAddr, _ := net.ResolveUDPAddr("udp", addr) uc, err := net.ListenUDP("udp", bindAddr) if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) return } defer uc.Close() @@ -1416,32 +1416,32 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) if err := reply.Write(conn); err != nil { - log.Logf("[socks5-udp] %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) + log.Logf("[socks5] udp-tun %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) return } if Debug { - log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply) + log.Logf("[socks5] udp-tun %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply) } - log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) + log.Logf("[socks5] udp-tun %s <-> %s", conn.RemoteAddr(), socksAddr) h.tunnelServerUDP(conn, uc) - log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) + log.Logf("[socks5] udp-tun %s >-< %s", conn.RemoteAddr(), socksAddr) return } cc, err := h.options.Chain.Conn() // connection error if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) reply := gosocks5.NewReply(gosocks5.Failure, nil) reply.Write(conn) - log.Logf("[socks5-udp] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply) + log.Logf("[socks5] udp-tun %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply) return } defer cc.Close() cc, err = socks5Handshake(cc, userSocks5HandshakeOption(h.options.Chain.LastNode().User)) if err != nil { - log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) return } // tunnel <-> tunnel, direct forwarding @@ -1449,9 +1449,9 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { // so we don't need to authenticate it, as it's as explicit as whitelisting req.Write(cc) - log.Logf("[socks5-udp] %s <-> %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) + log.Logf("[socks5] udp-tun %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) transport(conn, cc) - log.Logf("[socks5-udp] %s >-< %s [tun]", conn.RemoteAddr(), cc.RemoteAddr()) + log.Logf("[socks5] udp-tun %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) } func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err error) { @@ -1469,7 +1469,7 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err return } if h.options.Bypass.Contains(addr.String()) { - log.Log("[udp-tun] [bypass] read from", addr) + log.Log("[socks5] udp-tun bypass read from", addr) continue // bypass } @@ -1477,12 +1477,12 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err dgram := gosocks5.NewUDPDatagram( gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n]) if err := dgram.Write(cc); err != nil { - log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err) + log.Logf("[socks5] udp-tun %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err) errc <- err return } if Debug { - log.Logf("[udp-tun] %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data)) + log.Logf("[socks5] udp-tun %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data)) } } }() @@ -1502,16 +1502,16 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err continue // drop silently } if h.options.Bypass.Contains(addr.String()) { - log.Log("[udp-tun] [bypass] write to", addr) + log.Log("[socks5] udp-tun bypass write to", addr) continue // bypass } if _, err := pc.WriteTo(dgram.Data, addr); err != nil { - log.Logf("[udp-tun] %s -> %s : %s", cc.RemoteAddr(), addr, err) + log.Logf("[socks5] udp-tun %s -> %s : %s", cc.RemoteAddr(), addr, err) errc <- err return } if Debug { - log.Logf("[udp-tun] %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data)) + log.Logf("[socks5] udp-tun %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data)) } } }() @@ -1895,7 +1895,7 @@ func getSocks5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { if err != nil { c.Close() } - return conn, nil + return conn, err } type socks5UDPTunnelConn struct {