DNS resolver support DoH #335

This commit is contained in:
ginuerzh 2018-12-27 19:58:12 +08:00
parent 02f1d099c4
commit a7d49f0b37
4 changed files with 431 additions and 58 deletions

View File

@ -10,6 +10,7 @@ reload 10s
# ip[:port] [protocol] [hostname]
1.1.1.1:853 tls cloudflare-dns.com
https://1.0.0.1/dns-query https
8.8.8.8
8.8.8.8 tcp
1.1.1.1 udp

View File

@ -207,17 +207,34 @@ func parseResolver(cfg string) gost.Resolver {
if s == "" {
continue
}
if strings.HasPrefix(s, "https") {
ns := gost.NameServer{
Addr: s,
Protocol: "https",
}
if err := ns.Init(); err == nil {
nss = append(nss, ns)
}
continue
}
ss := strings.Split(s, "/")
if len(ss) == 1 {
nss = append(nss, gost.NameServer{
ns := gost.NameServer{
Addr: ss[0],
})
}
if err := ns.Init(); err == nil {
nss = append(nss, ns)
}
}
if len(ss) == 2 {
nss = append(nss, gost.NameServer{
ns := gost.NameServer{
Addr: ss[0],
Protocol: ss[1],
})
}
if err := ns.Init(); err == nil {
nss = append(nss, ns)
}
}
}
return gost.NewResolver(timeout, ttl, nss...)

View File

