From c1f4325b195320ee54f12a71f4089ccc59e900eb Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 8 Jul 2018 10:41:56 +0800 Subject: [PATCH] add ChainOptions for Chain.Dial --- chain.go | 66 ++++++++++++++++++++++++++---------------- cmd/gost/main.go | 63 +++++++++++++++++++---------------------- forward.go | 5 +++- handler.go | 35 ++++++++++++++++++++++- http.go | 74 ++++++++++++++++++++++++++++++------------------ http2.go | 7 ++++- redirect.go | 5 +++- resolver.go | 6 ---- sni.go | 8 +++++- socks.go | 12 ++++++-- ss.go | 7 ++++- ssh.go | 12 +++++++- 12 files changed, 200 insertions(+), 100 deletions(-) diff --git a/chain.go b/chain.go index fd7bf45..d9d6682 100644 --- a/chain.go +++ b/chain.go @@ -19,8 +19,6 @@ var ( type Chain struct { isRoute bool Retries int - Hosts *Hosts - Resolver Resolver nodeGroups []*NodeGroup } @@ -102,17 +100,22 @@ func (c *Chain) IsEmpty() bool { // Dial connects to the target address addr through the chain. // If the chain is empty, it will use the net.Dial directly. -func (c *Chain) Dial(addr string) (conn net.Conn, err error) { - var retries int - if c != nil { +func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error) { + options := &ChainOptions{} + for _, opt := range opts { + opt(options) + } + + retries := 1 + if c != nil && c.Retries > 0 { retries = c.Retries } - if retries == 0 { - retries = 1 + if options.Retries > 0 { + retries = options.Retries } for i := 0; i < retries; i++ { - conn, err = c.dial(addr) + conn, err = c.dialWithOptions(addr, options) if err == nil { break } @@ -120,16 +123,19 @@ func (c *Chain) Dial(addr string) (conn net.Conn, err error) { return } -func (c *Chain) dial(addr string) (net.Conn, error) { +func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, error) { + if options == nil { + options = &ChainOptions{} + } route, err := c.selectRouteFor(addr) if err != nil { return nil, err } - addr = c.resolve(addr) + addr = c.resolve(addr, options.Resolver, options.Hosts) if route.IsEmpty() { - return net.DialTimeout("tcp", addr, DialTimeout) + return net.DialTimeout("tcp", addr, options.Timeout) } conn, err := route.getConn() @@ -145,17 +151,17 @@ func (c *Chain) dial(addr string) (net.Conn, error) { return cc, nil } -func (c *Chain) resolve(addr string) string { +func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { host, port, err := net.SplitHostPort(addr) if err != nil { return addr } - if ip := c.Hosts.Lookup(host); ip != nil { + if ip := hosts.Lookup(host); ip != nil { return net.JoinHostPort(ip.String(), port) } - if c.Resolver != nil { - ips, err := c.Resolver.Resolve(host) + if resolver != nil { + ips, err := resolver.Resolve(host) if err != nil { log.Logf("[resolver] %s: %v", host, err) } @@ -168,8 +174,21 @@ func (c *Chain) resolve(addr string) string { // Conn obtains a handshaked connection to the last node of the chain. // If the chain is empty, it returns an ErrEmptyChain error. -func (c *Chain) Conn() (conn net.Conn, err error) { - for i := 0; i < c.Retries; i++ { +func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { + options := &ChainOptions{} + for _, opt := range opts { + opt(options) + } + + retries := 1 + if c != nil && c.Retries > 0 { + retries = c.Retries + } + if options.Retries > 0 { + retries = options.Retries + } + + for i := 0; i < retries; i++ { var route *Chain route, err = c.selectRoute() if err != nil { @@ -177,6 +196,7 @@ func (c *Chain) Conn() (conn net.Conn, err error) { } conn, err = route.getConn() if err != nil { + log.Log(err) continue } @@ -185,6 +205,7 @@ func (c *Chain) Conn() (conn net.Conn, err error) { return } +// getConn obtains a connection to the last node of the chain. func (c *Chain) getConn() (conn net.Conn, err error) { if c.IsEmpty() { err = ErrEmptyChain @@ -232,7 +253,7 @@ func (c *Chain) getConn() (conn net.Conn, err error) { } func (c *Chain) selectRoute() (route *Chain, err error) { - if c.isRoute { + if c.IsEmpty() || c.isRoute { return c, nil } @@ -256,7 +277,6 @@ func (c *Chain) selectRoute() (route *Chain, err error) { route.AddNode(node) } route.Retries = c.Retries - route.Resolver = c.Resolver if Debug { log.Log("select route:", buf.String()) @@ -299,9 +319,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { route.AddNode(node) } - route.Retries = c.Retries - route.Resolver = c.Resolver if Debug { buf.WriteString(addr) @@ -312,7 +330,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { // ChainOptions holds options for Chain. type ChainOptions struct { - Retry int + Retries int Timeout time.Duration Hosts *Hosts Resolver Resolver @@ -322,9 +340,9 @@ type ChainOptions struct { type ChainOption func(opts *ChainOptions) // RetryChainOption specifies the times of retry used by Chain.Dial. -func RetryChainOption(retry int) ChainOption { +func RetryChainOption(retries int) ChainOption { return func(opts *ChainOptions) { - opts.Retry = retry + opts.Retries = retries } } diff --git a/cmd/gost/main.go b/cmd/gost/main.go index fb26896..9625007 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -6,8 +6,7 @@ import ( "flag" "fmt" "net" - "net/http" - _ "net/http/pprof" + // _ "net/http/pprof" "os" "runtime" "time" @@ -59,9 +58,9 @@ func init() { } func main() { - go func() { - log.Log(http.ListenAndServe("localhost:6060", nil)) - }() + // go func() { + // log.Log(http.ListenAndServe("localhost:6060", nil)) + // }() // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. config, err := tlsConfig(defaultCertFile, defaultKeyFile) if err != nil { @@ -95,12 +94,7 @@ type route struct { func (r *route) initChain() (*gost.Chain, error) { chain := gost.NewChain() - chain.Retries = r.Retries - if chain.Retries == 0 { - chain.Retries = 1 - } - gid := 1 // group ID for _, ns := range r.ChainNodes { @@ -454,18 +448,6 @@ func (r *route) serve() error { return err } - var whitelist, blacklist *gost.Permissions - if node.Values.Get("whitelist") != "" { - if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil { - return err - } - } - if node.Values.Get("blacklist") != "" { - if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil { - return err - } - } - var handler gost.Handler switch node.Protocol { case "http2": @@ -502,6 +484,27 @@ func (r *route) serve() error { handler = gost.AutoHandler() } } + + var whitelist, blacklist *gost.Permissions + if node.Values.Get("whitelist") != "" { + if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil { + return err + } + } + if node.Values.Get("blacklist") != "" { + if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil { + return err + } + } + + var hosts *gost.Hosts + if f, _ := os.Open(node.Get("hosts")); f != nil { + hosts, err = gost.ParseHosts(f) + if err != nil { + log.Logf("[hosts] %s: %v", f.Name(), err) + } + } + handler.Init( gost.AddrHandlerOption(node.Addr), gost.ChainHandlerOption(chain), @@ -511,20 +514,12 @@ func (r *route) serve() error { gost.BlacklistHandlerOption(blacklist), gost.BypassHandlerOption(parseBypass(node.Get("bypass"))), gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))), + gost.ResolverHandlerOption(parseResolver(node.Get("dns"))), + gost.HostsHandlerOption(hosts), + gost.RetryHandlerOption(node.GetInt("retry")), + gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second), ) - chain.Resolver = parseResolver(node.Get("dns")) - if gost.Debug { - log.Logf("[resolver]\n%v", chain.Resolver) - } - - if f, _ := os.Open(node.Get("hosts")); f != nil { - chain.Hosts, err = gost.ParseHosts(f) - if err != nil { - log.Logf("[hosts] %s: %v", f.Name(), err) - } - } - srv := &gost.Server{Listener: ln} go srv.Serve(handler) } diff --git a/forward.go b/forward.go index 3f014c6..c9dced9 100644 --- a/forward.go +++ b/forward.go @@ -86,7 +86,10 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { } log.Logf("[tcp] %s - %s", conn.RemoteAddr(), node.Addr) - cc, err := h.options.Chain.Dial(node.Addr) + cc, err := h.options.Chain.Dial(node.Addr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) if err != nil { log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err) node.MarkDead() diff --git a/handler.go b/handler.go index d5d2fe8..d19855b 100644 --- a/handler.go +++ b/handler.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "net" "net/url" + "time" "github.com/ginuerzh/gosocks4" "github.com/ginuerzh/gosocks5" @@ -25,8 +26,12 @@ type HandlerOptions struct { TLSConfig *tls.Config Whitelist *Permissions Blacklist *Permissions - Bypass *Bypass Strategy Strategy + Bypass *Bypass + Retries int + Timeout time.Duration + Resolver Resolver + Hosts *Hosts } // HandlerOption allows a common way to set handler options. @@ -88,6 +93,34 @@ func StrategyHandlerOption(strategy Strategy) HandlerOption { } } +// RetryHandlerOption sets the retry option of HandlerOptions. +func RetryHandlerOption(retries int) HandlerOption { + return func(opts *HandlerOptions) { + opts.Retries = retries + } +} + +// TimeoutHandlerOption sets the timeout option of HandlerOptions. +func TimeoutHandlerOption(timeout time.Duration) HandlerOption { + return func(opts *HandlerOptions) { + opts.Timeout = timeout + } +} + +// ResolverHandlerOption sets the resolver option of HandlerOptions. +func ResolverHandlerOption(resolver Resolver) HandlerOption { + return func(opts *HandlerOptions) { + opts.Resolver = resolver + } +} + +// HostsHandlerOption sets the Hosts option of HandlerOptions. +func HostsHandlerOption(hosts *Hosts) HandlerOption { + return func(opts *HandlerOptions) { + opts.Hosts = hosts + } +} + type autoHandler struct { options *HandlerOptions } diff --git a/http.go b/http.go index 7c9f86d..a1923db 100644 --- a/http.go +++ b/http.go @@ -166,24 +166,50 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { req.Header.Del("Proxy-Authorization") // req.Header.Del("Proxy-Connection") - route, err := h.options.Chain.selectRouteFor(req.Host) - if err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) - return - } - // forward http request - lastNode := route.LastNode() - if req.Method != http.MethodConnect && lastNode.Protocol == "http" { - h.forwardRequest(conn, req, route) - return - } - host := req.Host if _, port, _ := net.SplitHostPort(host); port == "" { host = net.JoinHostPort(req.Host, "80") } - cc, err := route.Dial(host) + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var err error + var cc net.Conn + var route *Chain + for i := 0; i < retries; i++ { + route, err = h.options.Chain.selectRouteFor(req.Host) + if err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + continue + } + // forward http request + lastNode := route.LastNode() + if req.Method != http.MethodConnect && lastNode.Protocol == "http" { + err = h.forwardRequest(conn, req, route) + if err == nil { + return + } + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + continue + } + + cc, err = route.Dial(host, + RetryChainOption(1), + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) + if err == nil { + break + } + } + if err != nil { log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), host, err) @@ -218,23 +244,17 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { log.Logf("[http] %s >-< %s", cc.LocalAddr(), host) } -func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Chain) { +func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Chain) error { if route.IsEmpty() { - return + return nil } lastNode := route.LastNode() - cc, err := route.Conn() + cc, err := route.Conn( + RetryChainOption(1), // we control the retry manually. + ) if err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), lastNode.Addr, err) - - b := []byte("HTTP/1.1 503 Service unavailable\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n") - if Debug { - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), lastNode.Addr, string(b)) - } - conn.Write(b) - return + return err } defer cc.Close() @@ -253,14 +273,14 @@ func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Ch } if err = req.WriteProxy(cc); err != nil { log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) - return + return nil } cc.SetWriteDeadline(time.Time{}) log.Logf("[http] %s <-> %s", conn.RemoteAddr(), req.Host) transport(conn, cc) log.Logf("[http] %s >-< %s", conn.RemoteAddr(), req.Host) - return + return nil } func basicProxyAuth(proxyAuth string) (username, password string, ok bool) { diff --git a/http2.go b/http2.go index 6501d83..c30fd96 100644 --- a/http2.go +++ b/http2.go @@ -321,7 +321,12 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { r.Header.Del("Proxy-Authorization") r.Header.Del("Proxy-Connection") - cc, err := h.options.Chain.Dial(target) + cc, err := h.options.Chain.Dial(target, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) if err != nil { log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, target, err) w.WriteHeader(http.StatusServiceUnavailable) diff --git a/redirect.go b/redirect.go index f1be354..c36db97 100644 --- a/redirect.go +++ b/redirect.go @@ -49,7 +49,10 @@ func (h *tcpRedirectHandler) Handle(c net.Conn) { log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr) - cc, err := h.options.Chain.Dial(dstAddr.String()) + cc, err := h.options.Chain.Dial(dstAddr.String(), + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) if err != nil { log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) return diff --git a/resolver.go b/resolver.go index 2764ca7..08c38ff 100644 --- a/resolver.go +++ b/resolver.go @@ -48,12 +48,6 @@ func (ns NameServer) String() string { return fmt.Sprintf("%s/%s %s", addr, prot, host) } -type nameServers struct { - Servers []NameServer - Timeout time.Duration - TTL time.Duration -} - type resolverCacheItem struct { IPs []net.IP ts int64 diff --git a/sni.go b/sni.go index 0d99ae2..30c65a2 100644 --- a/sni.go +++ b/sni.go @@ -77,6 +77,7 @@ func (h *sniHandler) Handle(conn net.Conn) { req.URL.Scheme = "http" // make sure that the URL is absolute } handler := &httpHandler{options: h.options} + handler.Init() handler.handleRequest(conn, req) return } @@ -98,7 +99,12 @@ func (h *sniHandler) Handle(conn net.Conn) { return } - cc, err := h.options.Chain.Dial(addr) + cc, err := h.options.Chain.Dial(addr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) if err != nil { log.Logf("[sni] %s -> %s : %s", conn.RemoteAddr(), addr, err) return diff --git a/socks.go b/socks.go index fa4a926..b1bb24a 100644 --- a/socks.go +++ b/socks.go @@ -435,7 +435,12 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { return } - cc, err := h.options.Chain.Dial(addr) + cc, err := h.options.Chain.Dial(addr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) if err != nil { log.Logf("[socks5-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) rep := gosocks5.NewReply(gosocks5.HostUnreachable, nil) @@ -1181,7 +1186,10 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { return } - cc, err := h.options.Chain.Dial(addr) + cc, err := h.options.Chain.Dial(addr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) if err != nil { log.Logf("[socks4-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) rep := gosocks4.NewReply(gosocks4.Failed, nil) diff --git a/ss.go b/ss.go index 6b682c8..42cd419 100644 --- a/ss.go +++ b/ss.go @@ -152,7 +152,12 @@ func (h *shadowHandler) Handle(conn net.Conn) { return } - cc, err := h.options.Chain.Dial(addr) + cc, err := h.options.Chain.Dial(addr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) if err != nil { log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err) return diff --git a/ssh.go b/ssh.go index c66b9a3..67ddfbe 100644 --- a/ssh.go +++ b/ssh.go @@ -513,7 +513,17 @@ func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr return } - conn, err := h.options.Chain.Dial(raddr) + if h.options.Bypass.Contains(raddr) { + log.Logf("[ssh-tcp] [bypass] %s", raddr) + return + } + + conn, err := h.options.Chain.Dial(raddr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) if err != nil { log.Logf("[ssh-tcp] %s - %s : %s", h.options.Addr, raddr, err) return