add node fail filter
This commit is contained in:
parent
da30584df1
commit
3befde0128
40
chain.go
40
chain.go
@ -142,13 +142,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
|
|||||||
|
|
||||||
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
node.MarkDead()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
|
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
node.MarkDead()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
node.ResetDead()
|
||||||
|
|
||||||
preNode := node
|
preNode := node
|
||||||
for _, node := range nodes[1:] {
|
for _, node := range nodes[1:] {
|
||||||
@ -156,13 +159,17 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
|
|||||||
cc, err = preNode.Client.Connect(cn, node.Addr)
|
cc, err = preNode.Client.Connect(cn, node.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cn.Close()
|
cn.Close()
|
||||||
|
node.MarkDead()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
|
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cn.Close()
|
cn.Close()
|
||||||
|
node.MarkDead()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
node.ResetDead()
|
||||||
|
|
||||||
cn = cc
|
cn = cc
|
||||||
preNode = node
|
preNode = node
|
||||||
}
|
}
|
||||||
@ -179,46 +186,21 @@ func (c *Chain) selectRoute() (route *Chain, err error) {
|
|||||||
buf := bytes.Buffer{}
|
buf := bytes.Buffer{}
|
||||||
route = newRoute()
|
route = newRoute()
|
||||||
for _, group := range c.nodeGroups {
|
for _, group := range c.nodeGroups {
|
||||||
selector := group.Selector
|
node, err := group.Next()
|
||||||
if selector == nil {
|
|
||||||
selector = &defaultSelector{}
|
|
||||||
}
|
|
||||||
// select node from node group
|
|
||||||
node, err := selector.Select(group.Nodes(), group.Options...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err := selectIP(&node); err != nil {
|
buf.WriteString(fmt.Sprintf("%s -> ", node.String()))
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
buf.WriteString(fmt.Sprintf("%d@%s -> ", node.ID, node.Addr))
|
|
||||||
|
|
||||||
if node.Client.Transporter.Multiplex() {
|
if node.Client.Transporter.Multiplex() {
|
||||||
node.DialOptions = append(node.DialOptions,
|
node.DialOptions = append(node.DialOptions,
|
||||||
ChainDialOption(route),
|
ChainDialOption(route),
|
||||||
)
|
)
|
||||||
route = newRoute() // cutoff the chain for multiplex
|
route = newRoute() // cutoff the chain for multiplex.
|
||||||
}
|
}
|
||||||
|
|
||||||
route.AddNode(node)
|
route.AddNode(node)
|
||||||
}
|
}
|
||||||
log.Log("select route:", buf.String())
|
log.Log("select route:", buf.String())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func selectIP(node *Node) (string, error) {
|
|
||||||
s := node.Selector
|
|
||||||
if s == nil {
|
|
||||||
s = &RandomIPSelector{}
|
|
||||||
}
|
|
||||||
// select IP from IP list
|
|
||||||
ip, err := s.Select(node.IPs)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if ip != "" {
|
|
||||||
// override the original address
|
|
||||||
node.Addr = ip
|
|
||||||
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(node.Addr))
|
|
||||||
}
|
|
||||||
return node.Addr, nil
|
|
||||||
}
|
|
||||||
|
106
cmd/gost/main.go
106
cmd/gost/main.go
@ -87,35 +87,53 @@ func main() {
|
|||||||
|
|
||||||
func initChain() (*gost.Chain, error) {
|
func initChain() (*gost.Chain, error) {
|
||||||
chain := gost.NewChain()
|
chain := gost.NewChain()
|
||||||
|
gid := 1 // group ID
|
||||||
|
|
||||||
for _, ns := range options.ChainNodes {
|
for _, ns := range options.ChainNodes {
|
||||||
|
ngroup := gost.NewNodeGroup()
|
||||||
|
ngroup.ID = gid
|
||||||
|
gid++
|
||||||
|
|
||||||
// parse the base node
|
// parse the base node
|
||||||
node, err := parseChainNode(ns)
|
nodes, err := parseChainNode(ns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
id := 1 // start from 1
|
nid := 1 // node ID
|
||||||
|
|
||||||
node.ID = id
|
for i := range nodes {
|
||||||
ngroup := gost.NewNodeGroup(node)
|
nodes[i].ID = nid
|
||||||
|
nid++
|
||||||
|
}
|
||||||
|
ngroup.AddNode(nodes...)
|
||||||
|
|
||||||
// parse node peers if exists
|
// parse peer nodes if exists
|
||||||
peerCfg, err := loadPeerConfig(node.Values.Get("peer"))
|
peerCfg, err := loadPeerConfig(nodes[0].Values.Get("peer"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Log(err)
|
log.Log(err)
|
||||||
}
|
}
|
||||||
|
peerCfg.Validate()
|
||||||
ngroup.Options = append(ngroup.Options,
|
ngroup.Options = append(ngroup.Options,
|
||||||
// gost.WithFilter(),
|
gost.WithFilter(&gost.FailFilter{
|
||||||
|
MaxFails: peerCfg.MaxFails,
|
||||||
|
FailTimeout: time.Duration(peerCfg.FailTimeout) * time.Second,
|
||||||
|
}),
|
||||||
gost.WithStrategy(parseStrategy(peerCfg.Strategy)),
|
gost.WithStrategy(parseStrategy(peerCfg.Strategy)),
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, s := range peerCfg.Nodes {
|
for _, s := range peerCfg.Nodes {
|
||||||
node, err = parseChainNode(s)
|
nodes, err = parseChainNode(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
id++
|
|
||||||
node.ID = id
|
for i := range nodes {
|
||||||
ngroup.AddNode(node)
|
nodes[i].ID = nid
|
||||||
|
nid++
|
||||||
|
}
|
||||||
|
|
||||||
|
ngroup.AddNode(nodes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddNodeGroup(ngroup)
|
chain.AddNodeGroup(ngroup)
|
||||||
@ -124,24 +142,12 @@ func initChain() (*gost.Chain, error) {
|
|||||||
return chain, nil
|
return chain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseChainNode(ns string) (node gost.Node, err error) {
|
func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
||||||
node, err = gost.ParseNode(ns)
|
node, err := gost.ParseNode(ns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
node.IPs = parseIP(node.Values.Get("ip"))
|
|
||||||
for i, ip := range node.IPs {
|
|
||||||
if !strings.Contains(ip, ":") {
|
|
||||||
_, sport, _ := net.SplitHostPort(node.Addr)
|
|
||||||
if sport == "" {
|
|
||||||
sport = "8080" // default port
|
|
||||||
}
|
|
||||||
node.IPs[i] = ip + ":" + sport
|
|
||||||
}
|
|
||||||
}
|
|
||||||
node.Selector = &gost.RoundRobinIPSelector{}
|
|
||||||
|
|
||||||
users, err := parseUsers(node.Values.Get("secrets"))
|
users, err := parseUsers(node.Values.Get("secrets"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -149,7 +155,7 @@ func parseChainNode(ns string) (node gost.Node, err error) {
|
|||||||
if node.User == nil && len(users) > 0 {
|
if node.User == nil && len(users) > 0 {
|
||||||
node.User = users[0]
|
node.User = users[0]
|
||||||
}
|
}
|
||||||
serverName, _, _ := net.SplitHostPort(node.Addr)
|
serverName, sport, _ := net.SplitHostPort(node.Addr)
|
||||||
if serverName == "" {
|
if serverName == "" {
|
||||||
serverName = "localhost" // default server name
|
serverName = "localhost" // default server name
|
||||||
}
|
}
|
||||||
@ -191,7 +197,7 @@ func parseChainNode(ns string) (node gost.Node, err error) {
|
|||||||
*/
|
*/
|
||||||
config, err := parseKCPConfig(node.Values.Get("c"))
|
config, err := parseKCPConfig(node.Values.Get("c"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return node, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tr = gost.KCPTransporter(config)
|
tr = gost.KCPTransporter(config)
|
||||||
case "ssh":
|
case "ssh":
|
||||||
@ -220,7 +226,7 @@ func parseChainNode(ns string) (node gost.Node, err error) {
|
|||||||
|
|
||||||
case "obfs4":
|
case "obfs4":
|
||||||
if err := gost.Obfs4Init(node, false); err != nil {
|
if err := gost.Obfs4Init(node, false); err != nil {
|
||||||
return node, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tr = gost.Obfs4Transporter()
|
tr = gost.Obfs4Transporter()
|
||||||
case "ohttp":
|
case "ohttp":
|
||||||
@ -263,20 +269,31 @@ func parseChainNode(ns string) (node gost.Node, err error) {
|
|||||||
|
|
||||||
interval, _ := strconv.Atoi(node.Values.Get("ping"))
|
interval, _ := strconv.Atoi(node.Values.Get("ping"))
|
||||||
retry, _ := strconv.Atoi(node.Values.Get("retry"))
|
retry, _ := strconv.Atoi(node.Values.Get("retry"))
|
||||||
node.HandshakeOptions = append(node.HandshakeOptions,
|
handshakeOptions := []gost.HandshakeOption{
|
||||||
gost.AddrHandshakeOption(node.Addr),
|
gost.AddrHandshakeOption(node.Addr),
|
||||||
gost.HostHandshakeOption(node.Host),
|
gost.HostHandshakeOption(node.Host),
|
||||||
gost.UserHandshakeOption(node.User),
|
gost.UserHandshakeOption(node.User),
|
||||||
gost.TLSConfigHandshakeOption(tlsCfg),
|
gost.TLSConfigHandshakeOption(tlsCfg),
|
||||||
gost.IntervalHandshakeOption(time.Duration(interval)*time.Second),
|
gost.IntervalHandshakeOption(time.Duration(interval) * time.Second),
|
||||||
gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second),
|
gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second),
|
||||||
gost.RetryHandshakeOption(retry),
|
gost.RetryHandshakeOption(retry),
|
||||||
)
|
}
|
||||||
node.Client = &gost.Client{
|
node.Client = &gost.Client{
|
||||||
Connector: connector,
|
Connector: connector,
|
||||||
Transporter: tr,
|
Transporter: tr,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ips := parseIP(node.Values.Get("ip"), sport)
|
||||||
|
for _, ip := range ips {
|
||||||
|
node.Addr = ip
|
||||||
|
node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip))
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
}
|
||||||
|
if len(ips) == 0 {
|
||||||
|
node.HandshakeOptions = handshakeOptions
|
||||||
|
nodes = []gost.Node{node}
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -559,16 +576,23 @@ func parseUsers(authFile string) (users []*url.Userinfo, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIP(s string) (ips []string) {
|
func parseIP(s string, port string) (ips []string) {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
if port == "" {
|
||||||
|
port = "8080" // default port
|
||||||
|
}
|
||||||
|
|
||||||
file, err := os.Open(s)
|
file, err := os.Open(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ss := strings.Split(s, ",")
|
ss := strings.Split(s, ",")
|
||||||
for _, s := range ss {
|
for _, s := range ss {
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
if s != "" {
|
if s != "" {
|
||||||
|
if !strings.Contains(s, ":") {
|
||||||
|
s = s + ":" + port
|
||||||
|
}
|
||||||
ips = append(ips, s)
|
ips = append(ips, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -582,6 +606,9 @@ func parseIP(s string) (ips []string) {
|
|||||||
if line == "" || strings.HasPrefix(line, "#") {
|
if line == "" || strings.HasPrefix(line, "#") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(line, ":") {
|
||||||
|
line = line + ":" + port
|
||||||
|
}
|
||||||
ips = append(ips, line)
|
ips = append(ips, line)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -590,6 +617,8 @@ func parseIP(s string) (ips []string) {
|
|||||||
type peerConfig struct {
|
type peerConfig struct {
|
||||||
Strategy string `json:"strategy"`
|
Strategy string `json:"strategy"`
|
||||||
Filters []string `json:"filters"`
|
Filters []string `json:"filters"`
|
||||||
|
MaxFails int `json:"max_fails"`
|
||||||
|
FailTimeout int `json:"fail_timeout"`
|
||||||
Nodes []string `json:"nodes"`
|
Nodes []string `json:"nodes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -605,6 +634,15 @@ func loadPeerConfig(peer string) (config peerConfig, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cfg *peerConfig) Validate() {
|
||||||
|
if cfg.MaxFails <= 0 {
|
||||||
|
cfg.MaxFails = 3
|
||||||
|
}
|
||||||
|
if cfg.FailTimeout <= 0 {
|
||||||
|
cfg.FailTimeout = 30 // seconds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func parseStrategy(s string) gost.Strategy {
|
func parseStrategy(s string) gost.Strategy {
|
||||||
switch s {
|
switch s {
|
||||||
case "random":
|
case "random":
|
||||||
|
84
node.go
84
node.go
@ -1,16 +1,17 @@
|
|||||||
package gost
|
package gost
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Node is a proxy node, mainly used to construct a proxy chain.
|
// Node is a proxy node, mainly used to construct a proxy chain.
|
||||||
type Node struct {
|
type Node struct {
|
||||||
ID int
|
ID int
|
||||||
Addr string
|
Addr string
|
||||||
IPs []string
|
|
||||||
Host string
|
Host string
|
||||||
Protocol string
|
Protocol string
|
||||||
Transport string
|
Transport string
|
||||||
@ -20,7 +21,9 @@ type Node struct {
|
|||||||
DialOptions []DialOption
|
DialOptions []DialOption
|
||||||
HandshakeOptions []HandshakeOption
|
HandshakeOptions []HandshakeOption
|
||||||
Client *Client
|
Client *Client
|
||||||
Selector IPSelector
|
group *NodeGroup
|
||||||
|
failCount uint32
|
||||||
|
failTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseNode parses the node info.
|
// ParseNode parses the node info.
|
||||||
@ -83,38 +86,89 @@ func ParseNode(s string) (node Node, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarkDead marks the node fail status.
|
||||||
|
func (node *Node) MarkDead() {
|
||||||
|
atomic.AddUint32(&node.failCount, 1)
|
||||||
|
node.failTime = time.Now()
|
||||||
|
|
||||||
|
if node.group == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range node.group.nodes {
|
||||||
|
if node.group.nodes[i].ID == node.ID {
|
||||||
|
atomic.AddUint32(&node.group.nodes[i].failCount, 1)
|
||||||
|
node.group.nodes[i].failTime = time.Now()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDead resets the node fail status.
|
||||||
|
func (node *Node) ResetDead() {
|
||||||
|
atomic.StoreUint32(&node.failCount, 0)
|
||||||
|
node.failTime = time.Time{}
|
||||||
|
|
||||||
|
if node.group == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range node.group.nodes {
|
||||||
|
if node.group.nodes[i].ID == node.ID {
|
||||||
|
atomic.StoreUint32(&node.group.nodes[i].failCount, 0)
|
||||||
|
node.group.nodes[i].failTime = time.Time{}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *Node) String() string {
|
||||||
|
return fmt.Sprintf("%d@%s", node.ID, node.Addr)
|
||||||
|
}
|
||||||
|
|
||||||
// NodeGroup is a group of nodes.
|
// NodeGroup is a group of nodes.
|
||||||
type NodeGroup struct {
|
type NodeGroup struct {
|
||||||
|
ID int
|
||||||
nodes []Node
|
nodes []Node
|
||||||
Options []SelectOption
|
Options []SelectOption
|
||||||
Selector NodeSelector
|
Selector NodeSelector
|
||||||
mutex sync.Mutex
|
|
||||||
mFails map[string]int // node -> fail count
|
|
||||||
MaxFails int
|
|
||||||
FailTimeout int
|
|
||||||
Retries int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNodeGroup creates a node group
|
// NewNodeGroup creates a node group
|
||||||
func NewNodeGroup(nodes ...Node) *NodeGroup {
|
func NewNodeGroup(nodes ...Node) *NodeGroup {
|
||||||
return &NodeGroup{
|
return &NodeGroup{
|
||||||
nodes: nodes,
|
nodes: nodes,
|
||||||
mFails: make(map[string]int),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddNode adds node or node list into group
|
// AddNode adds node or node list into group
|
||||||
func (ng *NodeGroup) AddNode(node ...Node) {
|
func (group *NodeGroup) AddNode(node ...Node) {
|
||||||
if ng == nil {
|
if group == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ng.nodes = append(ng.nodes, node...)
|
group.nodes = append(group.nodes, node...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nodes returns node list in the group
|
// Nodes returns node list in the group
|
||||||
func (ng *NodeGroup) Nodes() []Node {
|
func (group *NodeGroup) Nodes() []Node {
|
||||||
if ng == nil {
|
if group == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return ng.nodes
|
return group.nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next selects the next node from group.
|
||||||
|
// It also selects IP if the IP list exists.
|
||||||
|
func (group *NodeGroup) Next() (node Node, err error) {
|
||||||
|
selector := group.Selector
|
||||||
|
if selector == nil {
|
||||||
|
selector = &defaultSelector{}
|
||||||
|
}
|
||||||
|
// select node from node group
|
||||||
|
node, err = selector.Select(group.Nodes(), group.Options...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
node.group = group
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
117
selector.go
117
selector.go
@ -2,6 +2,8 @@ package gost
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -37,9 +39,28 @@ func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, erro
|
|||||||
return sopts.Strategy.Apply(nodes), nil
|
return sopts.Strategy.Apply(nodes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter is used to filter a node during the selection process
|
// SelectOption used when making a select call
|
||||||
type Filter interface {
|
type SelectOption func(*SelectOptions)
|
||||||
Filter([]Node) []Node
|
|
||||||
|
// SelectOptions is the options for node selection
|
||||||
|
type SelectOptions struct {
|
||||||
|
Filters []Filter
|
||||||
|
Strategy Strategy
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFilter adds a filter function to the list of filters
|
||||||
|
// used during the Select call.
|
||||||
|
func WithFilter(f ...Filter) SelectOption {
|
||||||
|
return func(o *SelectOptions) {
|
||||||
|
o.Filters = append(o.Filters, f...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStrategy sets the selector strategy
|
||||||
|
func WithStrategy(s Strategy) SelectOption {
|
||||||
|
return func(o *SelectOptions) {
|
||||||
|
o.Strategy = s
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strategy is a selection strategy e.g random, round robin
|
// Strategy is a selection strategy e.g random, round robin
|
||||||
@ -68,82 +89,60 @@ func (s *RoundStrategy) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RandomStrategy is a strategy for node selector
|
// RandomStrategy is a strategy for node selector
|
||||||
type RandomStrategy struct{}
|
type RandomStrategy struct {
|
||||||
|
Seed int64
|
||||||
|
rand *rand.Rand
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
// Apply applies the random strategy for the nodes
|
// Apply applies the random strategy for the nodes
|
||||||
func (s *RandomStrategy) Apply(nodes []Node) Node {
|
func (s *RandomStrategy) Apply(nodes []Node) Node {
|
||||||
|
s.once.Do(func() {
|
||||||
|
seed := s.Seed
|
||||||
|
if seed == 0 {
|
||||||
|
seed = time.Now().UnixNano()
|
||||||
|
}
|
||||||
|
s.rand = rand.New(rand.NewSource(seed))
|
||||||
|
})
|
||||||
if len(nodes) == 0 {
|
if len(nodes) == 0 {
|
||||||
return Node{}
|
return Node{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes[time.Now().Nanosecond()%len(nodes)]
|
return nodes[s.rand.Int()%len(nodes)]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RandomStrategy) String() string {
|
func (s *RandomStrategy) String() string {
|
||||||
return "random"
|
return "random"
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectOption used when making a select call
|
// Filter is used to filter a node during the selection process
|
||||||
type SelectOption func(*SelectOptions)
|
type Filter interface {
|
||||||
|
Filter([]Node) []Node
|
||||||
// SelectOptions is the options for node selection
|
|
||||||
type SelectOptions struct {
|
|
||||||
Filters []Filter
|
|
||||||
Strategy Strategy
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithFilter adds a filter function to the list of filters
|
|
||||||
// used during the Select call.
|
|
||||||
func WithFilter(f ...Filter) SelectOption {
|
|
||||||
return func(o *SelectOptions) {
|
|
||||||
o.Filters = append(o.Filters, f...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithStrategy sets the selector strategy
|
|
||||||
func WithStrategy(s Strategy) SelectOption {
|
|
||||||
return func(o *SelectOptions) {
|
|
||||||
o.Strategy = s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPSelector as a mechanism to pick IPs and mark their status.
|
|
||||||
type IPSelector interface {
|
|
||||||
Select(ips []string) (string, error)
|
|
||||||
String() string
|
String() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RandomIPSelector is an IP Selector that selects an IP with random strategy.
|
// FailFilter filters the dead node.
|
||||||
type RandomIPSelector struct {
|
// A node is marked as dead if its failed count is greater than MaxFails.
|
||||||
|
type FailFilter struct {
|
||||||
|
MaxFails int
|
||||||
|
FailTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select selects an IP from ips list.
|
// Filter filters nodes.
|
||||||
func (s *RandomIPSelector) Select(ips []string) (string, error) {
|
func (f *FailFilter) Filter(nodes []Node) []Node {
|
||||||
if len(ips) == 0 {
|
if f.MaxFails <= 0 {
|
||||||
return "", nil
|
return nodes
|
||||||
}
|
}
|
||||||
return ips[time.Now().Nanosecond()%len(ips)], nil
|
nl := []Node{}
|
||||||
}
|
for _, node := range nodes {
|
||||||
|
if node.failCount < uint32(f.MaxFails) ||
|
||||||
func (s *RandomIPSelector) String() string {
|
time.Since(node.failTime) >= f.FailTimeout {
|
||||||
return "random"
|
nl = append(nl, node)
|
||||||
}
|
|
||||||
|
|
||||||
// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy.
|
|
||||||
type RoundRobinIPSelector struct {
|
|
||||||
count uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select selects an IP from ips list.
|
|
||||||
func (s *RoundRobinIPSelector) Select(ips []string) (string, error) {
|
|
||||||
if len(ips) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
}
|
||||||
old := s.count
|
}
|
||||||
atomic.AddUint64(&s.count, 1)
|
return nl
|
||||||
return ips[int(old%uint64(len(ips)))], nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RoundRobinIPSelector) String() string {
|
func (f *FailFilter) String() string {
|
||||||
return "round"
|
return "fail"
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user