From a7d49f0b3789bcb568eafb9bdcb2cf927d2f28f9 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 27 Dec 2018 19:58:12 +0800 Subject: [PATCH] DNS resolver support DoH #335 --- cmd/gost/.config/dns.txt | 1 + cmd/gost/cfg.go | 25 +++- resolver.go | 242 ++++++++++++++++++++++++++++++--------- resolver_test.go | 221 +++++++++++++++++++++++++++++++++++ 4 files changed, 431 insertions(+), 58 deletions(-) create mode 100644 resolver_test.go diff --git a/cmd/gost/.config/dns.txt b/cmd/gost/.config/dns.txt index 303ad2c..b2fc503 100644 --- a/cmd/gost/.config/dns.txt +++ b/cmd/gost/.config/dns.txt @@ -10,6 +10,7 @@ reload 10s # ip[:port] [protocol] [hostname] 1.1.1.1:853 tls cloudflare-dns.com +https://1.0.0.1/dns-query https 8.8.8.8 8.8.8.8 tcp 1.1.1.1 udp diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index d78700a..5b669e2 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -207,17 +207,34 @@ func parseResolver(cfg string) gost.Resolver { if s == "" { continue } + if strings.HasPrefix(s, "https") { + ns := gost.NameServer{ + Addr: s, + Protocol: "https", + } + if err := ns.Init(); err == nil { + nss = append(nss, ns) + } + continue + } + ss := strings.Split(s, "/") if len(ss) == 1 { - nss = append(nss, gost.NameServer{ + ns := gost.NameServer{ Addr: ss[0], - }) + } + if err := ns.Init(); err == nil { + nss = append(nss, ns) + } } if len(ss) == 2 { - nss = append(nss, gost.NameServer{ + ns := gost.NameServer{ Addr: ss[0], Protocol: ss[1], - }) + } + if err := ns.Init(); err == nil { + nss = append(nss, ns) + } } } return gost.NewResolver(timeout, ttl, nss...) diff --git a/resolver.go b/resolver.go index 2296817..dca415c 100644 --- a/resolver.go +++ b/resolver.go @@ -3,23 +3,28 @@ package gost import ( "bufio" "bytes" + "context" "crypto/tls" "fmt" "io" + "io/ioutil" "net" + "net/http" + "net/url" "strings" "sync" "time" "github.com/go-log/log" "github.com/miekg/dns" + "golang.org/x/net/http2" ) var ( // DefaultResolverTimeout is the default timeout for name resolution. - DefaultResolverTimeout = 30 * time.Second + DefaultResolverTimeout = 5 * time.Second // DefaultResolverTTL is the default cache TTL for name resolution. - DefaultResolverTTL = 60 * time.Second + DefaultResolverTTL = 1 * time.Hour ) // Resolver is a name resolver for domain name. @@ -39,9 +44,73 @@ type ReloadResolver interface { // NameServer is a name server. // Currently supported protocol: TCP, UDP and TLS. type NameServer struct { - Addr string - Protocol string - Hostname string // for TLS handshake verification + Addr string + Protocol string + Hostname string // for TLS handshake verification + Timeout time.Duration + exchanger Exchanger +} + +// Init initializes the name server. +func (ns *NameServer) Init() error { + switch strings.ToLower(ns.Protocol) { + case "tcp": + ns.exchanger = &dnsExchanger{ + endpoint: ns.Addr, + client: &dns.Client{ + Net: "tcp", + Timeout: ns.Timeout, + }, + } + case "tls": + cfg := &tls.Config{ + ServerName: ns.Hostname, + } + if cfg.ServerName == "" { + cfg.InsecureSkipVerify = true + } + + ns.exchanger = &dnsExchanger{ + endpoint: ns.Addr, + client: &dns.Client{ + Net: "tcp-tls", + Timeout: ns.Timeout, + TLSConfig: cfg, + }, + } + case "https": + u, err := url.Parse(ns.Addr) + if err != nil { + return err + } + cfg := &tls.Config{ServerName: u.Hostname()} + transport := &http.Transport{ + TLSClientConfig: cfg, + DisableCompression: true, + MaxIdleConns: 1, + } + http2.ConfigureTransport(transport) + + ns.exchanger = &dohExchanger{ + endpoint: u, + client: &http.Client{ + Transport: transport, + Timeout: ns.Timeout, + }, + } + case "udp": + fallthrough + default: + ns.exchanger = &dnsExchanger{ + endpoint: ns.Addr, + client: &dns.Client{ + Net: "udp", + Timeout: ns.Timeout, + }, + } + } + + return nil } func (ns NameServer) String() string { @@ -62,26 +131,19 @@ type resolverCacheItem struct { } type resolver struct { - Resolver *net.Resolver - Servers []NameServer - mCache *sync.Map - Timeout time.Duration - TTL time.Duration - period time.Duration - domain string - stopped chan struct{} - mux sync.RWMutex + Servers []NameServer + mCache *sync.Map + Timeout time.Duration + TTL time.Duration + period time.Duration + domain string + stopped chan struct{} + mux sync.RWMutex } // NewResolver create a new Resolver with the given name servers and resolution timeout. func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver { - r := &resolver{ - Servers: servers, - Timeout: timeout, - TTL: ttl, - mCache: &sync.Map{}, - stopped: make(chan struct{}), - } + r := newResolver(timeout, ttl, servers...) if r.Timeout <= 0 { r.Timeout = DefaultResolverTimeout @@ -92,6 +154,16 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv return r } +func newResolver(timeout, ttl time.Duration, servers ...NameServer) *resolver { + return &resolver{ + Servers: servers, + Timeout: timeout, + TTL: ttl, + mCache: &sync.Map{}, + stopped: make(chan struct{}), + } +} + func (r *resolver) copyServers() []NameServer { var servers []NameServer for i := range r.Servers { @@ -107,12 +179,11 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { } var domain string - var timeout, ttl time.Duration + var ttl time.Duration var servers []NameServer r.mux.RLock() domain = r.domain - timeout = r.Timeout ttl = r.TTL servers = r.copyServers() r.mux.RUnlock() @@ -133,7 +204,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { } for _, ns := range servers { - ips, err = r.resolve(ns, host, timeout) + ips, err = r.resolve(ns.exchanger, host) if err != nil { log.Logf("[resolver] %s via %s : %s", host, ns, err) continue @@ -151,36 +222,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { return } -func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) { - addr := ns.Addr - if _, port, _ := net.SplitHostPort(addr); port == "" { - addr = net.JoinHostPort(addr, "53") +func (*resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) { + if ex == nil { + return } - client := dns.Client{ - Timeout: timeout, - } - switch strings.ToLower(ns.Protocol) { - case "tcp": - client.Net = "tcp" - case "tls": - cfg := &tls.Config{ - ServerName: ns.Hostname, - } - if cfg.ServerName == "" { - cfg.InsecureSkipVerify = true - } - client.Net = "tcp-tls" - client.TLSConfig = cfg - case "udp": - fallthrough - default: - client.Net = "udp" - } - - m := dns.Msg{} - m.SetQuestion(dns.Fqdn(host), dns.TypeA) - mr, _, err := client.Exchange(&m, addr) + query := dns.Msg{} + query.SetQuestion(dns.Fqdn(host), dns.TypeA) + mr, err := ex.Exchange(context.Background(), &query) if err != nil { return } @@ -223,7 +272,7 @@ func (r *resolver) Reload(rd io.Reader) error { var domain string var nss []NameServer - if r.Stopped() { + if rd == nil || r.Stopped() { return nil } @@ -293,7 +342,15 @@ func (r *resolver) Reload(rd io.Reader) error { ns.Protocol = ss[1] ns.Hostname = ss[2] } - nss = append(nss, ns) + + ns.Timeout = timeout + if timeout <= 0 { + ns.Timeout = DefaultResolverTimeout + } + + if err := ns.Init(); err == nil { + nss = append(nss, ns) + } } } @@ -359,3 +416,80 @@ func (r *resolver) String() string { } return b.String() } + +// Exchanger is an interface for DNS synchronous query. +type Exchanger interface { + Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) +} + +type dnsExchanger struct { + endpoint string + client *dns.Client +} + +func (ex *dnsExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + ep := ex.endpoint + if _, port, _ := net.SplitHostPort(ep); port == "" { + ep = net.JoinHostPort(ep, "53") + } + mr, _, err := ex.client.Exchange(query, ep) + return mr, err +} + +type dohExchanger struct { + endpoint *url.URL + client *http.Client +} + +// reference: https://github.com/cloudflare/cloudflared/blob/master/tunneldns/https_upstream.go#L54 +func (ex *dohExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + queryBuf, err := query.Pack() + if err != nil { + return nil, fmt.Errorf("failed to pack DNS query: %s", err) + } + + // No content negotiation for now, use DNS wire format + buf, backendErr := ex.exchangeWireformat(queryBuf) + if backendErr == nil { + response := &dns.Msg{} + if err := response.Unpack(buf); err != nil { + return nil, fmt.Errorf("failed to unpack DNS response from body: %s", err) + } + + response.Id = query.Id + return response, nil + } + + return nil, backendErr +} + +// Perform message exchange with the default UDP wireformat defined in current draft +// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https +func (ex *dohExchanger) exchangeWireformat(msg []byte) ([]byte, error) { + req, err := http.NewRequest("POST", ex.endpoint.String(), bytes.NewBuffer(msg)) + if err != nil { + return nil, fmt.Errorf("failed to create an HTTPS request: %s", err) + } + + req.Header.Add("Content-Type", "application/dns-udpwireformat") + req.Host = ex.endpoint.Hostname() + + resp, err := ex.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to perform an HTTPS request: %s", err) + } + + // Check response status code + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("returned status code %d", resp.StatusCode) + } + + // Read wireformat response from the body + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read the response body: %s", err) + } + + return buf, nil +} diff --git a/resolver_test.go b/resolver_test.go new file mode 100644 index 0000000..b057fe6 --- /dev/null +++ b/resolver_test.go @@ -0,0 +1,221 @@ +package gost + +import ( + "bytes" + "fmt" + "io" + "testing" + "time" +) + +var dnsTests = []struct { + ns NameServer + host string + pass bool +}{ + {NameServer{Addr: "1.1.1.1"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:53"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:853", Protocol: "tls"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "example.com"}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true}, + {NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true}, + {NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:12345", Timeout: 1 * time.Second}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp", Timeout: 1 * time.Second}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:12345", Protocol: "tls", Timeout: 1 * time.Second}, "github.com", false}, + {NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https", Timeout: 1 * time.Second}, "github.com", false}, +} + +func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error { + ips, err := r.Resolve(host) + t.Log(host, ips, err) + if err != nil { + return err + } + + return nil +} + +func TestDNSResolver(t *testing.T) { + for i, tc := range dnsTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + ns := tc.ns + if err := ns.Init(); err != nil { + t.Error(err) + } + t.Log(ns) + r := NewResolver(0, 0, ns) + err := dnsResolverRoundtrip(t, r, tc.host) + if err != nil { + if tc.pass { + t.Error("got error:", err) + } + } else { + if !tc.pass { + t.Error("should failed") + } + } + }) + } +} + +var resolverReloadTests = []struct { + r io.Reader + + timeout time.Duration + ttl time.Duration + domain string + period time.Duration + ns *NameServer + + stopped bool +}{ + { + r: nil, + }, + { + r: bytes.NewBufferString(""), + }, + { + r: bytes.NewBufferString("reload 10s"), + period: 10 * time.Second, + }, + { + r: bytes.NewBufferString("timeout 10s\nreload 10s\n"), + timeout: 10 * time.Second, + period: 10 * time.Second, + }, + { + r: bytes.NewBufferString("ttl 10s\ntimeout 10s\nreload 10s\n"), + timeout: 10 * time.Second, + period: 10 * time.Second, + ttl: 10 * time.Second, + }, + { + r: bytes.NewBufferString("domain example.com\nttl 10s\ntimeout 10s\nreload 10s\n"), + timeout: 10 * time.Second, + period: 10 * time.Second, + ttl: 10 * time.Second, + domain: "example.com", + }, + { + r: bytes.NewBufferString("1.1.1.1"), + ns: &NameServer{ + Addr: "1.1.1.1", + Timeout: DefaultResolverTimeout, + }, + stopped: true, + }, + { + r: bytes.NewBufferString("timeout 10s\nsearch\nnameserver \nnameserver 1.1.1.1 udp"), + ns: &NameServer{ + Protocol: "udp", + Addr: "1.1.1.1", + Timeout: 10 * time.Second, + }, + timeout: 10 * time.Second, + stopped: true, + }, + { + r: bytes.NewBufferString("1.1.1.1 tcp"), + ns: &NameServer{ + Addr: "1.1.1.1", + Protocol: "tcp", + Timeout: DefaultResolverTimeout, + }, + stopped: true, + }, + { + r: bytes.NewBufferString("1.1.1.1:853 tls cloudflare-dns.com"), + ns: &NameServer{ + Addr: "1.1.1.1:853", + Protocol: "tls", + Hostname: "cloudflare-dns.com", + Timeout: DefaultResolverTimeout, + }, + stopped: true, + }, + { + r: bytes.NewBufferString("1.1.1.1:853 tls"), + ns: &NameServer{ + Addr: "1.1.1.1:853", + Protocol: "tls", + Timeout: DefaultResolverTimeout, + }, + stopped: true, + }, + { + r: bytes.NewBufferString("1.0.0.1:53 https"), + stopped: true, + }, + { + r: bytes.NewBufferString("https://1.0.0.1/dns-query https"), + ns: &NameServer{ + Addr: "https://1.0.0.1/dns-query", + Protocol: "https", + Timeout: DefaultResolverTimeout, + }, + stopped: true, + }, +} + +func TestResolverReload(t *testing.T) { + for i, tc := range resolverReloadTests { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + r := newResolver(0, 0) + if err := r.Reload(tc.r); err != nil { + t.Error(err) + } + t.Log(r.String()) + if r.Timeout != tc.timeout { + t.Errorf("timeout value should be %v, got %v", + tc.timeout, r.Timeout) + } + if r.TTL != tc.ttl { + t.Errorf("ttl value should be %v, got %v", + tc.ttl, r.TTL) + } + if r.Period() != tc.period { + t.Errorf("period value should be %v, got %v", + tc.period, r.period) + } + if r.domain != tc.domain { + t.Errorf("domain value should be %v, got %v", + tc.domain, r.domain) + } + + var ns *NameServer + if len(r.Servers) > 0 { + ns = &r.Servers[0] + } + + if !compareNameServer(ns, tc.ns) { + t.Errorf("nameserver not equal, should be %v, got %v", + tc.ns, r.Servers) + } + + if tc.stopped { + r.Stop() + } + if r.Stopped() != tc.stopped { + t.Errorf("stopped value should be %v, got %v", + tc.stopped, r.Stopped()) + } + }) + } +} + +func compareNameServer(n1, n2 *NameServer) bool { + if n1 == n2 { + return true + } + if n1 == nil || n2 == nil { + return false + } + return n1.Addr == n2.Addr && + n1.Hostname == n2.Hostname && + n1.Protocol == n2.Protocol && + n1.Timeout == n2.Timeout +}