@ -3,23 +3,28 @@ package gost
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/go-log/log"
"github.com/miekg/dns"
"golang.org/x/net/http2"
)
var (
// DefaultResolverTimeout is the default timeout for name resolution.
DefaultResolverTimeout = 30 * time.Second
DefaultResolverTimeout = 5 * time.Second
// DefaultResolverTTL is the default cache TTL for name resolution.
DefaultResolverTTL = 60 * time.Second
DefaultResolverTTL = 1 * time.Hour
)
// Resolver is a name resolver for domain name.
@ -39,9 +44,73 @@ type ReloadResolver interface {
// NameServer is a name server.
// Currently supported protocol: TCP, UDP and TLS.
type NameServer struct {
Addr string
Protocol string
Hostname string // for TLS handshake verification
Addr string
Protocol string
Hostname string // for TLS handshake verification
Timeout time.Duration
exchanger Exchanger
}
// Init initializes the name server.
func (ns *NameServer) Init() error {
switch strings.ToLower(ns.Protocol) {
case "tcp":
ns.exchanger = &dnsExchanger{
endpoint: ns.Addr,
client: &dns.Client{
Net: "tcp",
Timeout: ns.Timeout,
},
}
case "tls":
cfg := &tls.Config{
ServerName: ns.Hostname,
}
if cfg.ServerName == "" {
cfg.InsecureSkipVerify = true
}
ns.exchanger = &dnsExchanger{
endpoint: ns.Addr,
client: &dns.Client{
Net: "tcp-tls",
Timeout: ns.Timeout,
TLSConfig: cfg,
},
}
case "https":
u, err := url.Parse(ns.Addr)
if err != nil {
return err
}
cfg := &tls.Config{ServerName: u.Hostname()}
transport := &http.Transport{
TLSClientConfig: cfg,
DisableCompression: true,
MaxIdleConns: 1,
}
http2.ConfigureTransport(transport)
ns.exchanger = &dohExchanger{
endpoint: u,
client: &http.Client{
Transport: transport,
Timeout: ns.Timeout,
},
}
case "udp":
fallthrough
default:
ns.exchanger = &dnsExchanger{
endpoint: ns.Addr,
client: &dns.Client{
Net: "udp",
Timeout: ns.Timeout,
},
}
}
return nil
}
func (ns NameServer) String() string {
@ -62,26 +131,19 @@ type resolverCacheItem struct {
}
type resolver struct {
Resolver *net.Resolver
Servers []NameServer
mCache *sync.Map
Timeout time.Duration
TTL time.Duration
period time.Duration
domain string
stopped chan struct{}
mux sync.RWMutex
Servers []NameServer
mCache *sync.Map
Timeout time.Duration
TTL time.Duration
period time.Duration
domain string
stopped chan struct{}
mux sync.RWMutex
}
// NewResolver create a new Resolver with the given name servers and resolution timeout.
func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver {
r := &resolver{
Servers: servers,
Timeout: timeout,
TTL: ttl,
mCache: &sync.Map{},
stopped: make(chan struct{}),
}
r := newResolver(timeout, ttl, servers...)
if r.Timeout <= 0 {
r.Timeout = DefaultResolverTimeout
@ -92,6 +154,16 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
return r
}
func newResolver(timeout, ttl time.Duration, servers ...NameServer) *resolver {
return &resolver{
Servers: servers,
Timeout: timeout,
TTL: ttl,
mCache: &sync.Map{},
stopped: make(chan struct{}),
}
}
func (r *resolver) copyServers() []NameServer {
var servers []NameServer
for i := range r.Servers {
@ -107,12 +179,11 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
}
var domain string
var timeout, ttl time.Duration
var ttl time.Duration
var servers []NameServer
r.mux.RLock()
domain = r.domain
timeout = r.Timeout
ttl = r.TTL
servers = r.copyServers()
r.mux.RUnlock()
@ -133,7 +204,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
}
for _, ns := range servers {
ips, err = r.resolve(ns, host, timeout)
ips, err = r.resolve(ns.exchanger, host)
if err != nil {
log.Logf("[resolver] %s via %s : %s", host, ns, err)
continue
@ -151,36 +222,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
return
}
func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) {
addr := ns.Addr
if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "53")
func (*resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) {
if ex == nil {
return
}
client := dns.Client{
Timeout: timeout,
}
switch strings.ToLower(ns.Protocol) {
case "tcp":
client.Net = "tcp"
case "tls":
cfg := &tls.Config{
ServerName: ns.Hostname,
}
if cfg.ServerName == "" {
cfg.InsecureSkipVerify = true
}
client.Net = "tcp-tls"
client.TLSConfig = cfg
case "udp":
fallthrough
default:
client.Net = "udp"
}
m := dns.Msg{}
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
mr, _, err := client.Exchange(&m, addr)
query := dns.Msg{}
query.SetQuestion(dns.Fqdn(host), dns.TypeA)
mr, err := ex.Exchange(context.Background(), &query)
if err != nil {
return
}
@ -223,7 +272,7 @@ func (r *resolver) Reload(rd io.Reader) error {
var domain string
var nss []NameServer
if r.Stopped() {
if rd == nil || r.Stopped() {
return nil
}
@ -293,7 +342,15 @@ func (r *resolver) Reload(rd io.Reader) error {
ns.Protocol = ss[1]
ns.Hostname = ss[2]
}
nss = append(nss, ns)
ns.Timeout = timeout
if timeout <= 0 {
ns.Timeout = DefaultResolverTimeout
}
if err := ns.Init(); err == nil {
nss = append(nss, ns)
}
}
}
@ -359,3 +416,80 @@ func (r *resolver) String() string {
}
return b.String()
}
// Exchanger is an interface for DNS synchronous query.
type Exchanger interface {
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
}
type dnsExchanger struct {
endpoint string
client *dns.Client
}
func (ex *dnsExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
ep := ex.endpoint
if _, port, _ := net.SplitHostPort(ep); port == "" {
ep = net.JoinHostPort(ep, "53")
}
mr, _, err := ex.client.Exchange(query, ep)
return mr, err
}
type dohExchanger struct {
endpoint *url.URL
client *http.Client
}
// reference: https://github.com/cloudflare/cloudflared/blob/master/tunneldns/https_upstream.go#L54
func (ex *dohExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
queryBuf, err := query.Pack()
if err != nil {
return nil, fmt.Errorf("failed to pack DNS query: %s", err)
}
// No content negotiation for now, use DNS wire format
buf, backendErr := ex.exchangeWireformat(queryBuf)
if backendErr == nil {
response := &dns.Msg{}
if err := response.Unpack(buf); err != nil {
return nil, fmt.Errorf("failed to unpack DNS response from body: %s", err)
}
response.Id = query.Id
return response, nil
}
return nil, backendErr
}
// Perform message exchange with the default UDP wireformat defined in current draft
// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
func (ex *dohExchanger) exchangeWireformat(msg []byte) ([]byte, error) {
req, err := http.NewRequest("POST", ex.endpoint.String(), bytes.NewBuffer(msg))
if err != nil {
return nil, fmt.Errorf("failed to create an HTTPS request: %s", err)
}
req.Header.Add("Content-Type", "application/dns-udpwireformat")
req.Host = ex.endpoint.Hostname()
resp, err := ex.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform an HTTPS request: %s", err)
}
// Check response status code
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
}
// Read wireformat response from the body
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read the response body: %s", err)
}
return buf, nil
}

221
resolver_test.go Normal file
View File

