From d61407c7fb9b6dd3e4f9079640d50df33cd01598 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 4 Jun 2019 10:40:38 +0800 Subject: [PATCH] add ip option support for port forward handler --- cmd/gost/cfg.go | 1 + cmd/gost/route.go | 3 +- forward.go | 214 ++++++++++++++++++---------------------------- forward_test.go | 34 ++++++-- handler.go | 8 ++ selector.go | 45 ++++++++-- 6 files changed, 159 insertions(+), 146 deletions(-) diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index fdc4eab..8ecec21 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -148,6 +148,7 @@ func parseIP(s string, port string) (ips []string) { for _, s := range ss { s = strings.TrimSpace(s) if s != "" { + // TODO: support IPv6 if !strings.Contains(s, ":") { s = s + ":" + port } diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 17271b9..0aad21f 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -417,9 +417,9 @@ func (r *route) GenRouters() ([]router, error) { node.Bypass = parseBypass(node.Get("bypass")) resolver := parseResolver(node.Get("dns")) hosts := parseHosts(node.Get("hosts")) + ips := parseIP(node.Get("ip"), "") handler.Init( - // gost.AddrHandlerOption(node.Addr), gost.AddrHandlerOption(ln.Addr().String()), gost.ChainHandlerOption(chain), gost.UsersHandlerOption(node.User), @@ -435,6 +435,7 @@ func (r *route) GenRouters() ([]router, error) { gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second), gost.ProbeResistHandlerOption(node.Get("probe_resist")), gost.NodeHandlerOption(node), + gost.IPsHandlerOption(ips), ) rt := router{ diff --git a/forward.go b/forward.go index 1034dbb..e9d1f4c 100644 --- a/forward.go +++ b/forward.go @@ -26,42 +26,13 @@ func (c *forwardConnector) Connect(conn net.Conn, addr string, options ...Connec return conn, nil } -type tcpDirectForwardHandler struct { +type baseForwardHandler struct { raddr string group *NodeGroup options *HandlerOptions } -// TCPDirectForwardHandler creates a server Handler for TCP port forwarding server. -// The raddr is the remote address that the server will forward to. -// NOTE: as of 2.6, remote address can be a comma-separated address list. -func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { - h := &tcpDirectForwardHandler{ - raddr: raddr, - group: NewNodeGroup(), - } - - if raddr == "" { - raddr = ":0" // dummy address - } - for i, addr := range strings.Split(raddr, ",") { - if addr == "" { - continue - } - // We treat the remote target server as a node, so we can put them in a group, - // and perform the node selection for load balancing. - h.group.AddNode(Node{ - ID: i + 1, - Addr: addr, - Host: addr, - }) - } - h.Init(opts...) - - return h -} - -func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) { +func (h *baseForwardHandler) Init(options ...HandlerOption) { if h.options == nil { h.options = &HandlerOptions{} } @@ -70,6 +41,8 @@ func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) { opt(h.options) } + h.group = NewNodeGroup() // reset node group + h.group.SetSelector(&defaultSelector{}, WithStrategy(h.options.Strategy), WithFilter(&FailFilter{ @@ -77,6 +50,59 @@ func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) { FailTimeout: 30 * time.Second, }), ) + + n := 1 + addrs := append(strings.Split(h.raddr, ","), h.options.IPs...) + for _, addr := range addrs { + if addr == "" { + continue + } + + // We treat the remote target server as a node, so we can put them in a group, + // and perform the node selection for load balancing. + h.group.AddNode(Node{ + ID: n, + Addr: addr, + Host: addr, + marker: &failMarker{}, + }) + + n++ + } + if len(h.group.Nodes()) == 0 { + h.group.AddNode(Node{ // dummy address + ID: n, + Addr: ":0", + Host: ":0", + }) + } +} + +type tcpDirectForwardHandler struct { + *baseForwardHandler +} + +// TCPDirectForwardHandler creates a server Handler for TCP port forwarding server. +// The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. +func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &tcpDirectForwardHandler{ + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, + } + + for _, opt := range opts { + opt(h.options) + } + + return h +} + +func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) { + h.baseForwardHandler.Init(options...) } func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { @@ -125,9 +151,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { } type udpDirectForwardHandler struct { - raddr string - group *NodeGroup - options *HandlerOptions + *baseForwardHandler } // UDPDirectForwardHandler creates a server Handler for UDP port forwarding server. @@ -135,47 +159,22 @@ type udpDirectForwardHandler struct { // NOTE: as of 2.6, remote address can be a comma-separated address list. func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &udpDirectForwardHandler{ - raddr: raddr, - group: NewNodeGroup(), + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, } - if raddr == "" { - raddr = ":0" // dummy address + for _, opt := range opts { + opt(h.options) } - for i, addr := range strings.Split(raddr, ",") { - if addr == "" { - continue - } - // We treat the remote target server as a node, so we can put them in a group, - // and perform the node selection for load balancing. - h.group.AddNode(Node{ - ID: i + 1, - Addr: addr, - Host: addr, - }) - } - - h.Init(opts...) return h } func (h *udpDirectForwardHandler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - - for _, opt := range options { - opt(h.options) - } - - h.group.SetSelector(&defaultSelector{}, - WithStrategy(h.options.Strategy), - WithFilter(&FailFilter{ - MaxFails: 1, - FailTimeout: 30 * time.Second, - }), - ) + h.baseForwardHandler.Init(options...) } func (h *udpDirectForwardHandler) Handle(conn net.Conn) { @@ -220,9 +219,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { } type tcpRemoteForwardHandler struct { - raddr string - group *NodeGroup - options *HandlerOptions + *baseForwardHandler } // TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server. @@ -230,42 +227,22 @@ type tcpRemoteForwardHandler struct { // NOTE: as of 2.6, remote address can be a comma-separated address list. func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &tcpRemoteForwardHandler{ - raddr: raddr, - group: NewNodeGroup(), + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, } - for i, addr := range strings.Split(raddr, ",") { - if addr == "" { - continue - } - // We treat the remote target server as a node, so we can put them in a group, - // and perform the node selection for load balancing. - h.group.AddNode(Node{ - ID: i + 1, - Addr: addr, - Host: addr, - }) + for _, opt := range opts { + opt(h.options) } - h.Init(opts...) return h } func (h *tcpRemoteForwardHandler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - for _, opt := range options { - opt(h.options) - } - - h.group.SetSelector(&defaultSelector{}, - WithStrategy(h.options.Strategy), - WithFilter(&FailFilter{ - MaxFails: 1, - FailTimeout: 30 * time.Second, - }), - ) + h.baseForwardHandler.Init(options...) } func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { @@ -306,9 +283,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { } type udpRemoteForwardHandler struct { - raddr string - group *NodeGroup - options *HandlerOptions + *baseForwardHandler } // UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server. @@ -316,43 +291,22 @@ type udpRemoteForwardHandler struct { // NOTE: as of 2.6, remote address can be a comma-separated address list. func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &udpRemoteForwardHandler{ - raddr: raddr, - group: NewNodeGroup(), + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, } - for i, addr := range strings.Split(raddr, ",") { - if addr == "" { - continue - } - // We treat the remote target server as a node, so we can put them in a group, - // and perform the node selection for load balancing. - h.group.AddNode(Node{ - ID: i + 1, - Addr: addr, - Host: addr, - }) + for _, opt := range opts { + opt(h.options) } - h.Init(opts...) - return h } func (h *udpRemoteForwardHandler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - - for _, opt := range options { - opt(h.options) - } - h.group.SetSelector(&defaultSelector{}, - WithStrategy(h.options.Strategy), - WithFilter(&FailFilter{ - MaxFails: 1, - FailTimeout: 30 * time.Second, - }), - ) + h.baseForwardHandler.Init(options...) } func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { diff --git a/forward_test.go b/forward_test.go index fc58a3d..f8653de 100644 --- a/forward_test.go +++ b/forward_test.go @@ -23,9 +23,11 @@ func tcpDirectForwardRoundtrip(targetURL string, data []byte) error { Transporter: TCPTransporter(), } + h := TCPDirectForwardHandler(u.Host) + h.Init() server := &Server{ Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), + Handler: h, } go server.Run() @@ -68,9 +70,12 @@ func BenchmarkTCPDirectForward(b *testing.B) { if err != nil { b.Error(err) } + + h := TCPDirectForwardHandler(u.Host) + h.Init() server := &Server{ Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), + Handler: h, } go server.Run() defer server.Close() @@ -103,9 +108,12 @@ func BenchmarkTCPDirectForwardParallel(b *testing.B) { if err != nil { b.Error(err) } + + h := TCPDirectForwardHandler(u.Host) + h.Init() server := &Server{ Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), + Handler: h, } go server.Run() defer server.Close() @@ -130,9 +138,11 @@ func udpDirectForwardRoundtrip(t *testing.T, host string, data []byte) error { Transporter: UDPTransporter(), } + h := UDPDirectForwardHandler(host) + h.Init() server := &Server{ Listener: ln, - Handler: UDPDirectForwardHandler(host), + Handler: h, } go server.Run() @@ -172,9 +182,11 @@ func BenchmarkUDPDirectForward(b *testing.B) { Transporter: UDPTransporter(), } + h := UDPDirectForwardHandler(udpSrv.Addr()) + h.Init() server := &Server{ Listener: ln, - Handler: UDPDirectForwardHandler(udpSrv.Addr()), + Handler: h, } go server.Run() @@ -205,9 +217,11 @@ func BenchmarkUDPDirectForwardParallel(b *testing.B) { Transporter: UDPTransporter(), } + h := UDPDirectForwardHandler(udpSrv.Addr()) + h.Init() server := &Server{ Listener: ln, - Handler: UDPDirectForwardHandler(udpSrv.Addr()), + Handler: h, } go server.Run() @@ -238,9 +252,11 @@ func tcpRemoteForwardRoundtrip(t *testing.T, targetURL string, data []byte) erro Transporter: TCPTransporter(), } + h := TCPRemoteForwardHandler(u.Host) // forward to u.Host + h.Init() server := &Server{ Listener: ln, - Handler: TCPRemoteForwardHandler(u.Host), // forward to u.Host + Handler: h, } go server.Run() @@ -273,9 +289,11 @@ func udpRemoteForwardRoundtrip(t *testing.T, host string, data []byte) error { Transporter: UDPTransporter(), } + h := UDPRemoteForwardHandler(host) + h.Init() server := &Server{ Listener: ln, - Handler: UDPRemoteForwardHandler(host), + Handler: h, } go server.Run() diff --git a/handler.go b/handler.go index 603d110..568c341 100644 --- a/handler.go +++ b/handler.go @@ -36,6 +36,7 @@ type HandlerOptions struct { ProbeResist string Node Node Host string + IPs []string } // HandlerOption allows a common way to set handler options. @@ -163,6 +164,13 @@ func HostHandlerOption(host string) HandlerOption { } } +// IPsHandlerOption sets the ip list for port forward. +func IPsHandlerOption(ips []string) HandlerOption { + return func(opts *HandlerOptions) { + opts.IPs = ips + } +} + type autoHandler struct { options *HandlerOptions } diff --git a/selector.go b/selector.go index f6d2c14..cb678fd 100644 --- a/selector.go +++ b/selector.go @@ -174,13 +174,10 @@ func (f *FailFilter) Filter(nodes []Node) []Node { } nl := []Node{} for i := range nodes { - marker := &failMarker{} - if nil != nodes[i].marker { - marker = nodes[i].marker.Clone() - } - // log.Logf("%s: %d/%d %v/%v", nodes[i], marker.failCount, f.MaxFails, marker.failTime, f.FailTimeout) - if marker.failCount < uint32(f.MaxFails) || - time.Since(time.Unix(marker.failTime, 0)) >= f.FailTimeout { + marker := nodes[i].marker.Clone() + // log.Logf("%s: %d/%d %v/%v", nodes[i], marker.FailCount(), f.MaxFails, marker.FailTime(), f.FailTimeout) + if marker.FailCount() < uint32(f.MaxFails) || + time.Since(time.Unix(marker.FailTime(), 0)) >= f.FailTimeout { nl = append(nl, nodes[i]) } } @@ -197,7 +194,33 @@ type failMarker struct { mux sync.RWMutex } +func (m *failMarker) FailTime() int64 { + if m == nil { + return 0 + } + + m.mux.Lock() + defer m.mux.Unlock() + + return m.failTime +} + +func (m *failMarker) FailCount() uint32 { + if m == nil { + return 0 + } + + m.mux.Lock() + defer m.mux.Unlock() + + return m.failCount +} + func (m *failMarker) Mark() { + if m == nil { + return + } + m.mux.Lock() defer m.mux.Unlock() @@ -206,6 +229,10 @@ func (m *failMarker) Mark() { } func (m *failMarker) Reset() { + if m == nil { + return + } + m.mux.Lock() defer m.mux.Unlock() @@ -214,6 +241,10 @@ func (m *failMarker) Reset() { } func (m *failMarker) Clone() *failMarker { + if m == nil { + return nil + } + m.mux.RLock() defer m.mux.RUnlock()