add cache for dns

This commit is contained in:
ginuerzh 2020-01-15 22:30:37 +08:00
parent 8ec3d8cbcf
commit 99b141e5be
3 changed files with 175 additions and 110 deletions

21
dns.go
View File

@ -12,6 +12,20 @@ import (
"github.com/miekg/dns"
)
var (
defaultResolver Resolver
)
func init() {
defaultResolver = NewResolver(
DefaultResolverTimeout,
NameServer{
Addr: "127.0.0.1:53",
Protocol: "udp",
})
defaultResolver.Init()
}
type dnsHandler struct {
options *HandlerOptions
}
@ -58,7 +72,12 @@ func (h *dnsHandler) Handle(conn net.Conn) {
}
start := time.Now()
reply, err := h.options.Resolver.Exchange(context.Background(), b[:n])
resolver := h.options.Resolver
if resolver == nil {
resolver = defaultResolver
}
reply, err := resolver.Exchange(context.Background(), b[:n])
if err != nil {
log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
return

View File

@ -20,7 +20,7 @@ import (
)
// Version is the gost version.
const Version = "2.9.2"
const Version = "2.10.0-dev"
// Debug is a flag that enables the debug log.
var Debug bool

View File

@ -149,12 +149,12 @@ type ReloadResolver interface {
}
type resolver struct {
Servers []NameServer
mCache *sync.Map
TTL time.Duration
servers []NameServer
ttl time.Duration
timeout time.Duration
period time.Duration
domain string
cache *resolverCache
stopped chan struct{}
mux sync.RWMutex
prefer string // ipv4 or ipv6
@ -169,9 +169,8 @@ func NewResolver(ttl time.Duration, servers ...NameServer) ReloadResolver {
func newResolver(ttl time.Duration, servers ...NameServer) *resolver {
return &resolver{
Servers: servers,
TTL: ttl,
mCache: &sync.Map{},
servers: servers,
cache: newResolverCache(ttl),
stopped: make(chan struct{}),
}
}
@ -189,7 +188,7 @@ func (r *resolver) Init(opts ...ResolverOption) error {
}
var nss []NameServer
for _, ns := range r.Servers {
for _, ns := range r.servers {
if err := ns.Init( // init all name servers
ChainNameServerOption(r.options.chain),
TimeoutNameServerOption(r.timeout),
@ -199,33 +198,26 @@ func (r *resolver) Init(opts ...ResolverOption) error {
nss = append(nss, ns)
}
r.Servers = nss
r.servers = nss
return nil
}
func (r *resolver) copyServers() []NameServer {
var servers []NameServer
for i := range r.Servers {
servers = append(servers, r.Servers[i])
r.mux.RLock()
defer r.mux.RUnlock()
servers := make([]NameServer, len(r.servers))
for i := range r.servers {
servers[i] = r.servers[i]
}
return servers
}
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
if r == nil {
return
}
var domain string
var ttl time.Duration
var servers []NameServer
r.mux.RLock()
domain = r.domain
ttl = r.TTL
servers = r.copyServers()
domain := r.domain
r.mux.RUnlock()
if ip := net.ParseIP(host); ip != nil {
@ -235,140 +227,124 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
if !strings.Contains(host, ".") && domain != "" {
host = host + "." + domain
}
ips = r.loadCache(host, ttl)
if len(ips) > 0 {
if Debug {
log.Logf("[resolver] cache hit %s: %v", host, ips)
}
return
}
for _, ns := range servers {
ips, ttl, err = r.resolve(ns.exchanger, host)
for _, ns := range r.copyServers() {
ips, err = r.resolve(ns.exchanger, host)
if err != nil {
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
continue
}
if Debug {
log.Logf("[resolver] %s via %s %v(ttl: %v)", host, ns.String(), ips, ttl)
log.Logf("[resolver] %s via %s %v", host, ns.String(), ips)
}
if len(ips) > 0 {
break
}
}
r.storeCache(host, ips, ttl)
return
}
func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, ttl time.Duration, err error) {
func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) {
if ex == nil {
return
}
r.mux.RLock()
prefer := r.prefer
r.mux.RUnlock()
prefer = "ipv6"
ctx := context.Background()
if r.prefer == "ipv6" { // prefer ipv6
query := dns.Msg{}
query.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
ips, ttl, err = r.resolveIPs(ctx, ex, &query)
if prefer == "ipv6" { // prefer ipv6
mq := &dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
ips, err = r.resolveIPs(ctx, ex, mq)
if err != nil || len(ips) > 0 {
return
}
}
query := dns.Msg{}
query.SetQuestion(dns.Fqdn(host), dns.TypeA)
return r.resolveIPs(ctx, ex, &query)
mq := &dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
return r.resolveIPs(ctx, ex, mq)
}
func (*resolver) resolveIPs(ctx context.Context, ex Exchanger, query *dns.Msg) (ips []net.IP, ttl time.Duration, err error) {
// buf := mPool.Get().([]byte)
// defer mPool.Put(buf)
// buf = buf[:0]
// mq, err := query.PackBuffer(buf)
mq, err := query.Pack()
func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
mr, err := r.exchangeMsg(ctx, ex, mq)
if err != nil {
return
}
reply, err := ex.Exchange(ctx, mq)
if err != nil {
return
}
mr := &dns.Msg{}
if err = mr.Unpack(reply); err != nil {
return
}
for _, ans := range mr.Answer {
if ar, _ := ans.(*dns.AAAA); ar != nil {
ips = append(ips, ar.AAAA)
ttl = time.Duration(ar.Header().Ttl) * time.Second
}
if ar, _ := ans.(*dns.A); ar != nil {
ips = append(ips, ar.A)
ttl = time.Duration(ar.Header().Ttl) * time.Second
}
}
return
}
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
if r == nil {
mq := &dns.Msg{}
if err = mq.Unpack(query); err != nil {
return
}
var servers []NameServer
r.mux.RLock()
servers = r.copyServers()
r.mux.RUnlock()
for _, ns := range servers {
reply, err = ns.exchanger.Exchange(ctx, query)
var mr *dns.Msg
for _, ns := range r.copyServers() {
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
if err == nil {
break
}
}
if err != nil {
return
}
return mr.Pack()
}
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
// Only cache for single question.
if len(mq.Question) == 1 {
key := newResolverCacheKey(&mq.Question[0])
mr = r.cache.loadCache(key)
if mr != nil {
mr.Id = mq.Id
return
}
defer func() {
r.cache.storeCache(key, mr, r.TTL())
}()
}
query, err := mq.Pack()
if err != nil {
return
}
reply, err := ex.Exchange(ctx, query)
if err != nil {
return
}
mr = &dns.Msg{}
if err = mr.Unpack(reply); err != nil {
return nil, err
}
return
}
type resolverCacheItem struct {
IPs []net.IP
ts int64
ttl time.Duration
}
func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
if name == "" || ttl < 0 {
return nil
}
if v, ok := r.mCache.Load(name); ok {
item, _ := v.(*resolverCacheItem)
if ttl == 0 {
ttl = item.ttl
}
if time.Since(time.Unix(item.ts, 0)) > ttl {
r.mCache.Delete(name)
return nil
}
return item.IPs
}
return nil
}
func (r *resolver) storeCache(name string, ips []net.IP, ttl time.Duration) {
if name == "" || len(ips) == 0 || ttl < 0 {
return
}
r.mCache.Store(name, &resolverCacheItem{
IPs: ips,
ts: time.Now().Unix(),
ttl: ttl,
})
func (r *resolver) TTL() time.Duration {
r.mux.RLock()
defer r.mux.RUnlock()
return r.ttl
}
func (r *resolver) Reload(rd io.Reader) error {
@ -444,12 +420,12 @@ func (r *resolver) Reload(rd io.Reader) error {
}
r.mux.Lock()
r.TTL = ttl
r.ttl = ttl
r.timeout = timeout
r.domain = domain
r.period = period
r.prefer = prefer
r.Servers = nss
r.servers = nss
r.mux.Unlock()
r.Init()
@ -496,15 +472,85 @@ func (r *resolver) String() string {
defer r.mux.RUnlock()
b := &bytes.Buffer{}
fmt.Fprintf(b, "TTL %v\n", r.TTL)
fmt.Fprintf(b, "TTL %v\n", r.ttl)
fmt.Fprintf(b, "Reload %v\n", r.period)
fmt.Fprintf(b, "Domain %v\n", r.domain)
for i := range r.Servers {
fmt.Fprintln(b, r.Servers[i])
for i := range r.servers {
fmt.Fprintln(b, r.servers[i])
}
return b.String()
}
type resolverCacheKey string
// newResolverCacheKey generates resolver cache key from question of dns query.
func newResolverCacheKey(q *dns.Question) resolverCacheKey {
if q == nil {
return ""
}
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
return resolverCacheKey(key)
}
type resolverCacheItem struct {
mr *dns.Msg
ts int64
ttl time.Duration
}
type resolverCache struct {
m sync.Map
}
func newResolverCache(ttl time.Duration) *resolverCache {
return &resolverCache{}
}
func (rc *resolverCache) loadCache(key resolverCacheKey) *dns.Msg {
v, ok := rc.m.Load(key)
if !ok {
return nil
}
item, ok := v.(*resolverCacheItem)
if !ok {
return nil
}
elapsed := time.Since(time.Unix(item.ts, 0))
if item.ttl > 0 && elapsed > item.ttl {
rc.m.Delete(key)
return nil
}
for _, rr := range item.mr.Answer {
if elapsed > time.Duration(rr.Header().Ttl)*time.Second {
rc.m.Delete(key)
return nil
}
}
if Debug {
log.Logf("[resolver] cache hit %s", key)
}
return item.mr.Copy()
}
func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.Duration) {
if key == "" || mr == nil || ttl < 0 {
return
}
rc.m.Store(key, &resolverCacheItem{
mr: mr.Copy(),
ts: time.Now().Unix(),
ttl: ttl,
})
if Debug {
log.Logf("[resolver] cache store %s", key)
}
}
// Exchanger is an interface for DNS synchronous query.
type Exchanger interface {
Exchange(ctx context.Context, query []byte) ([]byte, error)