From 3befde01282a3b62a39b47f1183bcd8aeeff961f Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Tue, 14 Nov 2017 17:50:02 +0800 Subject: [PATCH] add node fail filter --- chain.go | 40 +++++----------- cmd/gost/main.go | 112 ++++++++++++++++++++++++++++++--------------- node.go | 92 +++++++++++++++++++++++++++++-------- selector.go | 117 +++++++++++++++++++++++------------------------ 4 files changed, 217 insertions(+), 144 deletions(-) diff --git a/chain.go b/chain.go index 1d2a3be..a09e7e0 100644 --- a/chain.go +++ b/chain.go @@ -142,13 +142,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) { cn, err := node.Client.Dial(node.Addr, node.DialOptions...) if err != nil { + node.MarkDead() return } cn, err = node.Client.Handshake(cn, node.HandshakeOptions...) if err != nil { + node.MarkDead() return } + node.ResetDead() preNode := node for _, node := range nodes[1:] { @@ -156,13 +159,17 @@ func (c *Chain) getConn() (conn net.Conn, err error) { cc, err = preNode.Client.Connect(cn, node.Addr) if err != nil { cn.Close() + node.MarkDead() return } cc, err = node.Client.Handshake(cc, node.HandshakeOptions...) if err != nil { cn.Close() + node.MarkDead() return } + node.ResetDead() + cn = cc preNode = node } @@ -179,46 +186,21 @@ func (c *Chain) selectRoute() (route *Chain, err error) { buf := bytes.Buffer{} route = newRoute() for _, group := range c.nodeGroups { - selector := group.Selector - if selector == nil { - selector = &defaultSelector{} - } - // select node from node group - node, err := selector.Select(group.Nodes(), group.Options...) + node, err := group.Next() if err != nil { return nil, err } - if _, err := selectIP(&node); err != nil { - return nil, err - } - buf.WriteString(fmt.Sprintf("%d@%s -> ", node.ID, node.Addr)) + buf.WriteString(fmt.Sprintf("%s -> ", node.String())) if node.Client.Transporter.Multiplex() { node.DialOptions = append(node.DialOptions, ChainDialOption(route), ) - route = newRoute() // cutoff the chain for multiplex + route = newRoute() // cutoff the chain for multiplex. } + route.AddNode(node) } log.Log("select route:", buf.String()) return } - -func selectIP(node *Node) (string, error) { - s := node.Selector - if s == nil { - s = &RandomIPSelector{} - } - // select IP from IP list - ip, err := s.Select(node.IPs) - if err != nil { - return "", err - } - if ip != "" { - // override the original address - node.Addr = ip - node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(node.Addr)) - } - return node.Addr, nil -} diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 8e69fdd..d15cabd 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -87,35 +87,53 @@ func main() { func initChain() (*gost.Chain, error) { chain := gost.NewChain() + gid := 1 // group ID + for _, ns := range options.ChainNodes { + ngroup := gost.NewNodeGroup() + ngroup.ID = gid + gid++ + // parse the base node - node, err := parseChainNode(ns) + nodes, err := parseChainNode(ns) if err != nil { return nil, err } - id := 1 // start from 1 + nid := 1 // node ID - node.ID = id - ngroup := gost.NewNodeGroup(node) + for i := range nodes { + nodes[i].ID = nid + nid++ + } + ngroup.AddNode(nodes...) - // parse node peers if exists - peerCfg, err := loadPeerConfig(node.Values.Get("peer")) + // parse peer nodes if exists + peerCfg, err := loadPeerConfig(nodes[0].Values.Get("peer")) if err != nil { log.Log(err) } + peerCfg.Validate() ngroup.Options = append(ngroup.Options, - // gost.WithFilter(), + gost.WithFilter(&gost.FailFilter{ + MaxFails: peerCfg.MaxFails, + FailTimeout: time.Duration(peerCfg.FailTimeout) * time.Second, + }), gost.WithStrategy(parseStrategy(peerCfg.Strategy)), ) + for _, s := range peerCfg.Nodes { - node, err = parseChainNode(s) + nodes, err = parseChainNode(s) if err != nil { return nil, err } - id++ - node.ID = id - ngroup.AddNode(node) + + for i := range nodes { + nodes[i].ID = nid + nid++ + } + + ngroup.AddNode(nodes...) } chain.AddNodeGroup(ngroup) @@ -124,24 +142,12 @@ func initChain() (*gost.Chain, error) { return chain, nil } -func parseChainNode(ns string) (node gost.Node, err error) { - node, err = gost.ParseNode(ns) +func parseChainNode(ns string) (nodes []gost.Node, err error) { + node, err := gost.ParseNode(ns) if err != nil { return } - node.IPs = parseIP(node.Values.Get("ip")) - for i, ip := range node.IPs { - if !strings.Contains(ip, ":") { - _, sport, _ := net.SplitHostPort(node.Addr) - if sport == "" { - sport = "8080" // default port - } - node.IPs[i] = ip + ":" + sport - } - } - node.Selector = &gost.RoundRobinIPSelector{} - users, err := parseUsers(node.Values.Get("secrets")) if err != nil { return @@ -149,7 +155,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { if node.User == nil && len(users) > 0 { node.User = users[0] } - serverName, _, _ := net.SplitHostPort(node.Addr) + serverName, sport, _ := net.SplitHostPort(node.Addr) if serverName == "" { serverName = "localhost" // default server name } @@ -191,7 +197,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { */ config, err := parseKCPConfig(node.Values.Get("c")) if err != nil { - return node, err + return nil, err } tr = gost.KCPTransporter(config) case "ssh": @@ -220,7 +226,7 @@ func parseChainNode(ns string) (node gost.Node, err error) { case "obfs4": if err := gost.Obfs4Init(node, false); err != nil { - return node, err + return nil, err } tr = gost.Obfs4Transporter() case "ohttp": @@ -263,20 +269,31 @@ func parseChainNode(ns string) (node gost.Node, err error) { interval, _ := strconv.Atoi(node.Values.Get("ping")) retry, _ := strconv.Atoi(node.Values.Get("retry")) - node.HandshakeOptions = append(node.HandshakeOptions, + handshakeOptions := []gost.HandshakeOption{ gost.AddrHandshakeOption(node.Addr), gost.HostHandshakeOption(node.Host), gost.UserHandshakeOption(node.User), gost.TLSConfigHandshakeOption(tlsCfg), - gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), - gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second), + gost.IntervalHandshakeOption(time.Duration(interval) * time.Second), + gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second), gost.RetryHandshakeOption(retry), - ) + } node.Client = &gost.Client{ Connector: connector, Transporter: tr, } + ips := parseIP(node.Values.Get("ip"), sport) + for _, ip := range ips { + node.Addr = ip + node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip)) + nodes = append(nodes, node) + } + if len(ips) == 0 { + node.HandshakeOptions = handshakeOptions + nodes = []gost.Node{node} + } + return } @@ -559,16 +576,23 @@ func parseUsers(authFile string) (users []*url.Userinfo, err error) { return } -func parseIP(s string) (ips []string) { +func parseIP(s string, port string) (ips []string) { if s == "" { - return nil + return } + if port == "" { + port = "8080" // default port + } + file, err := os.Open(s) if err != nil { ss := strings.Split(s, ",") for _, s := range ss { s = strings.TrimSpace(s) if s != "" { + if !strings.Contains(s, ":") { + s = s + ":" + port + } ips = append(ips, s) } @@ -582,15 +606,20 @@ func parseIP(s string) (ips []string) { if line == "" || strings.HasPrefix(line, "#") { continue } + if !strings.Contains(line, ":") { + line = line + ":" + port + } ips = append(ips, line) } return } type peerConfig struct { - Strategy string `json:"strategy"` - Filters []string `json:"filters"` - Nodes []string `json:"nodes"` + Strategy string `json:"strategy"` + Filters []string `json:"filters"` + MaxFails int `json:"max_fails"` + FailTimeout int `json:"fail_timeout"` + Nodes []string `json:"nodes"` } func loadPeerConfig(peer string) (config peerConfig, err error) { @@ -605,6 +634,15 @@ func loadPeerConfig(peer string) (config peerConfig, err error) { return } +func (cfg *peerConfig) Validate() { + if cfg.MaxFails <= 0 { + cfg.MaxFails = 3 + } + if cfg.FailTimeout <= 0 { + cfg.FailTimeout = 30 // seconds + } +} + func parseStrategy(s string) gost.Strategy { switch s { case "random": diff --git a/node.go b/node.go index 1266108..799b6e4 100644 --- a/node.go +++ b/node.go @@ -1,16 +1,17 @@ package gost import ( + "fmt" "net/url" "strings" - "sync" + "sync/atomic" + "time" ) // Node is a proxy node, mainly used to construct a proxy chain. type Node struct { ID int Addr string - IPs []string Host string Protocol string Transport string @@ -20,7 +21,9 @@ type Node struct { DialOptions []DialOption HandshakeOptions []HandshakeOption Client *Client - Selector IPSelector + group *NodeGroup + failCount uint32 + failTime time.Time } // ParseNode parses the node info. @@ -83,38 +86,89 @@ func ParseNode(s string) (node Node, err error) { return } +// MarkDead marks the node fail status. +func (node *Node) MarkDead() { + atomic.AddUint32(&node.failCount, 1) + node.failTime = time.Now() + + if node.group == nil { + return + } + for i := range node.group.nodes { + if node.group.nodes[i].ID == node.ID { + atomic.AddUint32(&node.group.nodes[i].failCount, 1) + node.group.nodes[i].failTime = time.Now() + break + } + } +} + +// ResetDead resets the node fail status. +func (node *Node) ResetDead() { + atomic.StoreUint32(&node.failCount, 0) + node.failTime = time.Time{} + + if node.group == nil { + return + } + + for i := range node.group.nodes { + if node.group.nodes[i].ID == node.ID { + atomic.StoreUint32(&node.group.nodes[i].failCount, 0) + node.group.nodes[i].failTime = time.Time{} + break + } + } +} + +func (node *Node) String() string { + return fmt.Sprintf("%d@%s", node.ID, node.Addr) +} + // NodeGroup is a group of nodes. type NodeGroup struct { - nodes []Node - Options []SelectOption - Selector NodeSelector - mutex sync.Mutex - mFails map[string]int // node -> fail count - MaxFails int - FailTimeout int - Retries int + ID int + nodes []Node + Options []SelectOption + Selector NodeSelector } // NewNodeGroup creates a node group func NewNodeGroup(nodes ...Node) *NodeGroup { return &NodeGroup{ - nodes: nodes, - mFails: make(map[string]int), + nodes: nodes, } } // AddNode adds node or node list into group -func (ng *NodeGroup) AddNode(node ...Node) { - if ng == nil { +func (group *NodeGroup) AddNode(node ...Node) { + if group == nil { return } - ng.nodes = append(ng.nodes, node...) + group.nodes = append(group.nodes, node...) } // Nodes returns node list in the group -func (ng *NodeGroup) Nodes() []Node { - if ng == nil { +func (group *NodeGroup) Nodes() []Node { + if group == nil { return nil } - return ng.nodes + return group.nodes +} + +// Next selects the next node from group. +// It also selects IP if the IP list exists. +func (group *NodeGroup) Next() (node Node, err error) { + selector := group.Selector + if selector == nil { + selector = &defaultSelector{} + } + // select node from node group + node, err = selector.Select(group.Nodes(), group.Options...) + if err != nil { + return + } + node.group = group + + return } diff --git a/selector.go b/selector.go index 675a254..345b265 100644 --- a/selector.go +++ b/selector.go @@ -2,6 +2,8 @@ package gost import ( "errors" + "math/rand" + "sync" "sync/atomic" "time" ) @@ -37,9 +39,28 @@ func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, erro return sopts.Strategy.Apply(nodes), nil } -// Filter is used to filter a node during the selection process -type Filter interface { - Filter([]Node) []Node +// SelectOption used when making a select call +type SelectOption func(*SelectOptions) + +// SelectOptions is the options for node selection +type SelectOptions struct { + Filters []Filter + Strategy Strategy +} + +// WithFilter adds a filter function to the list of filters +// used during the Select call. +func WithFilter(f ...Filter) SelectOption { + return func(o *SelectOptions) { + o.Filters = append(o.Filters, f...) + } +} + +// WithStrategy sets the selector strategy +func WithStrategy(s Strategy) SelectOption { + return func(o *SelectOptions) { + o.Strategy = s + } } // Strategy is a selection strategy e.g random, round robin @@ -68,82 +89,60 @@ func (s *RoundStrategy) String() string { } // RandomStrategy is a strategy for node selector -type RandomStrategy struct{} +type RandomStrategy struct { + Seed int64 + rand *rand.Rand + once sync.Once +} // Apply applies the random strategy for the nodes func (s *RandomStrategy) Apply(nodes []Node) Node { + s.once.Do(func() { + seed := s.Seed + if seed == 0 { + seed = time.Now().UnixNano() + } + s.rand = rand.New(rand.NewSource(seed)) + }) if len(nodes) == 0 { return Node{} } - return nodes[time.Now().Nanosecond()%len(nodes)] + return nodes[s.rand.Int()%len(nodes)] } func (s *RandomStrategy) String() string { return "random" } -// SelectOption used when making a select call -type SelectOption func(*SelectOptions) - -// SelectOptions is the options for node selection -type SelectOptions struct { - Filters []Filter - Strategy Strategy -} - -// WithFilter adds a filter function to the list of filters -// used during the Select call. -func WithFilter(f ...Filter) SelectOption { - return func(o *SelectOptions) { - o.Filters = append(o.Filters, f...) - } -} - -// WithStrategy sets the selector strategy -func WithStrategy(s Strategy) SelectOption { - return func(o *SelectOptions) { - o.Strategy = s - } -} - -// IPSelector as a mechanism to pick IPs and mark their status. -type IPSelector interface { - Select(ips []string) (string, error) +// Filter is used to filter a node during the selection process +type Filter interface { + Filter([]Node) []Node String() string } -// RandomIPSelector is an IP Selector that selects an IP with random strategy. -type RandomIPSelector struct { +// FailFilter filters the dead node. +// A node is marked as dead if its failed count is greater than MaxFails. +type FailFilter struct { + MaxFails int + FailTimeout time.Duration } -// Select selects an IP from ips list. -func (s *RandomIPSelector) Select(ips []string) (string, error) { - if len(ips) == 0 { - return "", nil +// Filter filters nodes. +func (f *FailFilter) Filter(nodes []Node) []Node { + if f.MaxFails <= 0 { + return nodes } - return ips[time.Now().Nanosecond()%len(ips)], nil -} - -func (s *RandomIPSelector) String() string { - return "random" -} - -// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy. -type RoundRobinIPSelector struct { - count uint64 -} - -// Select selects an IP from ips list. -func (s *RoundRobinIPSelector) Select(ips []string) (string, error) { - if len(ips) == 0 { - return "", nil + nl := []Node{} + for _, node := range nodes { + if node.failCount < uint32(f.MaxFails) || + time.Since(node.failTime) >= f.FailTimeout { + nl = append(nl, node) + } } - old := s.count - atomic.AddUint64(&s.count, 1) - return ips[int(old%uint64(len(ips)))], nil + return nl } -func (s *RoundRobinIPSelector) String() string { - return "round" +func (f *FailFilter) String() string { + return "fail" }