add ip option support for port forward handler

This commit is contained in:
ginuerzh 2019-06-04 10:40:38 +08:00
parent 2e5601bfd6
commit d61407c7fb
6 changed files with 159 additions and 146 deletions

View File

@ -148,6 +148,7 @@ func parseIP(s string, port string) (ips []string) {
for _, s := range ss { for _, s := range ss {
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
if s != "" { if s != "" {
// TODO: support IPv6
if !strings.Contains(s, ":") { if !strings.Contains(s, ":") {
s = s + ":" + port s = s + ":" + port
} }

View File

@ -417,9 +417,9 @@ func (r *route) GenRouters() ([]router, error) {
node.Bypass = parseBypass(node.Get("bypass")) node.Bypass = parseBypass(node.Get("bypass"))
resolver := parseResolver(node.Get("dns")) resolver := parseResolver(node.Get("dns"))
hosts := parseHosts(node.Get("hosts")) hosts := parseHosts(node.Get("hosts"))
ips := parseIP(node.Get("ip"), "")
handler.Init( handler.Init(
// gost.AddrHandlerOption(node.Addr),
gost.AddrHandlerOption(ln.Addr().String()), gost.AddrHandlerOption(ln.Addr().String()),
gost.ChainHandlerOption(chain), gost.ChainHandlerOption(chain),
gost.UsersHandlerOption(node.User), gost.UsersHandlerOption(node.User),
@ -435,6 +435,7 @@ func (r *route) GenRouters() ([]router, error) {
gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second), gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second),
gost.ProbeResistHandlerOption(node.Get("probe_resist")), gost.ProbeResistHandlerOption(node.Get("probe_resist")),
gost.NodeHandlerOption(node), gost.NodeHandlerOption(node),
gost.IPsHandlerOption(ips),
) )
rt := router{ rt := router{

View File

@ -26,42 +26,13 @@ func (c *forwardConnector) Connect(conn net.Conn, addr string, options ...Connec
return conn, nil return conn, nil
} }
type tcpDirectForwardHandler struct { type baseForwardHandler struct {
raddr string raddr string
group *NodeGroup group *NodeGroup
options *HandlerOptions options *HandlerOptions
} }
// TCPDirectForwardHandler creates a server Handler for TCP port forwarding server. func (h *baseForwardHandler) Init(options ...HandlerOption) {
// The raddr is the remote address that the server will forward to.
// NOTE: as of 2.6, remote address can be a comma-separated address list.
func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
h := &tcpDirectForwardHandler{
raddr: raddr,
group: NewNodeGroup(),
}
if raddr == "" {
raddr = ":0" // dummy address
}
for i, addr := range strings.Split(raddr, ",") {
if addr == "" {
continue
}
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
h.group.AddNode(Node{
ID: i + 1,
Addr: addr,
Host: addr,
})
}
h.Init(opts...)
return h
}
func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) {
if h.options == nil { if h.options == nil {
h.options = &HandlerOptions{} h.options = &HandlerOptions{}
} }
@ -70,6 +41,8 @@ func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) {
opt(h.options) opt(h.options)
} }
h.group = NewNodeGroup() // reset node group
h.group.SetSelector(&defaultSelector{}, h.group.SetSelector(&defaultSelector{},
WithStrategy(h.options.Strategy), WithStrategy(h.options.Strategy),
WithFilter(&FailFilter{ WithFilter(&FailFilter{
@ -77,6 +50,59 @@ func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) {
FailTimeout: 30 * time.Second, FailTimeout: 30 * time.Second,
}), }),
) )
n := 1
addrs := append(strings.Split(h.raddr, ","), h.options.IPs...)
for _, addr := range addrs {
if addr == "" {
continue
}
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
h.group.AddNode(Node{
ID: n,
Addr: addr,
Host: addr,
marker: &failMarker{},
})
n++
}
if len(h.group.Nodes()) == 0 {
h.group.AddNode(Node{ // dummy address
ID: n,
Addr: ":0",
Host: ":0",
})
}
}
type tcpDirectForwardHandler struct {
*baseForwardHandler
}
// TCPDirectForwardHandler creates a server Handler for TCP port forwarding server.
// The raddr is the remote address that the server will forward to.
// NOTE: as of 2.6, remote address can be a comma-separated address list.
func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
h := &tcpDirectForwardHandler{
baseForwardHandler: &baseForwardHandler{
raddr: raddr,
group: NewNodeGroup(),
options: &HandlerOptions{},
},
}
for _, opt := range opts {
opt(h.options)
}
return h
}
func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) {
h.baseForwardHandler.Init(options...)
} }
func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
@ -125,9 +151,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
} }
type udpDirectForwardHandler struct { type udpDirectForwardHandler struct {
raddr string *baseForwardHandler
group *NodeGroup
options *HandlerOptions
} }
// UDPDirectForwardHandler creates a server Handler for UDP port forwarding server. // UDPDirectForwardHandler creates a server Handler for UDP port forwarding server.
@ -135,47 +159,22 @@ type udpDirectForwardHandler struct {
// NOTE: as of 2.6, remote address can be a comma-separated address list. // NOTE: as of 2.6, remote address can be a comma-separated address list.
func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
h := &udpDirectForwardHandler{ h := &udpDirectForwardHandler{
baseForwardHandler: &baseForwardHandler{
raddr: raddr, raddr: raddr,
group: NewNodeGroup(), group: NewNodeGroup(),
options: &HandlerOptions{},
},
} }
if raddr == "" { for _, opt := range opts {
raddr = ":0" // dummy address opt(h.options)
} }
for i, addr := range strings.Split(raddr, ",") {
if addr == "" {
continue
}
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
h.group.AddNode(Node{
ID: i + 1,
Addr: addr,
Host: addr,
})
}
h.Init(opts...)
return h return h
} }
func (h *udpDirectForwardHandler) Init(options ...HandlerOption) { func (h *udpDirectForwardHandler) Init(options ...HandlerOption) {
if h.options == nil { h.baseForwardHandler.Init(options...)
h.options = &HandlerOptions{}
}
for _, opt := range options {
opt(h.options)
}
h.group.SetSelector(&defaultSelector{},
WithStrategy(h.options.Strategy),
WithFilter(&FailFilter{
MaxFails: 1,
FailTimeout: 30 * time.Second,
}),
)
} }
func (h *udpDirectForwardHandler) Handle(conn net.Conn) { func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
@ -220,9 +219,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
} }
type tcpRemoteForwardHandler struct { type tcpRemoteForwardHandler struct {
raddr string *baseForwardHandler
group *NodeGroup
options *HandlerOptions
} }
// TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server. // TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server.
@ -230,42 +227,22 @@ type tcpRemoteForwardHandler struct {
// NOTE: as of 2.6, remote address can be a comma-separated address list. // NOTE: as of 2.6, remote address can be a comma-separated address list.
func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler {
h := &tcpRemoteForwardHandler{ h := &tcpRemoteForwardHandler{
baseForwardHandler: &baseForwardHandler{
raddr: raddr, raddr: raddr,
group: NewNodeGroup(), group: NewNodeGroup(),
options: &HandlerOptions{},
},
} }
for i, addr := range strings.Split(raddr, ",") { for _, opt := range opts {
if addr == "" { opt(h.options)
continue
} }
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
h.group.AddNode(Node{
ID: i + 1,
Addr: addr,
Host: addr,
})
}
h.Init(opts...)
return h return h
} }
func (h *tcpRemoteForwardHandler) Init(options ...HandlerOption) { func (h *tcpRemoteForwardHandler) Init(options ...HandlerOption) {
if h.options == nil { h.baseForwardHandler.Init(options...)
h.options = &HandlerOptions{}
}
for _, opt := range options {
opt(h.options)
}
h.group.SetSelector(&defaultSelector{},
WithStrategy(h.options.Strategy),
WithFilter(&FailFilter{
MaxFails: 1,
FailTimeout: 30 * time.Second,
}),
)
} }
func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
@ -306,9 +283,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
} }
type udpRemoteForwardHandler struct { type udpRemoteForwardHandler struct {
raddr string *baseForwardHandler
group *NodeGroup
options *HandlerOptions
} }
// UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server. // UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server.
@ -316,43 +291,22 @@ type udpRemoteForwardHandler struct {
// NOTE: as of 2.6, remote address can be a comma-separated address list. // NOTE: as of 2.6, remote address can be a comma-separated address list.
func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler {
h := &udpRemoteForwardHandler{ h := &udpRemoteForwardHandler{
baseForwardHandler: &baseForwardHandler{
raddr: raddr, raddr: raddr,
group: NewNodeGroup(), group: NewNodeGroup(),
options: &HandlerOptions{},
},
} }
for i, addr := range strings.Split(raddr, ",") { for _, opt := range opts {
if addr == "" { opt(h.options)
continue
} }
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
h.group.AddNode(Node{
ID: i + 1,
Addr: addr,
Host: addr,
})
}
h.Init(opts...)
return h return h
} }
func (h *udpRemoteForwardHandler) Init(options ...HandlerOption) { func (h *udpRemoteForwardHandler) Init(options ...HandlerOption) {
if h.options == nil { h.baseForwardHandler.Init(options...)
h.options = &HandlerOptions{}
}
for _, opt := range options {
opt(h.options)
}
h.group.SetSelector(&defaultSelector{},
WithStrategy(h.options.Strategy),
WithFilter(&FailFilter{
MaxFails: 1,
FailTimeout: 30 * time.Second,
}),
)
} }
func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {

View File

@ -23,9 +23,11 @@ func tcpDirectForwardRoundtrip(targetURL string, data []byte) error {
Transporter: TCPTransporter(), Transporter: TCPTransporter(),
} }
h := TCPDirectForwardHandler(u.Host)
h.Init()
server := &Server{ server := &Server{
Listener: ln, Listener: ln,
Handler: TCPDirectForwardHandler(u.Host), Handler: h,
} }
go server.Run() go server.Run()
@ -68,9 +70,12 @@ func BenchmarkTCPDirectForward(b *testing.B) {
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
h := TCPDirectForwardHandler(u.Host)
h.Init()
server := &Server{ server := &Server{
Listener: ln, Listener: ln,
Handler: TCPDirectForwardHandler(u.Host), Handler: h,
} }
go server.Run() go server.Run()
defer server.Close() defer server.Close()
@ -103,9 +108,12 @@ func BenchmarkTCPDirectForwardParallel(b *testing.B) {
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
h := TCPDirectForwardHandler(u.Host)
h.Init()
server := &Server{ server := &Server{
Listener: ln, Listener: ln,
Handler: TCPDirectForwardHandler(u.Host), Handler: h,
} }
go server.Run() go server.Run()
defer server.Close() defer server.Close()
@ -130,9 +138,11 @@ func udpDirectForwardRoundtrip(t *testing.T, host string, data []byte) error {
Transporter: UDPTransporter(), Transporter: UDPTransporter(),
} }
h := UDPDirectForwardHandler(host)
h.Init()
server := &Server{ server := &Server{
Listener: ln, Listener: ln,
Handler: UDPDirectForwardHandler(host), Handler: h,
} }
go server.Run() go server.Run()
@ -172,9 +182,11 @@ func BenchmarkUDPDirectForward(b *testing.B) {
Transporter: UDPTransporter(), Transporter: UDPTransporter(),
} }
h := UDPDirectForwardHandler(udpSrv.Addr())
h.Init()
server := &Server{ server := &Server{
Listener: ln, Listener: ln,
Handler: UDPDirectForwardHandler(udpSrv.Addr()), Handler: h,
} }
go server.Run() go server.Run()
@ -205,9 +217,11 @@ func BenchmarkUDPDirectForwardParallel(b *testing.B) {
Transporter: UDPTransporter(), Transporter: UDPTransporter(),
} }
h := UDPDirectForwardHandler(udpSrv.Addr())
h.Init()
server := &Server{ server := &Server{
Listener: ln, Listener: ln,
Handler: UDPDirectForwardHandler(udpSrv.Addr()), Handler: h,
} }
go server.Run() go server.Run()
@ -238,9 +252,11 @@ func tcpRemoteForwardRoundtrip(t *testing.T, targetURL string, data []byte) erro
Transporter: TCPTransporter(), Transporter: TCPTransporter(),
} }
h := TCPRemoteForwardHandler(u.Host) // forward to u.Host
h.Init()
server := &Server{ server := &Server{
Listener: ln, Listener: ln,
Handler: TCPRemoteForwardHandler(u.Host), // forward to u.Host Handler: h,
} }
go server.Run() go server.Run()
@ -273,9 +289,11 @@ func udpRemoteForwardRoundtrip(t *testing.T, host string, data []byte) error {
Transporter: UDPTransporter(), Transporter: UDPTransporter(),
} }
h := UDPRemoteForwardHandler(host)
h.Init()
server := &Server{ server := &Server{
Listener: ln, Listener: ln,
Handler: UDPRemoteForwardHandler(host), Handler: h,
} }
go server.Run() go server.Run()

