From 99b141e5bea609219146b14e3b4ea61136573fe7 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 15 Jan 2020 22:30:37 +0800 Subject: [PATCH] add cache for dns --- dns.go | 21 ++++- gost.go | 2 +- resolver.go | 262 ++++++++++++++++++++++++++++++---------------------- 3 files changed, 175 insertions(+), 110 deletions(-) diff --git a/dns.go b/dns.go index fc29dc6..ae87b21 100644 --- a/dns.go +++ b/dns.go @@ -12,6 +12,20 @@ import ( "github.com/miekg/dns" ) +var ( + defaultResolver Resolver +) + +func init() { + defaultResolver = NewResolver( + DefaultResolverTimeout, + NameServer{ + Addr: "127.0.0.1:53", + Protocol: "udp", + }) + defaultResolver.Init() +} + type dnsHandler struct { options *HandlerOptions } @@ -58,7 +72,12 @@ func (h *dnsHandler) Handle(conn net.Conn) { } start := time.Now() - reply, err := h.options.Resolver.Exchange(context.Background(), b[:n]) + + resolver := h.options.Resolver + if resolver == nil { + resolver = defaultResolver + } + reply, err := resolver.Exchange(context.Background(), b[:n]) if err != nil { log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) return diff --git a/gost.go b/gost.go index bf4f825..1527745 100644 --- a/gost.go +++ b/gost.go @@ -20,7 +20,7 @@ import ( ) // Version is the gost version. -const Version = "2.9.2" +const Version = "2.10.0-dev" // Debug is a flag that enables the debug log. var Debug bool diff --git a/resolver.go b/resolver.go index 73f6951..c8d3e65 100644 --- a/resolver.go +++ b/resolver.go @@ -149,12 +149,12 @@ type ReloadResolver interface { } type resolver struct { - Servers []NameServer - mCache *sync.Map - TTL time.Duration + servers []NameServer + ttl time.Duration timeout time.Duration period time.Duration domain string + cache *resolverCache stopped chan struct{} mux sync.RWMutex prefer string // ipv4 or ipv6 @@ -169,9 +169,8 @@ func NewResolver(ttl time.Duration, servers ...NameServer) ReloadResolver { func newResolver(ttl time.Duration, servers ...NameServer) *resolver { return &resolver{ - Servers: servers, - TTL: ttl, - mCache: &sync.Map{}, + servers: servers, + cache: newResolverCache(ttl), stopped: make(chan struct{}), } } @@ -189,7 +188,7 @@ func (r *resolver) Init(opts ...ResolverOption) error { } var nss []NameServer - for _, ns := range r.Servers { + for _, ns := range r.servers { if err := ns.Init( // init all name servers ChainNameServerOption(r.options.chain), TimeoutNameServerOption(r.timeout), @@ -199,33 +198,26 @@ func (r *resolver) Init(opts ...ResolverOption) error { nss = append(nss, ns) } - r.Servers = nss + r.servers = nss return nil } func (r *resolver) copyServers() []NameServer { - var servers []NameServer - for i := range r.Servers { - servers = append(servers, r.Servers[i]) + r.mux.RLock() + defer r.mux.RUnlock() + + servers := make([]NameServer, len(r.servers)) + for i := range r.servers { + servers[i] = r.servers[i] } return servers } func (r *resolver) Resolve(host string) (ips []net.IP, err error) { - if r == nil { - return - } - - var domain string - var ttl time.Duration - var servers []NameServer - r.mux.RLock() - domain = r.domain - ttl = r.TTL - servers = r.copyServers() + domain := r.domain r.mux.RUnlock() if ip := net.ParseIP(host); ip != nil { @@ -235,140 +227,124 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { if !strings.Contains(host, ".") && domain != "" { host = host + "." + domain } - ips = r.loadCache(host, ttl) - if len(ips) > 0 { - if Debug { - log.Logf("[resolver] cache hit %s: %v", host, ips) - } - return - } - for _, ns := range servers { - ips, ttl, err = r.resolve(ns.exchanger, host) + for _, ns := range r.copyServers() { + ips, err = r.resolve(ns.exchanger, host) if err != nil { log.Logf("[resolver] %s via %s : %s", host, ns.String(), err) continue } if Debug { - log.Logf("[resolver] %s via %s %v(ttl: %v)", host, ns.String(), ips, ttl) + log.Logf("[resolver] %s via %s %v", host, ns.String(), ips) } if len(ips) > 0 { break } } - r.storeCache(host, ips, ttl) return } -func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, ttl time.Duration, err error) { +func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) { if ex == nil { return } + r.mux.RLock() + prefer := r.prefer + r.mux.RUnlock() + + prefer = "ipv6" + ctx := context.Background() - if r.prefer == "ipv6" { // prefer ipv6 - query := dns.Msg{} - query.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) - ips, ttl, err = r.resolveIPs(ctx, ex, &query) + if prefer == "ipv6" { // prefer ipv6 + mq := &dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) + ips, err = r.resolveIPs(ctx, ex, mq) if err != nil || len(ips) > 0 { return } } - query := dns.Msg{} - query.SetQuestion(dns.Fqdn(host), dns.TypeA) - return r.resolveIPs(ctx, ex, &query) + mq := &dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeA) + return r.resolveIPs(ctx, ex, mq) } -func (*resolver) resolveIPs(ctx context.Context, ex Exchanger, query *dns.Msg) (ips []net.IP, ttl time.Duration, err error) { - // buf := mPool.Get().([]byte) - // defer mPool.Put(buf) - - // buf = buf[:0] - // mq, err := query.PackBuffer(buf) - mq, err := query.Pack() +func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { + mr, err := r.exchangeMsg(ctx, ex, mq) if err != nil { return } - reply, err := ex.Exchange(ctx, mq) - if err != nil { - return - } - mr := &dns.Msg{} - if err = mr.Unpack(reply); err != nil { - return - } for _, ans := range mr.Answer { if ar, _ := ans.(*dns.AAAA); ar != nil { ips = append(ips, ar.AAAA) - ttl = time.Duration(ar.Header().Ttl) * time.Second } if ar, _ := ans.(*dns.A); ar != nil { ips = append(ips, ar.A) - ttl = time.Duration(ar.Header().Ttl) * time.Second } } + return } func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) { - if r == nil { + mq := &dns.Msg{} + if err = mq.Unpack(query); err != nil { return } - var servers []NameServer - r.mux.RLock() - servers = r.copyServers() - r.mux.RUnlock() - - for _, ns := range servers { - reply, err = ns.exchanger.Exchange(ctx, query) + var mr *dns.Msg + for _, ns := range r.copyServers() { + mr, err = r.exchangeMsg(ctx, ns.exchanger, mq) if err == nil { - return + break } } + if err != nil { + return + } + return mr.Pack() +} + +func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { + // Only cache for single question. + if len(mq.Question) == 1 { + key := newResolverCacheKey(&mq.Question[0]) + mr = r.cache.loadCache(key) + if mr != nil { + mr.Id = mq.Id + return + } + + defer func() { + r.cache.storeCache(key, mr, r.TTL()) + }() + } + + query, err := mq.Pack() + if err != nil { + return + } + reply, err := ex.Exchange(ctx, query) + if err != nil { + return + } + + mr = &dns.Msg{} + if err = mr.Unpack(reply); err != nil { + return nil, err + } + return } -type resolverCacheItem struct { - IPs []net.IP - ts int64 - ttl time.Duration -} - -func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP { - if name == "" || ttl < 0 { - return nil - } - - if v, ok := r.mCache.Load(name); ok { - item, _ := v.(*resolverCacheItem) - if ttl == 0 { - ttl = item.ttl - } - - if time.Since(time.Unix(item.ts, 0)) > ttl { - r.mCache.Delete(name) - return nil - } - return item.IPs - } - - return nil -} - -func (r *resolver) storeCache(name string, ips []net.IP, ttl time.Duration) { - if name == "" || len(ips) == 0 || ttl < 0 { - return - } - r.mCache.Store(name, &resolverCacheItem{ - IPs: ips, - ts: time.Now().Unix(), - ttl: ttl, - }) +func (r *resolver) TTL() time.Duration { + r.mux.RLock() + defer r.mux.RUnlock() + return r.ttl } func (r *resolver) Reload(rd io.Reader) error { @@ -444,12 +420,12 @@ func (r *resolver) Reload(rd io.Reader) error { } r.mux.Lock() - r.TTL = ttl + r.ttl = ttl r.timeout = timeout r.domain = domain r.period = period r.prefer = prefer - r.Servers = nss + r.servers = nss r.mux.Unlock() r.Init() @@ -496,15 +472,85 @@ func (r *resolver) String() string { defer r.mux.RUnlock() b := &bytes.Buffer{} - fmt.Fprintf(b, "TTL %v\n", r.TTL) + fmt.Fprintf(b, "TTL %v\n", r.ttl) fmt.Fprintf(b, "Reload %v\n", r.period) fmt.Fprintf(b, "Domain %v\n", r.domain) - for i := range r.Servers { - fmt.Fprintln(b, r.Servers[i]) + for i := range r.servers { + fmt.Fprintln(b, r.servers[i]) } return b.String() } +type resolverCacheKey string + +// newResolverCacheKey generates resolver cache key from question of dns query. +func newResolverCacheKey(q *dns.Question) resolverCacheKey { + if q == nil { + return "" + } + key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String()) + return resolverCacheKey(key) +} + +type resolverCacheItem struct { + mr *dns.Msg + ts int64 + ttl time.Duration +} + +type resolverCache struct { + m sync.Map +} + +func newResolverCache(ttl time.Duration) *resolverCache { + return &resolverCache{} +} + +func (rc *resolverCache) loadCache(key resolverCacheKey) *dns.Msg { + v, ok := rc.m.Load(key) + if !ok { + return nil + } + + item, ok := v.(*resolverCacheItem) + if !ok { + return nil + } + + elapsed := time.Since(time.Unix(item.ts, 0)) + if item.ttl > 0 && elapsed > item.ttl { + rc.m.Delete(key) + return nil + } + for _, rr := range item.mr.Answer { + if elapsed > time.Duration(rr.Header().Ttl)*time.Second { + rc.m.Delete(key) + return nil + } + } + + if Debug { + log.Logf("[resolver] cache hit %s", key) + } + + return item.mr.Copy() +} + +func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.Duration) { + if key == "" || mr == nil || ttl < 0 { + return + } + + rc.m.Store(key, &resolverCacheItem{ + mr: mr.Copy(), + ts: time.Now().Unix(), + ttl: ttl, + }) + if Debug { + log.Logf("[resolver] cache store %s", key) + } +} + // Exchanger is an interface for DNS synchronous query. type Exchanger interface { Exchange(ctx context.Context, query []byte) ([]byte, error)