add failMarker for fail filter

This commit is contained in:
ginuerzh 2018-11-26 11:28:07 +08:00
parent 83906404ae
commit 194b651dd8
5 changed files with 99 additions and 103 deletions

View File

@ -151,7 +151,7 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e
return cc, nil return cc, nil
} }
func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return addr return addr
@ -216,16 +216,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
cn, err := node.Client.Dial(node.Addr, node.DialOptions...) cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil { if err != nil {
node.group.MarkDeadNode(node.ID) node.MarkDead()
return return
} }
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...) cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
if err != nil { if err != nil {
node.group.MarkDeadNode(node.ID) node.MarkDead()
return return
} }
node.group.ResetDeadNode(node.ID) node.ResetDead()
preNode := node preNode := node
for _, node := range nodes[1:] { for _, node := range nodes[1:] {
@ -233,16 +233,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
cc, err = preNode.Client.Connect(cn, node.Addr) cc, err = preNode.Client.Connect(cn, node.Addr)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.group.MarkDeadNode(node.ID) node.MarkDead()
return return
} }
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...) cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.group.MarkDeadNode(node.ID) node.MarkDead()
return return
} }
node.group.ResetDeadNode(node.ID) node.ResetDead()
cn = cc cn = cc
preNode = node preNode = node
@ -321,10 +321,9 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
} }
route.Retries = c.Retries route.Retries = c.Retries
if Debug {
buf.WriteString(addr) buf.WriteString(addr)
log.Log("[route]", buf.String()) log.Log("[route]", buf.String())
}
return return
} }

View File

@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
) )
if err != nil { if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err) log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
node.group.MarkDeadNode(node.ID) node.MarkDead()
} else { } else {
break break
} }
@ -116,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
return return
} }
node.group.ResetDeadNode(node.ID) node.ResetDead()
defer cc.Close() defer cc.Close()
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr) log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
if h.options.Chain.IsEmpty() { if h.options.Chain.IsEmpty() {
raddr, err := net.ResolveUDPAddr("udp", node.Addr) raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil { if err != nil {
node.group.MarkDeadNode(node.ID) node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return return
} }
cc, err = net.DialUDP("udp", nil, raddr) cc, err = net.DialUDP("udp", nil, raddr)
if err != nil { if err != nil {
node.group.MarkDeadNode(node.ID) node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return return
} }
@ -212,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
} }
defer cc.Close() defer cc.Close()
node.group.ResetDeadNode(node.ID) node.ResetDead()
log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr) log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc) transport(conn, cc)
@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout) cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
if err != nil { if err != nil {
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err) log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
node.group.MarkDeadNode(node.ID) node.MarkDead()
} else { } else {
break break
} }
@ -301,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
} }
defer cc.Close() defer cc.Close()
node.group.ResetDeadNode(node.ID) node.ResetDead()
log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr) log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
transport(cc, conn) transport(cc, conn)
@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
raddr, err := net.ResolveUDPAddr("udp", node.Addr) raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil { if err != nil {
node.group.MarkDeadNode(node.ID) node.MarkDead()
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return return
} }
cc, err := net.DialUDP("udp", nil, raddr) cc, err := net.DialUDP("udp", nil, raddr)
if err != nil { if err != nil {
node.group.MarkDeadNode(node.ID) node.MarkDead()
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return return
} }
defer cc.Close() defer cc.Close()
node.group.ResetDeadNode(node.ID) node.ResetDead()
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr) log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc) transport(conn, cc)

View File

@ -15,7 +15,7 @@ import (
) )
// Version is the gost version. // Version is the gost version.
const Version = "2.6.1" const Version = "2.7-dev"
// Debug is a flag that enables the debug log. // Debug is a flag that enables the debug log.
var Debug bool var Debug bool

109
node.go
View File

