diff --git a/chain.go b/chain.go index a09e7e0..9ebd14b 100644 --- a/chain.go +++ b/chain.go @@ -17,6 +17,7 @@ var ( // Chain is a proxy chain that holds a list of proxy nodes. type Chain struct { isRoute bool + Retries int nodeGroups []*NodeGroup } @@ -58,8 +59,8 @@ func (c *Chain) LastNode() Node { if c.IsEmpty() { return Node{} } - last := c.nodeGroups[len(c.nodeGroups)-1] - return last.nodes[0] + group := c.nodeGroups[len(c.nodeGroups)-1] + return group.nodes[0].Clone() } // LastNodeGroup returns the last group of the group list. @@ -185,6 +186,8 @@ func (c *Chain) selectRoute() (route *Chain, err error) { buf := bytes.Buffer{} route = newRoute() + route.Retries = c.Retries + for _, group := range c.nodeGroups { node, err := group.Next() if err != nil { @@ -197,6 +200,7 @@ func (c *Chain) selectRoute() (route *Chain, err error) { ChainDialOption(route), ) route = newRoute() // cutoff the chain for multiplex. + route.Retries = c.Retries } route.AddNode(node) diff --git a/cmd/gost/gost.json b/cmd/gost/gost.json index 8602a51..e7217c9 100644 --- a/cmd/gost/gost.json +++ b/cmd/gost/gost.json @@ -1,4 +1,6 @@ { + "Debug": false, + "Retries": 1, "ServeNodes": [ ":8080", "ss://chacha20:12345678@:8338" @@ -6,5 +8,23 @@ "ChainNodes": [ "http://192.168.1.1:8080", "https://10.0.2.1:443" + ], + + "Routes": [ + { + "Retries": 1, + "ServeNodes": [ + "ws://:1443" + ], + "ChainNodes": [ + "socks://:192.168.1.1:1080" + ] + }, + { + "Retries": 3, + "ServeNodes": [ + "quic://:443" + ] + } ] } \ No newline at end of file diff --git a/cmd/gost/main.go b/cmd/gost/main.go index d15cabd..b075cfa 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -22,10 +22,8 @@ import ( ) var ( - options struct { - ChainNodes, ServeNodes stringList - Debug bool - } + options route + routes []route ) func init() { @@ -43,12 +41,17 @@ func init() { flag.BoolVar(&printVersion, "V", false, "print version") flag.Parse() + if len(options.ServeNodes) > 0 { + routes = append(routes, options) + } + gost.Debug = options.Debug + if err := loadConfigureFile(configureFile); err != nil { log.Log(err) os.Exit(1) } - if flag.NFlag() == 0 { + if flag.NFlag() == 0 || len(routes) == 0 { flag.PrintDefaults() os.Exit(0) } @@ -58,7 +61,7 @@ func init() { os.Exit(0) } - gost.Debug = options.Debug + log.Log("Debug:", gost.Debug) } func main() { @@ -72,24 +75,29 @@ func main() { Certificates: []tls.Certificate{cert}, } - chain, err := initChain() - if err != nil { - log.Log(err) - os.Exit(1) - } - if err := serve(chain); err != nil { - log.Log(err) - os.Exit(1) + for _, route := range routes { + if err := route.serve(); err != nil { + log.Log(err) + os.Exit(1) + } } select {} } -func initChain() (*gost.Chain, error) { +type route struct { + ChainNodes, ServeNodes stringList + Retries int + Debug bool +} + +func (r *route) initChain() (*gost.Chain, error) { chain := gost.NewChain() + chain.Retries = r.Retries + gid := 1 // group ID - for _, ns := range options.ChainNodes { + for _, ns := range r.ChainNodes { ngroup := gost.NewNodeGroup() ngroup.ID = gid gid++ @@ -297,8 +305,13 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { return } -func serve(chain *gost.Chain) error { - for _, ns := range options.ServeNodes { +func (r *route) serve() error { + chain, err := r.initChain() + if err != nil { + return err + } + + for _, ns := range r.ServeNodes { node, err := gost.ParseNode(ns) if err != nil { return err @@ -507,9 +520,24 @@ func loadConfigureFile(configureFile string) error { if err != nil { return err } - if err := json.Unmarshal(content, &options); err != nil { + var cfg struct { + route + Routes []route + } + if err := json.Unmarshal(content, &cfg); err != nil { return err } + + if len(cfg.route.ServeNodes) > 0 { + routes = append(routes, cfg.route) + } + for _, route := range cfg.Routes { + if len(cfg.route.ServeNodes) > 0 { + routes = append(routes, route) + } + } + gost.Debug = cfg.Debug + return nil } diff --git a/handler.go b/handler.go index 2b51e4e..84162b3 100644 --- a/handler.go +++ b/handler.go @@ -95,7 +95,8 @@ func (h *autoHandler) Handle(conn net.Conn) { cc := &bufferdConn{Conn: conn, br: br} switch b[0] { case gosocks4.Ver4: - return // SOCKS4(a) does not suppport authentication method, so we ignore it. + cc.Close() + return // SOCKS4(a) does not suppport authentication method, so we ignore it for security reason. case gosocks5.Ver5: SOCKS5Handler(h.options...).Handle(cc) default: // http diff --git a/node.go b/node.go index 799b6e4..14d5b17 100644 --- a/node.go +++ b/node.go @@ -23,7 +23,7 @@ type Node struct { Client *Client group *NodeGroup failCount uint32 - failTime time.Time + failTime int64 } // ParseNode parses the node info. @@ -89,7 +89,7 @@ func ParseNode(s string) (node Node, err error) { // MarkDead marks the node fail status. func (node *Node) MarkDead() { atomic.AddUint32(&node.failCount, 1) - node.failTime = time.Now() + atomic.StoreInt64(&node.failTime, time.Now().Unix()) if node.group == nil { return @@ -97,7 +97,7 @@ func (node *Node) MarkDead() { for i := range node.group.nodes { if node.group.nodes[i].ID == node.ID { atomic.AddUint32(&node.group.nodes[i].failCount, 1) - node.group.nodes[i].failTime = time.Now() + atomic.StoreInt64(&node.group.nodes[i].failTime, time.Now().Unix()) break } } @@ -106,7 +106,7 @@ func (node *Node) MarkDead() { // ResetDead resets the node fail status. func (node *Node) ResetDead() { atomic.StoreUint32(&node.failCount, 0) - node.failTime = time.Time{} + atomic.StoreInt64(&node.failTime, 0) if node.group == nil { return @@ -115,12 +115,32 @@ func (node *Node) ResetDead() { for i := range node.group.nodes { if node.group.nodes[i].ID == node.ID { atomic.StoreUint32(&node.group.nodes[i].failCount, 0) - node.group.nodes[i].failTime = time.Time{} + atomic.StoreInt64(&node.group.nodes[i].failTime, 0) break } } } +// Clone clones the node, it will prevent data race. +func (node *Node) Clone() Node { + return Node{ + ID: node.ID, + Addr: node.Addr, + Host: node.Host, + Protocol: node.Protocol, + Transport: node.Transport, + Remote: node.Remote, + User: node.User, + Values: node.Values, + DialOptions: node.DialOptions, + HandshakeOptions: node.HandshakeOptions, + Client: node.Client, + group: node.group, + failCount: atomic.LoadUint32(&node.failCount), + failTime: atomic.LoadInt64(&node.failTime), + } +} + func (node *Node) String() string { return fmt.Sprintf("%d@%s", node.ID, node.Addr) } diff --git a/selector.go b/selector.go index 345b265..bea2c65 100644 --- a/selector.go +++ b/selector.go @@ -79,7 +79,7 @@ func (s *RoundStrategy) Apply(nodes []Node) Node { if len(nodes) == 0 { return Node{} } - old := s.count + old := atomic.LoadUint64(&s.count) atomic.AddUint64(&s.count, 1) return nodes[int(old%uint64(len(nodes)))] } @@ -134,10 +134,10 @@ func (f *FailFilter) Filter(nodes []Node) []Node { return nodes } nl := []Node{} - for _, node := range nodes { - if node.failCount < uint32(f.MaxFails) || - time.Since(node.failTime) >= f.FailTimeout { - nl = append(nl, node) + for i := range nodes { + if atomic.LoadUint32(&nodes[i].failCount) < uint32(f.MaxFails) || + time.Since(time.Unix(atomic.LoadInt64(&nodes[i].failTime), 0)) >= f.FailTimeout { + nl = append(nl, nodes[i].Clone()) } } return nl