dns: add edns0 subnet option support
This commit is contained in:
parent
8121e20cbd
commit
cbc9c1f77e
@ -557,6 +557,8 @@ func (r *route) GenRouters() ([]router, error) {
|
||||
gost.ChainResolverOption(chain),
|
||||
gost.TimeoutResolverOption(timeout),
|
||||
gost.TTLResolverOption(ttl),
|
||||
gost.PreferResolverOption(node.Get("prefer")),
|
||||
gost.SrcIPResolverOption(net.ParseIP(node.Get("ip"))),
|
||||
)
|
||||
}
|
||||
|
||||
|
112
resolver.go
112
resolver.go
@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -122,6 +123,8 @@ type resolverOptions struct {
|
||||
chain *Chain
|
||||
timeout time.Duration
|
||||
ttl time.Duration
|
||||
prefer string
|
||||
srcIP net.IP
|
||||
}
|
||||
|
||||
// ResolverOption allows a common way to set Resolver options.
|
||||
@ -148,6 +151,20 @@ func TTLResolverOption(ttl time.Duration) ResolverOption {
|
||||
}
|
||||
}
|
||||
|
||||
// PreferResolverOption sets the prefer for Resolver.
|
||||
func PreferResolverOption(prefer string) ResolverOption {
|
||||
return func(opts *resolverOptions) {
|
||||
opts.prefer = prefer
|
||||
}
|
||||
}
|
||||
|
||||
// SrcIPResolverOption sets the source IP for Resolver.
|
||||
func SrcIPResolverOption(ip net.IP) ResolverOption {
|
||||
return func(opts *resolverOptions) {
|
||||
opts.srcIP = ip
|
||||
}
|
||||
}
|
||||
|
||||
// Resolver is a name resolver for domain name.
|
||||
// It contains a list of name servers.
|
||||
type Resolver interface {
|
||||
@ -177,6 +194,7 @@ type resolver struct {
|
||||
stopped chan struct{}
|
||||
mux sync.RWMutex
|
||||
prefer string // ipv4 or ipv6
|
||||
srcIP net.IP // for edns0 subnet option
|
||||
options resolverOptions
|
||||
}
|
||||
|
||||
@ -217,6 +235,12 @@ func (r *resolver) Init(opts ...ResolverOption) error {
|
||||
if r.options.ttl != 0 {
|
||||
r.ttl = r.options.ttl
|
||||
}
|
||||
if r.options.prefer != "" {
|
||||
r.prefer = r.options.prefer
|
||||
}
|
||||
if r.options.srcIP != nil {
|
||||
r.srcIP = r.options.srcIP
|
||||
}
|
||||
|
||||
var nss []NameServer
|
||||
for _, ns := range r.servers {
|
||||
@ -259,8 +283,9 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
host = host + "." + domain
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, ns := range r.copyServers() {
|
||||
ips, err = r.resolve(ns.exchanger, host)
|
||||
ips, err = r.resolve(ctx, ns.exchanger, host)
|
||||
if err != nil {
|
||||
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
|
||||
continue
|
||||
@ -277,7 +302,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) {
|
||||
func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
|
||||
if ex == nil {
|
||||
return
|
||||
}
|
||||
@ -286,7 +311,6 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error)
|
||||
prefer := r.prefer
|
||||
r.mux.RUnlock()
|
||||
|
||||
ctx := context.Background()
|
||||
if prefer == "ipv6" { // prefer ipv6
|
||||
mq := &dns.Msg{}
|
||||
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
|
||||
@ -302,10 +326,16 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error)
|
||||
}
|
||||
|
||||
func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
|
||||
mr, _, err := r.exchangeMsg(ctx, ex, mq)
|
||||
key := newResolverCacheKey(&mq.Question[0])
|
||||
mr := r.cache.loadCache(key)
|
||||
if mr == nil {
|
||||
r.addSubnetOpt(mq)
|
||||
mr, err = r.exchangeMsg(ctx, ex, mq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r.cache.storeCache(key, mr, r.TTL())
|
||||
}
|
||||
|
||||
for _, ans := range mr.Answer {
|
||||
if ar, _ := ans.(*dns.AAAA); ar != nil {
|
||||
@ -319,22 +349,61 @@ func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (i
|
||||
return
|
||||
}
|
||||
|
||||
func (r *resolver) addSubnetOpt(m *dns.Msg) {
|
||||
if m == nil || r.srcIP == nil {
|
||||
return
|
||||
}
|
||||
opt := new(dns.OPT)
|
||||
opt.Hdr.Name = "."
|
||||
opt.Hdr.Rrtype = dns.TypeOPT
|
||||
e := new(dns.EDNS0_SUBNET)
|
||||
e.Code = dns.EDNS0SUBNET
|
||||
if ip := r.srcIP.To4(); ip != nil {
|
||||
e.Family = 1
|
||||
e.SourceNetmask = 32
|
||||
e.Address = ip.To4()
|
||||
} else {
|
||||
e.Family = 2
|
||||
e.SourceNetmask = 128
|
||||
e.Address = r.srcIP
|
||||
}
|
||||
opt.Option = append(opt.Option, e)
|
||||
m.Extra = append(m.Extra, opt)
|
||||
}
|
||||
|
||||
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
mq := &dns.Msg{}
|
||||
if err = mq.Unpack(query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var qs string
|
||||
if len(mq.Question) > 0 {
|
||||
qs = mq.Question[0].String()
|
||||
if len(mq.Question) == 0 {
|
||||
return nil, errors.New("empty question")
|
||||
}
|
||||
|
||||
var mr *dns.Msg
|
||||
// Only cache for single question.
|
||||
if len(mq.Question) == 1 {
|
||||
key := newResolverCacheKey(&mq.Question[0])
|
||||
mr = r.cache.loadCache(key)
|
||||
if mr != nil {
|
||||
log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||
mr.Id = mq.Id
|
||||
return mr.Pack()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if mr != nil {
|
||||
r.cache.storeCache(key, mr, r.TTL())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
r.addSubnetOpt(mq)
|
||||
|
||||
for _, ns := range r.copyServers() {
|
||||
var cache bool
|
||||
mr, cache, err = r.exchangeMsg(ctx, ns.exchanger, mq)
|
||||
log.Logf("[dns] exchange message %d via %s (cache hit: %v): %s", mq.Id, ns.String(), cache, qs)
|
||||
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String())
|
||||
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
@ -346,22 +415,7 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er
|
||||
return mr.Pack()
|
||||
}
|
||||
|
||||
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, cache bool, err error) {
|
||||
// Only cache for single question.
|
||||
if len(mq.Question) == 1 {
|
||||
key := newResolverCacheKey(&mq.Question[0])
|
||||
mr = r.cache.loadCache(key)
|
||||
if mr != nil {
|
||||
cache = true
|
||||
mr.Id = mq.Id
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
r.cache.storeCache(key, mr, r.TTL())
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
|
||||
query, err := mq.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
@ -386,6 +440,7 @@ func (r *resolver) TTL() time.Duration {
|
||||
func (r *resolver) Reload(rd io.Reader) error {
|
||||
var ttl, timeout, period time.Duration
|
||||
var domain, prefer string
|
||||
var srcIP net.IP
|
||||
var nss []NameServer
|
||||
|
||||
if rd == nil || r.Stopped() {
|
||||
@ -422,6 +477,10 @@ func (r *resolver) Reload(rd io.Reader) error {
|
||||
if len(ss) > 1 {
|
||||
prefer = strings.ToLower(ss[1])
|
||||
}
|
||||
case "ip":
|
||||
if len(ss) > 1 {
|
||||
srcIP = net.ParseIP(ss[1])
|
||||
}
|
||||
case "nameserver": // nameserver option, compatible with /etc/resolv.conf
|
||||
if len(ss) <= 1 {
|
||||
break
|
||||
@ -461,6 +520,7 @@ func (r *resolver) Reload(rd io.Reader) error {
|
||||
r.domain = domain
|
||||
r.period = period
|
||||
r.prefer = prefer
|
||||
r.srcIP = srcIP
|
||||
r.servers = nss
|
||||
r.mux.Unlock()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user