add node fail filter

This commit is contained in:
rui.zheng 2017-11-14 17:50:02 +08:00
parent da30584df1
commit 3befde0128
4 changed files with 217 additions and 144 deletions

View File

@ -142,13 +142,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
cn, err := node.Client.Dial(node.Addr, node.DialOptions...) cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil { if err != nil {
node.MarkDead()
return return
} }
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...) cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
if err != nil { if err != nil {
node.MarkDead()
return return
} }
node.ResetDead()
preNode := node preNode := node
for _, node := range nodes[1:] { for _, node := range nodes[1:] {
@ -156,13 +159,17 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
cc, err = preNode.Client.Connect(cn, node.Addr) cc, err = preNode.Client.Connect(cn, node.Addr)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.MarkDead()
return return
} }
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...) cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.MarkDead()
return return
} }
node.ResetDead()
cn = cc cn = cc
preNode = node preNode = node
} }
@ -179,46 +186,21 @@ func (c *Chain) selectRoute() (route *Chain, err error) {
buf := bytes.Buffer{} buf := bytes.Buffer{}
route = newRoute() route = newRoute()
for _, group := range c.nodeGroups { for _, group := range c.nodeGroups {
selector := group.Selector node, err := group.Next()
if selector == nil {
selector = &defaultSelector{}
}
// select node from node group
node, err := selector.Select(group.Nodes(), group.Options...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, err := selectIP(&node); err != nil { buf.WriteString(fmt.Sprintf("%s -> ", node.String()))
return nil, err
}
buf.WriteString(fmt.Sprintf("%d@%s -> ", node.ID, node.Addr))
if node.Client.Transporter.Multiplex() { if node.Client.Transporter.Multiplex() {
node.DialOptions = append(node.DialOptions, node.DialOptions = append(node.DialOptions,
ChainDialOption(route), ChainDialOption(route),
) )
route = newRoute() // cutoff the chain for multiplex route = newRoute() // cutoff the chain for multiplex.
} }
route.AddNode(node) route.AddNode(node)
} }
log.Log("select route:", buf.String()) log.Log("select route:", buf.String())
return return
} }
func selectIP(node *Node) (string, error) {
s := node.Selector
if s == nil {
s = &RandomIPSelector{}
}
// select IP from IP list
ip, err := s.Select(node.IPs)
if err != nil {
return "", err
}
if ip != "" {
// override the original address
node.Addr = ip
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(node.Addr))
}
return node.Addr, nil
}

View File

