diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 2a8a6a4..36f76c2 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -557,6 +557,8 @@ func (r *route) GenRouters() ([]router, error) { gost.ChainResolverOption(chain), gost.TimeoutResolverOption(timeout), gost.TTLResolverOption(ttl), + gost.PreferResolverOption(node.Get("prefer")), + gost.SrcIPResolverOption(net.ParseIP(node.Get("ip"))), ) } diff --git a/resolver.go b/resolver.go index e7b3f91..618d724 100644 --- a/resolver.go +++ b/resolver.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -122,6 +123,8 @@ type resolverOptions struct { chain *Chain timeout time.Duration ttl time.Duration + prefer string + srcIP net.IP } // ResolverOption allows a common way to set Resolver options. @@ -148,6 +151,20 @@ func TTLResolverOption(ttl time.Duration) ResolverOption { } } +// PreferResolverOption sets the prefer for Resolver. +func PreferResolverOption(prefer string) ResolverOption { + return func(opts *resolverOptions) { + opts.prefer = prefer + } +} + +// SrcIPResolverOption sets the source IP for Resolver. +func SrcIPResolverOption(ip net.IP) ResolverOption { + return func(opts *resolverOptions) { + opts.srcIP = ip + } +} + // Resolver is a name resolver for domain name. // It contains a list of name servers. type Resolver interface { @@ -177,6 +194,7 @@ type resolver struct { stopped chan struct{} mux sync.RWMutex prefer string // ipv4 or ipv6 + srcIP net.IP // for edns0 subnet option options resolverOptions } @@ -217,6 +235,12 @@ func (r *resolver) Init(opts ...ResolverOption) error { if r.options.ttl != 0 { r.ttl = r.options.ttl } + if r.options.prefer != "" { + r.prefer = r.options.prefer + } + if r.options.srcIP != nil { + r.srcIP = r.options.srcIP + } var nss []NameServer for _, ns := range r.servers { @@ -259,8 +283,9 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { host = host + "." + domain } + ctx := context.Background() for _, ns := range r.copyServers() { - ips, err = r.resolve(ns.exchanger, host) + ips, err = r.resolve(ctx, ns.exchanger, host) if err != nil { log.Logf("[resolver] %s via %s : %s", host, ns.String(), err) continue @@ -277,7 +302,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { return } -func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) { +func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { if ex == nil { return } @@ -286,7 +311,6 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) prefer := r.prefer r.mux.RUnlock() - ctx := context.Background() if prefer == "ipv6" { // prefer ipv6 mq := &dns.Msg{} mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) @@ -302,9 +326,15 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) } func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { - mr, _, err := r.exchangeMsg(ctx, ex, mq) - if err != nil { - return + key := newResolverCacheKey(&mq.Question[0]) + mr := r.cache.loadCache(key) + if mr == nil { + r.addSubnetOpt(mq) + mr, err = r.exchangeMsg(ctx, ex, mq) + if err != nil { + return + } + r.cache.storeCache(key, mr, r.TTL()) } for _, ans := range mr.Answer { @@ -319,22 +349,61 @@ func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (i return } +func (r *resolver) addSubnetOpt(m *dns.Msg) { + if m == nil || r.srcIP == nil { + return + } + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + if ip := r.srcIP.To4(); ip != nil { + e.Family = 1 + e.SourceNetmask = 32 + e.Address = ip.To4() + } else { + e.Family = 2 + e.SourceNetmask = 128 + e.Address = r.srcIP + } + opt.Option = append(opt.Option, e) + m.Extra = append(m.Extra, opt) +} + func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) { mq := &dns.Msg{} if err = mq.Unpack(query); err != nil { return } - var qs string - if len(mq.Question) > 0 { - qs = mq.Question[0].String() + if len(mq.Question) == 0 { + return nil, errors.New("empty question") } var mr *dns.Msg + // Only cache for single question. + if len(mq.Question) == 1 { + key := newResolverCacheKey(&mq.Question[0]) + mr = r.cache.loadCache(key) + if mr != nil { + log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) + mr.Id = mq.Id + return mr.Pack() + } + + defer func() { + if mr != nil { + r.cache.storeCache(key, mr, r.TTL()) + } + }() + } + + r.addSubnetOpt(mq) + for _, ns := range r.copyServers() { - var cache bool - mr, cache, err = r.exchangeMsg(ctx, ns.exchanger, mq) - log.Logf("[dns] exchange message %d via %s (cache hit: %v): %s", mq.Id, ns.String(), cache, qs) + log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String()) + mr, err = r.exchangeMsg(ctx, ns.exchanger, mq) if err == nil { break } @@ -346,22 +415,7 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er return mr.Pack() } -func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, cache bool, err error) { - // Only cache for single question. - if len(mq.Question) == 1 { - key := newResolverCacheKey(&mq.Question[0]) - mr = r.cache.loadCache(key) - if mr != nil { - cache = true - mr.Id = mq.Id - return - } - - defer func() { - r.cache.storeCache(key, mr, r.TTL()) - }() - } - +func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { query, err := mq.Pack() if err != nil { return @@ -386,6 +440,7 @@ func (r *resolver) TTL() time.Duration { func (r *resolver) Reload(rd io.Reader) error { var ttl, timeout, period time.Duration var domain, prefer string + var srcIP net.IP var nss []NameServer if rd == nil || r.Stopped() { @@ -422,6 +477,10 @@ func (r *resolver) Reload(rd io.Reader) error { if len(ss) > 1 { prefer = strings.ToLower(ss[1]) } + case "ip": + if len(ss) > 1 { + srcIP = net.ParseIP(ss[1]) + } case "nameserver": // nameserver option, compatible with /etc/resolv.conf if len(ss) <= 1 { break @@ -461,6 +520,7 @@ func (r *resolver) Reload(rd io.Reader) error { r.domain = domain r.period = period r.prefer = prefer + r.srcIP = srcIP r.servers = nss r.mux.Unlock()