@ -0,0 +1,221 @@
package gost
import (
"bytes"
"fmt"
"io"
"testing"
"time"
)
var dnsTests = []struct {
ns NameServer
host string
pass bool
}{
{NameServer{Addr: "1.1.1.1"}, "github.com", true},
{NameServer{Addr: "1.1.1.1:53"}, "github.com", true},
{NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true},
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls"}, "github.com", true},
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "example.com"}, "github.com", false},
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true},
{NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true},
{NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true},
{NameServer{Addr: "1.1.1.1:12345", Timeout: 1 * time.Second}, "github.com", false},
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp", Timeout: 1 * time.Second}, "github.com", false},
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tls", Timeout: 1 * time.Second}, "github.com", false},
{NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https", Timeout: 1 * time.Second}, "github.com", false},
}
func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error {
ips, err := r.Resolve(host)
t.Log(host, ips, err)
if err != nil {
return err
}
return nil
}
func TestDNSResolver(t *testing.T) {
for i, tc := range dnsTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
ns := tc.ns
if err := ns.Init(); err != nil {
t.Error(err)
}
t.Log(ns)
r := NewResolver(0, 0, ns)
err := dnsResolverRoundtrip(t, r, tc.host)
if err != nil {
if tc.pass {
t.Error("got error:", err)
}
} else {
if !tc.pass {
t.Error("should failed")
}
}
})
}
}
var resolverReloadTests = []struct {
r io.Reader
timeout time.Duration
ttl time.Duration
domain string
period time.Duration
ns *NameServer
stopped bool
}{
{
r: nil,
},
{
r: bytes.NewBufferString(""),
},
{
r: bytes.NewBufferString("reload 10s"),
period: 10 * time.Second,
},
{
r: bytes.NewBufferString("timeout 10s\nreload 10s\n"),
timeout: 10 * time.Second,
period: 10 * time.Second,
},
{
r: bytes.NewBufferString("ttl 10s\ntimeout 10s\nreload 10s\n"),
timeout: 10 * time.Second,
period: 10 * time.Second,
ttl: 10 * time.Second,
},
{
r: bytes.NewBufferString("domain example.com\nttl 10s\ntimeout 10s\nreload 10s\n"),
timeout: 10 * time.Second,
period: 10 * time.Second,
ttl: 10 * time.Second,
domain: "example.com",
},
{
r: bytes.NewBufferString("1.1.1.1"),
ns: &NameServer{
Addr: "1.1.1.1",
Timeout: DefaultResolverTimeout,
},
stopped: true,
},
{
r: bytes.NewBufferString("timeout 10s\nsearch\nnameserver \nnameserver 1.1.1.1 udp"),
ns: &NameServer{
Protocol: "udp",
Addr: "1.1.1.1",
Timeout: 10 * time.Second,
},
timeout: 10 * time.Second,
stopped: true,
},
{
r: bytes.NewBufferString("1.1.1.1 tcp"),
ns: &NameServer{
Addr: "1.1.1.1",
Protocol: "tcp",
Timeout: DefaultResolverTimeout,
},
stopped: true,
},
{
r: bytes.NewBufferString("1.1.1.1:853 tls cloudflare-dns.com"),
ns: &NameServer{
Addr: "1.1.1.1:853",
Protocol: "tls",
Hostname: "cloudflare-dns.com",
Timeout: DefaultResolverTimeout,
},
stopped: true,
},
{
r: bytes.NewBufferString("1.1.1.1:853 tls"),
ns: &NameServer{
Addr: "1.1.1.1:853",
Protocol: "tls",
Timeout: DefaultResolverTimeout,
},
stopped: true,
},
{
r: bytes.NewBufferString("1.0.0.1:53 https"),
stopped: true,
},
{
r: bytes.NewBufferString("https://1.0.0.1/dns-query https"),
ns: &NameServer{
Addr: "https://1.0.0.1/dns-query",
Protocol: "https",
Timeout: DefaultResolverTimeout,
},
stopped: true,
},
}
func TestResolverReload(t *testing.T) {
for i, tc := range resolverReloadTests {
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
r := newResolver(0, 0)
if err := r.Reload(tc.r); err != nil {
t.Error(err)
}
t.Log(r.String())
if r.Timeout != tc.timeout {
t.Errorf("timeout value should be %v, got %v",
tc.timeout, r.Timeout)
}
if r.TTL != tc.ttl {
t.Errorf("ttl value should be %v, got %v",
tc.ttl, r.TTL)
}
if r.Period() != tc.period {
t.Errorf("period value should be %v, got %v",
tc.period, r.period)
}
if r.domain != tc.domain {
t.Errorf("domain value should be %v, got %v",
tc.domain, r.domain)
}
var ns *NameServer
if len(r.Servers) > 0 {
ns = &r.Servers[0]
}
if !compareNameServer(ns, tc.ns) {
t.Errorf("nameserver not equal, should be %v, got %v",
tc.ns, r.Servers)
}
if tc.stopped {
r.Stop()
}
if r.Stopped() != tc.stopped {
t.Errorf("stopped value should be %v, got %v",
tc.stopped, r.Stopped())
}
})
}
}
func compareNameServer(n1, n2 *NameServer) bool {
if n1 == n2 {
return true
}
if n1 == nil || n2 == nil {
return false
}
return n1.Addr == n2.Addr &&
n1.Hostname == n2.Hostname &&
n1.Protocol == n2.Protocol &&
n1.Timeout == n2.Timeout
}