@ -1,13 +1,17 @@
package gost package gost
import ( import (
"errors"
"fmt" "fmt"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" )
"time"
var (
// ErrInvalidNode is an error that implies the node is invalid.
ErrInvalidNode = errors.New("invalid node")
) )
// Node is a proxy node, mainly used to construct a proxy chain. // Node is a proxy node, mainly used to construct a proxy chain.
@ -23,9 +27,7 @@ type Node struct {
DialOptions []DialOption DialOptions []DialOption
HandshakeOptions []HandshakeOption HandshakeOptions []HandshakeOption
Client *Client Client *Client
group *NodeGroup marker *failMarker
failCount uint32
failTime int64
Bypass *Bypass Bypass *Bypass
} }
@ -33,8 +35,9 @@ type Node struct {
// The proxy node string pattern is [scheme://][user:pass@host]:port. // The proxy node string pattern is [scheme://][user:pass@host]:port.
// Scheme can be divided into two parts by character '+', such as: http+tls. // Scheme can be divided into two parts by character '+', such as: http+tls.
func ParseNode(s string) (node Node, err error) { func ParseNode(s string) (node Node, err error) {
s = strings.TrimSpace(s)
if s == "" { if s == "" {
return Node{}, nil return Node{}, ErrInvalidNode
} }
if !strings.Contains(s, "://") { if !strings.Contains(s, "://") {
@ -51,6 +54,7 @@ func ParseNode(s string) (node Node, err error) {
Remote: strings.Trim(u.EscapedPath(), "/"), Remote: strings.Trim(u.EscapedPath(), "/"),
Values: u.Query(), Values: u.Query(),
User: u.User, User: u.User,
marker: &failMarker{},
} }
schemes := strings.Split(u.Scheme, "+") schemes := strings.Split(u.Scheme, "+")
@ -89,25 +93,29 @@ func ParseNode(s string) (node Node, err error) {
return return
} }
// MarkDead marks the node fail status.
func (node *Node) MarkDead() {
if node.marker == nil {
return
}
node.marker.Mark()
}
// ResetDead resets the node fail status.
func (node *Node) ResetDead() {
if node.marker == nil {
return
}
node.marker.Reset()
}
// Clone clones the node, it will prevent data race. // Clone clones the node, it will prevent data race.
func (node *Node) Clone() Node { func (node *Node) Clone() Node {
return Node{ nd := *node
ID: node.ID, if node.marker != nil {
Addr: node.Addr, nd.marker = node.marker.Clone()
Host: node.Host,
Protocol: node.Protocol,
Transport: node.Transport,
Remote: node.Remote,
User: node.User,
Values: node.Values,
DialOptions: node.DialOptions,
HandshakeOptions: node.HandshakeOptions,
Client: node.Client,
group: node.group,
failCount: atomic.LoadUint32(&node.failCount),
failTime: atomic.LoadInt64(&node.failTime),
Bypass: node.Bypass,
} }
return nd
} }
// Get returns node parameter specified by key. // Get returns node parameter specified by key.
@ -127,8 +135,9 @@ func (node *Node) GetInt(key string) int {
return n return n
} }
func (node *Node) String() string { func (node Node) String() string {
return fmt.Sprintf("%d@%s", node.ID, node.Addr) return fmt.Sprintf("%d@%s+%s://%s",
node.ID, node.Protocol, node.Transport, node.Addr)
} }
// NodeGroup is a group of nodes. // NodeGroup is a group of nodes.
@ -194,18 +203,7 @@ func (group *NodeGroup) Nodes() []Node {
return group.nodes return group.nodes
} }
func (group *NodeGroup) copyNodes() []Node { // GetNode returns the node specified by index in the group.
group.mux.RLock()
defer group.mux.RUnlock()
var nodes []Node
for i := range group.nodes {
nodes = append(nodes, group.nodes[i])
}
return nodes
}
// GetNode returns a copy of the node specified by index in the group.
func (group *NodeGroup) GetNode(i int) Node { func (group *NodeGroup) GetNode(i int) Node {
group.mux.RLock() group.mux.RLock()
defer group.mux.RUnlock() defer group.mux.RUnlock()
@ -213,43 +211,7 @@ func (group *NodeGroup) GetNode(i int) Node {
if i < 0 || group == nil || len(group.nodes) <= i { if i < 0 || group == nil || len(group.nodes) <= i {
return Node{} return Node{}
} }
return group.nodes[i].Clone() return group.nodes[i]
}
// MarkDeadNode marks the node with ID nid status to dead.
func (group *NodeGroup) MarkDeadNode(nid int) {
group.mux.RLock()
defer group.mux.RUnlock()
if group == nil || nid <= 0 {
return
}
for i := range group.nodes {
if group.nodes[i].ID == nid {
atomic.AddUint32(&group.nodes[i].failCount, 1)
atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix())
break
}
}
}
// ResetDeadNode resets the node with ID nid status.
func (group *NodeGroup) ResetDeadNode(nid int) {
group.mux.RLock()
defer group.mux.RUnlock()
if group == nil || nid <= 0 {
return
}
for i := range group.nodes {
if group.nodes[i].ID == nid {
atomic.StoreUint32(&group.nodes[i].failCount, 0)
atomic.StoreInt64(&group.nodes[i].failTime, 0)
break
}
}
} }
// Next selects a node from group. // Next selects a node from group.
@ -272,7 +234,6 @@ func (group *NodeGroup) Next() (node Node, err error) {
if err != nil { if err != nil {
return return
} }
node.group = group
return return
} }

