From 694c05b50a22ffeb85919a8854f912ad5f24a4e9 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 16 Jan 2020 22:38:19 +0800 Subject: [PATCH] add dot & doh --- cmd/gost/route.go | 5 +- dns.go | 253 +++++++++++++++++++++++++++++++++++++++------- node.go | 2 +- resolver.go | 21 +++- 4 files changed, 235 insertions(+), 46 deletions(-) diff --git a/cmd/gost/route.go b/cmd/gost/route.go index f3809ca..323f807 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -481,7 +481,8 @@ func (r *route) GenRouters() ([]router, error) { ln, err = gost.DNSListener( node.Addr, &gost.DNSOptions{ - TCPMode: node.GetBool("tcp"), + Mode: node.Get("mode"), + TLSConfig: tlsCfg, }, ) default: @@ -525,7 +526,7 @@ func (r *route) GenRouters() ([]router, error) { handler = gost.TunHandler() case "tap": handler = gost.TapHandler() - case "dns", "dot", "doh": + case "dns": handler = gost.DNSHandler(node.Remote) default: // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. diff --git a/dns.go b/dns.go index ae87b21..36445f5 100644 --- a/dns.go +++ b/dns.go @@ -3,9 +3,15 @@ package gost import ( "bytes" "context" + "crypto/tls" + "encoding/base64" "errors" + "io" + "io/ioutil" "net" + "net/http" "strconv" + "strings" "time" "github.com/go-log/log" @@ -112,15 +118,16 @@ func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { } type DNSOptions struct { - TCPMode bool + Mode string UDPSize int ReadTimeout time.Duration WriteTimeout time.Duration + TLSConfig *tls.Config } type dnsListener struct { addr net.Addr - server *dns.Server + server dnsServer connChan chan net.Conn errc chan error } @@ -130,31 +137,70 @@ func DNSListener(addr string, options *DNSOptions) (Listener, error) { options = &DNSOptions{} } + tlsConfig := options.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + ln := &dnsListener{ connChan: make(chan net.Conn, 128), errc: make(chan error, 1), } - var nets string + var srv dnsServer var err error + switch strings.ToLower(options.Mode) { + case "tcp": + srv = &dns.Server{ + Net: "tcp", + Addr: addr, + Handler: ln, + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + } + case "tls": + srv = &dns.Server{ + Net: "tcp-tls", + Addr: addr, + Handler: ln, + TLSConfig: tlsConfig, + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + } + case "https": + srv = &dohServer{ + addr: addr, + tlsConfig: tlsConfig, + server: &http.Server{ + Handler: ln, + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + }, + } - if options.TCPMode { - nets = "tcp" + default: ln.addr, err = net.ResolveTCPAddr("tcp", addr) - } else { - nets = "udp" - ln.addr, err = net.ResolveUDPAddr("udp", addr) + srv = &dns.Server{ + Net: "udp", + Addr: addr, + Handler: ln, + UDPSize: options.UDPSize, + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + } } if err != nil { return nil, err } - ln.server = &dns.Server{ - Addr: addr, - Net: nets, + if ln.addr == nil { + ln.addr, err = net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } } - dns.HandleFunc(".", ln.handleRequest) + ln.server = srv go func() { if err := ln.server.ListenAndServe(); err != nil { @@ -172,30 +218,76 @@ func DNSListener(addr string, options *DNSOptions) (Listener, error) { return ln, nil } -func (l *dnsListener) handleRequest(w dns.ResponseWriter, m *dns.Msg) { - if w == nil || m == nil { - return - } - - conn := &dnsServerConn{ - mq: make(chan []byte, 1), - ResponseWriter: w, - } - - buf := mPool.Get().([]byte) - defer mPool.Put(buf) - buf = buf[:0] - b, err := m.PackBuffer(buf) - if err != nil { - log.Logf("[dns] %s: %v", l.addr, err) - return - } - conn.mq <- b +func (l *dnsListener) serve(w dnsResponseWriter, mq []byte) (err error) { + conn := newDNSServerConn(l.addr, w.RemoteAddr()) + conn.mq <- mq select { case l.connChan <- conn: default: - log.Logf("[dns] %s: connection queue is full", l.addr) + return errors.New("connection queue is full") + } + + select { + case mr := <-conn.mr: + _, err = w.Write(mr) + case <-conn.cclose: + err = io.EOF + } + return +} + +func (l *dnsListener) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { + b, err := m.Pack() + if err != nil { + log.Logf("[dns] %s: %v", l.addr, err) + return + } + if err := l.serve(w, b); err != nil { + log.Logf("[dns] %s: %v", l.addr, err) + } +} + +// Based on https://github.com/semihalev/sdns +func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var buf []byte + var err error + switch r.Method { + case http.MethodGet: + buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns")) + if len(buf) == 0 || err != nil { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + case http.MethodPost: + if r.Header.Get("Content-Type") != "application/dns-message" { + http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType) + return + } + + buf, err = ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + default: + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + + mq := &dns.Msg{} + if err := mq.Unpack(buf); err != nil { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + + w.Header().Set("Server", "SDNS") + w.Header().Set("Content-Type", "application/dns-message") + + raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) + if err := l.serve(newDoHResponseWriter(raddr, w), buf); err != nil { + log.Logf("[dns] %s: %v", l.addr, err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } } @@ -215,21 +307,85 @@ func (l *dnsListener) Addr() net.Addr { return l.addr } +type dnsServer interface { + ListenAndServe() error + Shutdown() error +} + +type dohServer struct { + addr string + tlsConfig *tls.Config + server *http.Server +} + +func (s *dohServer) ListenAndServe() error { + ln, err := net.Listen("tcp", s.addr) + if err != nil { + return err + } + ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, s.tlsConfig) + return s.server.Serve(ln) +} + +func (s *dohServer) Shutdown() error { + return s.server.Shutdown(context.Background()) +} + type dnsServerConn struct { - mq chan []byte - dns.ResponseWriter + mq chan []byte + mr chan []byte + cclose chan struct{} + laddr, raddr net.Addr +} + +func newDNSServerConn(laddr, raddr net.Addr) *dnsServerConn { + return &dnsServerConn{ + mq: make(chan []byte, 1), + mr: make(chan []byte, 1), + laddr: laddr, + raddr: raddr, + cclose: make(chan struct{}), + } } func (c *dnsServerConn) Read(b []byte) (n int, err error) { - var mb []byte select { - case mb = <-c.mq: - default: + case mb := <-c.mq: + n = copy(b, mb) + case <-c.cclose: + err = errors.New("connection is closed") } - n = copy(b, mb) return } +func (c *dnsServerConn) Write(b []byte) (n int, err error) { + select { + case c.mr <- b: + n = len(b) + case <-c.cclose: + err = errors.New("broken pipe") + } + + return +} + +func (c *dnsServerConn) Close() error { + select { + case <-c.cclose: + default: + close(c.cclose) + } + return nil +} + +func (c *dnsServerConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *dnsServerConn) RemoteAddr() net.Addr { + return c.raddr +} + func (c *dnsServerConn) SetDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } @@ -241,3 +397,24 @@ func (c *dnsServerConn) SetReadDeadline(t time.Time) error { func (c *dnsServerConn) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } + +type dnsResponseWriter interface { + io.Writer + RemoteAddr() net.Addr +} + +type dohResponseWriter struct { + raddr net.Addr + http.ResponseWriter +} + +func newDoHResponseWriter(raddr net.Addr, w http.ResponseWriter) dnsResponseWriter { + return &dohResponseWriter{ + raddr: raddr, + ResponseWriter: w, + } +} + +func (w *dohResponseWriter) RemoteAddr() net.Addr { + return w.raddr +} diff --git a/node.go b/node.go index 80a9170..079d661 100644 --- a/node.go +++ b/node.go @@ -84,7 +84,7 @@ func ParseNode(s string) (node Node, err error) { case "ohttp": // obfs-http case "tun", "tap": // tun/tap device case "ftcp": // fake TCP - case "dns", "dot", "doh": + case "dns": default: node.Transport = "tcp" } diff --git a/resolver.go b/resolver.go index c8d3e65..68873b1 100644 --- a/resolver.go +++ b/resolver.go @@ -187,11 +187,16 @@ func (r *resolver) Init(opts ...ResolverOption) error { opt(&r.options) } + timeout := r.timeout + if timeout <= 0 { + timeout = DefaultResolverTimeout + } + var nss []NameServer for _, ns := range r.servers { if err := ns.Init( // init all name servers ChainNameServerOption(r.options.chain), - TimeoutNameServerOption(r.timeout), + TimeoutNameServerOption(timeout), ); err != nil { continue // ignore invalid name servers } @@ -255,8 +260,6 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) prefer := r.prefer r.mux.RUnlock() - prefer = "ipv6" - ctx := context.Background() if prefer == "ipv6" { // prefer ipv6 mq := &dns.Msg{} @@ -614,10 +617,12 @@ func (ex *dnsExchanger) dial(ctx context.Context, network, address string) (conn } func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + t := time.Now() c, err := ex.dial(ctx, "udp", ex.addr) if err != nil { return nil, err } + c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) mq := &dns.Msg{} if err = mq.Unpack(query); err != nil { @@ -672,10 +677,12 @@ func (ex *dnsTCPExchanger) dial(ctx context.Context, network, address string) (c } func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + t := time.Now() c, err := ex.dial(ctx, "tcp", ex.addr) if err != nil { return nil, err } + c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) conn := &dns.Conn{ Conn: c, @@ -731,17 +738,21 @@ func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn } else { conn, err = ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) } - if err == nil { - conn = tls.Client(conn, ex.tlsConfig) + if err != nil { + return } + conn = tls.Client(conn, ex.tlsConfig) + return } func (ex *dotExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + t := time.Now() c, err := ex.dial(ctx, "tcp", ex.addr) if err != nil { return nil, err } + c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) conn := &dns.Conn{ Conn: c,