add failMarker for fail filter

This commit is contained in:
ginuerzh 2018-11-26 11:28:07 +08:00
parent 83906404ae
commit 194b651dd8
5 changed files with 99 additions and 103 deletions

View File

@ -151,7 +151,7 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e
return cc, nil
}
func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return addr
@ -216,16 +216,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil {
node.group.MarkDeadNode(node.ID)
node.MarkDead()
return
}
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
if err != nil {
node.group.MarkDeadNode(node.ID)
node.MarkDead()
return
}
node.group.ResetDeadNode(node.ID)
node.ResetDead()
preNode := node
for _, node := range nodes[1:] {
@ -233,16 +233,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
cc, err = preNode.Client.Connect(cn, node.Addr)
if err != nil {
cn.Close()
node.group.MarkDeadNode(node.ID)
node.MarkDead()
return
}
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
if err != nil {
cn.Close()
node.group.MarkDeadNode(node.ID)
node.MarkDead()
return
}
node.group.ResetDeadNode(node.ID)
node.ResetDead()
cn = cc
preNode = node
@ -321,10 +321,9 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
}
route.Retries = c.Retries
if Debug {
buf.WriteString(addr)
log.Log("[route]", buf.String())
}
return
}

View File

@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
)
if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
node.group.MarkDeadNode(node.ID)
node.MarkDead()
} else {
break
}
@ -116,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
return
}
node.group.ResetDeadNode(node.ID)
node.ResetDead()
defer cc.Close()
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
if h.options.Chain.IsEmpty() {
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.group.MarkDeadNode(node.ID)
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
node.group.MarkDeadNode(node.ID)
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
@ -212,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
}
defer cc.Close()
node.group.ResetDeadNode(node.ID)
node.ResetDead()
log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc)
@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
if err != nil {
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
node.group.MarkDeadNode(node.ID)
node.MarkDead()
} else {
break
}
@ -301,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
}
defer cc.Close()
node.group.ResetDeadNode(node.ID)
node.ResetDead()
log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
transport(cc, conn)
@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.group.MarkDeadNode(node.ID)
node.MarkDead()
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
cc, err := net.DialUDP("udp", nil, raddr)
if err != nil {
node.group.MarkDeadNode(node.ID)
node.MarkDead()
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
defer cc.Close()
node.group.ResetDeadNode(node.ID)
node.ResetDead()
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc)

View File

@ -15,7 +15,7 @@ import (
)
// Version is the gost version.
const Version = "2.6.1"
const Version = "2.7-dev"
// Debug is a flag that enables the debug log.
var Debug bool

109
node.go
View File

