add cache support for DNS resolver

This commit is contained in:
ginuerzh 2018-05-19 19:17:53 +08:00
parent 56bc433cd6
commit 482d2857f6
5 changed files with 84 additions and 17 deletions

View File

@ -130,11 +130,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
if err == nil { if err == nil {
addrs, er := c.Resolver.Resolve(host) addrs, er := c.Resolver.Resolve(host)
if er != nil { if er != nil {
log.Logf("[resolver] %s: %v", addr, er) log.Logf("[resolver] %s: %v", host, er)
return nil, er
}
if Debug {
log.Logf("[resolver] %s %v", addr, addrs)
} }
if len(addrs) > 0 { if len(addrs) > 0 {
addr = net.JoinHostPort(addrs[0].IP.String(), port) addr = net.JoinHostPort(addrs[0].IP.String(), port)
@ -280,7 +276,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
if node.Bypass.Contains(addr) { if node.Bypass.Contains(addr) {
if Debug { if Debug {
buf.WriteString(fmt.Sprintf("[%d@bypass: %s]", node.ID, addr)) buf.WriteString(fmt.Sprintf("[%d@bypass: %s]", node.ID, addr))
log.Log("select route:", buf.String()) log.Log("[route]", buf.String())
} }
return return
} }
@ -301,7 +297,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
route.Resolver = c.Resolver route.Resolver = c.Resolver
if Debug { if Debug {
log.Log("select route:", buf.String()) log.Log("[route]", buf.String())
} }
return return
} }

View File

@ -266,6 +266,7 @@ func parseResolver(cfg string) gost.Resolver {
return nil return nil
} }
timeout := 30 * time.Second timeout := 30 * time.Second
ttl := 60 * time.Second
var nss []gost.NameServer var nss []gost.NameServer
f, err := os.Open(cfg) f, err := os.Open(cfg)
@ -288,7 +289,7 @@ func parseResolver(cfg string) gost.Resolver {
}) })
} }
} }
return gost.NewResolver(nss, timeout) return gost.NewResolver(nss, timeout, ttl)
} }
scanner := bufio.NewScanner(f) scanner := bufio.NewScanner(f)
@ -312,7 +313,7 @@ func parseResolver(cfg string) gost.Resolver {
continue continue
} }
if ss[0] == "timeout" { if strings.ToLower(ss[0]) == "timeout" {
if len(ss) >= 2 { if len(ss) >= 2 {
if n, _ := strconv.Atoi(ss[1]); n > 0 { if n, _ := strconv.Atoi(ss[1]); n > 0 {
timeout = time.Second * time.Duration(n) timeout = time.Second * time.Duration(n)
@ -320,6 +321,13 @@ func parseResolver(cfg string) gost.Resolver {
} }
continue continue
} }
if strings.ToLower(ss[0]) == "ttl" {
if len(ss) >= 2 {
n, _ := strconv.Atoi(ss[1])
ttl = time.Second * time.Duration(n)
}
continue
}
var ns gost.NameServer var ns gost.NameServer
switch len(ss) { switch len(ss) {
@ -335,5 +343,5 @@ func parseResolver(cfg string) gost.Resolver {
} }
nss = append(nss, ns) nss = append(nss, ns)
} }
return gost.NewResolver(nss, timeout) return gost.NewResolver(nss, timeout, ttl)
} }

View File

@ -1,6 +1,9 @@
# resolver timeout, default 30s. # resolver timeout, default 30s.
timeout 10 timeout 10
# resolver cache TTL, default 60s, minus value means that cache is disabled.
ttl 300
# ip[:port] [protocol] [hostname] # ip[:port] [protocol] [hostname]
1.1.1.1:853 tls cloudflare-dns.com 1.1.1.1:853 tls cloudflare-dns.com

View File

