From 56bc433cd622946734b473bb1d9d05b5847341b2 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 19 May 2018 17:52:34 +0800 Subject: [PATCH] fix dns resolver --- chain.go | 10 ++++++++-- cmd/gost/cfg.go | 29 ++++++++++++++++++++--------- cmd/gost/dns.txt | 6 +++--- cmd/gost/main.go | 1 - gost.go | 8 ++++---- resolver.go | 31 ++++++++++++++++++++++++++----- 6 files changed, 61 insertions(+), 24 deletions(-) diff --git a/chain.go b/chain.go index af72047..5ca322e 100644 --- a/chain.go +++ b/chain.go @@ -128,8 +128,14 @@ func (c *Chain) dial(addr string) (net.Conn, error) { if c != nil && c.Resolver != nil { host, port, err := net.SplitHostPort(addr) if err == nil { - addrs, _ := c.Resolver.Resolve(host) - log.Log(addr, addrs) + addrs, er := c.Resolver.Resolve(host) + if er != nil { + log.Logf("[resolver] %s: %v", addr, er) + return nil, er + } + if Debug { + log.Logf("[resolver] %s %v", addr, addrs) + } if len(addrs) > 0 { addr = net.JoinHostPort(addrs[0].IP.String(), port) } diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index 2241b58..7793df0 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -265,6 +265,9 @@ func parseResolver(cfg string) gost.Resolver { if cfg == "" { return nil } + timeout := 30 * time.Second + var nss []gost.NameServer + f, err := os.Open(cfg) if err != nil { for _, s := range strings.Split(cfg, ",") { @@ -272,13 +275,22 @@ func parseResolver(cfg string) gost.Resolver { if s == "" { continue } + ss := strings.Split(s, "/") + if len(ss) == 1 { + nss = append(nss, gost.NameServer{ + Addr: ss[0], + }) + } + if len(ss) == 2 { + nss = append(nss, gost.NameServer{ + Addr: ss[0], + Protocol: ss[1], + }) + } } - // return gost.NewBypass(matchers, reversed) + return gost.NewResolver(nss, timeout) } - timeout := 30 * time.Second - - var nss []gost.NameServer scanner := bufio.NewScanner(f) for scanner.Scan() { line := scanner.Text() @@ -310,14 +322,13 @@ func parseResolver(cfg string) gost.Resolver { } var ns gost.NameServer - if len(ss) == 1 { + switch len(ss) { + case 1: ns.Addr = ss[0] - } - if len(ss) == 2 { + case 2: ns.Addr = ss[0] ns.Protocol = ss[1] - } - if len(ss) == 3 { + default: ns.Addr = ss[0] ns.Protocol = ss[1] ns.Hostname = ss[2] diff --git a/cmd/gost/dns.txt b/cmd/gost/dns.txt index 8d9d55b..d06d899 100644 --- a/cmd/gost/dns.txt +++ b/cmd/gost/dns.txt @@ -1,8 +1,8 @@ -# ip[:port] [protocol] [hostname] - -# resolver timeout +# resolver timeout, default 30s. timeout 10 +# ip[:port] [protocol] [hostname] + 1.1.1.1:853 tls cloudflare-dns.com 8.8.8.8 8.8.8.8 tcp diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 23665f1..fbf7321 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -506,7 +506,6 @@ func (r *route) serve() error { ) chain.Resolver = parseResolver(node.Get("dns")) - log.Log(chain.Resolver) go srv.Serve(handler) } diff --git a/gost.go b/gost.go index 1c672b3..cad43dc 100644 --- a/gost.go +++ b/gost.go @@ -44,19 +44,19 @@ var ( ) var ( - // DefaultTLSConfig is a default TLS config for internal use + // DefaultTLSConfig is a default TLS config for internal use. DefaultTLSConfig *tls.Config - // DefaultUserAgent is the default HTTP User-Agent header used by HTTP and websocket + // DefaultUserAgent is the default HTTP User-Agent header used by HTTP and websocket. DefaultUserAgent = "Chrome/60.0.3112.90" ) -// SetLogger sets a new logger for internal log system +// SetLogger sets a new logger for internal log system. func SetLogger(logger log.Logger) { log.DefaultLogger = logger } -// GenCertificate generates a random TLS certificate +// GenCertificate generates a random TLS certificate. func GenCertificate() (cert tls.Certificate, err error) { rawCert, rawKey, err := generateKeyPair() if err != nil { diff --git a/resolver.go b/resolver.go index db06823..8c4bf09 100644 --- a/resolver.go +++ b/resolver.go @@ -6,7 +6,10 @@ import ( "crypto/tls" "fmt" "net" + "strings" "time" + + "github.com/go-log/log" ) var ( @@ -29,6 +32,19 @@ type NameServer struct { Hostname string // for TLS handshake verification } +func (ns *NameServer) String() string { + addr := ns.Addr + prot := ns.Protocol + host := ns.Hostname + if !strings.Contains(addr, ":") { + addr += ":53" + } + if prot == "" { + prot = "udp" + } + return fmt.Sprintf("%s/%s %s", addr, prot, host) +} + type resolver struct { Resolver *net.Resolver Servers []NameServer @@ -54,6 +70,7 @@ func (r *resolver) init() { if err == nil { break } + log.Logf("[resolver] %s : %s", ns, err) } return }, @@ -63,11 +80,15 @@ func (r *resolver) init() { func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { var d net.Dialer - switch ns.Protocol { + addr := ns.Addr + if !strings.Contains(addr, ":") { + addr += ":53" + } + switch strings.ToLower(ns.Protocol) { case "tcp": - return d.DialContext(ctx, "tcp", ns.Addr) + return d.DialContext(ctx, "tcp", addr) case "tls": - conn, err := d.DialContext(ctx, "tcp", ns.Addr) + conn, err := d.DialContext(ctx, "tcp", addr) if err != nil { return nil, err } @@ -81,7 +102,7 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { case "udp": fallthrough default: - return d.DialContext(ctx, "udp", ns.Addr) + return d.DialContext(ctx, "udp", addr) } } @@ -104,7 +125,7 @@ func (r *resolver) String() string { b := &bytes.Buffer{} fmt.Fprintf(b, "timeout %v\n", r.Timeout) for i := range r.Servers { - fmt.Fprintf(b, "%s/%s %s\n", r.Servers[i].Addr, r.Servers[i].Protocol, r.Servers[i].Hostname) + fmt.Fprintln(b, r.Servers[i]) } return b.String() }