add max_fails & fail_timeout options support for port forwarding
This commit is contained in:
parent
e8ad44cab3
commit
bea815ec9a
@ -13,11 +13,6 @@ import (
|
||||
"github.com/ginuerzh/gost"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxFails = 1
|
||||
defaultFailTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type peerConfig struct {
|
||||
Strategy string `json:"strategy"`
|
||||
MaxFails int `json:"max_fails"`
|
||||
@ -36,12 +31,6 @@ func newPeerConfig() *peerConfig {
|
||||
}
|
||||
|
||||
func (cfg *peerConfig) Validate() {
|
||||
if cfg.MaxFails <= 0 {
|
||||
cfg.MaxFails = defaultMaxFails
|
||||
}
|
||||
if cfg.FailTimeout <= 0 {
|
||||
cfg.FailTimeout = defaultFailTimeout // seconds
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *peerConfig) Reload(r io.Reader) error {
|
||||
|
@ -51,20 +51,11 @@ func (r *route) parseChain() (*gost.Chain, error) {
|
||||
}
|
||||
ngroup.AddNode(nodes...)
|
||||
|
||||
maxFails := nodes[0].GetInt("max_fails")
|
||||
if maxFails == 0 {
|
||||
maxFails = defaultMaxFails
|
||||
}
|
||||
failTimeout := nodes[0].GetDuration("fail_timeout")
|
||||
if failTimeout == 0 {
|
||||
failTimeout = defaultFailTimeout
|
||||
}
|
||||
|
||||
ngroup.SetSelector(nil,
|
||||
gost.WithFilter(
|
||||
&gost.FailFilter{
|
||||
MaxFails: maxFails,
|
||||
FailTimeout: failTimeout,
|
||||
MaxFails: nodes[0].GetInt("max_fails"),
|
||||
FailTimeout: nodes[0].GetDuration("fail_timeout"),
|
||||
},
|
||||
&gost.InvalidFilter{},
|
||||
),
|
||||
@ -444,6 +435,8 @@ func (r *route) GenRouters() ([]router, error) {
|
||||
gost.WhitelistHandlerOption(whitelist),
|
||||
gost.BlacklistHandlerOption(blacklist),
|
||||
gost.StrategyHandlerOption(gost.NewStrategy(node.Get("strategy"))),
|
||||
gost.MaxFailsHandlerOption(node.GetInt("max_fails")),
|
||||
gost.FailTimeoutHandlerOption(node.GetDuration("fail_timeout")),
|
||||
gost.BypassHandlerOption(node.Bypass),
|
||||
gost.ResolverHandlerOption(resolver),
|
||||
gost.HostsHandlerOption(hosts),
|
||||
|
@ -46,8 +46,8 @@ func (h *baseForwardHandler) Init(options ...HandlerOption) {
|
||||
h.group.SetSelector(&defaultSelector{},
|
||||
WithStrategy(h.options.Strategy),
|
||||
WithFilter(&FailFilter{
|
||||
MaxFails: 1,
|
||||
FailTimeout: 30 * time.Second,
|
||||
MaxFails: h.options.MaxFails,
|
||||
FailTimeout: h.options.FailTimeout,
|
||||
}),
|
||||
)
|
||||
|
||||
|
16
handler.go
16
handler.go
@ -28,6 +28,8 @@ type HandlerOptions struct {
|
||||
Whitelist *Permissions
|
||||
Blacklist *Permissions
|
||||
Strategy Strategy
|
||||
MaxFails int
|
||||
FailTimeout time.Duration
|
||||
Bypass *Bypass
|
||||
Retries int
|
||||
Timeout time.Duration
|
||||
@ -116,6 +118,20 @@ func StrategyHandlerOption(strategy Strategy) HandlerOption {
|
||||
}
|
||||
}
|
||||
|
||||
// MaxFailsHandlerOption sets the max_fails option of HandlerOptions.
|
||||
func MaxFailsHandlerOption(n int) HandlerOption {
|
||||
return func(opts *HandlerOptions) {
|
||||
opts.MaxFails = n
|
||||
}
|
||||
}
|
||||
|
||||
// FailTimeoutHandlerOption sets the fail_timeout option of HandlerOptions.
|
||||
func FailTimeoutHandlerOption(d time.Duration) HandlerOption {
|
||||
return func(opts *HandlerOptions) {
|
||||
opts.FailTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
// RetryHandlerOption sets the retry option of HandlerOptions.
|
||||
func RetryHandlerOption(retries int) HandlerOption {
|
||||
return func(opts *HandlerOptions) {
|
||||
|
21
selector.go
21
selector.go
@ -162,6 +162,12 @@ type Filter interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
// default options for FailFilter
|
||||
const (
|
||||
DefaultMaxFails = 1
|
||||
DefaultFailTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// FailFilter filters the dead node.
|
||||
// A node is marked as dead if its failed count is greater than MaxFails.
|
||||
type FailFilter struct {
|
||||
@ -171,15 +177,24 @@ type FailFilter struct {
|
||||
|
||||
// Filter filters dead nodes.
|
||||
func (f *FailFilter) Filter(nodes []Node) []Node {
|
||||
if len(nodes) <= 1 || f.MaxFails <= 0 {
|
||||
maxFails := f.MaxFails
|
||||
if maxFails == 0 {
|
||||
maxFails = DefaultMaxFails
|
||||
}
|
||||
failTimeout := f.FailTimeout
|
||||
if failTimeout == 0 {
|
||||
failTimeout = DefaultFailTimeout
|
||||
}
|
||||
|
||||
if len(nodes) <= 1 || maxFails < 0 {
|
||||
return nodes
|
||||
}
|
||||
nl := []Node{}
|
||||
for i := range nodes {
|
||||
marker := nodes[i].marker.Clone()
|
||||
// log.Logf("%s: %d/%d %v/%v", nodes[i], marker.FailCount(), f.MaxFails, marker.FailTime(), f.FailTimeout)
|
||||
if marker.FailCount() < uint32(f.MaxFails) ||
|
||||
time.Since(time.Unix(marker.FailTime(), 0)) >= f.FailTimeout {
|
||||
if marker.FailCount() < uint32(maxFails) ||
|
||||
time.Since(time.Unix(marker.FailTime(), 0)) >= failTimeout {
|
||||
nl = append(nl, nodes[i])
|
||||
}
|
||||
}
|
||||
|
@ -97,13 +97,14 @@ func TestFailFilter(t *testing.T) {
|
||||
t.Error("unexpected node", v)
|
||||
}
|
||||
|
||||
filter.MaxFails = 1
|
||||
filter.MaxFails = -1
|
||||
nodes[0].MarkDead()
|
||||
if v := filter.Filter(nodes); !isEqual(v, nodes) {
|
||||
t.Error("unexpected node", v)
|
||||
}
|
||||
|
||||
nodes[0].MarkDead()
|
||||
if v := filter.Filter(nodes); !isEqual(v, nodes) {
|
||||
filter.MaxFails = 0
|
||||
if v := filter.Filter(nodes); isEqual(v, nodes) {
|
||||
t.Error("unexpected node", v)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user