From 9345a2dc2b4dc145138c1206da7cce753df07187 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 20 May 2018 22:03:41 +0800 Subject: [PATCH] #249 add load balancing support for port forwarding --- forward.go | 185 +++++++++++++++++++++++++++++++++++++++++++++-------- node.go | 10 +++ 2 files changed, 169 insertions(+), 26 deletions(-) diff --git a/forward.go b/forward.go index 718c0a8..a556552 100644 --- a/forward.go +++ b/forward.go @@ -3,6 +3,7 @@ package gost import ( "errors" "net" + "strings" "sync" "time" @@ -27,17 +28,39 @@ func (c *forwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error) type tcpDirectForwardHandler struct { raddr string + group *NodeGroup options *HandlerOptions } // TCPDirectForwardHandler creates a server Handler for TCP port forwarding server. // The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { - if raddr == "" { - raddr = "0.0.0.0:0" + group := NewNodeGroup() + group.SetSelector(&defaultSelector{}, + WithStrategy(&RoundStrategy{}), + WithFilter(&FailFilter{ + MaxFails: 1, + FailTimeout: 30 * time.Second, + }), + ) + + for i, addr := range strings.Split(raddr, ",") { + if addr == "" { + continue + } + // We treat the remote target server as a node, so we can put them in a group, + // and perform the node selection for load balancing. + group.AddNode(Node{ + ID: i + 1, + Addr: addr, + Host: addr, + }) } + h := &tcpDirectForwardHandler{ raddr: raddr, + group: group, options: &HandlerOptions{}, } for _, opt := range opts { @@ -49,29 +72,63 @@ func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() - log.Logf("[tcp] %s - %s", conn.RemoteAddr(), h.raddr) - cc, err := h.options.Chain.Dial(h.raddr) + node, err := h.group.Next() if err != nil { - log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), h.raddr, err) + log.Logf("[tcp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) + return + } + + log.Logf("[tcp] %s - %s", conn.RemoteAddr(), node.Addr) + cc, err := h.options.Chain.Dial(node.Addr) + if err != nil { + log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err) + node.MarkDead() return } defer cc.Close() - log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), h.raddr) + node.ResetDead() + + log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr) transport(conn, cc) - log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), h.raddr) + log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), node.Addr) } type udpDirectForwardHandler struct { raddr string + group *NodeGroup options *HandlerOptions } // UDPDirectForwardHandler creates a server Handler for UDP port forwarding server. // The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { + group := NewNodeGroup() + group.SetSelector(&defaultSelector{}, + WithStrategy(&RoundStrategy{}), + WithFilter(&FailFilter{ + MaxFails: 1, + FailTimeout: 30 * time.Second, + }), + ) + + for i, addr := range strings.Split(raddr, ",") { + if addr == "" { + continue + } + // We treat the remote target server as a node, so we can put them in a group, + // and perform the node selection for load balancing. + group.AddNode(Node{ + ID: i + 1, + Addr: addr, + Host: addr, + }) + } + h := &udpDirectForwardHandler{ raddr: raddr, + group: group, options: &HandlerOptions{}, } for _, opt := range opts { @@ -83,45 +140,79 @@ func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { func (h *udpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() + node, err := h.group.Next() + if err != nil { + log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) + return + } + var cc net.Conn if h.options.Chain.IsEmpty() { - raddr, err := net.ResolveUDPAddr("udp", h.raddr) + raddr, err := net.ResolveUDPAddr("udp", node.Addr) if err != nil { - log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + node.MarkDead() + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) return } cc, err = net.DialUDP("udp", nil, raddr) if err != nil { - log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + node.MarkDead() + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) return } } else { var err error cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) if err != nil { - log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) return } - cc = &udpTunnelConn{Conn: cc, raddr: h.raddr} + cc = &udpTunnelConn{Conn: cc, raddr: node.Addr} } defer cc.Close() + node.ResetDead() - log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), h.raddr) + log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr) transport(conn, cc) - log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), h.raddr) + log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), node.Addr) } type tcpRemoteForwardHandler struct { raddr string + group *NodeGroup options *HandlerOptions } // TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server. // The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { + group := NewNodeGroup() + group.SetSelector(&defaultSelector{}, + WithStrategy(&RoundStrategy{}), + WithFilter(&FailFilter{ + MaxFails: 1, + FailTimeout: 30 * time.Second, + }), + ) + + for i, addr := range strings.Split(raddr, ",") { + if addr == "" { + continue + } + // We treat the remote target server as a node, so we can put them in a group, + // and perform the node selection for load balancing. + group.AddNode(Node{ + ID: i + 1, + Addr: addr, + Host: addr, + }) + } + h := &tcpRemoteForwardHandler{ raddr: raddr, + group: group, options: &HandlerOptions{}, } for _, opt := range opts { @@ -133,28 +224,60 @@ func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { defer conn.Close() - cc, err := net.DialTimeout("tcp", h.raddr, DialTimeout) + node, err := h.group.Next() if err != nil { - log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), h.raddr, err) + log.Logf("[rtcp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + return + } + + cc, err := net.DialTimeout("tcp", node.Addr, DialTimeout) + if err != nil { + node.MarkDead() + log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err) return } defer cc.Close() + node.ResetDead() - log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), h.raddr) + log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr) transport(cc, conn) - log.Logf("[rtcp] %s >-< %s", conn.LocalAddr(), h.raddr) + log.Logf("[rtcp] %s >-< %s", conn.LocalAddr(), node.Addr) } type udpRemoteForwardHandler struct { raddr string + group *NodeGroup options *HandlerOptions } // UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server. // The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { + group := NewNodeGroup() + group.SetSelector(&defaultSelector{}, + WithStrategy(&RoundStrategy{}), + WithFilter(&FailFilter{ + MaxFails: 1, + FailTimeout: 30 * time.Second, + }), + ) + for i, addr := range strings.Split(raddr, ",") { + if addr == "" { + continue + } + // We treat the remote target server as a node, so we can put them in a group, + // and perform the node selection for load balancing. + group.AddNode(Node{ + ID: i + 1, + Addr: addr, + Host: addr, + }) + } + h := &udpRemoteForwardHandler{ raddr: raddr, + group: group, options: &HandlerOptions{}, } for _, opt := range opts { @@ -166,20 +289,30 @@ func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { defer conn.Close() - raddr, err := net.ResolveUDPAddr("udp", h.raddr) - if err != nil { - log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) - return - } - cc, err := net.DialUDP("udp", nil, raddr) + node, err := h.group.Next() if err != nil { log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) return } - log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), h.raddr) + raddr, err := net.ResolveUDPAddr("udp", node.Addr) + if err != nil { + node.MarkDead() + log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) + return + } + cc, err := net.DialUDP("udp", nil, raddr) + if err != nil { + node.MarkDead() + log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) + return + } + defer cc.Close() + node.ResetDead() + + log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr) transport(conn, cc) - log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), h.raddr) + log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr) } type udpDirectForwardListener struct { diff --git a/node.go b/node.go index c174ba5..c70c134 100644 --- a/node.go +++ b/node.go @@ -140,6 +140,7 @@ func (node *Node) Clone() Node { group: node.group, failCount: atomic.LoadUint32(&node.failCount), failTime: atomic.LoadInt64(&node.failTime), + Bypass: node.Bypass, } } @@ -187,6 +188,15 @@ func (group *NodeGroup) AddNode(node ...Node) { group.nodes = append(group.nodes, node...) } +// SetSelector sets node selector with options for the group. +func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption) { + if group == nil { + return + } + group.Selector = selector + group.Options = opts +} + // Nodes returns node list in the group func (group *NodeGroup) Nodes() []Node { if group == nil {