@ -1,13 +1,17 @@
package gost
import (
"errors"
"fmt"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
var (
// ErrInvalidNode is an error that implies the node is invalid.
ErrInvalidNode = errors.New("invalid node")
)
// Node is a proxy node, mainly used to construct a proxy chain.
@ -23,9 +27,7 @@ type Node struct {
DialOptions []DialOption
HandshakeOptions []HandshakeOption
Client *Client
group *NodeGroup
failCount uint32
failTime int64
marker *failMarker
Bypass *Bypass
}
@ -33,8 +35,9 @@ type Node struct {
// The proxy node string pattern is [scheme://][user:pass@host]:port.
// Scheme can be divided into two parts by character '+', such as: http+tls.
func ParseNode(s string) (node Node, err error) {
s = strings.TrimSpace(s)
if s == "" {
return Node{}, nil
return Node{}, ErrInvalidNode
}
if !strings.Contains(s, "://") {
@ -51,6 +54,7 @@ func ParseNode(s string) (node Node, err error) {
Remote: strings.Trim(u.EscapedPath(), "/"),
Values: u.Query(),
User: u.User,
marker: &failMarker{},
}
schemes := strings.Split(u.Scheme, "+")
@ -89,25 +93,29 @@ func ParseNode(s string) (node Node, err error) {
return
}
// MarkDead marks the node fail status.
func (node *Node) MarkDead() {
if node.marker == nil {
return
}
node.marker.Mark()
}
// ResetDead resets the node fail status.
func (node *Node) ResetDead() {
if node.marker == nil {
return
}
node.marker.Reset()
}
// Clone clones the node, it will prevent data race.
func (node *Node) Clone() Node {
return Node{
ID: node.ID,
Addr: node.Addr,
Host: node.Host,
Protocol: node.Protocol,
Transport: node.Transport,
Remote: node.Remote,
User: node.User,
Values: node.Values,
DialOptions: node.DialOptions,
HandshakeOptions: node.HandshakeOptions,
Client: node.Client,
group: node.group,
failCount: atomic.LoadUint32(&node.failCount),
failTime: atomic.LoadInt64(&node.failTime),
Bypass: node.Bypass,
nd := *node
if node.marker != nil {
nd.marker = node.marker.Clone()
}
return nd
}
// Get returns node parameter specified by key.
@ -127,8 +135,9 @@ func (node *Node) GetInt(key string) int {
return n
}
func (node *Node) String() string {
return fmt.Sprintf("%d@%s", node.ID, node.Addr)
func (node Node) String() string {
return fmt.Sprintf("%d@%s+%s://%s",
node.ID, node.Protocol, node.Transport, node.Addr)
}
// NodeGroup is a group of nodes.
@ -194,18 +203,7 @@ func (group *NodeGroup) Nodes() []Node {
return group.nodes
}
func (group *NodeGroup) copyNodes() []Node {
group.mux.RLock()
defer group.mux.RUnlock()
var nodes []Node
for i := range group.nodes {
nodes = append(nodes, group.nodes[i])
}
return nodes
}
// GetNode returns a copy of the node specified by index in the group.
// GetNode returns the node specified by index in the group.
func (group *NodeGroup) GetNode(i int) Node {
group.mux.RLock()
defer group.mux.RUnlock()
@ -213,43 +211,7 @@ func (group *NodeGroup) GetNode(i int) Node {
if i < 0 || group == nil || len(group.nodes) <= i {
return Node{}
}
return group.nodes[i].Clone()
}
// MarkDeadNode marks the node with ID nid status to dead.
func (group *NodeGroup) MarkDeadNode(nid int) {
group.mux.RLock()
defer group.mux.RUnlock()
if group == nil || nid <= 0 {
return
}
for i := range group.nodes {
if group.nodes[i].ID == nid {
atomic.AddUint32(&group.nodes[i].failCount, 1)
atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix())
break
}
}
}
// ResetDeadNode resets the node with ID nid status.
func (group *NodeGroup) ResetDeadNode(nid int) {
group.mux.RLock()
defer group.mux.RUnlock()
if group == nil || nid <= 0 {
return
}
for i := range group.nodes {
if group.nodes[i].ID == nid {
atomic.StoreUint32(&group.nodes[i].failCount, 0)
atomic.StoreInt64(&group.nodes[i].failTime, 0)
break
}
}
return group.nodes[i]
}
// Next selects a node from group.
@ -272,7 +234,6 @@ func (group *NodeGroup) Next() (node Node, err error) {
if err != nil {
return
}
node.group = group
return
}

View File

@ -71,7 +71,7 @@ type Strategy interface {
// RoundStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm.
type RoundStrategy struct {
count uint64
counter uint64
}
// Apply applies the round-robin strategy for the nodes.
@ -79,9 +79,9 @@ func (s *RoundStrategy) Apply(nodes []Node) Node {
if len(nodes) == 0 {
return Node{}
}
old := atomic.LoadUint64(&s.count)
atomic.AddUint64(&s.count, 1)
return nodes[int(old%uint64(len(nodes)))]
n := atomic.AddUint64(&s.counter, 1) - 1
return nodes[int(n%uint64(len(nodes)))]
}
func (s *RoundStrategy) String() string {
@ -158,9 +158,11 @@ func (f *FailFilter) Filter(nodes []Node) []Node {
}
nl := []Node{}
for i := range nodes {
if atomic.LoadUint32(&nodes[i].failCount) < uint32(f.MaxFails) ||
time.Since(time.Unix(atomic.LoadInt64(&nodes[i].failTime), 0)) >= f.FailTimeout {
nl = append(nl, nodes[i].Clone())
marker := nodes[i].marker.Clone()
// log.Logf("%s: %d/%d %d/%d", nodes[i], marker.failCount, f.MaxFails, marker.failTime, f.FailTimeout)
if marker.failCount < uint32(f.MaxFails) ||
time.Since(time.Unix(marker.failTime, 0)) >= f.FailTimeout {
nl = append(nl, nodes[i])
}
}
return nl
@ -169,3 +171,37 @@ func (f *FailFilter) Filter(nodes []Node) []Node {
func (f *FailFilter) String() string {
return "fail"
}
type failMarker struct {
failTime int64
failCount uint32
mux sync.RWMutex
}
func (m *failMarker) Mark() {
m.mux.Lock()
defer m.mux.Unlock()
m.failTime = time.Now().Unix()
m.failCount++
}
func (m *failMarker) Reset() {
m.mux.Lock()
defer m.mux.Unlock()
m.failTime = 0
m.failCount = 0
}
func (m *failMarker) Clone() *failMarker {
m.mux.RLock()
defer m.mux.RUnlock()
fc, ft := m.failCount, m.failTime
return &failMarker{
failCount: fc,
failTime: ft,
}
}