@ -87,35 +87,53 @@ func main() {
func initChain() (*gost.Chain, error) { func initChain() (*gost.Chain, error) {
chain := gost.NewChain() chain := gost.NewChain()
gid := 1 // group ID
for _, ns := range options.ChainNodes { for _, ns := range options.ChainNodes {
ngroup := gost.NewNodeGroup()
ngroup.ID = gid
gid++
// parse the base node // parse the base node
node, err := parseChainNode(ns) nodes, err := parseChainNode(ns)
if err != nil { if err != nil {
return nil, err return nil, err
} }
id := 1 // start from 1 nid := 1 // node ID
node.ID = id for i := range nodes {
ngroup := gost.NewNodeGroup(node) nodes[i].ID = nid
nid++
}
ngroup.AddNode(nodes...)
// parse node peers if exists // parse peer nodes if exists
peerCfg, err := loadPeerConfig(node.Values.Get("peer")) peerCfg, err := loadPeerConfig(nodes[0].Values.Get("peer"))
if err != nil { if err != nil {
log.Log(err) log.Log(err)
} }
peerCfg.Validate()
ngroup.Options = append(ngroup.Options, ngroup.Options = append(ngroup.Options,
// gost.WithFilter(), gost.WithFilter(&gost.FailFilter{
MaxFails: peerCfg.MaxFails,
FailTimeout: time.Duration(peerCfg.FailTimeout) * time.Second,
}),
gost.WithStrategy(parseStrategy(peerCfg.Strategy)), gost.WithStrategy(parseStrategy(peerCfg.Strategy)),
) )
for _, s := range peerCfg.Nodes { for _, s := range peerCfg.Nodes {
node, err = parseChainNode(s) nodes, err = parseChainNode(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
id++
node.ID = id for i := range nodes {
ngroup.AddNode(node) nodes[i].ID = nid
nid++
}
ngroup.AddNode(nodes...)
} }
chain.AddNodeGroup(ngroup) chain.AddNodeGroup(ngroup)
@ -124,24 +142,12 @@ func initChain() (*gost.Chain, error) {
return chain, nil return chain, nil
} }
func parseChainNode(ns string) (node gost.Node, err error) { func parseChainNode(ns string) (nodes []gost.Node, err error) {
node, err = gost.ParseNode(ns) node, err := gost.ParseNode(ns)
if err != nil { if err != nil {
return return
} }
node.IPs = parseIP(node.Values.Get("ip"))
for i, ip := range node.IPs {
if !strings.Contains(ip, ":") {
_, sport, _ := net.SplitHostPort(node.Addr)
if sport == "" {
sport = "8080" // default port
}
node.IPs[i] = ip + ":" + sport
}
}
node.Selector = &gost.RoundRobinIPSelector{}
users, err := parseUsers(node.Values.Get("secrets")) users, err := parseUsers(node.Values.Get("secrets"))
if err != nil { if err != nil {
return return
@ -149,7 +155,7 @@ func parseChainNode(ns string) (node gost.Node, err error) {
if node.User == nil && len(users) > 0 { if node.User == nil && len(users) > 0 {
node.User = users[0] node.User = users[0]
} }
serverName, _, _ := net.SplitHostPort(node.Addr) serverName, sport, _ := net.SplitHostPort(node.Addr)
if serverName == "" { if serverName == "" {
serverName = "localhost" // default server name serverName = "localhost" // default server name
} }
@ -191,7 +197,7 @@ func parseChainNode(ns string) (node gost.Node, err error) {
*/ */
config, err := parseKCPConfig(node.Values.Get("c")) config, err := parseKCPConfig(node.Values.Get("c"))
if err != nil { if err != nil {
return node, err return nil, err
} }
tr = gost.KCPTransporter(config) tr = gost.KCPTransporter(config)
case "ssh": case "ssh":
@ -220,7 +226,7 @@ func parseChainNode(ns string) (node gost.Node, err error) {
case "obfs4": case "obfs4":
if err := gost.Obfs4Init(node, false); err != nil { if err := gost.Obfs4Init(node, false); err != nil {
return node, err return nil, err
} }
tr = gost.Obfs4Transporter() tr = gost.Obfs4Transporter()
case "ohttp": case "ohttp":
@ -263,20 +269,31 @@ func parseChainNode(ns string) (node gost.Node, err error) {
interval, _ := strconv.Atoi(node.Values.Get("ping")) interval, _ := strconv.Atoi(node.Values.Get("ping"))
retry, _ := strconv.Atoi(node.Values.Get("retry")) retry, _ := strconv.Atoi(node.Values.Get("retry"))
node.HandshakeOptions = append(node.HandshakeOptions, handshakeOptions := []gost.HandshakeOption{
gost.AddrHandshakeOption(node.Addr), gost.AddrHandshakeOption(node.Addr),
gost.HostHandshakeOption(node.Host), gost.HostHandshakeOption(node.Host),
gost.UserHandshakeOption(node.User), gost.UserHandshakeOption(node.User),
gost.TLSConfigHandshakeOption(tlsCfg), gost.TLSConfigHandshakeOption(tlsCfg),
gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), gost.IntervalHandshakeOption(time.Duration(interval) * time.Second),
gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second), gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second),
gost.RetryHandshakeOption(retry), gost.RetryHandshakeOption(retry),
) }
node.Client = &gost.Client{ node.Client = &gost.Client{
Connector: connector, Connector: connector,
Transporter: tr, Transporter: tr,
} }
ips := parseIP(node.Values.Get("ip"), sport)
for _, ip := range ips {
node.Addr = ip
node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip))
nodes = append(nodes, node)
}
if len(ips) == 0 {
node.HandshakeOptions = handshakeOptions
nodes = []gost.Node{node}
}
return return
} }
@ -559,16 +576,23 @@ func parseUsers(authFile string) (users []*url.Userinfo, err error) {
return return
} }
func parseIP(s string) (ips []string) { func parseIP(s string, port string) (ips []string) {
if s == "" { if s == "" {
return nil return
} }
if port == "" {
port = "8080" // default port
}
file, err := os.Open(s) file, err := os.Open(s)
if err != nil { if err != nil {
ss := strings.Split(s, ",") ss := strings.Split(s, ",")
for _, s := range ss { for _, s := range ss {
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
if s != "" { if s != "" {
if !strings.Contains(s, ":") {
s = s + ":" + port
}
ips = append(ips, s) ips = append(ips, s)
} }
@ -582,6 +606,9 @@ func parseIP(s string) (ips []string) {
if line == "" || strings.HasPrefix(line, "#") { if line == "" || strings.HasPrefix(line, "#") {
continue continue
} }
if !strings.Contains(line, ":") {
line = line + ":" + port
}
ips = append(ips, line) ips = append(ips, line)
} }
return return
@ -590,6 +617,8 @@ func parseIP(s string) (ips []string) {
type peerConfig struct { type peerConfig struct {
Strategy string `json:"strategy"` Strategy string `json:"strategy"`
Filters []string `json:"filters"` Filters []string `json:"filters"`
MaxFails int `json:"max_fails"`
FailTimeout int `json:"fail_timeout"`
Nodes []string `json:"nodes"` Nodes []string `json:"nodes"`
} }
@ -605,6 +634,15 @@ func loadPeerConfig(peer string) (config peerConfig, err error) {
return return
} }
func (cfg *peerConfig) Validate() {
if cfg.MaxFails <= 0 {
cfg.MaxFails = 3
}
if cfg.FailTimeout <= 0 {
cfg.FailTimeout = 30 // seconds
}
}
func parseStrategy(s string) gost.Strategy { func parseStrategy(s string) gost.Strategy {
switch s { switch s {
case "random": case "random":

84
node.go
View File

@ -1,16 +1,17 @@
package gost package gost
import ( import (
"fmt"
"net/url" "net/url"
"strings" "strings"
"sync" "sync/atomic"
"time"
) )
// Node is a proxy node, mainly used to construct a proxy chain. // Node is a proxy node, mainly used to construct a proxy chain.
type Node struct { type Node struct {
ID int ID int
Addr string Addr string
IPs []string
Host string Host string
Protocol string Protocol string
Transport string Transport string
@ -20,7 +21,9 @@ type Node struct {
DialOptions []DialOption DialOptions []DialOption
HandshakeOptions []HandshakeOption HandshakeOptions []HandshakeOption
Client *Client Client *Client
Selector IPSelector group *NodeGroup
failCount uint32
failTime time.Time
} }
// ParseNode parses the node info. // ParseNode parses the node info.
@ -83,38 +86,89 @@ func ParseNode(s string) (node Node, err error) {
return return
} }
// MarkDead marks the node fail status.
func (node *Node) MarkDead() {
atomic.AddUint32(&node.failCount, 1)
node.failTime = time.Now()
if node.group == nil {
return
}
for i := range node.group.nodes {
if node.group.nodes[i].ID == node.ID {
atomic.AddUint32(&node.group.nodes[i].failCount, 1)
node.group.nodes[i].failTime = time.Now()
break
}
}
}
// ResetDead resets the node fail status.
func (node *Node) ResetDead() {
atomic.StoreUint32(&node.failCount, 0)
node.failTime = time.Time{}
if node.group == nil {
return
}
for i := range node.group.nodes {
if node.group.nodes[i].ID == node.ID {
atomic.StoreUint32(&node.group.nodes[i].failCount, 0)
node.group.nodes[i].failTime = time.Time{}
break
}
}
}
func (node *Node) String() string {
return fmt.Sprintf("%d@%s", node.ID, node.Addr)
}
// NodeGroup is a group of nodes. // NodeGroup is a group of nodes.
type NodeGroup struct { type NodeGroup struct {
ID int
nodes []Node nodes []Node
Options []SelectOption Options []SelectOption
Selector NodeSelector Selector NodeSelector
mutex sync.Mutex
mFails map[string]int // node -> fail count
MaxFails int
FailTimeout int
Retries int
} }
// NewNodeGroup creates a node group // NewNodeGroup creates a node group
func NewNodeGroup(nodes ...Node) *NodeGroup { func NewNodeGroup(nodes ...Node) *NodeGroup {
return &NodeGroup{ return &NodeGroup{
nodes: nodes, nodes: nodes,
mFails: make(map[string]int),
} }
} }
// AddNode adds node or node list into group // AddNode adds node or node list into group
func (ng *NodeGroup) AddNode(node ...Node) { func (group *NodeGroup) AddNode(node ...Node) {
if ng == nil { if group == nil {
return return
} }
ng.nodes = append(ng.nodes, node...) group.nodes = append(group.nodes, node...)
} }
// Nodes returns node list in the group // Nodes returns node list in the group
func (ng *NodeGroup) Nodes() []Node { func (group *NodeGroup) Nodes() []Node {
if ng == nil { if group == nil {
return nil return nil
} }
return ng.nodes return group.nodes
}
// Next selects the next node from group.
// It also selects IP if the IP list exists.
func (group *NodeGroup) Next() (node Node, err error) {
selector := group.Selector
if selector == nil {
selector = &defaultSelector{}
}
// select node from node group
node, err = selector.Select(group.Nodes(), group.Options...)
if err != nil {
return
}
node.group = group
return
} }

View File

@ -2,6 +2,8 @@ package gost
import ( import (
"errors" "errors"
"math/rand"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
@ -37,9 +39,28 @@ func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, erro
return sopts.Strategy.Apply(nodes), nil return sopts.Strategy.Apply(nodes), nil
} }
// Filter is used to filter a node during the selection process // SelectOption used when making a select call
type Filter interface { type SelectOption func(*SelectOptions)
Filter([]Node) []Node
// SelectOptions is the options for node selection
type SelectOptions struct {
Filters []Filter
Strategy Strategy
}
// WithFilter adds a filter function to the list of filters
// used during the Select call.
func WithFilter(f ...Filter) SelectOption {
return func(o *SelectOptions) {
o.Filters = append(o.Filters, f...)
}
}
// WithStrategy sets the selector strategy
func WithStrategy(s Strategy) SelectOption {
return func(o *SelectOptions) {
o.Strategy = s
}
} }
// Strategy is a selection strategy e.g random, round robin // Strategy is a selection strategy e.g random, round robin
@ -68,82 +89,60 @@ func (s *RoundStrategy) String() string {
} }
// RandomStrategy is a strategy for node selector // RandomStrategy is a strategy for node selector
type RandomStrategy struct{} type RandomStrategy struct {
Seed int64
rand *rand.Rand
once sync.Once
}
// Apply applies the random strategy for the nodes // Apply applies the random strategy for the nodes
func (s *RandomStrategy) Apply(nodes []Node) Node { func (s *RandomStrategy) Apply(nodes []Node) Node {
s.once.Do(func() {
seed := s.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
s.rand = rand.New(rand.NewSource(seed))
})
if len(nodes) == 0 { if len(nodes) == 0 {
return Node{} return Node{}
} }
return nodes[time.Now().Nanosecond()%len(nodes)] return nodes[s.rand.Int()%len(nodes)]
} }
func (s *RandomStrategy) String() string { func (s *RandomStrategy) String() string {
return "random" return "random"
} }
// SelectOption used when making a select call // Filter is used to filter a node during the selection process
type SelectOption func(*SelectOptions) type Filter interface {
Filter([]Node) []Node
// SelectOptions is the options for node selection
type SelectOptions struct {
Filters []Filter
Strategy Strategy
}
// WithFilter adds a filter function to the list of filters
// used during the Select call.
func WithFilter(f ...Filter) SelectOption {
return func(o *SelectOptions) {
o.Filters = append(o.Filters, f...)
}
}
// WithStrategy sets the selector strategy
func WithStrategy(s Strategy) SelectOption {
return func(o *SelectOptions) {
o.Strategy = s
}
}
// IPSelector as a mechanism to pick IPs and mark their status.
type IPSelector interface {
Select(ips []string) (string, error)
String() string String() string
} }
// RandomIPSelector is an IP Selector that selects an IP with random strategy. // FailFilter filters the dead node.
type RandomIPSelector struct { // A node is marked as dead if its failed count is greater than MaxFails.
type FailFilter struct {
MaxFails int
FailTimeout time.Duration
} }
// Select selects an IP from ips list. // Filter filters nodes.
func (s *RandomIPSelector) Select(ips []string) (string, error) { func (f *FailFilter) Filter(nodes []Node) []Node {
if len(ips) == 0 { if f.MaxFails <= 0 {
return "", nil return nodes
} }
return ips[time.Now().Nanosecond()%len(ips)], nil nl := []Node{}
} for _, node := range nodes {
if node.failCount < uint32(f.MaxFails) ||
func (s *RandomIPSelector) String() string { time.Since(node.failTime) >= f.FailTimeout {
return "random" nl = append(nl, node)
}
// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy.
type RoundRobinIPSelector struct {
count uint64
}
// Select selects an IP from ips list.
func (s *RoundRobinIPSelector) Select(ips []string) (string, error) {
if len(ips) == 0 {
return "", nil
} }
old := s.count }
atomic.AddUint64(&s.count, 1) return nl
return ips[int(old%uint64(len(ips)))], nil
} }
func (s *RoundRobinIPSelector) String() string { func (f *FailFilter) String() string {
return "round" return "fail"
} }