fix dns resolver

This commit is contained in:
ginuerzh 2018-05-19 17:52:34 +08:00
parent 3271a50bdb
commit 56bc433cd6
6 changed files with 61 additions and 24 deletions

View File

@ -128,8 +128,14 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
if c != nil && c.Resolver != nil { if c != nil && c.Resolver != nil {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err == nil { if err == nil {
addrs, _ := c.Resolver.Resolve(host) addrs, er := c.Resolver.Resolve(host)
log.Log(addr, addrs) if er != nil {
log.Logf("[resolver] %s: %v", addr, er)
return nil, er
}
if Debug {
log.Logf("[resolver] %s %v", addr, addrs)
}
if len(addrs) > 0 { if len(addrs) > 0 {
addr = net.JoinHostPort(addrs[0].IP.String(), port) addr = net.JoinHostPort(addrs[0].IP.String(), port)
} }

View File

@ -265,6 +265,9 @@ func parseResolver(cfg string) gost.Resolver {
if cfg == "" { if cfg == "" {
return nil return nil
} }
timeout := 30 * time.Second
var nss []gost.NameServer
f, err := os.Open(cfg) f, err := os.Open(cfg)
if err != nil { if err != nil {
for _, s := range strings.Split(cfg, ",") { for _, s := range strings.Split(cfg, ",") {
@ -272,13 +275,22 @@ func parseResolver(cfg string) gost.Resolver {
if s == "" { if s == "" {
continue continue
} }
ss := strings.Split(s, "/")
if len(ss) == 1 {
nss = append(nss, gost.NameServer{
Addr: ss[0],
})
} }
// return gost.NewBypass(matchers, reversed) if len(ss) == 2 {
nss = append(nss, gost.NameServer{
Addr: ss[0],
Protocol: ss[1],
})
}
}
return gost.NewResolver(nss, timeout)
} }
timeout := 30 * time.Second
var nss []gost.NameServer
scanner := bufio.NewScanner(f) scanner := bufio.NewScanner(f)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@ -310,14 +322,13 @@ func parseResolver(cfg string) gost.Resolver {
} }
var ns gost.NameServer var ns gost.NameServer
if len(ss) == 1 { switch len(ss) {
case 1:
ns.Addr = ss[0] ns.Addr = ss[0]
} case 2:
if len(ss) == 2 {
ns.Addr = ss[0] ns.Addr = ss[0]
ns.Protocol = ss[1] ns.Protocol = ss[1]
} default:
if len(ss) == 3 {
ns.Addr = ss[0] ns.Addr = ss[0]
ns.Protocol = ss[1] ns.Protocol = ss[1]
ns.Hostname = ss[2] ns.Hostname = ss[2]

View File

@ -1,8 +1,8 @@
# ip[:port] [protocol] [hostname] # resolver timeout, default 30s.
# resolver timeout
timeout 10 timeout 10
# ip[:port] [protocol] [hostname]
1.1.1.1:853 tls cloudflare-dns.com 1.1.1.1:853 tls cloudflare-dns.com
8.8.8.8 8.8.8.8
8.8.8.8 tcp 8.8.8.8 tcp

View File

@ -506,7 +506,6 @@ func (r *route) serve() error {
) )
chain.Resolver = parseResolver(node.Get("dns")) chain.Resolver = parseResolver(node.Get("dns"))
log.Log(chain.Resolver)
go srv.Serve(handler) go srv.Serve(handler)
} }

View File

@ -44,19 +44,19 @@ var (
) )
var ( var (
// DefaultTLSConfig is a default TLS config for internal use // DefaultTLSConfig is a default TLS config for internal use.
DefaultTLSConfig *tls.Config DefaultTLSConfig *tls.Config
// DefaultUserAgent is the default HTTP User-Agent header used by HTTP and websocket // DefaultUserAgent is the default HTTP User-Agent header used by HTTP and websocket.
DefaultUserAgent = "Chrome/60.0.3112.90" DefaultUserAgent = "Chrome/60.0.3112.90"
) )
// SetLogger sets a new logger for internal log system // SetLogger sets a new logger for internal log system.
func SetLogger(logger log.Logger) { func SetLogger(logger log.Logger) {
log.DefaultLogger = logger log.DefaultLogger = logger
} }
// GenCertificate generates a random TLS certificate // GenCertificate generates a random TLS certificate.
func GenCertificate() (cert tls.Certificate, err error) { func GenCertificate() (cert tls.Certificate, err error) {
rawCert, rawKey, err := generateKeyPair() rawCert, rawKey, err := generateKeyPair()
if err != nil { if err != nil {

View File

@ -6,7 +6,10 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"strings"
"time" "time"
"github.com/go-log/log"
) )
var ( var (
@ -29,6 +32,19 @@ type NameServer struct {
Hostname string // for TLS handshake verification Hostname string // for TLS handshake verification
} }
func (ns *NameServer) String() string {
addr := ns.Addr
prot := ns.Protocol
host := ns.Hostname
if !strings.Contains(addr, ":") {
addr += ":53"
}
if prot == "" {
prot = "udp"
}
return fmt.Sprintf("%s/%s %s", addr, prot, host)
}
type resolver struct { type resolver struct {
Resolver *net.Resolver Resolver *net.Resolver
Servers []NameServer Servers []NameServer
@ -54,6 +70,7 @@ func (r *resolver) init() {
if err == nil { if err == nil {
break break
} }
log.Logf("[resolver] %s : %s", ns, err)
} }
return return
}, },
@ -63,11 +80,15 @@ func (r *resolver) init() {
func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
var d net.Dialer var d net.Dialer
switch ns.Protocol { addr := ns.Addr
if !strings.Contains(addr, ":") {
addr += ":53"
}
switch strings.ToLower(ns.Protocol) {
case "tcp": case "tcp":
return d.DialContext(ctx, "tcp", ns.Addr) return d.DialContext(ctx, "tcp", addr)
case "tls": case "tls":
conn, err := d.DialContext(ctx, "tcp", ns.Addr) conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -81,7 +102,7 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
case "udp": case "udp":
fallthrough fallthrough
default: default:
return d.DialContext(ctx, "udp", ns.Addr) return d.DialContext(ctx, "udp", addr)
} }
} }
@ -104,7 +125,7 @@ func (r *resolver) String() string {
b := &bytes.Buffer{} b := &bytes.Buffer{}
fmt.Fprintf(b, "timeout %v\n", r.Timeout) fmt.Fprintf(b, "timeout %v\n", r.Timeout)
for i := range r.Servers { for i := range r.Servers {
fmt.Fprintf(b, "%s/%s %s\n", r.Servers[i].Addr, r.Servers[i].Protocol, r.Servers[i].Hostname) fmt.Fprintln(b, r.Servers[i])
} }
return b.String() return b.String()
} }