diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index ad44de0..2897f06 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -223,9 +223,7 @@ func parseResolver(cfg string) gost.Resolver { Addr: s, Protocol: "https", } - if err := ns.Init(); err == nil { - nss = append(nss, ns) - } + nss = append(nss, ns) continue } @@ -234,18 +232,14 @@ func parseResolver(cfg string) gost.Resolver { ns := gost.NameServer{ Addr: ss[0], } - if err := ns.Init(); err == nil { - nss = append(nss, ns) - } + nss = append(nss, ns) } if len(ss) == 2 { ns := gost.NameServer{ Addr: ss[0], Protocol: ss[1], } - if err := ns.Init(); err == nil { - nss = append(nss, ns) - } + nss = append(nss, ns) } } return gost.NewResolver(0, nss...) diff --git a/cmd/gost/route.go b/cmd/gost/route.go index ff115a8..f3809ca 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -477,6 +477,13 @@ func (r *route) GenRouters() ([]router, error) { QueueSize: node.GetInt("queue"), }, ) + case "dns": + ln, err = gost.DNSListener( + node.Addr, + &gost.DNSOptions{ + TCPMode: node.GetBool("tcp"), + }, + ) default: ln, err = gost.TCPListener(node.Addr) } @@ -518,6 +525,8 @@ func (r *route) GenRouters() ([]router, error) { handler = gost.TunHandler() case "tap": handler = gost.TapHandler() + case "dns", "dot", "doh": + handler = gost.DNSHandler(node.Remote) default: // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. if node.Remote != "" { @@ -540,10 +549,14 @@ func (r *route) GenRouters() ([]router, error) { } node.Bypass = parseBypass(node.Get("bypass")) - resolver := parseResolver(node.Get("dns")) hosts := parseHosts(node.Get("hosts")) ips := parseIP(node.Get("ip"), "") + resolver := parseResolver(node.Get("dns")) + if resolver != nil { + resolver.Init(gost.ChainResolverOption(chain)) + } + handler.Init( gost.AddrHandlerOption(ln.Addr().String()), gost.ChainHandlerOption(chain), diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..fc29dc6 --- /dev/null +++ b/dns.go @@ -0,0 +1,224 @@ +package gost + +import ( + "bytes" + "context" + "errors" + "net" + "strconv" + "time" + + "github.com/go-log/log" + "github.com/miekg/dns" +) + +type dnsHandler struct { + options *HandlerOptions +} + +// DNSHandler creates a Handler for DNS server. +func DNSHandler(raddr string, opts ...HandlerOption) Handler { + h := &dnsHandler{} + + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *dnsHandler) Init(opts ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range opts { + opt(h.options) + } +} + +func (h *dnsHandler) Handle(conn net.Conn) { + defer conn.Close() + + b := mPool.Get().([]byte) + defer mPool.Put(b) + + n, err := conn.Read(b) + if err != nil { + log.Logf("[dns] %s - %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + } + + mq := &dns.Msg{} + if err = mq.Unpack(b[:n]); err != nil { + log.Logf("[dns] %s - %s request unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + log.Logf("[dns] %s -> %s: %s", conn.RemoteAddr(), conn.LocalAddr(), h.dumpMsgHeader(mq)) + if Debug { + log.Logf("[dns] %s >>> %s: %s", conn.RemoteAddr(), conn.LocalAddr(), mq.String()) + } + + start := time.Now() + reply, err := h.options.Resolver.Exchange(context.Background(), b[:n]) + if err != nil { + log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + rtt := time.Since(start) + + mr := &dns.Msg{} + if err = mr.Unpack(reply); err != nil { + log.Logf("[dns] %s - %s reply unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + log.Logf("[dns] %s <- %s: %s [%s]", + conn.RemoteAddr(), conn.LocalAddr(), h.dumpMsgHeader(mr), rtt) + if Debug { + log.Logf("[dns] %s <<< %s: %s", conn.RemoteAddr(), conn.LocalAddr(), mr.String()) + } + + if _, err = conn.Write(reply); err != nil { + log.Logf("[dns] %s - %s reply unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + } +} + +func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { + buf := new(bytes.Buffer) + buf.WriteString(m.MsgHdr.String() + " ") + buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ") + buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ") + buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ") + buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra))) + return buf.String() +} + +type DNSOptions struct { + TCPMode bool + UDPSize int + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +type dnsListener struct { + addr net.Addr + server *dns.Server + connChan chan net.Conn + errc chan error +} + +func DNSListener(addr string, options *DNSOptions) (Listener, error) { + if options == nil { + options = &DNSOptions{} + } + + ln := &dnsListener{ + connChan: make(chan net.Conn, 128), + errc: make(chan error, 1), + } + + var nets string + var err error + + if options.TCPMode { + nets = "tcp" + ln.addr, err = net.ResolveTCPAddr("tcp", addr) + } else { + nets = "udp" + ln.addr, err = net.ResolveUDPAddr("udp", addr) + } + if err != nil { + return nil, err + } + + ln.server = &dns.Server{ + Addr: addr, + Net: nets, + } + + dns.HandleFunc(".", ln.handleRequest) + + go func() { + if err := ln.server.ListenAndServe(); err != nil { + ln.errc <- err + return + } + }() + + select { + case err := <-ln.errc: + return nil, err + default: + } + + return ln, nil +} + +func (l *dnsListener) handleRequest(w dns.ResponseWriter, m *dns.Msg) { + if w == nil || m == nil { + return + } + + conn := &dnsServerConn{ + mq: make(chan []byte, 1), + ResponseWriter: w, + } + + buf := mPool.Get().([]byte) + defer mPool.Put(buf) + buf = buf[:0] + b, err := m.PackBuffer(buf) + if err != nil { + log.Logf("[dns] %s: %v", l.addr, err) + return + } + conn.mq <- b + + select { + case l.connChan <- conn: + default: + log.Logf("[dns] %s: connection queue is full", l.addr) + } +} + +func (l *dnsListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errc: + } + return +} + +func (l *dnsListener) Close() error { + return l.server.Shutdown() +} + +func (l *dnsListener) Addr() net.Addr { + return l.addr +} + +type dnsServerConn struct { + mq chan []byte + dns.ResponseWriter +} + +func (c *dnsServerConn) Read(b []byte) (n int, err error) { + var mb []byte + select { + case mb = <-c.mq: + default: + } + n = copy(b, mb) + return +} + +func (c *dnsServerConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *dnsServerConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *dnsServerConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/go.mod b/go.mod index 27a189d..199366f 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/lucas-clemente/aes12 v0.0.0-20171027163421-cd47fb39b79f // indirect github.com/lucas-clemente/quic-go v0.10.0 github.com/lucas-clemente/quic-go-certificates v0.0.0-20160823095156-d2f86524cced // indirect - github.com/miekg/dns v1.1.3 + github.com/miekg/dns v1.1.27 github.com/milosgajdos83/tenus v0.0.0-20190415114537-1f3ed00ae7d8 github.com/onsi/ginkgo v1.7.0 // indirect github.com/onsi/gomega v1.4.3 // indirect @@ -40,9 +40,8 @@ require ( github.com/templexxx/xor v0.0.0-20181023030647-4e92f724b73b // indirect github.com/tjfoc/gmsm v1.0.1 // indirect github.com/xtaci/tcpraw v1.2.25 - golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 - golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 - golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect + golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 + golang.org/x/net v0.0.0-20190923162816-aa69164e4478 gopkg.in/gorilla/websocket.v1 v1.4.0 gopkg.in/xtaci/kcp-go.v4 v4.3.2 gopkg.in/xtaci/smux.v1 v1.0.7 diff --git a/go.sum b/go.sum index a548b91..a62f2a4 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/lucas-clemente/quic-go-certificates v0.0.0-20160823095156-d2f86524cce github.com/lucas-clemente/quic-go-certificates v0.0.0-20160823095156-d2f86524cced/go.mod h1:NCcRLrOTZbzhZvixZLlERbJtDtYsmMw8Jc4vS8Z0g58= github.com/miekg/dns v1.1.3 h1:1g0r1IvskvgL8rR+AcHzUA+oFmGcQlaIm4IqakufeMM= github.com/miekg/dns v1.1.3/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM= +github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= github.com/milosgajdos83/tenus v0.0.0-20190415114537-1f3ed00ae7d8 h1:4WFQEfEJ7zaHYViIVM2Cd6tnQOOhiEHbmQtlcV7aOpc= github.com/milosgajdos83/tenus v0.0.0-20190415114537-1f3ed00ae7d8/go.mod h1:G95Wwn625/q6JCCytI4VR/a5VtPwrtI0B+Q1Gi38QLA= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -88,23 +90,35 @@ golang.org/x/crypto v0.0.0-20190130090550-b01c7a725664 h1:YbZJ76lQ1BqNhVe7dKTSB6 golang.org/x/crypto v0.0.0-20190130090550-b01c7a725664/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3 h1:ulvT7fqt0yHWzpJwI57MezWnYDVpCAYBVuYst/L+fAY= golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e h1:ZytStCyV048ZqDsWHiYDdoI2Vd4msMcrDECFxS+tL9c= golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67 h1:1Fzlr8kkDLQwqMP8GxrhptBLqZG/EDpiATneiZHY998= golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe h1:6fAMxZRR6sl1Uq8U61gxU+kPTs2tR8uOySCbBP7BN/M= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= diff --git a/node.go b/node.go index 3305667..80a9170 100644 --- a/node.go +++ b/node.go @@ -84,6 +84,7 @@ func ParseNode(s string) (node Node, err error) { case "ohttp": // obfs-http case "tun", "tap": // tun/tap device case "ftcp": // fake TCP + case "dns", "dot", "doh": default: node.Transport = "tcp" } @@ -97,6 +98,7 @@ func ParseNode(s string) (node Node, err error) { case "redirect": // TCP transparent proxy case "tun", "tap": // tun/tap device case "ftcp": // fake TCP + case "dns", "dot", "doh": default: node.Protocol = "" } diff --git a/resolver.go b/resolver.go index 2041689..73f6951 100644 --- a/resolver.go +++ b/resolver.go @@ -17,7 +17,6 @@ import ( "github.com/go-log/log" "github.com/miekg/dns" - "golang.org/x/net/http2" ) var ( @@ -25,18 +24,23 @@ var ( DefaultResolverTimeout = 5 * time.Second ) -// Resolver is a name resolver for domain name. -// It contains a list of name servers. -type Resolver interface { - // Resolve returns a slice of that host's IPv4 and IPv6 addresses. - Resolve(host string) ([]net.IP, error) +type nameServerOptions struct { + timeout time.Duration + chain *Chain } -// ReloadResolver is resolover that support live reloading. -type ReloadResolver interface { - Resolver - Reloader - Stoppable +type NameServerOption func(*nameServerOptions) + +func TimeoutNameServerOption(timeout time.Duration) NameServerOption { + return func(opts *nameServerOptions) { + opts.timeout = timeout + } +} + +func ChainNameServerOption(chain *Chain) NameServerOption { + return func(opts *nameServerOptions) { + opts.chain = chain + } } // NameServer is a name server. @@ -45,26 +49,23 @@ type NameServer struct { Addr string Protocol string Hostname string // for TLS handshake verification - Timeout time.Duration exchanger Exchanger + options nameServerOptions } // Init initializes the name server. -func (ns *NameServer) Init() error { - timeout := ns.Timeout - if timeout <= 0 { - timeout = DefaultResolverTimeout +func (ns *NameServer) Init(opts ...NameServerOption) error { + for _, opt := range opts { + opt(&ns.options) } switch strings.ToLower(ns.Protocol) { case "tcp": - ns.exchanger = &dnsExchanger{ - endpoint: ns.Addr, - client: &dns.Client{ - Net: "tcp", - Timeout: timeout, - }, - } + ns.exchanger = NewDNSTCPExchanger( + ns.Addr, + TimeoutExchangerOption(ns.options.timeout), + ChainExchangerOption(ns.options.chain), + ) case "tls": cfg := &tls.Config{ ServerName: ns.Hostname, @@ -72,51 +73,39 @@ func (ns *NameServer) Init() error { if cfg.ServerName == "" { cfg.InsecureSkipVerify = true } - - ns.exchanger = &dnsExchanger{ - endpoint: ns.Addr, - client: &dns.Client{ - Net: "tcp-tls", - Timeout: timeout, - TLSConfig: cfg, - }, - } + ns.exchanger = NewDoTExchanger( + ns.Addr, cfg, + TimeoutExchangerOption(ns.options.timeout), + ChainExchangerOption(ns.options.chain), + ) case "https": u, err := url.Parse(ns.Addr) if err != nil { return err } cfg := &tls.Config{ServerName: u.Hostname()} - transport := &http.Transport{ - TLSClientConfig: cfg, - DisableCompression: true, - MaxIdleConns: 1, - } - http2.ConfigureTransport(transport) - - ns.exchanger = &dohExchanger{ - endpoint: u, - client: &http.Client{ - Transport: transport, - Timeout: timeout, - }, + if cfg.ServerName == "" { + cfg.InsecureSkipVerify = true } + ns.exchanger = NewDoHExchanger( + u, cfg, + TimeoutExchangerOption(ns.options.timeout), + ChainExchangerOption(ns.options.chain), + ) case "udp": fallthrough default: - ns.exchanger = &dnsExchanger{ - endpoint: ns.Addr, - client: &dns.Client{ - Net: "udp", - Timeout: timeout, - }, - } + ns.exchanger = NewDNSExchanger( + ns.Addr, + TimeoutExchangerOption(ns.options.timeout), + ChainExchangerOption(ns.options.chain), + ) } return nil } -func (ns NameServer) String() string { +func (ns *NameServer) String() string { addr := ns.Addr prot := ns.Protocol if _, port, _ := net.SplitHostPort(addr); port == "" { @@ -128,15 +117,48 @@ func (ns NameServer) String() string { return fmt.Sprintf("%s/%s", addr, prot) } +type resolverOptions struct { + chain *Chain +} + +type ResolverOption func(*resolverOptions) + +func ChainResolverOption(chain *Chain) ResolverOption { + return func(opts *resolverOptions) { + opts.chain = chain + } +} + +// Resolver is a name resolver for domain name. +// It contains a list of name servers. +type Resolver interface { + // Init initializes the Resolver instance. + Init(opts ...ResolverOption) error + // Resolve returns a slice of that host's IPv4 and IPv6 addresses. + Resolve(host string) ([]net.IP, error) + // Exchange performs a synchronous query, + // It sends the message query and waits for a reply. + Exchange(ctx context.Context, query []byte) (reply []byte, err error) +} + +// ReloadResolver is resolover that support live reloading. +type ReloadResolver interface { + Resolver + Reloader + Stoppable +} + type resolver struct { Servers []NameServer mCache *sync.Map TTL time.Duration + timeout time.Duration period time.Duration domain string stopped chan struct{} mux sync.RWMutex prefer string // ipv4 or ipv6 + options resolverOptions } // NewResolver create a new Resolver with the given name servers and resolution timeout. @@ -154,6 +176,34 @@ func newResolver(ttl time.Duration, servers ...NameServer) *resolver { } } +func (r *resolver) Init(opts ...ResolverOption) error { + if r == nil { + return nil + } + + r.mux.Lock() + defer r.mux.Unlock() + + for _, opt := range opts { + opt(&r.options) + } + + var nss []NameServer + for _, ns := range r.Servers { + if err := ns.Init( // init all name servers + ChainNameServerOption(r.options.chain), + TimeoutNameServerOption(r.timeout), + ); err != nil { + continue // ignore invalid name servers + } + nss = append(nss, ns) + } + + r.Servers = nss + + return nil +} + func (r *resolver) copyServers() []NameServer { var servers []NameServer for i := range r.Servers { @@ -196,12 +246,12 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { for _, ns := range servers { ips, ttl, err = r.resolve(ns.exchanger, host) if err != nil { - log.Logf("[resolver] %s via %s : %s", host, ns, err) + log.Logf("[resolver] %s via %s : %s", host, ns.String(), err) continue } if Debug { - log.Logf("[resolver] %s via %s %v(ttl: %v)", host, ns, ips, ttl) + log.Logf("[resolver] %s via %s %v(ttl: %v)", host, ns.String(), ips, ttl) } if len(ips) > 0 { break @@ -233,10 +283,24 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, ttl time.Du } func (*resolver) resolveIPs(ctx context.Context, ex Exchanger, query *dns.Msg) (ips []net.IP, ttl time.Duration, err error) { - mr, err := ex.Exchange(ctx, query) + // buf := mPool.Get().([]byte) + // defer mPool.Put(buf) + + // buf = buf[:0] + // mq, err := query.PackBuffer(buf) + mq, err := query.Pack() if err != nil { return } + reply, err := ex.Exchange(ctx, mq) + if err != nil { + return + } + mr := &dns.Msg{} + if err = mr.Unpack(reply); err != nil { + return + } + for _, ans := range mr.Answer { if ar, _ := ans.(*dns.AAAA); ar != nil { ips = append(ips, ar.AAAA) @@ -250,6 +314,25 @@ func (*resolver) resolveIPs(ctx context.Context, ex Exchanger, query *dns.Msg) ( return } +func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) { + if r == nil { + return + } + + var servers []NameServer + r.mux.RLock() + servers = r.copyServers() + r.mux.RUnlock() + + for _, ns := range servers { + reply, err = ns.exchanger.Exchange(ctx, query) + if err == nil { + return + } + } + return +} + type resolverCacheItem struct { IPs []net.IP ts int64 @@ -352,11 +435,7 @@ func (r *resolver) Reload(rd io.Reader) error { if strings.HasPrefix(ns.Addr, "https") { ns.Protocol = "https" } - ns.Timeout = timeout - - if err := ns.Init(); err == nil { - nss = append(nss, ns) - } + nss = append(nss, ns) } } @@ -366,12 +445,15 @@ func (r *resolver) Reload(rd io.Reader) error { r.mux.Lock() r.TTL = ttl + r.timeout = timeout r.domain = domain r.period = period r.prefer = prefer r.Servers = nss r.mux.Unlock() + r.Init() + return nil } @@ -425,62 +507,270 @@ func (r *resolver) String() string { // Exchanger is an interface for DNS synchronous query. type Exchanger interface { - Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) + Exchange(ctx context.Context, query []byte) ([]byte, error) +} + +type exchangerOptions struct { + chain *Chain + timeout time.Duration +} + +type ExchangerOption func(opts *exchangerOptions) + +func ChainExchangerOption(chain *Chain) ExchangerOption { + return func(opts *exchangerOptions) { + opts.chain = chain + } +} + +func TimeoutExchangerOption(timeout time.Duration) ExchangerOption { + return func(opts *exchangerOptions) { + opts.timeout = timeout + } } type dnsExchanger struct { - endpoint string - client *dns.Client + addr string + options exchangerOptions } -func (ex *dnsExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - ep := ex.endpoint - if _, port, _ := net.SplitHostPort(ep); port == "" { - ep = net.JoinHostPort(ep, "53") +func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger { + var options exchangerOptions + for _, opt := range opts { + opt(&options) } - mr, _, err := ex.client.Exchange(query, ep) - return mr, err + + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "53") + } + + return &dnsExchanger{ + addr: addr, + options: options, + } +} + +func (ex *dnsExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + if ex.options.chain.IsEmpty() { + d := &net.Dialer{ + Timeout: ex.options.timeout, + } + return d.DialContext(ctx, network, address) + } + + raddr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return + } + cc, err := getSOCKS5UDPTunnel(ex.options.chain, nil) + conn = &udpTunnelConn{Conn: cc, raddr: raddr} + return +} + +func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + c, err := ex.dial(ctx, "udp", ex.addr) + if err != nil { + return nil, err + } + + mq := &dns.Msg{} + if err = mq.Unpack(query); err != nil { + return nil, err + } + + conn := &dns.Conn{ + Conn: c, + } + + if err = conn.WriteMsg(mq); err != nil { + return nil, err + } + + mr, err := conn.ReadMsg() + if err != nil { + return nil, err + } + + return mr.Pack() +} + +type dnsTCPExchanger struct { + addr string + options exchangerOptions +} + +func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger { + var options exchangerOptions + for _, opt := range opts { + opt(&options) + } + + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "53") + } + + return &dnsTCPExchanger{ + addr: addr, + options: options, + } +} + +func (ex *dnsTCPExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + if ex.options.chain.IsEmpty() { + d := &net.Dialer{ + Timeout: ex.options.timeout, + } + return d.DialContext(ctx, network, address) + } + return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) +} + +func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + c, err := ex.dial(ctx, "tcp", ex.addr) + if err != nil { + return nil, err + } + + conn := &dns.Conn{ + Conn: c, + } + + if _, err = conn.Write(query); err != nil { + return nil, err + } + + mr, err := conn.ReadMsg() + if err != nil { + log.Log("[dns] exchange", err) + return nil, err + } + + return mr.Pack() +} + +type dotExchanger struct { + addr string + tlsConfig *tls.Config + options exchangerOptions +} + +func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger { + var options exchangerOptions + for _, opt := range opts { + opt(&options) + } + + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "53") + } + + if tlsConfig == nil { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + return &dotExchanger{ + addr: addr, + tlsConfig: tlsConfig, + options: options, + } +} + +func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + if ex.options.chain.IsEmpty() { + d := &net.Dialer{ + Timeout: ex.options.timeout, + } + conn, err = d.DialContext(ctx, network, address) + } else { + conn, err = ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) + } + if err == nil { + conn = tls.Client(conn, ex.tlsConfig) + } + return +} + +func (ex *dotExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + c, err := ex.dial(ctx, "tcp", ex.addr) + if err != nil { + return nil, err + } + + conn := &dns.Conn{ + Conn: c, + } + + if _, err = conn.Write(query); err != nil { + return nil, err + } + + mr, err := conn.ReadMsg() + if err != nil { + return nil, err + } + + return mr.Pack() } type dohExchanger struct { endpoint *url.URL client *http.Client + options exchangerOptions } -// reference: https://github.com/cloudflare/cloudflared/blob/master/tunneldns/https_upstream.go#L54 -func (ex *dohExchanger) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - queryBuf, err := query.Pack() - if err != nil { - return nil, fmt.Errorf("failed to pack DNS query: %s", err) +func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger { + var options exchangerOptions + for _, opt := range opts { + opt(&options) + } + ex := &dohExchanger{ + endpoint: urlStr, + options: options, } - // No content negotiation for now, use DNS wire format - buf, backendErr := ex.exchangeWireformat(queryBuf) - if backendErr == nil { - response := &dns.Msg{} - if err := response.Unpack(buf); err != nil { - return nil, fmt.Errorf("failed to unpack DNS response from body: %s", err) + ex.client = &http.Client{ + Timeout: options.timeout, + Transport: &http.Transport{ + // Proxy: ProxyFromEnvironment, + TLSClientConfig: tlsConfig, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: options.timeout, + ExpectContinueTimeout: 1 * time.Second, + DialContext: ex.dialContext, + }, + } + + return ex +} + +func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (conn net.Conn, err error) { + if ex.options.chain.IsEmpty() { + d := &net.Dialer{ + Timeout: ex.options.timeout, } - - response.Id = query.Id - return response, nil + return d.DialContext(ctx, network, address) } - - return nil, backendErr + return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) } -// Perform message exchange with the default UDP wireformat defined in current draft -// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https -func (ex *dohExchanger) exchangeWireformat(msg []byte) ([]byte, error) { - req, err := http.NewRequest("POST", ex.endpoint.String(), bytes.NewBuffer(msg)) +func (ex *dohExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, "POST", ex.endpoint.String(), bytes.NewBuffer(query)) if err != nil { return nil, fmt.Errorf("failed to create an HTTPS request: %s", err) } - req.Header.Add("Content-Type", "application/dns-udpwireformat") + // req.Header.Add("Content-Type", "application/dns-udpwireformat") + req.Header.Add("Content-Type", "application/dns-message") req.Host = ex.endpoint.Hostname() - resp, err := ex.client.Do(req) + client := ex.client + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to perform an HTTPS request: %s", err) }