View File

@ -71,7 +71,7 @@ type Strategy interface {
// RoundStrategy is a strategy for node selector. // RoundStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm. // The node will be selected by round-robin algorithm.
type RoundStrategy struct { type RoundStrategy struct {
count uint64 counter uint64
} }
// Apply applies the round-robin strategy for the nodes. // Apply applies the round-robin strategy for the nodes.
@ -79,9 +79,9 @@ func (s *RoundStrategy) Apply(nodes []Node) Node {
if len(nodes) == 0 { if len(nodes) == 0 {
return Node{} return Node{}
} }
old := atomic.LoadUint64(&s.count)
atomic.AddUint64(&s.count, 1) n := atomic.AddUint64(&s.counter, 1) - 1
return nodes[int(old%uint64(len(nodes)))] return nodes[int(n%uint64(len(nodes)))]
} }
func (s *RoundStrategy) String() string { func (s *RoundStrategy) String() string {
@ -158,9 +158,11 @@ func (f *FailFilter) Filter(nodes []Node) []Node {
} }
nl := []Node{} nl := []Node{}
for i := range nodes { for i := range nodes {
if atomic.LoadUint32(&nodes[i].failCount) < uint32(f.MaxFails) || marker := nodes[i].marker.Clone()
time.Since(time.Unix(atomic.LoadInt64(&nodes[i].failTime), 0)) >= f.FailTimeout { // log.Logf("%s: %d/%d %d/%d", nodes[i], marker.failCount, f.MaxFails, marker.failTime, f.FailTimeout)
nl = append(nl, nodes[i].Clone()) if marker.failCount < uint32(f.MaxFails) ||
time.Since(time.Unix(marker.failTime, 0)) >= f.FailTimeout {
nl = append(nl, nodes[i])
} }
} }
return nl return nl
@ -169,3 +171,37 @@ func (f *FailFilter) Filter(nodes []Node) []Node {
func (f *FailFilter) String() string { func (f *FailFilter) String() string {
return "fail" return "fail"
} }
type failMarker struct {
failTime int64
failCount uint32
mux sync.RWMutex
}
func (m *failMarker) Mark() {
m.mux.Lock()
defer m.mux.Unlock()
m.failTime = time.Now().Unix()
m.failCount++
}
func (m *failMarker) Reset() {
m.mux.Lock()
defer m.mux.Unlock()
m.failTime = 0
m.failCount = 0
}
func (m *failMarker) Clone() *failMarker {
m.mux.RLock()
defer m.mux.RUnlock()
fc, ft := m.failCount, m.failTime
return &failMarker{
failCount: fc,
failTime: ft,
}
}