add cache for dns
This commit is contained in:
parent
8ec3d8cbcf
commit
99b141e5be
21
dns.go
21
dns.go
@ -12,6 +12,20 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultResolver Resolver
|
||||
)
|
||||
|
||||
func init() {
|
||||
defaultResolver = NewResolver(
|
||||
DefaultResolverTimeout,
|
||||
NameServer{
|
||||
Addr: "127.0.0.1:53",
|
||||
Protocol: "udp",
|
||||
})
|
||||
defaultResolver.Init()
|
||||
}
|
||||
|
||||
type dnsHandler struct {
|
||||
options *HandlerOptions
|
||||
}
|
||||
@ -58,7 +72,12 @@ func (h *dnsHandler) Handle(conn net.Conn) {
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
reply, err := h.options.Resolver.Exchange(context.Background(), b[:n])
|
||||
|
||||
resolver := h.options.Resolver
|
||||
if resolver == nil {
|
||||
resolver = defaultResolver
|
||||
}
|
||||
reply, err := resolver.Exchange(context.Background(), b[:n])
|
||||
if err != nil {
|
||||
log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
return
|
||||
|
2
gost.go
2
gost.go
@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
// Version is the gost version.
|
||||
const Version = "2.9.2"
|
||||
const Version = "2.10.0-dev"
|
||||
|
||||
// Debug is a flag that enables the debug log.
|
||||
var Debug bool
|
||||
|
260
resolver.go
260
resolver.go
@ -149,12 +149,12 @@ type ReloadResolver interface {
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
Servers []NameServer
|
||||
mCache *sync.Map
|
||||
TTL time.Duration
|
||||
servers []NameServer
|
||||
ttl time.Duration
|
||||
timeout time.Duration
|
||||
period time.Duration
|
||||
domain string
|
||||
cache *resolverCache
|
||||
stopped chan struct{}
|
||||
mux sync.RWMutex
|
||||
prefer string // ipv4 or ipv6
|
||||
@ -169,9 +169,8 @@ func NewResolver(ttl time.Duration, servers ...NameServer) ReloadResolver {
|
||||
|
||||
func newResolver(ttl time.Duration, servers ...NameServer) *resolver {
|
||||
return &resolver{
|
||||
Servers: servers,
|
||||
TTL: ttl,
|
||||
mCache: &sync.Map{},
|
||||
servers: servers,
|
||||
cache: newResolverCache(ttl),
|
||||
stopped: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
@ -189,7 +188,7 @@ func (r *resolver) Init(opts ...ResolverOption) error {
|
||||
}
|
||||
|
||||
var nss []NameServer
|
||||
for _, ns := range r.Servers {
|
||||
for _, ns := range r.servers {
|
||||
if err := ns.Init( // init all name servers
|
||||
ChainNameServerOption(r.options.chain),
|
||||
TimeoutNameServerOption(r.timeout),
|
||||
@ -199,33 +198,26 @@ func (r *resolver) Init(opts ...ResolverOption) error {
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
|
||||
r.Servers = nss
|
||||
r.servers = nss
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolver) copyServers() []NameServer {
|
||||
var servers []NameServer
|
||||
for i := range r.Servers {
|
||||
servers = append(servers, r.Servers[i])
|
||||
r.mux.RLock()
|
||||
defer r.mux.RUnlock()
|
||||
|
||||
servers := make([]NameServer, len(r.servers))
|
||||
for i := range r.servers {
|
||||
servers[i] = r.servers[i]
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var domain string
|
||||
var ttl time.Duration
|
||||
var servers []NameServer
|
||||
|
||||
r.mux.RLock()
|
||||
domain = r.domain
|
||||
ttl = r.TTL
|
||||
servers = r.copyServers()
|
||||
domain := r.domain
|
||||
r.mux.RUnlock()
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
@ -235,140 +227,124 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
if !strings.Contains(host, ".") && domain != "" {
|
||||
host = host + "." + domain
|
||||
}
|
||||
ips = r.loadCache(host, ttl)
|
||||
if len(ips) > 0 {
|
||||
if Debug {
|
||||
log.Logf("[resolver] cache hit %s: %v", host, ips)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, ns := range servers {
|
||||
ips, ttl, err = r.resolve(ns.exchanger, host)
|
||||
for _, ns := range r.copyServers() {
|
||||
ips, err = r.resolve(ns.exchanger, host)
|
||||
if err != nil {
|
||||
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
if Debug {
|
||||
log.Logf("[resolver] %s via %s %v(ttl: %v)", host, ns.String(), ips, ttl)
|
||||
log.Logf("[resolver] %s via %s %v", host, ns.String(), ips)
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
r.storeCache(host, ips, ttl)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, ttl time.Duration, err error) {
|
||||
func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) {
|
||||
if ex == nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.mux.RLock()
|
||||
prefer := r.prefer
|
||||
r.mux.RUnlock()
|
||||
|
||||
prefer = "ipv6"
|
||||
|
||||
ctx := context.Background()
|
||||
if r.prefer == "ipv6" { // prefer ipv6
|
||||
query := dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
|
||||
ips, ttl, err = r.resolveIPs(ctx, ex, &query)
|
||||
if prefer == "ipv6" { // prefer ipv6
|
||||
mq := &dns.Msg{}
|
||||
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
|
||||
ips, err = r.resolveIPs(ctx, ex, mq)
|
||||
if err != nil || len(ips) > 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
query := dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||
return r.resolveIPs(ctx, ex, &query)
|
||||
mq := &dns.Msg{}
|
||||
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||
return r.resolveIPs(ctx, ex, mq)
|
||||
}
|
||||
|
||||
func (*resolver) resolveIPs(ctx context.Context, ex Exchanger, query *dns.Msg) (ips []net.IP, ttl time.Duration, err error) {
|
||||
// buf := mPool.Get().([]byte)
|
||||
// defer mPool.Put(buf)
|
||||
|
||||
// buf = buf[:0]
|
||||
// mq, err := query.PackBuffer(buf)
|
||||
mq, err := query.Pack()
|
||||
func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
|
||||
mr, err := r.exchangeMsg(ctx, ex, mq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reply, err := ex.Exchange(ctx, mq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mr := &dns.Msg{}
|
||||
if err = mr.Unpack(reply); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ans := range mr.Answer {
|
||||
if ar, _ := ans.(*dns.AAAA); ar != nil {
|
||||
ips = append(ips, ar.AAAA)
|
||||
ttl = time.Duration(ar.Header().Ttl) * time.Second
|
||||
}
|
||||
if ar, _ := ans.(*dns.A); ar != nil {
|
||||
ips = append(ips, ar.A)
|
||||
ttl = time.Duration(ar.Header().Ttl) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
if r == nil {
|
||||
mq := &dns.Msg{}
|
||||
if err = mq.Unpack(query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var servers []NameServer
|
||||
r.mux.RLock()
|
||||
servers = r.copyServers()
|
||||
r.mux.RUnlock()
|
||||
|
||||
for _, ns := range servers {
|
||||
reply, err = ns.exchanger.Exchange(ctx, query)
|
||||
var mr *dns.Msg
|
||||
for _, ns := range r.copyServers() {
|
||||
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return mr.Pack()
|
||||
}
|
||||
|
||||
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
|
||||
// Only cache for single question.
|
||||
if len(mq.Question) == 1 {
|
||||
key := newResolverCacheKey(&mq.Question[0])
|
||||
mr = r.cache.loadCache(key)
|
||||
if mr != nil {
|
||||
mr.Id = mq.Id
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
r.cache.storeCache(key, mr, r.TTL())
|
||||
}()
|
||||
}
|
||||
|
||||
query, err := mq.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reply, err := ex.Exchange(ctx, query)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mr = &dns.Msg{}
|
||||
if err = mr.Unpack(reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type resolverCacheItem struct {
|
||||
IPs []net.IP
|
||||
ts int64
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
|
||||
if name == "" || ttl < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if v, ok := r.mCache.Load(name); ok {
|
||||
item, _ := v.(*resolverCacheItem)
|
||||
if ttl == 0 {
|
||||
ttl = item.ttl
|
||||
}
|
||||
|
||||
if time.Since(time.Unix(item.ts, 0)) > ttl {
|
||||
r.mCache.Delete(name)
|
||||
return nil
|
||||
}
|
||||
return item.IPs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolver) storeCache(name string, ips []net.IP, ttl time.Duration) {
|
||||
if name == "" || len(ips) == 0 || ttl < 0 {
|
||||
return
|
||||
}
|
||||
r.mCache.Store(name, &resolverCacheItem{
|
||||
IPs: ips,
|
||||
ts: time.Now().Unix(),
|
||||
ttl: ttl,
|
||||
})
|
||||
func (r *resolver) TTL() time.Duration {
|
||||
r.mux.RLock()
|
||||
defer r.mux.RUnlock()
|
||||
return r.ttl
|
||||
}
|
||||
|
||||
func (r *resolver) Reload(rd io.Reader) error {
|
||||
@ -444,12 +420,12 @@ func (r *resolver) Reload(rd io.Reader) error {
|
||||
}
|
||||
|
||||
r.mux.Lock()
|
||||
r.TTL = ttl
|
||||
r.ttl = ttl
|
||||
r.timeout = timeout
|
||||
r.domain = domain
|
||||
r.period = period
|
||||
r.prefer = prefer
|
||||
r.Servers = nss
|
||||
r.servers = nss
|
||||
r.mux.Unlock()
|
||||
|
||||
r.Init()
|
||||
@ -496,15 +472,85 @@ func (r *resolver) String() string {
|
||||
defer r.mux.RUnlock()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
fmt.Fprintf(b, "TTL %v\n", r.TTL)
|
||||
fmt.Fprintf(b, "TTL %v\n", r.ttl)
|
||||
fmt.Fprintf(b, "Reload %v\n", r.period)
|
||||
fmt.Fprintf(b, "Domain %v\n", r.domain)
|
||||
for i := range r.Servers {
|
||||
fmt.Fprintln(b, r.Servers[i])
|
||||
for i := range r.servers {
|
||||
fmt.Fprintln(b, r.servers[i])
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type resolverCacheKey string
|
||||
|
||||
// newResolverCacheKey generates resolver cache key from question of dns query.
|
||||
func newResolverCacheKey(q *dns.Question) resolverCacheKey {
|
||||
if q == nil {
|
||||
return ""
|
||||
}
|
||||
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
|
||||
return resolverCacheKey(key)
|
||||
}
|
||||
|
||||
type resolverCacheItem struct {
|
||||
mr *dns.Msg
|
||||
ts int64
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
type resolverCache struct {
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func newResolverCache(ttl time.Duration) *resolverCache {
|
||||
return &resolverCache{}
|
||||
}
|
||||
|
||||
func (rc *resolverCache) loadCache(key resolverCacheKey) *dns.Msg {
|
||||
v, ok := rc.m.Load(key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
item, ok := v.(*resolverCacheItem)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
elapsed := time.Since(time.Unix(item.ts, 0))
|
||||
if item.ttl > 0 && elapsed > item.ttl {
|
||||
rc.m.Delete(key)
|
||||
return nil
|
||||
}
|
||||
for _, rr := range item.mr.Answer {
|
||||
if elapsed > time.Duration(rr.Header().Ttl)*time.Second {
|
||||
rc.m.Delete(key)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if Debug {
|
||||
log.Logf("[resolver] cache hit %s", key)
|
||||
}
|
||||
|
||||
return item.mr.Copy()
|
||||
}
|
||||
|
||||
func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.Duration) {
|
||||
if key == "" || mr == nil || ttl < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
rc.m.Store(key, &resolverCacheItem{
|
||||
mr: mr.Copy(),
|
||||
ts: time.Now().Unix(),
|
||||
ttl: ttl,
|
||||
})
|
||||
if Debug {
|
||||
log.Logf("[resolver] cache store %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// Exchanger is an interface for DNS synchronous query.
|
||||
type Exchanger interface {
|
||||
Exchange(ctx context.Context, query []byte) ([]byte, error)
|
||||
|
Loading…
Reference in New Issue
Block a user