add max_fails & fail_timeout options support for port forwarding

This commit is contained in:
ginuerzh 2019-06-20 10:47:58 +08:00
parent e8ad44cab3
commit bea815ec9a
6 changed files with 44 additions and 30 deletions

View File

@ -13,11 +13,6 @@ import (
"github.com/ginuerzh/gost"
)
const (
defaultMaxFails = 1
defaultFailTimeout = 30 * time.Second
)
type peerConfig struct {
Strategy string `json:"strategy"`
MaxFails int `json:"max_fails"`
@ -36,12 +31,6 @@ func newPeerConfig() *peerConfig {
}
func (cfg *peerConfig) Validate() {
if cfg.MaxFails <= 0 {
cfg.MaxFails = defaultMaxFails
}
if cfg.FailTimeout <= 0 {
cfg.FailTimeout = defaultFailTimeout // seconds
}
}
func (cfg *peerConfig) Reload(r io.Reader) error {

View File

@ -51,20 +51,11 @@ func (r *route) parseChain() (*gost.Chain, error) {
}
ngroup.AddNode(nodes...)
maxFails := nodes[0].GetInt("max_fails")
if maxFails == 0 {
maxFails = defaultMaxFails
}
failTimeout := nodes[0].GetDuration("fail_timeout")
if failTimeout == 0 {
failTimeout = defaultFailTimeout
}
ngroup.SetSelector(nil,
gost.WithFilter(
&gost.FailFilter{
MaxFails: maxFails,
FailTimeout: failTimeout,
MaxFails: nodes[0].GetInt("max_fails"),
FailTimeout: nodes[0].GetDuration("fail_timeout"),
},
&gost.InvalidFilter{},
),
@ -444,6 +435,8 @@ func (r *route) GenRouters() ([]router, error) {
gost.WhitelistHandlerOption(whitelist),
gost.BlacklistHandlerOption(blacklist),
gost.StrategyHandlerOption(gost.NewStrategy(node.Get("strategy"))),
gost.MaxFailsHandlerOption(node.GetInt("max_fails")),
gost.FailTimeoutHandlerOption(node.GetDuration("fail_timeout")),
gost.BypassHandlerOption(node.Bypass),
gost.ResolverHandlerOption(resolver),
gost.HostsHandlerOption(hosts),

View File

@ -46,8 +46,8 @@ func (h *baseForwardHandler) Init(options ...HandlerOption) {
h.group.SetSelector(&defaultSelector{},
WithStrategy(h.options.Strategy),
WithFilter(&FailFilter{
MaxFails: 1,
FailTimeout: 30 * time.Second,
MaxFails: h.options.MaxFails,
FailTimeout: h.options.FailTimeout,
}),
)

View File

@ -28,6 +28,8 @@ type HandlerOptions struct {
Whitelist *Permissions
Blacklist *Permissions
Strategy Strategy
MaxFails int
FailTimeout time.Duration
Bypass *Bypass
Retries int
Timeout time.Duration
@ -116,6 +118,20 @@ func StrategyHandlerOption(strategy Strategy) HandlerOption {
}
}
// MaxFailsHandlerOption sets the max_fails option of HandlerOptions.
func MaxFailsHandlerOption(n int) HandlerOption {
return func(opts *HandlerOptions) {
opts.MaxFails = n
}
}
// FailTimeoutHandlerOption sets the fail_timeout option of HandlerOptions.
func FailTimeoutHandlerOption(d time.Duration) HandlerOption {
return func(opts *HandlerOptions) {
opts.FailTimeout = d
}
}
// RetryHandlerOption sets the retry option of HandlerOptions.
func RetryHandlerOption(retries int) HandlerOption {
return func(opts *HandlerOptions) {

View File

@ -162,6 +162,12 @@ type Filter interface {
String() string
}
// default options for FailFilter
const (
DefaultMaxFails = 1
DefaultFailTimeout = 30 * time.Second
)
// FailFilter filters the dead node.
// A node is marked as dead if its failed count is greater than MaxFails.
type FailFilter struct {
@ -171,15 +177,24 @@ type FailFilter struct {
// Filter filters dead nodes.
func (f *FailFilter) Filter(nodes []Node) []Node {
if len(nodes) <= 1 || f.MaxFails <= 0 {
maxFails := f.MaxFails
if maxFails == 0 {
maxFails = DefaultMaxFails
}
failTimeout := f.FailTimeout
if failTimeout == 0 {
failTimeout = DefaultFailTimeout
}
if len(nodes) <= 1 || maxFails < 0 {
return nodes
}
nl := []Node{}
for i := range nodes {
marker := nodes[i].marker.Clone()
// log.Logf("%s: %d/%d %v/%v", nodes[i], marker.FailCount(), f.MaxFails, marker.FailTime(), f.FailTimeout)
if marker.FailCount() < uint32(f.MaxFails) ||
time.Since(time.Unix(marker.FailTime(), 0)) >= f.FailTimeout {
if marker.FailCount() < uint32(maxFails) ||
time.Since(time.Unix(marker.FailTime(), 0)) >= failTimeout {
nl = append(nl, nodes[i])
}
}

View File

@ -97,13 +97,14 @@ func TestFailFilter(t *testing.T) {
t.Error("unexpected node", v)
}
filter.MaxFails = 1
filter.MaxFails = -1
nodes[0].MarkDead()
if v := filter.Filter(nodes); !isEqual(v, nodes) {
t.Error("unexpected node", v)
}
nodes[0].MarkDead()
if v := filter.Filter(nodes); !isEqual(v, nodes) {
filter.MaxFails = 0
if v := filter.Filter(nodes); isEqual(v, nodes) {
t.Error("unexpected node", v)
}