diff --git a/forward.go b/forward.go index c9dced9..715e4a6 100644 --- a/forward.go +++ b/forward.go @@ -79,25 +79,42 @@ func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) { func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() - node, err := h.group.Next() - if err != nil { - log.Logf("[tcp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) - return + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries } - log.Logf("[tcp] %s - %s", conn.RemoteAddr(), node.Addr) - cc, err := h.options.Chain.Dial(node.Addr, - RetryChainOption(h.options.Retries), - TimeoutChainOption(h.options.Timeout), - ) + var cc net.Conn + var node Node + var err error + for i := 0; i < retries; i++ { + node, err = h.group.Next() + if err != nil { + log.Logf("[tcp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) + return + } + + log.Logf("[tcp] %s - %s", conn.RemoteAddr(), node.Addr) + cc, err = h.options.Chain.Dial(node.Addr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) + if err != nil { + log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err) + node.MarkDead() + } else { + break + } + } if err != nil { - log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err) - node.MarkDead() return } - defer cc.Close() node.ResetDead() + defer cc.Close() log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr) transport(conn, cc) @@ -248,18 +265,35 @@ func (h *tcpRemoteForwardHandler) Init(options ...HandlerOption) { func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { defer conn.Close() - node, err := h.group.Next() + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var cc net.Conn + var node Node + var err error + for i := 0; i < retries; i++ { + node, err = h.group.Next() + if err != nil { + log.Logf("[rtcp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + return + } + cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout) + if err != nil { + log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err) + node.MarkDead() + } else { + break + } + } if err != nil { - log.Logf("[rtcp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) return } - cc, err := net.DialTimeout("tcp", node.Addr, DialTimeout) - if err != nil { - node.MarkDead() - log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err) - return - } defer cc.Close() node.ResetDead()