@ -506,7 +506,9 @@ func (r *route) serve() error {
) )
chain.Resolver = parseResolver(node.Get("dns")) chain.Resolver = parseResolver(node.Get("dns"))
if gost.Debug {
log.Logf("[resolver]\n%v", chain.Resolver)
}
go srv.Serve(handler) go srv.Serve(handler)
} }

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/go-log/log" "github.com/go-log/log"
@ -15,6 +16,8 @@ import (
var ( var (
// DefaultResolverTimeout is the default timeout for name resolution. // DefaultResolverTimeout is the default timeout for name resolution.
DefaultResolverTimeout = 30 * time.Second DefaultResolverTimeout = 30 * time.Second
// DefaultResolverTTL is the default cache TTL for name resolution.
DefaultResolverTTL = 60 * time.Second
) )
// Resolver is a name resolver for domain name. // Resolver is a name resolver for domain name.
@ -32,7 +35,7 @@ type NameServer struct {
Hostname string // for TLS handshake verification Hostname string // for TLS handshake verification
} }
func (ns *NameServer) String() string { func (ns NameServer) String() string {
addr := ns.Addr addr := ns.Addr
prot := ns.Protocol prot := ns.Protocol
host := ns.Hostname host := ns.Hostname
@ -45,23 +48,39 @@ func (ns *NameServer) String() string {
return fmt.Sprintf("%s/%s %s", addr, prot, host) return fmt.Sprintf("%s/%s %s", addr, prot, host)
} }
type resolverCacheItem struct {
IPAddrs []net.IPAddr
ts int64
}
type resolver struct { type resolver struct {
Resolver *net.Resolver Resolver *net.Resolver
Servers []NameServer Servers []NameServer
Timeout time.Duration Timeout time.Duration
TTL time.Duration
mCache *sync.Map
} }
// NewResolver create a new Resolver with the given name servers and resolution timeout. // NewResolver create a new Resolver with the given name servers and resolution timeout.
func NewResolver(servers []NameServer, timeout time.Duration) Resolver { func NewResolver(servers []NameServer, timeout, ttl time.Duration) Resolver {
r := &resolver{ r := &resolver{
Servers: servers, Servers: servers,
Timeout: timeout, Timeout: timeout,
TTL: ttl,
mCache: &sync.Map{},
} }
r.init() r.init()
return r return r
} }
func (r *resolver) init() { func (r *resolver) init() {
if r.Timeout <= 0 {
r.Timeout = DefaultResolverTimeout
}
if r.TTL == 0 {
r.TTL = DefaultResolverTTL
}
r.Resolver = &net.Resolver{ r.Resolver = &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) { Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) {
@ -106,15 +125,53 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
} }
} }
func (r *resolver) Resolve(name string) ([]net.IPAddr, error) { func (r *resolver) Resolve(name string) (addrs []net.IPAddr, err error) {
timeout := r.Timeout timeout := r.Timeout
if timeout <= 0 {
timeout = DefaultResolverTimeout addrs = r.loadCache(name)
if len(addrs) > 0 {
if Debug {
log.Logf("[resolver] cache hit: %s %v", name, addrs)
}
return
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
return r.Resolver.LookupIPAddr(ctx, name) addrs, err = r.Resolver.LookupIPAddr(ctx, name)
r.storeCache(name, addrs)
if len(addrs) > 0 && Debug {
log.Logf("[resolver] %s %v", name, addrs)
}
return
}
func (r *resolver) loadCache(name string) []net.IPAddr {
ttl := r.TTL
if ttl < 0 {
return nil
}
if v, ok := r.mCache.Load(name); ok {
item, _ := v.(*resolverCacheItem)
if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl {
return nil
}
return item.IPAddrs
}
return nil
}
func (r *resolver) storeCache(name string, addrs []net.IPAddr) {
ttl := r.TTL
if ttl < 0 || name == "" || len(addrs) == 0 {
return
}
r.mCache.Store(name, &resolverCacheItem{
IPAddrs: addrs,
ts: time.Now().Unix(),
})
} }
func (r *resolver) String() string { func (r *resolver) String() string {
@ -124,6 +181,7 @@ func (r *resolver) String() string {
b := &bytes.Buffer{} b := &bytes.Buffer{}
fmt.Fprintf(b, "timeout %v\n", r.Timeout) fmt.Fprintf(b, "timeout %v\n", r.Timeout)
fmt.Fprintf(b, "ttl %v\n", r.TTL)
for i := range r.Servers { for i := range r.Servers {
fmt.Fprintln(b, r.Servers[i]) fmt.Fprintln(b, r.Servers[i])
} }