add cache for dns
This commit is contained in:
parent
8ec3d8cbcf
commit
99b141e5be
21
dns.go
21
dns.go
@ -12,6 +12,20 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultResolver Resolver
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
defaultResolver = NewResolver(
|
||||||
|
DefaultResolverTimeout,
|
||||||
|
NameServer{
|
||||||
|
Addr: "127.0.0.1:53",
|
||||||
|
Protocol: "udp",
|
||||||
|
})
|
||||||
|
defaultResolver.Init()
|
||||||
|
}
|
||||||
|
|
||||||
type dnsHandler struct {
|
type dnsHandler struct {
|
||||||
options *HandlerOptions
|
options *HandlerOptions
|
||||||
}
|
}
|
||||||
@ -58,7 +72,12 @@ func (h *dnsHandler) Handle(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
reply, err := h.options.Resolver.Exchange(context.Background(), b[:n])
|
|
||||||
|
resolver := h.options.Resolver
|
||||||
|
if resolver == nil {
|
||||||
|
resolver = defaultResolver
|
||||||
|
}
|
||||||
|
reply, err := resolver.Exchange(context.Background(), b[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
|
log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||||
return
|
return
|
||||||
|
2
gost.go
2
gost.go
@ -20,7 +20,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Version is the gost version.
|
// Version is the gost version.
|
||||||
const Version = "2.9.2"
|
const Version = "2.10.0-dev"
|
||||||
|
|
||||||
// Debug is a flag that enables the debug log.
|
// Debug is a flag that enables the debug log.
|
||||||
var Debug bool
|
var Debug bool
|
||||||
|
262
resolver.go
262
resolver.go
@ -149,12 +149,12 @@ type ReloadResolver interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type resolver struct {
|
type resolver struct {
|
||||||
Servers []NameServer
|
servers []NameServer
|
||||||
mCache *sync.Map
|
ttl time.Duration
|
||||||
TTL time.Duration
|
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
period time.Duration
|
period time.Duration
|
||||||
domain string
|
domain string
|
||||||
|
cache *resolverCache
|
||||||
stopped chan struct{}
|
stopped chan struct{}
|
||||||
mux sync.RWMutex
|
mux sync.RWMutex
|
||||||
prefer string // ipv4 or ipv6
|
prefer string // ipv4 or ipv6
|
||||||
@ -169,9 +169,8 @@ func NewResolver(ttl time.Duration, servers ...NameServer) ReloadResolver {
|
|||||||
|
|
||||||
func newResolver(ttl time.Duration, servers ...NameServer) *resolver {
|
func newResolver(ttl time.Duration, servers ...NameServer) *resolver {
|
||||||
return &resolver{
|
return &resolver{
|
||||||
Servers: servers,
|
servers: servers,
|
||||||
TTL: ttl,
|
cache: newResolverCache(ttl),
|
||||||
mCache: &sync.Map{},
|
|
||||||
stopped: make(chan struct{}),
|
stopped: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -189,7 +188,7 @@ func (r *resolver) Init(opts ...ResolverOption) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var nss []NameServer
|
var nss []NameServer
|
||||||
for _, ns := range r.Servers {
|
for _, ns := range r.servers {
|
||||||
if err := ns.Init( // init all name servers
|
if err := ns.Init( // init all name servers
|
||||||
ChainNameServerOption(r.options.chain),
|
ChainNameServerOption(r.options.chain),
|
||||||
TimeoutNameServerOption(r.timeout),
|
TimeoutNameServerOption(r.timeout),
|
||||||
@ -199,33 +198,26 @@ func (r *resolver) Init(opts ...ResolverOption) error {
|
|||||||
nss = append(nss, ns)
|
nss = append(nss, ns)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Servers = nss
|
r.servers = nss
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) copyServers() []NameServer {
|
func (r *resolver) copyServers() []NameServer {
|
||||||
var servers []NameServer
|
r.mux.RLock()
|
||||||
for i := range r.Servers {
|
defer r.mux.RUnlock()
|
||||||
servers = append(servers, r.Servers[i])
|
|
||||||
|
servers := make([]NameServer, len(r.servers))
|
||||||
|
for i := range r.servers {
|
||||||
|
servers[i] = r.servers[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
return servers
|
return servers
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||||
if r == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var domain string
|
|
||||||
var ttl time.Duration
|
|
||||||
var servers []NameServer
|
|
||||||
|
|
||||||
r.mux.RLock()
|
r.mux.RLock()
|
||||||
domain = r.domain
|
domain := r.domain
|
||||||
ttl = r.TTL
|
|
||||||
servers = r.copyServers()
|
|
||||||
r.mux.RUnlock()
|
r.mux.RUnlock()
|
||||||
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
@ -235,140 +227,124 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
|||||||
if !strings.Contains(host, ".") && domain != "" {
|
if !strings.Contains(host, ".") && domain != "" {
|
||||||
host = host + "." + domain
|
host = host + "." + domain
|
||||||
}
|
}
|
||||||
ips = r.loadCache(host, ttl)
|
|
||||||
if len(ips) > 0 {
|
|
||||||
if Debug {
|
|
||||||
log.Logf("[resolver] cache hit %s: %v", host, ips)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ns := range servers {
|
for _, ns := range r.copyServers() {
|
||||||
ips, ttl, err = r.resolve(ns.exchanger, host)
|
ips, err = r.resolve(ns.exchanger, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
|
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if Debug {
|
if Debug {
|
||||||
log.Logf("[resolver] %s via %s %v(ttl: %v)", host, ns.String(), ips, ttl)
|
log.Logf("[resolver] %s via %s %v", host, ns.String(), ips)
|
||||||
}
|
}
|
||||||
if len(ips) > 0 {
|
if len(ips) > 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.storeCache(host, ips, ttl)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, ttl time.Duration, err error) {
|
func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) {
|
||||||
if ex == nil {
|
if ex == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.mux.RLock()
|
||||||
|
prefer := r.prefer
|
||||||
|
r.mux.RUnlock()
|
||||||
|
|
||||||
|
prefer = "ipv6"
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
if r.prefer == "ipv6" { // prefer ipv6
|
if prefer == "ipv6" { // prefer ipv6
|
||||||
query := dns.Msg{}
|
mq := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
|
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
|
||||||
ips, ttl, err = r.resolveIPs(ctx, ex, &query)
|
ips, err = r.resolveIPs(ctx, ex, mq)
|
||||||
if err != nil || len(ips) > 0 {
|
if err != nil || len(ips) > 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := dns.Msg{}
|
mq := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||||
return r.resolveIPs(ctx, ex, &query)
|
return r.resolveIPs(ctx, ex, mq)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*resolver) resolveIPs(ctx context.Context, ex Exchanger, query *dns.Msg) (ips []net.IP, ttl time.Duration, err error) {
|
func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
|
||||||
// buf := mPool.Get().([]byte)
|
mr, err := r.exchangeMsg(ctx, ex, mq)
|
||||||
// defer mPool.Put(buf)
|
|
||||||
|
|
||||||
// buf = buf[:0]
|
|
||||||
// mq, err := query.PackBuffer(buf)
|
|
||||||
mq, err := query.Pack()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reply, err := ex.Exchange(ctx, mq)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
mr := &dns.Msg{}
|
|
||||||
if err = mr.Unpack(reply); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ans := range mr.Answer {
|
for _, ans := range mr.Answer {
|
||||||
if ar, _ := ans.(*dns.AAAA); ar != nil {
|
if ar, _ := ans.(*dns.AAAA); ar != nil {
|
||||||
ips = append(ips, ar.AAAA)
|
ips = append(ips, ar.AAAA)
|
||||||
ttl = time.Duration(ar.Header().Ttl) * time.Second
|
|
||||||
}
|
}
|
||||||
if ar, _ := ans.(*dns.A); ar != nil {
|
if ar, _ := ans.(*dns.A); ar != nil {
|
||||||
ips = append(ips, ar.A)
|
ips = append(ips, ar.A)
|
||||||
ttl = time.Duration(ar.Header().Ttl) * time.Second
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
|
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
if r == nil {
|
mq := &dns.Msg{}
|
||||||
|
if err = mq.Unpack(query); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var servers []NameServer
|
var mr *dns.Msg
|
||||||
r.mux.RLock()
|
for _, ns := range r.copyServers() {
|
||||||
servers = r.copyServers()
|
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
|
||||||
r.mux.RUnlock()
|
|
||||||
|
|
||||||
for _, ns := range servers {
|
|
||||||
reply, err = ns.exchanger.Exchange(ctx, query)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return mr.Pack()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
|
||||||
|
// Only cache for single question.
|
||||||
|
if len(mq.Question) == 1 {
|
||||||
|
key := newResolverCacheKey(&mq.Question[0])
|
||||||
|
mr = r.cache.loadCache(key)
|
||||||
|
if mr != nil {
|
||||||
|
mr.Id = mq.Id
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
r.cache.storeCache(key, mr, r.TTL())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
query, err := mq.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reply, err := ex.Exchange(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mr = &dns.Msg{}
|
||||||
|
if err = mr.Unpack(reply); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type resolverCacheItem struct {
|
func (r *resolver) TTL() time.Duration {
|
||||||
IPs []net.IP
|
r.mux.RLock()
|
||||||
ts int64
|
defer r.mux.RUnlock()
|
||||||
ttl time.Duration
|
return r.ttl
|
||||||
}
|
|
||||||
|
|
||||||
func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
|
|
||||||
if name == "" || ttl < 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := r.mCache.Load(name); ok {
|
|
||||||
item, _ := v.(*resolverCacheItem)
|
|
||||||
if ttl == 0 {
|
|
||||||
ttl = item.ttl
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Since(time.Unix(item.ts, 0)) > ttl {
|
|
||||||
r.mCache.Delete(name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return item.IPs
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *resolver) storeCache(name string, ips []net.IP, ttl time.Duration) {
|
|
||||||
if name == "" || len(ips) == 0 || ttl < 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.mCache.Store(name, &resolverCacheItem{
|
|
||||||
IPs: ips,
|
|
||||||
ts: time.Now().Unix(),
|
|
||||||
ttl: ttl,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) Reload(rd io.Reader) error {
|
func (r *resolver) Reload(rd io.Reader) error {
|
||||||
@ -444,12 +420,12 @@ func (r *resolver) Reload(rd io.Reader) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.mux.Lock()
|
r.mux.Lock()
|
||||||
r.TTL = ttl
|
r.ttl = ttl
|
||||||
r.timeout = timeout
|
r.timeout = timeout
|
||||||
r.domain = domain
|
r.domain = domain
|
||||||
r.period = period
|
r.period = period
|
||||||
r.prefer = prefer
|
r.prefer = prefer
|
||||||
r.Servers = nss
|
r.servers = nss
|
||||||
r.mux.Unlock()
|
r.mux.Unlock()
|
||||||
|
|
||||||
r.Init()
|
r.Init()
|
||||||
@ -496,15 +472,85 @@ func (r *resolver) String() string {
|
|||||||
defer r.mux.RUnlock()
|
defer r.mux.RUnlock()
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
fmt.Fprintf(b, "TTL %v\n", r.TTL)
|
fmt.Fprintf(b, "TTL %v\n", r.ttl)
|
||||||
fmt.Fprintf(b, "Reload %v\n", r.period)
|
fmt.Fprintf(b, "Reload %v\n", r.period)
|
||||||
fmt.Fprintf(b, "Domain %v\n", r.domain)
|
fmt.Fprintf(b, "Domain %v\n", r.domain)
|
||||||
for i := range r.Servers {
|
for i := range r.servers {
|
||||||
fmt.Fprintln(b, r.Servers[i])
|
fmt.Fprintln(b, r.servers[i])
|
||||||
}
|
}
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type resolverCacheKey string
|
||||||
|
|
||||||
|
// newResolverCacheKey generates resolver cache key from question of dns query.
|
||||||
|
func newResolverCacheKey(q *dns.Question) resolverCacheKey {
|
||||||
|
if q == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
|
||||||
|
return resolverCacheKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolverCacheItem struct {
|
||||||
|
mr *dns.Msg
|
||||||
|
ts int64
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolverCache struct {
|
||||||
|
m sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResolverCache(ttl time.Duration) *resolverCache {
|
||||||
|
return &resolverCache{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *resolverCache) loadCache(key resolverCacheKey) *dns.Msg {
|
||||||
|
v, ok := rc.m.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
item, ok := v.(*resolverCacheItem)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed := time.Since(time.Unix(item.ts, 0))
|
||||||
|
if item.ttl > 0 && elapsed > item.ttl {
|
||||||
|
rc.m.Delete(key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, rr := range item.mr.Answer {
|
||||||
|
if elapsed > time.Duration(rr.Header().Ttl)*time.Second {
|
||||||
|
rc.m.Delete(key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if Debug {
|
||||||
|
log.Logf("[resolver] cache hit %s", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return item.mr.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.Duration) {
|
||||||
|
if key == "" || mr == nil || ttl < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rc.m.Store(key, &resolverCacheItem{
|
||||||
|
mr: mr.Copy(),
|
||||||
|
ts: time.Now().Unix(),
|
||||||
|
ttl: ttl,
|
||||||
|
})
|
||||||
|
if Debug {
|
||||||
|
log.Logf("[resolver] cache store %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Exchanger is an interface for DNS synchronous query.
|
// Exchanger is an interface for DNS synchronous query.
|
||||||
type Exchanger interface {
|
type Exchanger interface {
|
||||||
Exchange(ctx context.Context, query []byte) ([]byte, error)
|
Exchange(ctx context.Context, query []byte) ([]byte, error)
|
||||||
|
Loading…
Reference in New Issue
Block a user