diff --git a/bypass.go b/bypass.go index fdeaa13..416d045 100644 --- a/bypass.go +++ b/bypass.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net" + "strconv" "strings" glob "github.com/gobwas/glob" @@ -147,8 +148,10 @@ func (bp *Bypass) Contains(addr string) bool { return false } // try to strip the port - if host, _, _ := net.SplitHostPort(addr); host != "" { - addr = host + if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { + if p, _ := strconv.Atoi(port); p > 0 { // port is valid + addr = host + } } var matched bool for _, matcher := range bp.matchers { diff --git a/bypass_test.go b/bypass_test.go index dbe4793..975b769 100644 --- a/bypass_test.go +++ b/bypass_test.go @@ -77,15 +77,15 @@ var bypassTests = []struct { } func TestBypass(t *testing.T) { - for i, test := range bypassTests { + for _, test := range bypassTests { bp := NewBypassPatterns(test.patterns, test.reversed) if bp.Contains(test.addr) != test.bypassed { - t.Errorf("test %d failed", i) + t.Errorf("test failed: %v, %s", test.patterns, test.addr) } rbp := NewBypassPatterns(test.patterns, !test.reversed) if rbp.Contains(test.addr) == test.bypassed { - t.Errorf("reverse test %d failed", i) + t.Errorf("reverse test failed: %v, %s", test.patterns, test.addr) } } } diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 4cdaf1b..067f620 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -123,12 +123,18 @@ func (r *route) initChain() (*gost.Chain, error) { log.Log(err) } peerCfg.Validate() + + strategy := peerCfg.Strategy + // overwrite the strategry in the peer config if `strategy` param exists. + if s := nodes[0].Get("strategy"); s != "" { + strategy = s + } ngroup.Options = append(ngroup.Options, gost.WithFilter(&gost.FailFilter{ MaxFails: peerCfg.MaxFails, FailTimeout: time.Duration(peerCfg.FailTimeout) * time.Second, }), - gost.WithStrategy(parseStrategy(peerCfg.Strategy)), + gost.WithStrategy(parseStrategy(strategy)), ) for _, s := range peerCfg.Nodes { @@ -146,6 +152,7 @@ func (r *route) initChain() (*gost.Chain, error) { } var bypass *gost.Bypass + // global bypass if peerCfg.Bypass != nil { bypass = gost.NewBypassPatterns(peerCfg.Bypass.Patterns, peerCfg.Bypass.Reverse) } @@ -464,6 +471,7 @@ func (r *route) serve() error { gost.WhitelistHandlerOption(whitelist), gost.BlacklistHandlerOption(blacklist), gost.BypassHandlerOption(parseBypass(node.Get("bypass"))), + gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))), ) var handler gost.Handler switch node.Protocol { diff --git a/forward.go b/forward.go index a556552..b84ca15 100644 --- a/forward.go +++ b/forward.go @@ -36,9 +36,17 @@ type tcpDirectForwardHandler struct { // The raddr is the remote address that the server will forward to. // NOTE: as of 2.6, remote address can be a comma-separated address list. func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &tcpDirectForwardHandler{ + raddr: raddr, + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + group := NewNodeGroup() group.SetSelector(&defaultSelector{}, - WithStrategy(&RoundStrategy{}), + WithStrategy(h.options.Strategy), WithFilter(&FailFilter{ MaxFails: 1, FailTimeout: 30 * time.Second, @@ -57,15 +65,8 @@ func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { Host: addr, }) } + h.group = group - h := &tcpDirectForwardHandler{ - raddr: raddr, - group: group, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } return h } @@ -104,9 +105,17 @@ type udpDirectForwardHandler struct { // The raddr is the remote address that the server will forward to. // NOTE: as of 2.6, remote address can be a comma-separated address list. func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &udpDirectForwardHandler{ + raddr: raddr, + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + group := NewNodeGroup() group.SetSelector(&defaultSelector{}, - WithStrategy(&RoundStrategy{}), + WithStrategy(h.options.Strategy), WithFilter(&FailFilter{ MaxFails: 1, FailTimeout: 30 * time.Second, @@ -125,15 +134,8 @@ func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { Host: addr, }) } + h.group = group - h := &udpDirectForwardHandler{ - raddr: raddr, - group: group, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } return h } @@ -188,9 +190,17 @@ type tcpRemoteForwardHandler struct { // The raddr is the remote address that the server will forward to. // NOTE: as of 2.6, remote address can be a comma-separated address list. func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &tcpRemoteForwardHandler{ + raddr: raddr, + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + group := NewNodeGroup() group.SetSelector(&defaultSelector{}, - WithStrategy(&RoundStrategy{}), + WithStrategy(h.options.Strategy), WithFilter(&FailFilter{ MaxFails: 1, FailTimeout: 30 * time.Second, @@ -209,15 +219,8 @@ func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { Host: addr, }) } + h.group = group - h := &tcpRemoteForwardHandler{ - raddr: raddr, - group: group, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } return h } @@ -254,9 +257,17 @@ type udpRemoteForwardHandler struct { // The raddr is the remote address that the server will forward to. // NOTE: as of 2.6, remote address can be a comma-separated address list. func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &udpRemoteForwardHandler{ + raddr: raddr, + options: &HandlerOptions{}, + } + for _, opt := range opts { + opt(h.options) + } + group := NewNodeGroup() group.SetSelector(&defaultSelector{}, - WithStrategy(&RoundStrategy{}), + WithStrategy(h.options.Strategy), WithFilter(&FailFilter{ MaxFails: 1, FailTimeout: 30 * time.Second, @@ -274,15 +285,8 @@ func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { Host: addr, }) } + h.group = group - h := &udpRemoteForwardHandler{ - raddr: raddr, - group: group, - options: &HandlerOptions{}, - } - for _, opt := range opts { - opt(h.options) - } return h } diff --git a/handler.go b/handler.go index f0802e7..559c6c4 100644 --- a/handler.go +++ b/handler.go @@ -25,6 +25,7 @@ type HandlerOptions struct { Whitelist *Permissions Blacklist *Permissions Bypass *Bypass + Strategy Strategy } // HandlerOption allows a common way to set handler options. @@ -79,6 +80,13 @@ func BypassHandlerOption(bypass *Bypass) HandlerOption { } } +// StrategyHandlerOption sets the strategy option of HandlerOptions. +func StrategyHandlerOption(strategy Strategy) HandlerOption { + return func(opts *HandlerOptions) { + opts.Strategy = strategy + } +} + type autoHandler struct { options []HandlerOption }