diff --git a/chain.go b/chain.go index 9dc672b..b12857a 100644 --- a/chain.go +++ b/chain.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "strings" "github.com/go-log/log" ) @@ -100,11 +101,7 @@ func (c *Chain) IsEmpty() bool { // Dial connects to the target address addr through the chain. // If the chain is empty, it will use the net.Dial directly. func (c *Chain) Dial(addr string) (conn net.Conn, err error) { - if c.IsEmpty() { - return net.DialTimeout("tcp", addr, DialTimeout) - } - - for i := 0; i < c.Retries+1; i++ { + for i := 0; i < c.Retries; i++ { conn, err = c.dial(addr) if err == nil { break @@ -114,10 +111,13 @@ func (c *Chain) Dial(addr string) (conn net.Conn, err error) { } func (c *Chain) dial(addr string) (net.Conn, error) { - route, err := c.selectRoute() + route, err := c.selectRouteFor(addr) if err != nil { return nil, err } + if route.IsEmpty() { + return net.DialTimeout("tcp", addr, DialTimeout) + } conn, err := route.getConn() if err != nil { @@ -135,7 +135,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) { // Conn obtains a handshaked connection to the last node of the chain. // If the chain is empty, it returns an ErrEmptyChain error. func (c *Chain) Conn() (conn net.Conn, err error) { - for i := 0; i < c.Retries+1; i++ { + for i := 0; i < c.Retries; i++ { var route *Chain route, err = c.selectRoute() if err != nil { @@ -228,3 +228,50 @@ func (c *Chain) selectRoute() (route *Chain, err error) { } return } + +// selectRouteFor selects route with bypass testing. +func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { + if c.IsEmpty() || c.isRoute { + return c, nil + } + + buf := bytes.Buffer{} + route = newRoute() + route.Retries = c.Retries + + for _, group := range c.nodeGroups { + var node Node + node, err = group.Next() + if err != nil { + return + } + + // NOTE: IPv6 will not work. + if strings.Contains(addr, ":") { + addr = strings.Split(addr, ":")[0] + } + if node.Bypass.Contains(addr) { + if Debug { + buf.WriteString(fmt.Sprintf("[%d@bypass: %s]", node.ID, addr)) + log.Log("select route:", buf.String()) + } + return + } + + buf.WriteString(fmt.Sprintf("%s -> ", node.String())) + + if node.Client.Transporter.Multiplex() { + node.DialOptions = append(node.DialOptions, + ChainDialOption(route), + ) + route = newRoute() // cutoff the chain for multiplex. + route.Retries = c.Retries + } + + route.AddNode(node) + } + if Debug { + log.Log("select route:", buf.String()) + } + return +} diff --git a/cmd/gost/bypass.txt b/cmd/gost/bypass.txt new file mode 100644 index 0000000..04bf178 --- /dev/null +++ b/cmd/gost/bypass.txt @@ -0,0 +1,5 @@ +10.0.0.1 +192.168.0.0/24 +172.1.0.0/16 +192.168.100.190/32 +*.example.com \ No newline at end of file diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 1e0b49e..8abf2f3 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -92,7 +92,11 @@ type route struct { func (r *route) initChain() (*gost.Chain, error) { chain := gost.NewChain() + chain.Retries = r.Retries + if chain.Retries == 0 { + chain.Retries = 1 + } gid := 1 // group ID @@ -143,6 +147,17 @@ func (r *route) initChain() (*gost.Chain, error) { ngroup.AddNode(nodes...) } + var bypass *gost.Bypass + if peerCfg.Bypass != nil { + bypass = gost.NewBypassPatterns(peerCfg.Bypass.Patterns, peerCfg.Bypass.Reverse) + } + nodes = ngroup.Nodes() + for i := range nodes { + if nodes[i].Bypass == nil { + nodes[i].Bypass = bypass // use global bypass if local bypass does not exist. + } + } + chain.AddNodeGroup(ngroup) } @@ -297,9 +312,12 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { Transporter: tr, } + node.Bypass = parseBypass(node.Get("bypass")) + ips := parseIP(node.Get("ip"), sport) for _, ip := range ips { node.Addr = ip + // override the default node address node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip)) nodes = append(nodes, node) } @@ -484,14 +502,9 @@ func (r *route) serve() error { } } - fBypass := node.Get("bypass") - if fBypass == "" { - fBypass = "bypass" // default bypass file - } - srv := &gost.Server{Listener: ln} srv.Init( - gost.BypassServerOption(parseBypass(fBypass)), + gost.BypassServerOption(parseBypass(node.Get("bypass"))), ) go srv.Serve(handler) } @@ -657,6 +670,12 @@ type peerConfig struct { MaxFails int `json:"max_fails"` FailTimeout int `json:"fail_timeout"` Nodes []string `json:"nodes"` + Bypass *bypass `json:"bypass"` // global bypass +} + +type bypass struct { + Reverse bool `json:"reverse"` + Patterns []string `json:"patterns"` } func loadPeerConfig(peer string) (config peerConfig, err error) { @@ -694,16 +713,29 @@ func parseStrategy(s string) gost.Strategy { } } -func parseBypass(fpath string) (bypass *gost.Bypass) { - if fpath == "" { - return +func parseBypass(s string) *gost.Bypass { + if s == "" { + return nil } - f, err := os.Open(fpath) - if err != nil { - return + var matchers []gost.Matcher + var reversed bool + if strings.HasPrefix(s, "~") { + reversed = true + s = strings.TrimLeft(s, "~") + } + + f, err := os.Open(s) + if err != nil { + for _, s := range strings.Split(s, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + matchers = append(matchers, gost.NewMatcher(s)) + } + return gost.NewBypass(matchers, reversed) } - var matchers []gost.Matcher scanner := bufio.NewScanner(f) for scanner.Scan() { line := scanner.Text() @@ -716,6 +748,5 @@ func parseBypass(fpath string) (bypass *gost.Bypass) { } matchers = append(matchers, gost.NewMatcher(line)) } - bypass = gost.NewBypass(matchers, strings.HasPrefix(fpath, "~")) - return + return gost.NewBypass(matchers, reversed) } diff --git a/cmd/gost/peer.json b/cmd/gost/peer.json index a4eee76..4cbcd4d 100644 --- a/cmd/gost/peer.json +++ b/cmd/gost/peer.json @@ -6,5 +6,13 @@ "socks5://:1081", "socks://:1082", "socks4a://:1083" - ] + ], + "bypass":{ + "reverse": false, + "patterns": [ + "10.0.0.1", + "192.168.0.0/24", + "*.example.com" + ] + } } \ No newline at end of file diff --git a/http.go b/http.go index 4b2c507..ccd8bea 100644 --- a/http.go +++ b/http.go @@ -150,10 +150,15 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { } } + route, err := h.options.Chain.selectRouteFor(req.Host) + if err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + return + } // forward http request - lastNode := h.options.Chain.LastNode() + lastNode := route.LastNode() if req.Method != http.MethodConnect && lastNode.Protocol == "http" { - h.forwardRequest(conn, req) + h.forwardRequest(conn, req, route) return } @@ -162,7 +167,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { host += ":80" } - cc, err := h.options.Chain.Dial(host) + cc, err := route.Dial(host) if err != nil { log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), host, err) @@ -197,13 +202,13 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { log.Logf("[http] %s >-< %s", cc.LocalAddr(), host) } -func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request) { - if h.options.Chain.IsEmpty() { +func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Chain) { + if route.IsEmpty() { return } - lastNode := h.options.Chain.LastNode() + lastNode := route.LastNode() - cc, err := h.options.Chain.Conn() + cc, err := route.Conn() if err != nil { log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), lastNode.Addr, err) diff --git a/node.go b/node.go index b63fac9..c174ba5 100644 --- a/node.go +++ b/node.go @@ -25,6 +25,7 @@ type Node struct { group *NodeGroup failCount uint32 failTime int64 + Bypass *Bypass } // ParseNode parses the node info. diff --git a/socks.go b/socks.go index 65bfc1f..35e7800 100644 --- a/socks.go +++ b/socks.go @@ -267,7 +267,7 @@ func (c *socks4Connector) Connect(conn net.Conn, addr string) (net.Conn, error) return nil, err } if len(taddr.IP) == 0 { - taddr.IP = net.IPv4(0, 0, 0, 0) + taddr.IP = net.IPv4zero } req := gosocks4.NewRequest(gosocks4.CmdConnect,