DNS resolver support DoH #335
This commit is contained in:
parent
02f1d099c4
commit
a7d49f0b37
@ -10,6 +10,7 @@ reload 10s
|
||||
# ip[:port] [protocol] [hostname]
|
||||
|
||||
1.1.1.1:853 tls cloudflare-dns.com
|
||||
https://1.0.0.1/dns-query https
|
||||
8.8.8.8
|
||||
8.8.8.8 tcp
|
||||
1.1.1.1 udp
|
||||
|
@ -207,17 +207,34 @@ func parseResolver(cfg string) gost.Resolver {
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(s, "https") {
|
||||
ns := gost.NameServer{
|
||||
Addr: s,
|
||||
Protocol: "https",
|
||||
}
|
||||
if err := ns.Init(); err == nil {
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
ss := strings.Split(s, "/")
|
||||
if len(ss) == 1 {
|
||||
nss = append(nss, gost.NameServer{
|
||||
ns := gost.NameServer{
|
||||
Addr: ss[0],
|
||||
})
|
||||
}
|
||||
if err := ns.Init(); err == nil {
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
}
|
||||
if len(ss) == 2 {
|
||||
nss = append(nss, gost.NameServer{
|
||||
ns := gost.NameServer{
|
||||
Addr: ss[0],
|
||||
Protocol: ss[1],
|
||||
})
|
||||
}
|
||||
if err := ns.Init(); err == nil {
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
return gost.NewResolver(timeout, ttl, nss...)
|
||||
|
242
resolver.go
242
resolver.go
@ -3,23 +3,28 @@ package gost
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-log/log"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultResolverTimeout is the default timeout for name resolution.
|
||||
DefaultResolverTimeout = 30 * time.Second
|
||||
DefaultResolverTimeout = 5 * time.Second
|
||||
// DefaultResolverTTL is the default cache TTL for name resolution.
|
||||
DefaultResolverTTL = 60 * time.Second
|
||||
DefaultResolverTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
// Resolver is a name resolver for domain name.
|
||||
@ -39,9 +44,73 @@ type ReloadResolver interface {
|
||||
// NameServer is a name server.
|
||||
// Currently supported protocol: TCP, UDP and TLS.
|
||||
type NameServer struct {
|
||||
Addr string
|
||||
Protocol string
|
||||
Hostname string // for TLS handshake verification
|
||||
Addr string
|
||||
Protocol string
|
||||
Hostname string // for TLS handshake verification
|
||||
Timeout time.Duration
|
||||
exchanger Exchanger
|
||||
}
|
||||
|
||||
// Init initializes the name server.
|
||||
func (ns *NameServer) Init() error {
|
||||
switch strings.ToLower(ns.Protocol) {
|
||||
case "tcp":
|
||||
ns.exchanger = &dnsExchanger{
|
||||
endpoint: ns.Addr,
|
||||
client: &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: ns.Timeout,
|
||||
},
|
||||
}
|
||||
case "tls":
|
||||
cfg := &tls.Config{
|
||||
ServerName: ns.Hostname,
|
||||
}
|
||||
if cfg.ServerName == "" {
|
||||
cfg.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
ns.exchanger = &dnsExchanger{
|
||||
endpoint: ns.Addr,
|
||||
client: &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Timeout: ns.Timeout,
|
||||
TLSConfig: cfg,
|
||||
},
|
||||
}
|
||||
case "https":
|
||||
u, err := url.Parse(ns.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg := &tls.Config{ServerName: u.Hostname()}
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: cfg,
|
||||
DisableCompression: true,
|
||||
MaxIdleConns: 1,
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
|
||||
ns.exchanger = &dohExchanger{
|
||||
endpoint: u,
|
||||
client: &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: ns.Timeout,
|
||||
},
|
||||
}
|
||||
case "udp":
|
||||
fallthrough
|
||||
default:
|
||||
ns.exchanger = &dnsExchanger{
|
||||
endpoint: ns.Addr,
|
||||
client: &dns.Client{
|
||||
Net: "udp",
|
||||
Timeout: ns.Timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ns NameServer) String() string {
|
||||
@ -62,26 +131,19 @@ type resolverCacheItem struct {
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
Resolver *net.Resolver
|
||||
Servers []NameServer
|
||||
mCache *sync.Map
|
||||
Timeout time.Duration
|
||||
TTL time.Duration
|
||||
period time.Duration
|
||||
domain string
|
||||
stopped chan struct{}
|
||||
mux sync.RWMutex
|
||||
Servers []NameServer
|
||||
mCache *sync.Map
|
||||
Timeout time.Duration
|
||||
TTL time.Duration
|
||||
period time.Duration
|
||||
domain string
|
||||
stopped chan struct{}
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewResolver create a new Resolver with the given name servers and resolution timeout.
|
||||
func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver {
|
||||
r := &resolver{
|
||||
Servers: servers,
|
||||
Timeout: timeout,
|
||||
TTL: ttl,
|
||||
mCache: &sync.Map{},
|
||||
stopped: make(chan struct{}),
|
||||
}
|
||||
r := newResolver(timeout, ttl, servers...)
|
||||
|
||||
if r.Timeout <= 0 {
|
||||
r.Timeout = DefaultResolverTimeout
|
||||
@ -92,6 +154,16 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
|
||||
return r
|
||||
}
|
||||
|
||||
func newResolver(timeout, ttl time.Duration, servers ...NameServer) *resolver {
|
||||
return &resolver{
|
||||
Servers: servers,
|
||||
Timeout: timeout,
|
||||
TTL: ttl,
|
||||
mCache: &sync.Map{},
|
||||
stopped: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) copyServers() []NameServer {
|
||||
var servers []NameServer
|
||||
for i := range r.Servers {
|
||||
@ -107,12 +179,11 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
}
|
||||
|
||||
var domain string
|
||||
var timeout, ttl time.Duration
|
||||
var ttl time.Duration
|
||||
var servers []NameServer
|
||||
|
||||
r.mux.RLock()
|
||||
domain = r.domain
|
||||
timeout = r.Timeout
|
||||
ttl = r.TTL
|
||||
servers = r.copyServers()
|
||||
r.mux.RUnlock()
|
||||
@ -133,7 +204,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
}
|
||||
|
||||
for _, ns := range servers {
|
||||
ips, err = r.resolve(ns, host, timeout)
|
||||
ips, err = r.resolve(ns.exchanger, host)
|
||||
if err != nil {
|
||||
log.Logf("[resolver] %s via %s : %s", host, ns, err)
|
||||
continue
|
||||
@ -151,36 +222,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) {
|
||||
addr := ns.Addr
|
||||
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
||||
addr = net.JoinHostPort(addr, "53")
|
||||
func (*resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) {
|
||||
if ex == nil {
|
||||
return
|
||||
}
|
||||
|
||||
client := dns.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
switch strings.ToLower(ns.Protocol) {
|
||||
case "tcp":
|
||||
client.Net = "tcp"
|
||||
case "tls":
|
||||
cfg := &tls.Config{
|
||||
ServerName: ns.Hostname,
|
||||
}
|
||||
if cfg.ServerName == "" {
|
||||
cfg.InsecureSkipVerify = true
|
||||
}
|
||||
client.Net = "tcp-tls"
|
||||
client.TLSConfig = cfg
|
||||
case "udp":
|
||||
fallthrough
|
||||
default:
|
||||
client.Net = "udp"
|
||||
}
|
||||
|
||||
m := dns.Msg{}
|
||||
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||
mr, _, err := client.Exchange(&m, addr)
|
||||
query := dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||
mr, err := ex.Exchange(context.Background(), &query)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -223,7 +272,7 @@ func (r *resolver) Reload(rd io.Reader) error {
|
||||
var domain string
|
||||
var nss []NameServer
|
||||
|
||||
if r.Stopped() {
|
||||
if rd == nil || r.Stopped() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -293,7 +342,15 @@ func (r *resolver) Reload(rd io.Reader) error {
|
||||
ns.Protocol = ss[1]
|
||||
ns.Hostname = ss[2]
|
||||
}
|
||||
nss = append(nss, ns)
|
||||
|
||||
ns.Timeout = timeout
|
||||
if timeout <= 0 {
|
||||
ns.Timeout = DefaultResolverTimeout
|
||||
}
|
||||
|
||||
if err := ns.Init(); err == nil {
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -359,3 +416,80 @@ func (r *resolver) String() string {
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Exchanger is an interface for DNS synchronous query.
|
||||
type Exchanger interface {
|
||||
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
||||
}
|
||||
|
||||
type dnsExchanger struct {
|
||||
endpoint string
|
||||
client *dns.Client
|
||||
}
|
||||
|
||||
func (ex *dnsExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
ep := ex.endpoint
|
||||
if _, port, _ := net.SplitHostPort(ep); port == "" {
|
||||
ep = net.JoinHostPort(ep, "53")
|
||||
}
|
||||
mr, _, err := ex.client.Exchange(query, ep)
|
||||
return mr, err
|
||||
}
|
||||
|
||||
type dohExchanger struct {
|
||||
endpoint *url.URL
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// reference: https://github.com/cloudflare/cloudflared/blob/master/tunneldns/https_upstream.go#L54
|
||||
func (ex *dohExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
queryBuf, err := query.Pack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pack DNS query: %s", err)
|
||||
}
|
||||
|
||||
// No content negotiation for now, use DNS wire format
|
||||
buf, backendErr := ex.exchangeWireformat(queryBuf)
|
||||
if backendErr == nil {
|
||||
response := &dns.Msg{}
|
||||
if err := response.Unpack(buf); err != nil {
|
||||
return nil, fmt.Errorf("failed to unpack DNS response from body: %s", err)
|
||||
}
|
||||
|
||||
response.Id = query.Id
|
||||
return response, nil
|
||||
}
|
||||
|
||||
return nil, backendErr
|
||||
}
|
||||
|
||||
// Perform message exchange with the default UDP wireformat defined in current draft
|
||||
// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
|
||||
func (ex *dohExchanger) exchangeWireformat(msg []byte) ([]byte, error) {
|
||||
req, err := http.NewRequest("POST", ex.endpoint.String(), bytes.NewBuffer(msg))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create an HTTPS request: %s", err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/dns-udpwireformat")
|
||||
req.Host = ex.endpoint.Hostname()
|
||||
|
||||
resp, err := ex.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform an HTTPS request: %s", err)
|
||||
}
|
||||
|
||||
// Check response status code
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read wireformat response from the body
|
||||
buf, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read the response body: %s", err)
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
221
resolver_test.go
Normal file
221
resolver_test.go
Normal file
@ -0,0 +1,221 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var dnsTests = []struct {
|
||||
ns NameServer
|
||||
host string
|
||||
pass bool
|
||||
}{
|
||||
{NameServer{Addr: "1.1.1.1"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:53"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "example.com"}, "github.com", false},
|
||||
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true},
|
||||
{NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true},
|
||||
{NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true},
|
||||
{NameServer{Addr: "1.1.1.1:12345", Timeout: 1 * time.Second}, "github.com", false},
|
||||
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp", Timeout: 1 * time.Second}, "github.com", false},
|
||||
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tls", Timeout: 1 * time.Second}, "github.com", false},
|
||||
{NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https", Timeout: 1 * time.Second}, "github.com", false},
|
||||
}
|
||||
|
||||
func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error {
|
||||
ips, err := r.Resolve(host)
|
||||
t.Log(host, ips, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDNSResolver(t *testing.T) {
|
||||
for i, tc := range dnsTests {
|
||||
tc := tc
|
||||
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||
ns := tc.ns
|
||||
if err := ns.Init(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Log(ns)
|
||||
r := NewResolver(0, 0, ns)
|
||||
err := dnsResolverRoundtrip(t, r, tc.host)
|
||||
if err != nil {
|
||||
if tc.pass {
|
||||
t.Error("got error:", err)
|
||||
}
|
||||
} else {
|
||||
if !tc.pass {
|
||||
t.Error("should failed")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var resolverReloadTests = []struct {
|
||||
r io.Reader
|
||||
|
||||
timeout time.Duration
|
||||
ttl time.Duration
|
||||
domain string
|
||||
period time.Duration
|
||||
ns *NameServer
|
||||
|
||||
stopped bool
|
||||
}{
|
||||
{
|
||||
r: nil,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString(""),
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("reload 10s"),
|
||||
period: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("timeout 10s\nreload 10s\n"),
|
||||
timeout: 10 * time.Second,
|
||||
period: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("ttl 10s\ntimeout 10s\nreload 10s\n"),
|
||||
timeout: 10 * time.Second,
|
||||
period: 10 * time.Second,
|
||||
ttl: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("domain example.com\nttl 10s\ntimeout 10s\nreload 10s\n"),
|
||||
timeout: 10 * time.Second,
|
||||
period: 10 * time.Second,
|
||||
ttl: 10 * time.Second,
|
||||
domain: "example.com",
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.1.1.1"),
|
||||
ns: &NameServer{
|
||||
Addr: "1.1.1.1",
|
||||
Timeout: DefaultResolverTimeout,
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("timeout 10s\nsearch\nnameserver \nnameserver 1.1.1.1 udp"),
|
||||
ns: &NameServer{
|
||||
Protocol: "udp",
|
||||
Addr: "1.1.1.1",
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
timeout: 10 * time.Second,
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.1.1.1 tcp"),
|
||||
ns: &NameServer{
|
||||
Addr: "1.1.1.1",
|
||||
Protocol: "tcp",
|
||||
Timeout: DefaultResolverTimeout,
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.1.1.1:853 tls cloudflare-dns.com"),
|
||||
ns: &NameServer{
|
||||
Addr: "1.1.1.1:853",
|
||||
Protocol: "tls",
|
||||
Hostname: "cloudflare-dns.com",
|
||||
Timeout: DefaultResolverTimeout,
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.1.1.1:853 tls"),
|
||||
ns: &NameServer{
|
||||
Addr: "1.1.1.1:853",
|
||||
Protocol: "tls",
|
||||
Timeout: DefaultResolverTimeout,
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("1.0.0.1:53 https"),
|
||||
stopped: true,
|
||||
},
|
||||
{
|
||||
r: bytes.NewBufferString("https://1.0.0.1/dns-query https"),
|
||||
ns: &NameServer{
|
||||
Addr: "https://1.0.0.1/dns-query",
|
||||
Protocol: "https",
|
||||
Timeout: DefaultResolverTimeout,
|
||||
},
|
||||
stopped: true,
|
||||
},
|
||||
}
|
||||
|
||||
func TestResolverReload(t *testing.T) {
|
||||
for i, tc := range resolverReloadTests {
|
||||
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||
r := newResolver(0, 0)
|
||||
if err := r.Reload(tc.r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Log(r.String())
|
||||
if r.Timeout != tc.timeout {
|
||||
t.Errorf("timeout value should be %v, got %v",
|
||||
tc.timeout, r.Timeout)
|
||||
}
|
||||
if r.TTL != tc.ttl {
|
||||
t.Errorf("ttl value should be %v, got %v",
|
||||
tc.ttl, r.TTL)
|
||||
}
|
||||
if r.Period() != tc.period {
|
||||
t.Errorf("period value should be %v, got %v",
|
||||
tc.period, r.period)
|
||||
}
|
||||
if r.domain != tc.domain {
|
||||
t.Errorf("domain value should be %v, got %v",
|
||||
tc.domain, r.domain)
|
||||
}
|
||||
|
||||
var ns *NameServer
|
||||
if len(r.Servers) > 0 {
|
||||
ns = &r.Servers[0]
|
||||
}
|
||||
|
||||
if !compareNameServer(ns, tc.ns) {
|
||||
t.Errorf("nameserver not equal, should be %v, got %v",
|
||||
tc.ns, r.Servers)
|
||||
}
|
||||
|
||||
if tc.stopped {
|
||||
r.Stop()
|
||||
}
|
||||
if r.Stopped() != tc.stopped {
|
||||
t.Errorf("stopped value should be %v, got %v",
|
||||
tc.stopped, r.Stopped())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func compareNameServer(n1, n2 *NameServer) bool {
|
||||
if n1 == n2 {
|
||||
return true
|
||||
}
|
||||
if n1 == nil || n2 == nil {
|
||||
return false
|
||||
}
|
||||
return n1.Addr == n2.Addr &&
|
||||
n1.Hostname == n2.Hostname &&
|
||||
n1.Protocol == n2.Protocol &&
|
||||
n1.Timeout == n2.Timeout
|
||||
}
|
Loading…
Reference in New Issue
Block a user