This commit is contained in:
ginuerzh 2019-06-15 16:35:33 +08:00
parent 466d2fef6c
commit f89062a84b
3 changed files with 37 additions and 9 deletions

View File

@ -57,10 +57,13 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
group := cfg.group group := cfg.group
group.SetSelector( group.SetSelector(
nil, nil,
gost.WithFilter(&gost.FailFilter{ gost.WithFilter(
&gost.FailFilter{
MaxFails: cfg.MaxFails, MaxFails: cfg.MaxFails,
FailTimeout: cfg.FailTimeout, FailTimeout: cfg.FailTimeout,
}), },
&gost.InvalidFilter{},
),
gost.WithStrategy(gost.NewStrategy(cfg.Strategy)), gost.WithStrategy(gost.NewStrategy(cfg.Strategy)),
) )

View File

@ -52,10 +52,13 @@ func (r *route) parseChain() (*gost.Chain, error) {
ngroup.AddNode(nodes...) ngroup.AddNode(nodes...)
ngroup.SetSelector(nil, ngroup.SetSelector(nil,
gost.WithFilter(&gost.FailFilter{ gost.WithFilter(
&gost.FailFilter{
MaxFails: defaultMaxFails, MaxFails: defaultMaxFails,
FailTimeout: defaultFailTimeout, FailTimeout: defaultFailTimeout,
}), },
&gost.InvalidFilter{},
),
gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))), gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))),
) )

View File

@ -3,6 +3,8 @@ package gost
import ( import (
"errors" "errors"
"math/rand" "math/rand"
"net"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -167,7 +169,7 @@ type FailFilter struct {
FailTimeout time.Duration FailTimeout time.Duration
} }
// Filter filters nodes. // Filter filters dead nodes.
func (f *FailFilter) Filter(nodes []Node) []Node { func (f *FailFilter) Filter(nodes []Node) []Node {
if len(nodes) <= 1 || f.MaxFails <= 0 { if len(nodes) <= 1 || f.MaxFails <= 0 {
return nodes return nodes
@ -188,6 +190,26 @@ func (f *FailFilter) String() string {
return "fail" return "fail"
} }
// InvalidFilter filters the invalid node.
// A node is invalid if its port is invalid (negative or zero value).
type InvalidFilter struct{}
// Filter filters invalid nodes.
func (f *InvalidFilter) Filter(nodes []Node) []Node {
nl := []Node{}
for i := range nodes {
_, sport, _ := net.SplitHostPort(nodes[i].Addr)
if port, _ := strconv.Atoi(sport); port > 0 {
nl = append(nl, nodes[i])
}
}
return nl
}
func (f *InvalidFilter) String() string {
return "invalid"
}
type failMarker struct { type failMarker struct {
failTime int64 failTime int64
failCount uint32 failCount uint32