From 3ebf423e87fe7fe612b5d27740f7b06312df9680 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 10 Nov 2018 12:14:26 +0800 Subject: [PATCH] fix resolver timeout --- resolver.go | 104 ++++++++++++++++++++++++++-------------------------- 1 file changed, 53 insertions(+), 51 deletions(-) diff --git a/resolver.go b/resolver.go index 87e81eb..ca283c9 100644 --- a/resolver.go +++ b/resolver.go @@ -3,7 +3,6 @@ package gost import ( "bufio" "bytes" - "context" "crypto/tls" "fmt" "io" @@ -13,6 +12,7 @@ import ( "time" "github.com/go-log/log" + "github.com/miekg/dns" ) var ( @@ -46,14 +46,13 @@ type NameServer struct { func (ns NameServer) String() string { addr := ns.Addr prot := ns.Protocol - host := ns.Hostname if _, port, _ := net.SplitHostPort(addr); port == "" { addr = net.JoinHostPort(addr, "53") } if prot == "" { prot = "udp" } - return fmt.Sprintf("%s/%s %s", addr, prot, host) + return fmt.Sprintf("%s/%s", addr, prot) } type resolverCacheItem struct { @@ -89,78 +88,81 @@ func (r *resolver) init() { if r.TTL == 0 { r.TTL = DefaultResolverTTL } - - r.Resolver = &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) { - for _, ns := range r.Servers { - conn, err = r.dial(ctx, ns) - if err == nil { - break - } - log.Logf("[resolver] %s : %s", ns, err) - } - return - }, - } } -func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { - var d net.Dialer +func (r *resolver) Resolve(host string) (ips []net.IP, err error) { + if r == nil { + return + } + if ip := net.ParseIP(host); ip != nil { + return []net.IP{ip}, nil + } + + ips = r.loadCache(host) + if len(ips) > 0 { + if Debug { + log.Logf("[resolver] cache hit %s: %v", host, ips) + } + return + } + + for _, ns := range r.Servers { + ips, err = r.resolve(ns, host) + if err != nil { + log.Logf("[resolver] %s via %s : %s", host, ns, err) + continue + } + + if Debug { + log.Logf("[resolver] %s via %s %v", host, ns, ips) + } + if len(ips) > 0 { + break + } + } + + r.storeCache(host, ips) + return +} + +func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) { addr := ns.Addr if _, port, _ := net.SplitHostPort(addr); port == "" { addr = net.JoinHostPort(addr, "53") } + + client := dns.Client{ + Timeout: r.Timeout, + } switch strings.ToLower(ns.Protocol) { case "tcp": - return d.DialContext(ctx, "tcp", addr) + client.Net = "tcp" case "tls": - conn, err := d.DialContext(ctx, "tcp", addr) - if err != nil { - return nil, err - } cfg := &tls.Config{ ServerName: ns.Hostname, } if cfg.ServerName == "" { cfg.InsecureSkipVerify = true } - return tls.Client(conn, cfg), nil + client.Net = "tcp-tls" + client.TLSConfig = cfg case "udp": fallthrough default: - return d.DialContext(ctx, "udp", addr) + client.Net = "udp" } -} -func (r *resolver) Resolve(name string) (ips []net.IP, err error) { - if r == nil { + m := dns.Msg{} + m.SetQuestion(dns.Fqdn(host), dns.TypeA) + mr, _, err := client.Exchange(&m, addr) + if err != nil { return } - timeout := r.Timeout - - if ip := net.ParseIP(name); ip != nil { - return []net.IP{ip}, nil - } - - ips = r.loadCache(name) - if len(ips) > 0 { - if Debug { - log.Logf("[resolver] cache hit: %s %v", name, ips) + for _, ans := range mr.Answer { + if ar, _ := ans.(*dns.A); ar != nil { + ips = append(ips, ar.A) } - return - } - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - addrs, err := r.Resolver.LookupIPAddr(ctx, name) - for _, addr := range addrs { - ips = append(ips, addr.IP) - } - r.storeCache(name, ips) - if len(ips) > 0 && Debug { - log.Logf("[resolver] %s %v", name, ips) } return }