diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go index e2cb495..9e1d00f 100644 --- a/cmd/gost/peer.go +++ b/cmd/gost/peer.go @@ -13,11 +13,6 @@ import ( "github.com/ginuerzh/gost" ) -const ( - defaultMaxFails = 1 - defaultFailTimeout = 30 * time.Second -) - type peerConfig struct { Strategy string `json:"strategy"` MaxFails int `json:"max_fails"` @@ -36,12 +31,6 @@ func newPeerConfig() *peerConfig { } func (cfg *peerConfig) Validate() { - if cfg.MaxFails <= 0 { - cfg.MaxFails = defaultMaxFails - } - if cfg.FailTimeout <= 0 { - cfg.FailTimeout = defaultFailTimeout // seconds - } } func (cfg *peerConfig) Reload(r io.Reader) error { diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 38721da..b4e358a 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -51,20 +51,11 @@ func (r *route) parseChain() (*gost.Chain, error) { } ngroup.AddNode(nodes...) - maxFails := nodes[0].GetInt("max_fails") - if maxFails == 0 { - maxFails = defaultMaxFails - } - failTimeout := nodes[0].GetDuration("fail_timeout") - if failTimeout == 0 { - failTimeout = defaultFailTimeout - } - ngroup.SetSelector(nil, gost.WithFilter( &gost.FailFilter{ - MaxFails: maxFails, - FailTimeout: failTimeout, + MaxFails: nodes[0].GetInt("max_fails"), + FailTimeout: nodes[0].GetDuration("fail_timeout"), }, &gost.InvalidFilter{}, ), @@ -444,6 +435,8 @@ func (r *route) GenRouters() ([]router, error) { gost.WhitelistHandlerOption(whitelist), gost.BlacklistHandlerOption(blacklist), gost.StrategyHandlerOption(gost.NewStrategy(node.Get("strategy"))), + gost.MaxFailsHandlerOption(node.GetInt("max_fails")), + gost.FailTimeoutHandlerOption(node.GetDuration("fail_timeout")), gost.BypassHandlerOption(node.Bypass), gost.ResolverHandlerOption(resolver), gost.HostsHandlerOption(hosts), diff --git a/forward.go b/forward.go index e9d1f4c..d403102 100644 --- a/forward.go +++ b/forward.go @@ -46,8 +46,8 @@ func (h *baseForwardHandler) Init(options ...HandlerOption) { h.group.SetSelector(&defaultSelector{}, WithStrategy(h.options.Strategy), WithFilter(&FailFilter{ - MaxFails: 1, - FailTimeout: 30 * time.Second, + MaxFails: h.options.MaxFails, + FailTimeout: h.options.FailTimeout, }), ) diff --git a/handler.go b/handler.go index b0521fe..66f894d 100644 --- a/handler.go +++ b/handler.go @@ -28,6 +28,8 @@ type HandlerOptions struct { Whitelist *Permissions Blacklist *Permissions Strategy Strategy + MaxFails int + FailTimeout time.Duration Bypass *Bypass Retries int Timeout time.Duration @@ -116,6 +118,20 @@ func StrategyHandlerOption(strategy Strategy) HandlerOption { } } +// MaxFailsHandlerOption sets the max_fails option of HandlerOptions. +func MaxFailsHandlerOption(n int) HandlerOption { + return func(opts *HandlerOptions) { + opts.MaxFails = n + } +} + +// FailTimeoutHandlerOption sets the fail_timeout option of HandlerOptions. +func FailTimeoutHandlerOption(d time.Duration) HandlerOption { + return func(opts *HandlerOptions) { + opts.FailTimeout = d + } +} + // RetryHandlerOption sets the retry option of HandlerOptions. func RetryHandlerOption(retries int) HandlerOption { return func(opts *HandlerOptions) { diff --git a/selector.go b/selector.go index 10efa04..12545ac 100644 --- a/selector.go +++ b/selector.go @@ -162,6 +162,12 @@ type Filter interface { String() string } +// default options for FailFilter +const ( + DefaultMaxFails = 1 + DefaultFailTimeout = 30 * time.Second +) + // FailFilter filters the dead node. // A node is marked as dead if its failed count is greater than MaxFails. type FailFilter struct { @@ -171,15 +177,24 @@ type FailFilter struct { // Filter filters dead nodes. func (f *FailFilter) Filter(nodes []Node) []Node { - if len(nodes) <= 1 || f.MaxFails <= 0 { + maxFails := f.MaxFails + if maxFails == 0 { + maxFails = DefaultMaxFails + } + failTimeout := f.FailTimeout + if failTimeout == 0 { + failTimeout = DefaultFailTimeout + } + + if len(nodes) <= 1 || maxFails < 0 { return nodes } nl := []Node{} for i := range nodes { marker := nodes[i].marker.Clone() // log.Logf("%s: %d/%d %v/%v", nodes[i], marker.FailCount(), f.MaxFails, marker.FailTime(), f.FailTimeout) - if marker.FailCount() < uint32(f.MaxFails) || - time.Since(time.Unix(marker.FailTime(), 0)) >= f.FailTimeout { + if marker.FailCount() < uint32(maxFails) || + time.Since(time.Unix(marker.FailTime(), 0)) >= failTimeout { nl = append(nl, nodes[i]) } } diff --git a/selector_test.go b/selector_test.go index 41fe95f..5da667c 100644 --- a/selector_test.go +++ b/selector_test.go @@ -97,13 +97,14 @@ func TestFailFilter(t *testing.T) { t.Error("unexpected node", v) } - filter.MaxFails = 1 + filter.MaxFails = -1 + nodes[0].MarkDead() if v := filter.Filter(nodes); !isEqual(v, nodes) { t.Error("unexpected node", v) } - nodes[0].MarkDead() - if v := filter.Filter(nodes); !isEqual(v, nodes) { + filter.MaxFails = 0 + if v := filter.Filter(nodes); isEqual(v, nodes) { t.Error("unexpected node", v) }