diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go index 4ed2e45..e2cb495 100644 --- a/cmd/gost/peer.go +++ b/cmd/gost/peer.go @@ -57,10 +57,13 @@ func (cfg *peerConfig) Reload(r io.Reader) error { group := cfg.group group.SetSelector( nil, - gost.WithFilter(&gost.FailFilter{ - MaxFails: cfg.MaxFails, - FailTimeout: cfg.FailTimeout, - }), + gost.WithFilter( + &gost.FailFilter{ + MaxFails: cfg.MaxFails, + FailTimeout: cfg.FailTimeout, + }, + &gost.InvalidFilter{}, + ), gost.WithStrategy(gost.NewStrategy(cfg.Strategy)), ) diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 7ececfb..cdbc0ed 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -52,10 +52,13 @@ func (r *route) parseChain() (*gost.Chain, error) { ngroup.AddNode(nodes...) ngroup.SetSelector(nil, - gost.WithFilter(&gost.FailFilter{ - MaxFails: defaultMaxFails, - FailTimeout: defaultFailTimeout, - }), + gost.WithFilter( + &gost.FailFilter{ + MaxFails: defaultMaxFails, + FailTimeout: defaultFailTimeout, + }, + &gost.InvalidFilter{}, + ), gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))), ) diff --git a/selector.go b/selector.go index cb678fd..10efa04 100644 --- a/selector.go +++ b/selector.go @@ -3,6 +3,8 @@ package gost import ( "errors" "math/rand" + "net" + "strconv" "sync" "sync/atomic" "time" @@ -167,7 +169,7 @@ type FailFilter struct { FailTimeout time.Duration } -// Filter filters nodes. +// Filter filters dead nodes. func (f *FailFilter) Filter(nodes []Node) []Node { if len(nodes) <= 1 || f.MaxFails <= 0 { return nodes @@ -188,6 +190,26 @@ func (f *FailFilter) String() string { return "fail" } +// InvalidFilter filters the invalid node. +// A node is invalid if its port is invalid (negative or zero value). +type InvalidFilter struct{} + +// Filter filters invalid nodes. +func (f *InvalidFilter) Filter(nodes []Node) []Node { + nl := []Node{} + for i := range nodes { + _, sport, _ := net.SplitHostPort(nodes[i].Addr) + if port, _ := strconv.Atoi(sport); port > 0 { + nl = append(nl, nodes[i]) + } + } + return nl +} + +func (f *InvalidFilter) String() string { + return "invalid" +} + type failMarker struct { failTime int64 failCount uint32