From 482d2857f62ceb73e9b49f76256d3c6d7b6ee14c Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 19 May 2018 19:17:53 +0800 Subject: [PATCH] add cache support for DNS resolver --- chain.go | 10 +++---- cmd/gost/cfg.go | 14 +++++++--- cmd/gost/dns.txt | 3 +++ cmd/gost/main.go | 4 ++- resolver.go | 70 +++++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 84 insertions(+), 17 deletions(-) diff --git a/chain.go b/chain.go index 5ca322e..1e553db 100644 --- a/chain.go +++ b/chain.go @@ -130,11 +130,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) { if err == nil { addrs, er := c.Resolver.Resolve(host) if er != nil { - log.Logf("[resolver] %s: %v", addr, er) - return nil, er - } - if Debug { - log.Logf("[resolver] %s %v", addr, addrs) + log.Logf("[resolver] %s: %v", host, er) } if len(addrs) > 0 { addr = net.JoinHostPort(addrs[0].IP.String(), port) @@ -280,7 +276,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { if node.Bypass.Contains(addr) { if Debug { buf.WriteString(fmt.Sprintf("[%d@bypass: %s]", node.ID, addr)) - log.Log("select route:", buf.String()) + log.Log("[route]", buf.String()) } return } @@ -301,7 +297,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { route.Resolver = c.Resolver if Debug { - log.Log("select route:", buf.String()) + log.Log("[route]", buf.String()) } return } diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index 7793df0..1b758dc 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -266,6 +266,7 @@ func parseResolver(cfg string) gost.Resolver { return nil } timeout := 30 * time.Second + ttl := 60 * time.Second var nss []gost.NameServer f, err := os.Open(cfg) @@ -288,7 +289,7 @@ func parseResolver(cfg string) gost.Resolver { }) } } - return gost.NewResolver(nss, timeout) + return gost.NewResolver(nss, timeout, ttl) } scanner := bufio.NewScanner(f) @@ -312,7 +313,7 @@ func parseResolver(cfg string) gost.Resolver { continue } - if ss[0] == "timeout" { + if strings.ToLower(ss[0]) == "timeout" { if len(ss) >= 2 { if n, _ := strconv.Atoi(ss[1]); n > 0 { timeout = time.Second * time.Duration(n) @@ -320,6 +321,13 @@ func parseResolver(cfg string) gost.Resolver { } continue } + if strings.ToLower(ss[0]) == "ttl" { + if len(ss) >= 2 { + n, _ := strconv.Atoi(ss[1]) + ttl = time.Second * time.Duration(n) + } + continue + } var ns gost.NameServer switch len(ss) { @@ -335,5 +343,5 @@ func parseResolver(cfg string) gost.Resolver { } nss = append(nss, ns) } - return gost.NewResolver(nss, timeout) + return gost.NewResolver(nss, timeout, ttl) } diff --git a/cmd/gost/dns.txt b/cmd/gost/dns.txt index d06d899..d2ddc45 100644 --- a/cmd/gost/dns.txt +++ b/cmd/gost/dns.txt @@ -1,6 +1,9 @@ # resolver timeout, default 30s. timeout 10 +# resolver cache TTL, default 60s, minus value means that cache is disabled. +ttl 300 + # ip[:port] [protocol] [hostname] 1.1.1.1:853 tls cloudflare-dns.com diff --git a/cmd/gost/main.go b/cmd/gost/main.go index fbf7321..46eb902 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -506,7 +506,9 @@ func (r *route) serve() error { ) chain.Resolver = parseResolver(node.Get("dns")) - + if gost.Debug { + log.Logf("[resolver]\n%v", chain.Resolver) + } go srv.Serve(handler) } diff --git a/resolver.go b/resolver.go index 8c4bf09..b1f26f8 100644 --- a/resolver.go +++ b/resolver.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "strings" + "sync" "time" "github.com/go-log/log" @@ -15,6 +16,8 @@ import ( var ( // DefaultResolverTimeout is the default timeout for name resolution. DefaultResolverTimeout = 30 * time.Second + // DefaultResolverTTL is the default cache TTL for name resolution. + DefaultResolverTTL = 60 * time.Second ) // Resolver is a name resolver for domain name. @@ -32,7 +35,7 @@ type NameServer struct { Hostname string // for TLS handshake verification } -func (ns *NameServer) String() string { +func (ns NameServer) String() string { addr := ns.Addr prot := ns.Protocol host := ns.Hostname @@ -45,23 +48,39 @@ func (ns *NameServer) String() string { return fmt.Sprintf("%s/%s %s", addr, prot, host) } +type resolverCacheItem struct { + IPAddrs []net.IPAddr + ts int64 +} + type resolver struct { Resolver *net.Resolver Servers []NameServer Timeout time.Duration + TTL time.Duration + mCache *sync.Map } // NewResolver create a new Resolver with the given name servers and resolution timeout. -func NewResolver(servers []NameServer, timeout time.Duration) Resolver { +func NewResolver(servers []NameServer, timeout, ttl time.Duration) Resolver { r := &resolver{ Servers: servers, Timeout: timeout, + TTL: ttl, + mCache: &sync.Map{}, } r.init() return r } func (r *resolver) init() { + if r.Timeout <= 0 { + r.Timeout = DefaultResolverTimeout + } + if r.TTL == 0 { + r.TTL = DefaultResolverTTL + } + r.Resolver = &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) { @@ -106,15 +125,53 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { } } -func (r *resolver) Resolve(name string) ([]net.IPAddr, error) { +func (r *resolver) Resolve(name string) (addrs []net.IPAddr, err error) { timeout := r.Timeout - if timeout <= 0 { - timeout = DefaultResolverTimeout + + addrs = r.loadCache(name) + if len(addrs) > 0 { + if Debug { + log.Logf("[resolver] cache hit: %s %v", name, addrs) + } + return } ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - return r.Resolver.LookupIPAddr(ctx, name) + addrs, err = r.Resolver.LookupIPAddr(ctx, name) + r.storeCache(name, addrs) + if len(addrs) > 0 && Debug { + log.Logf("[resolver] %s %v", name, addrs) + } + return +} + +func (r *resolver) loadCache(name string) []net.IPAddr { + ttl := r.TTL + if ttl < 0 { + return nil + } + + if v, ok := r.mCache.Load(name); ok { + item, _ := v.(*resolverCacheItem) + if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl { + return nil + } + return item.IPAddrs + } + + return nil +} + +func (r *resolver) storeCache(name string, addrs []net.IPAddr) { + ttl := r.TTL + if ttl < 0 || name == "" || len(addrs) == 0 { + return + } + r.mCache.Store(name, &resolverCacheItem{ + IPAddrs: addrs, + ts: time.Now().Unix(), + }) } func (r *resolver) String() string { @@ -124,6 +181,7 @@ func (r *resolver) String() string { b := &bytes.Buffer{} fmt.Fprintf(b, "timeout %v\n", r.Timeout) + fmt.Fprintf(b, "ttl %v\n", r.TTL) for i := range r.Servers { fmt.Fprintln(b, r.Servers[i]) }