add cache support for DNS resolver
This commit is contained in:
parent
56bc433cd6
commit
482d2857f6
10
chain.go
10
chain.go
@ -130,11 +130,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
|
||||
if err == nil {
|
||||
addrs, er := c.Resolver.Resolve(host)
|
||||
if er != nil {
|
||||
log.Logf("[resolver] %s: %v", addr, er)
|
||||
return nil, er
|
||||
}
|
||||
if Debug {
|
||||
log.Logf("[resolver] %s %v", addr, addrs)
|
||||
log.Logf("[resolver] %s: %v", host, er)
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
addr = net.JoinHostPort(addrs[0].IP.String(), port)
|
||||
@ -280,7 +276,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
|
||||
if node.Bypass.Contains(addr) {
|
||||
if Debug {
|
||||
buf.WriteString(fmt.Sprintf("[%d@bypass: %s]", node.ID, addr))
|
||||
log.Log("select route:", buf.String())
|
||||
log.Log("[route]", buf.String())
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -301,7 +297,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
|
||||
route.Resolver = c.Resolver
|
||||
|
||||
if Debug {
|
||||
log.Log("select route:", buf.String())
|
||||
log.Log("[route]", buf.String())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -266,6 +266,7 @@ func parseResolver(cfg string) gost.Resolver {
|
||||
return nil
|
||||
}
|
||||
timeout := 30 * time.Second
|
||||
ttl := 60 * time.Second
|
||||
var nss []gost.NameServer
|
||||
|
||||
f, err := os.Open(cfg)
|
||||
@ -288,7 +289,7 @@ func parseResolver(cfg string) gost.Resolver {
|
||||
})
|
||||
}
|
||||
}
|
||||
return gost.NewResolver(nss, timeout)
|
||||
return gost.NewResolver(nss, timeout, ttl)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
@ -312,7 +313,7 @@ func parseResolver(cfg string) gost.Resolver {
|
||||
continue
|
||||
}
|
||||
|
||||
if ss[0] == "timeout" {
|
||||
if strings.ToLower(ss[0]) == "timeout" {
|
||||
if len(ss) >= 2 {
|
||||
if n, _ := strconv.Atoi(ss[1]); n > 0 {
|
||||
timeout = time.Second * time.Duration(n)
|
||||
@ -320,6 +321,13 @@ func parseResolver(cfg string) gost.Resolver {
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(ss[0]) == "ttl" {
|
||||
if len(ss) >= 2 {
|
||||
n, _ := strconv.Atoi(ss[1])
|
||||
ttl = time.Second * time.Duration(n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
var ns gost.NameServer
|
||||
switch len(ss) {
|
||||
@ -335,5 +343,5 @@ func parseResolver(cfg string) gost.Resolver {
|
||||
}
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
return gost.NewResolver(nss, timeout)
|
||||
return gost.NewResolver(nss, timeout, ttl)
|
||||
}
|
||||
|
@ -1,6 +1,9 @@
|
||||
# resolver timeout, default 30s.
|
||||
timeout 10
|
||||
|
||||
# resolver cache TTL, default 60s, minus value means that cache is disabled.
|
||||
ttl 300
|
||||
|
||||
# ip[:port] [protocol] [hostname]
|
||||
|
||||
1.1.1.1:853 tls cloudflare-dns.com
|
||||
|
@ -506,7 +506,9 @@ func (r *route) serve() error {
|
||||
)
|
||||
|
||||
chain.Resolver = parseResolver(node.Get("dns"))
|
||||
|
||||
if gost.Debug {
|
||||
log.Logf("[resolver]\n%v", chain.Resolver)
|
||||
}
|
||||
go srv.Serve(handler)
|
||||
}
|
||||
|
||||
|
70
resolver.go
70
resolver.go
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-log/log"
|
||||
@ -15,6 +16,8 @@ import (
|
||||
var (
|
||||
// DefaultResolverTimeout is the default timeout for name resolution.
|
||||
DefaultResolverTimeout = 30 * time.Second
|
||||
// DefaultResolverTTL is the default cache TTL for name resolution.
|
||||
DefaultResolverTTL = 60 * time.Second
|
||||
)
|
||||
|
||||
// Resolver is a name resolver for domain name.
|
||||
@ -32,7 +35,7 @@ type NameServer struct {
|
||||
Hostname string // for TLS handshake verification
|
||||
}
|
||||
|
||||
func (ns *NameServer) String() string {
|
||||
func (ns NameServer) String() string {
|
||||
addr := ns.Addr
|
||||
prot := ns.Protocol
|
||||
host := ns.Hostname
|
||||
@ -45,23 +48,39 @@ func (ns *NameServer) String() string {
|
||||
return fmt.Sprintf("%s/%s %s", addr, prot, host)
|
||||
}
|
||||
|
||||
type resolverCacheItem struct {
|
||||
IPAddrs []net.IPAddr
|
||||
ts int64
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
Resolver *net.Resolver
|
||||
Servers []NameServer
|
||||
Timeout time.Duration
|
||||
TTL time.Duration
|
||||
mCache *sync.Map
|
||||
}
|
||||
|
||||
// NewResolver create a new Resolver with the given name servers and resolution timeout.
|
||||
func NewResolver(servers []NameServer, timeout time.Duration) Resolver {
|
||||
func NewResolver(servers []NameServer, timeout, ttl time.Duration) Resolver {
|
||||
r := &resolver{
|
||||
Servers: servers,
|
||||
Timeout: timeout,
|
||||
TTL: ttl,
|
||||
mCache: &sync.Map{},
|
||||
}
|
||||
r.init()
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *resolver) init() {
|
||||
if r.Timeout <= 0 {
|
||||
r.Timeout = DefaultResolverTimeout
|
||||
}
|
||||
if r.TTL == 0 {
|
||||
r.TTL = DefaultResolverTTL
|
||||
}
|
||||
|
||||
r.Resolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
@ -106,15 +125,53 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) Resolve(name string) ([]net.IPAddr, error) {
|
||||
func (r *resolver) Resolve(name string) (addrs []net.IPAddr, err error) {
|
||||
timeout := r.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = DefaultResolverTimeout
|
||||
|
||||
addrs = r.loadCache(name)
|
||||
if len(addrs) > 0 {
|
||||
if Debug {
|
||||
log.Logf("[resolver] cache hit: %s %v", name, addrs)
|
||||
}
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
return r.Resolver.LookupIPAddr(ctx, name)
|
||||
addrs, err = r.Resolver.LookupIPAddr(ctx, name)
|
||||
r.storeCache(name, addrs)
|
||||
if len(addrs) > 0 && Debug {
|
||||
log.Logf("[resolver] %s %v", name, addrs)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *resolver) loadCache(name string) []net.IPAddr {
|
||||
ttl := r.TTL
|
||||
if ttl < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if v, ok := r.mCache.Load(name); ok {
|
||||
item, _ := v.(*resolverCacheItem)
|
||||
if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl {
|
||||
return nil
|
||||
}
|
||||
return item.IPAddrs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolver) storeCache(name string, addrs []net.IPAddr) {
|
||||
ttl := r.TTL
|
||||
if ttl < 0 || name == "" || len(addrs) == 0 {
|
||||
return
|
||||
}
|
||||
r.mCache.Store(name, &resolverCacheItem{
|
||||
IPAddrs: addrs,
|
||||
ts: time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *resolver) String() string {
|
||||
@ -124,6 +181,7 @@ func (r *resolver) String() string {
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
fmt.Fprintf(b, "timeout %v\n", r.Timeout)
|
||||
fmt.Fprintf(b, "ttl %v\n", r.TTL)
|
||||
for i := range r.Servers {
|
||||
fmt.Fprintln(b, r.Servers[i])
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user