#249 add load balancing support for port forwarding

This commit is contained in:
ginuerzh 2018-05-20 22:03:41 +08:00
parent b15c5bb38c
commit 9345a2dc2b
2 changed files with 169 additions and 26 deletions

View File

@ -3,6 +3,7 @@ package gost
import ( import (
"errors" "errors"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@ -27,17 +28,39 @@ func (c *forwardConnector) Connect(conn net.Conn, addr string) (net.Conn, error)
type tcpDirectForwardHandler struct { type tcpDirectForwardHandler struct {
raddr string raddr string
group *NodeGroup
options *HandlerOptions options *HandlerOptions
} }
// TCPDirectForwardHandler creates a server Handler for TCP port forwarding server. // TCPDirectForwardHandler creates a server Handler for TCP port forwarding server.
// The raddr is the remote address that the server will forward to. // The raddr is the remote address that the server will forward to.
// NOTE: as of 2.6, remote address can be a comma-separated address list.
func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
if raddr == "" { group := NewNodeGroup()
raddr = "0.0.0.0:0" group.SetSelector(&defaultSelector{},
WithStrategy(&RoundStrategy{}),
WithFilter(&FailFilter{
MaxFails: 1,
FailTimeout: 30 * time.Second,
}),
)
for i, addr := range strings.Split(raddr, ",") {
if addr == "" {
continue
} }
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
group.AddNode(Node{
ID: i + 1,
Addr: addr,
Host: addr,
})
}
h := &tcpDirectForwardHandler{ h := &tcpDirectForwardHandler{
raddr: raddr, raddr: raddr,
group: group,
options: &HandlerOptions{}, options: &HandlerOptions{},
} }
for _, opt := range opts { for _, opt := range opts {
@ -49,29 +72,63 @@ func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
defer conn.Close() defer conn.Close()
log.Logf("[tcp] %s - %s", conn.RemoteAddr(), h.raddr) node, err := h.group.Next()
cc, err := h.options.Chain.Dial(h.raddr)
if err != nil { if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), h.raddr, err) log.Logf("[tcp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err)
return
}
log.Logf("[tcp] %s - %s", conn.RemoteAddr(), node.Addr)
cc, err := h.options.Chain.Dial(node.Addr)
if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
node.MarkDead()
return return
} }
defer cc.Close() defer cc.Close()
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), h.raddr) node.ResetDead()
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc) transport(conn, cc)
log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), h.raddr) log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), node.Addr)
} }
type udpDirectForwardHandler struct { type udpDirectForwardHandler struct {
raddr string raddr string
group *NodeGroup
options *HandlerOptions options *HandlerOptions
} }
// UDPDirectForwardHandler creates a server Handler for UDP port forwarding server. // UDPDirectForwardHandler creates a server Handler for UDP port forwarding server.
// The raddr is the remote address that the server will forward to. // The raddr is the remote address that the server will forward to.
// NOTE: as of 2.6, remote address can be a comma-separated address list.
func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
group := NewNodeGroup()
group.SetSelector(&defaultSelector{},
WithStrategy(&RoundStrategy{}),
WithFilter(&FailFilter{
MaxFails: 1,
FailTimeout: 30 * time.Second,
}),
)
for i, addr := range strings.Split(raddr, ",") {
if addr == "" {
continue
}
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
group.AddNode(Node{
ID: i + 1,
Addr: addr,
Host: addr,
})
}
h := &udpDirectForwardHandler{ h := &udpDirectForwardHandler{
raddr: raddr, raddr: raddr,
group: group,
options: &HandlerOptions{}, options: &HandlerOptions{},
} }
for _, opt := range opts { for _, opt := range opts {
@ -83,45 +140,79 @@ func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler {
func (h *udpDirectForwardHandler) Handle(conn net.Conn) { func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
defer conn.Close() defer conn.Close()
node, err := h.group.Next()
if err != nil {
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err)
return
}
var cc net.Conn var cc net.Conn
if h.options.Chain.IsEmpty() { if h.options.Chain.IsEmpty() {
raddr, err := net.ResolveUDPAddr("udp", h.raddr) raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil { if err != nil {
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return return
} }
cc, err = net.DialUDP("udp", nil, raddr) cc, err = net.DialUDP("udp", nil, raddr)
if err != nil { if err != nil {
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return return
} }
} else { } else {
var err error var err error
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
if err != nil { if err != nil {
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return return
} }
cc = &udpTunnelConn{Conn: cc, raddr: h.raddr} cc = &udpTunnelConn{Conn: cc, raddr: node.Addr}
} }
defer cc.Close() defer cc.Close()
node.ResetDead()
log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), h.raddr) log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc) transport(conn, cc)
log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), h.raddr) log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), node.Addr)
} }
type tcpRemoteForwardHandler struct { type tcpRemoteForwardHandler struct {
raddr string raddr string
group *NodeGroup
options *HandlerOptions options *HandlerOptions
} }
// TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server. // TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server.
// The raddr is the remote address that the server will forward to. // The raddr is the remote address that the server will forward to.
// NOTE: as of 2.6, remote address can be a comma-separated address list.
func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler {
group := NewNodeGroup()
group.SetSelector(&defaultSelector{},
WithStrategy(&RoundStrategy{}),
WithFilter(&FailFilter{
MaxFails: 1,
FailTimeout: 30 * time.Second,
}),
)
for i, addr := range strings.Split(raddr, ",") {
if addr == "" {
continue
}
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
group.AddNode(Node{
ID: i + 1,
Addr: addr,
Host: addr,
})
}
h := &tcpRemoteForwardHandler{ h := &tcpRemoteForwardHandler{
raddr: raddr, raddr: raddr,
group: group,
options: &HandlerOptions{}, options: &HandlerOptions{},
} }
for _, opt := range opts { for _, opt := range opts {
@ -133,28 +224,60 @@ func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler {
func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
defer conn.Close() defer conn.Close()
cc, err := net.DialTimeout("tcp", h.raddr, DialTimeout) node, err := h.group.Next()
if err != nil { if err != nil {
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), h.raddr, err) log.Logf("[rtcp] %s - %s : %s", conn.LocalAddr(), h.raddr, err)
return
}
cc, err := net.DialTimeout("tcp", node.Addr, DialTimeout)
if err != nil {
node.MarkDead()
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
return return
} }
defer cc.Close() defer cc.Close()
node.ResetDead()
log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), h.raddr) log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
transport(cc, conn) transport(cc, conn)
log.Logf("[rtcp] %s >-< %s", conn.LocalAddr(), h.raddr) log.Logf("[rtcp] %s >-< %s", conn.LocalAddr(), node.Addr)
} }
type udpRemoteForwardHandler struct { type udpRemoteForwardHandler struct {
raddr string raddr string
group *NodeGroup
options *HandlerOptions options *HandlerOptions
} }
// UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server. // UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server.
// The raddr is the remote address that the server will forward to. // The raddr is the remote address that the server will forward to.
// NOTE: as of 2.6, remote address can be a comma-separated address list.
func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler {
group := NewNodeGroup()
group.SetSelector(&defaultSelector{},
WithStrategy(&RoundStrategy{}),
WithFilter(&FailFilter{
MaxFails: 1,
FailTimeout: 30 * time.Second,
}),
)
for i, addr := range strings.Split(raddr, ",") {
if addr == "" {
continue
}
// We treat the remote target server as a node, so we can put them in a group,
// and perform the node selection for load balancing.
group.AddNode(Node{
ID: i + 1,
Addr: addr,
Host: addr,
})
}
h := &udpRemoteForwardHandler{ h := &udpRemoteForwardHandler{
raddr: raddr, raddr: raddr,
group: group,
options: &HandlerOptions{}, options: &HandlerOptions{},
} }
for _, opt := range opts { for _, opt := range opts {
@ -166,20 +289,30 @@ func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler {
func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
defer conn.Close() defer conn.Close()
raddr, err := net.ResolveUDPAddr("udp", h.raddr) node, err := h.group.Next()
if err != nil {
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err)
return
}
cc, err := net.DialUDP("udp", nil, raddr)
if err != nil { if err != nil {
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err)
return return
} }
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), h.raddr) raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.MarkDead()
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
cc, err := net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
defer cc.Close()
node.ResetDead()
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc) transport(conn, cc)
log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), h.raddr) log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr)
} }
type udpDirectForwardListener struct { type udpDirectForwardListener struct {

10
node.go
View File

@ -140,6 +140,7 @@ func (node *Node) Clone() Node {
group: node.group, group: node.group,
failCount: atomic.LoadUint32(&node.failCount), failCount: atomic.LoadUint32(&node.failCount),
failTime: atomic.LoadInt64(&node.failTime), failTime: atomic.LoadInt64(&node.failTime),
Bypass: node.Bypass,
} }
} }
@ -187,6 +188,15 @@ func (group *NodeGroup) AddNode(node ...Node) {
group.nodes = append(group.nodes, node...) group.nodes = append(group.nodes, node...)
} }
// SetSelector sets node selector with options for the group.
func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption) {
if group == nil {
return
}
group.Selector = selector
group.Options = opts
}
// Nodes returns node list in the group // Nodes returns node list in the group
func (group *NodeGroup) Nodes() []Node { func (group *NodeGroup) Nodes() []Node {
if group == nil { if group == nil {