fix resolver timeout

This commit is contained in:
ginuerzh 2018-11-10 12:14:26 +08:00
parent c66751b017
commit 3ebf423e87

View File

@ -3,7 +3,6 @@ package gost
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
@ -13,6 +12,7 @@ import (
"time" "time"
"github.com/go-log/log" "github.com/go-log/log"
"github.com/miekg/dns"
) )
var ( var (
@ -46,14 +46,13 @@ type NameServer struct {
func (ns NameServer) String() string { func (ns NameServer) String() string {
addr := ns.Addr addr := ns.Addr
prot := ns.Protocol prot := ns.Protocol
host := ns.Hostname
if _, port, _ := net.SplitHostPort(addr); port == "" { if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "53") addr = net.JoinHostPort(addr, "53")
} }
if prot == "" { if prot == "" {
prot = "udp" prot = "udp"
} }
return fmt.Sprintf("%s/%s %s", addr, prot, host) return fmt.Sprintf("%s/%s", addr, prot)
} }
type resolverCacheItem struct { type resolverCacheItem struct {
@ -89,78 +88,81 @@ func (r *resolver) init() {
if r.TTL == 0 { if r.TTL == 0 {
r.TTL = DefaultResolverTTL r.TTL = DefaultResolverTTL
} }
r.Resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) {
for _, ns := range r.Servers {
conn, err = r.dial(ctx, ns)
if err == nil {
break
}
log.Logf("[resolver] %s : %s", ns, err)
}
return
},
}
} }
func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
var d net.Dialer if r == nil {
return
}
if ip := net.ParseIP(host); ip != nil {
return []net.IP{ip}, nil
}
ips = r.loadCache(host)
if len(ips) > 0 {
if Debug {
log.Logf("[resolver] cache hit %s: %v", host, ips)
}
return
}
for _, ns := range r.Servers {
ips, err = r.resolve(ns, host)
if err != nil {
log.Logf("[resolver] %s via %s : %s", host, ns, err)
continue
}
if Debug {
log.Logf("[resolver] %s via %s %v", host, ns, ips)
}
if len(ips) > 0 {
break
}
}
r.storeCache(host, ips)
return
}
func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) {
addr := ns.Addr addr := ns.Addr
if _, port, _ := net.SplitHostPort(addr); port == "" { if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "53") addr = net.JoinHostPort(addr, "53")
} }
client := dns.Client{
Timeout: r.Timeout,
}
switch strings.ToLower(ns.Protocol) { switch strings.ToLower(ns.Protocol) {
case "tcp": case "tcp":
return d.DialContext(ctx, "tcp", addr) client.Net = "tcp"
case "tls": case "tls":
conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
cfg := &tls.Config{ cfg := &tls.Config{
ServerName: ns.Hostname, ServerName: ns.Hostname,
} }
if cfg.ServerName == "" { if cfg.ServerName == "" {
cfg.InsecureSkipVerify = true cfg.InsecureSkipVerify = true
} }
return tls.Client(conn, cfg), nil client.Net = "tcp-tls"
client.TLSConfig = cfg
case "udp": case "udp":
fallthrough fallthrough
default: default:
return d.DialContext(ctx, "udp", addr) client.Net = "udp"
} }
}
func (r *resolver) Resolve(name string) (ips []net.IP, err error) { m := dns.Msg{}
if r == nil { m.SetQuestion(dns.Fqdn(host), dns.TypeA)
mr, _, err := client.Exchange(&m, addr)
if err != nil {
return return
} }
timeout := r.Timeout for _, ans := range mr.Answer {
if ar, _ := ans.(*dns.A); ar != nil {
if ip := net.ParseIP(name); ip != nil { ips = append(ips, ar.A)
return []net.IP{ip}, nil
}
ips = r.loadCache(name)
if len(ips) > 0 {
if Debug {
log.Logf("[resolver] cache hit: %s %v", name, ips)
} }
return
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
addrs, err := r.Resolver.LookupIPAddr(ctx, name)
for _, addr := range addrs {
ips = append(ips, addr.IP)
}
r.storeCache(name, ips)
if len(ips) > 0 && Debug {
log.Logf("[resolver] %s %v", name, ips)
} }
return return
} }