View File

@ -36,6 +36,7 @@ type HandlerOptions struct {
ProbeResist string ProbeResist string
Node Node Node Node
Host string Host string
IPs []string
} }
// HandlerOption allows a common way to set handler options. // HandlerOption allows a common way to set handler options.
@ -163,6 +164,13 @@ func HostHandlerOption(host string) HandlerOption {
} }
} }
// IPsHandlerOption sets the ip list for port forward.
func IPsHandlerOption(ips []string) HandlerOption {
return func(opts *HandlerOptions) {
opts.IPs = ips
}
}
type autoHandler struct { type autoHandler struct {
options *HandlerOptions options *HandlerOptions
} }

View File

@ -174,13 +174,10 @@ func (f *FailFilter) Filter(nodes []Node) []Node {
} }
nl := []Node{} nl := []Node{}
for i := range nodes { for i := range nodes {
marker := &failMarker{} marker := nodes[i].marker.Clone()
if nil != nodes[i].marker { // log.Logf("%s: %d/%d %v/%v", nodes[i], marker.FailCount(), f.MaxFails, marker.FailTime(), f.FailTimeout)
marker = nodes[i].marker.Clone() if marker.FailCount() < uint32(f.MaxFails) ||
} time.Since(time.Unix(marker.FailTime(), 0)) >= f.FailTimeout {
// log.Logf("%s: %d/%d %v/%v", nodes[i], marker.failCount, f.MaxFails, marker.failTime, f.FailTimeout)
if marker.failCount < uint32(f.MaxFails) ||
time.Since(time.Unix(marker.failTime, 0)) >= f.FailTimeout {
nl = append(nl, nodes[i]) nl = append(nl, nodes[i])
} }
} }
@ -197,7 +194,33 @@ type failMarker struct {
mux sync.RWMutex mux sync.RWMutex
} }
func (m *failMarker) FailTime() int64 {
if m == nil {
return 0
}
m.mux.Lock()
defer m.mux.Unlock()
return m.failTime
}
func (m *failMarker) FailCount() uint32 {
if m == nil {
return 0
}
m.mux.Lock()
defer m.mux.Unlock()
return m.failCount
}
func (m *failMarker) Mark() { func (m *failMarker) Mark() {
if m == nil {
return
}
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@ -206,6 +229,10 @@ func (m *failMarker) Mark() {
} }
func (m *failMarker) Reset() { func (m *failMarker) Reset() {
if m == nil {
return
}
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@ -214,6 +241,10 @@ func (m *failMarker) Reset() {
} }
func (m *failMarker) Clone() *failMarker { func (m *failMarker) Clone() *failMarker {
if m == nil {
return nil
}
m.mux.RLock() m.mux.RLock()
defer m.mux.RUnlock() defer m.mux.RUnlock()