diff --git a/chain.go b/chain.go
index 9116508..49fe439 100644
--- a/chain.go
+++ b/chain.go
@@ -151,7 +151,7 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e
 	return cc, nil
 }
-func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
+func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
 	host, port, err := net.SplitHostPort(addr)
 	if err != nil {
 		return addr
 	}
@@ -216,16 +216,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
 
 	cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
 	if err != nil {
-		node.group.MarkDeadNode(node.ID)
+		node.MarkDead()
 		return
 	}
 
 	cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
 	if err != nil {
-		node.group.MarkDeadNode(node.ID)
+		node.MarkDead()
 		return
 	}
-	node.group.ResetDeadNode(node.ID)
+	node.ResetDead()
 
 	preNode := node
 	for _, node := range nodes[1:] {
@@ -233,16 +233,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
 		cc, err = preNode.Client.Connect(cn, node.Addr)
 		if err != nil {
 			cn.Close()
-			node.group.MarkDeadNode(node.ID)
+			node.MarkDead()
 			return
 		}
 		cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
 		if err != nil {
 			cn.Close()
-			node.group.MarkDeadNode(node.ID)
+			node.MarkDead()
 			return
 		}
-		node.group.ResetDeadNode(node.ID)
+		node.ResetDead()
 
 		cn = cc
 		preNode = node
@@ -321,10 +321,9 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
 	}
 	route.Retries = c.Retries
 
-	if Debug {
-		buf.WriteString(addr)
-		log.Log("[route]", buf.String())
-	}
+	buf.WriteString(addr)
+	log.Log("[route]", buf.String())
+
 	return
 }
 
diff --git a/forward.go b/forward.go
index 8519c66..c901c60 100644
--- a/forward.go
+++ b/forward.go
@@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
 		)
 		if err != nil {
 			log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
-			node.group.MarkDeadNode(node.ID)
+			node.MarkDead()
 		} else {
 			break
 		}
@@ -116,7 +116,7 @@
 		return
 	}
 
-	node.group.ResetDeadNode(node.ID)
+	node.ResetDead()
 	defer cc.Close()
 
 	log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
@@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
 	if h.options.Chain.IsEmpty() {
 		raddr, err := net.ResolveUDPAddr("udp", node.Addr)
 		if err != nil {
-			node.group.MarkDeadNode(node.ID)
+			node.MarkDead()
 			log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
 			return
 		}
 		cc, err = net.DialUDP("udp", nil, raddr)
 		if err != nil {
-			node.group.MarkDeadNode(node.ID)
+			node.MarkDead()
 			log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
 			return
 		}
@@ -212,7 +212,7 @@
 	}
 	defer cc.Close()
 
-	node.group.ResetDeadNode(node.ID)
+	node.ResetDead()
 
 	log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
 	transport(conn, cc)
@@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
 		cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
 		if err != nil {
 			log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
-			node.group.MarkDeadNode(node.ID)
+			node.MarkDead()
 		} else {
 			break
 		}
@@ -301,7 +301,7 @@
 	}
 	defer cc.Close()
 
-	node.group.ResetDeadNode(node.ID)
+	node.ResetDead()
 
 	log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
 	transport(cc, conn)
@@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
 
 	raddr, err := net.ResolveUDPAddr("udp", node.Addr)
 	if err != nil {
-		node.group.MarkDeadNode(node.ID)
+		node.MarkDead()
 		log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
 		return
 	}
 	cc, err := net.DialUDP("udp", nil, raddr)
 	if err != nil {
-		node.group.MarkDeadNode(node.ID)
+		node.MarkDead()
 		log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
 		return
 	}
 	defer cc.Close()
-	node.group.ResetDeadNode(node.ID)
+	node.ResetDead()
 
 	log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
 	transport(conn, cc)
diff --git a/gost.go b/gost.go
index 432651f..6889460 100644
--- a/gost.go
+++ b/gost.go
@@ -15,7 +15,7 @@ import (
 )
 
 // Version is the gost version.
-const Version = "2.6.1"
+const Version = "2.7-dev"
 
 // Debug is a flag that enables the debug log.
 var Debug bool
diff --git a/node.go b/node.go
index d6c7e7e..fbe929e 100644
--- a/node.go
+++ b/node.go
@@ -1,13 +1,17 @@
 package gost
 
 import (
+	"errors"
 	"fmt"
 	"net/url"
 	"strconv"
 	"strings"
 	"sync"
-	"sync/atomic"
-	"time"
+)
+
+var (
+	// ErrInvalidNode is an error that implies the node is invalid.
+	ErrInvalidNode = errors.New("invalid node")
 )
 
 // Node is a proxy node, mainly used to construct a proxy chain.
@@ -23,9 +27,7 @@ type Node struct {
 	DialOptions      []DialOption
 	HandshakeOptions []HandshakeOption
 	Client           *Client
-	group            *NodeGroup
-	failCount        uint32
-	failTime         int64
+	marker           *failMarker
 	Bypass           *Bypass
 }
 
@@ -33,8 +35,9 @@
 // The proxy node string pattern is [scheme://][user:pass@host]:port.
 // Scheme can be divided into two parts by character '+', such as: http+tls.
 func ParseNode(s string) (node Node, err error) {
+	s = strings.TrimSpace(s)
 	if s == "" {
-		return Node{}, nil
+		return Node{}, ErrInvalidNode
 	}
 
 	if !strings.Contains(s, "://") {
@@ -51,6 +54,7 @@ func ParseNode(s string) (node Node, err error) {
 		Remote: strings.Trim(u.EscapedPath(), "/"),
 		Values: u.Query(),
 		User:   u.User,
+		marker: &failMarker{},
 	}
 
 	schemes := strings.Split(u.Scheme, "+")
@@ -89,25 +93,29 @@ func ParseNode(s string) (node Node, err error) {
 	return
 }
 
+// MarkDead marks the node fail status.
+func (node *Node) MarkDead() {
+	if node.marker == nil {
+		return
+	}
+	node.marker.Mark()
+}
+
+// ResetDead resets the node fail status.
+func (node *Node) ResetDead() {
+	if node.marker == nil {
+		return
+	}
+	node.marker.Reset()
+}
+
 // Clone clones the node, it will prevent data race.
 func (node *Node) Clone() Node {
-	return Node{
-		ID:               node.ID,
-		Addr:             node.Addr,
-		Host:             node.Host,
-		Protocol:         node.Protocol,
-		Transport:        node.Transport,
-		Remote:           node.Remote,
-		User:             node.User,
-		Values:           node.Values,
-		DialOptions:      node.DialOptions,
-		HandshakeOptions: node.HandshakeOptions,
-		Client:           node.Client,
-		group:            node.group,
-		failCount:        atomic.LoadUint32(&node.failCount),
-		failTime:         atomic.LoadInt64(&node.failTime),
-		Bypass:           node.Bypass,
+	nd := *node
+	if node.marker != nil {
+		nd.marker = node.marker.Clone()
 	}
+	return nd
 }
 
 // Get returns node parameter specified by key.
@@ -127,8 +135,9 @@ func (node *Node) GetInt(key string) int {
 	return n
 }
 
-func (node *Node) String() string {
-	return fmt.Sprintf("%d@%s", node.ID, node.Addr)
+func (node Node) String() string {
+	return fmt.Sprintf("%d@%s+%s://%s",
+		node.ID, node.Protocol, node.Transport, node.Addr)
 }
 
 // NodeGroup is a group of nodes.
@@ -194,18 +203,7 @@ func (group *NodeGroup) Nodes() []Node {
 	return group.nodes
 }
 
-func (group *NodeGroup) copyNodes() []Node {
-	group.mux.RLock()
-	defer group.mux.RUnlock()
-
-	var nodes []Node
-	for i := range group.nodes {
-		nodes = append(nodes, group.nodes[i])
-	}
-	return nodes
-}
-
-// GetNode returns a copy of the node specified by index in the group.
+// GetNode returns the node specified by index in the group.
 func (group *NodeGroup) GetNode(i int) Node {
 	group.mux.RLock()
 	defer group.mux.RUnlock()
@@ -213,43 +211,7 @@ func (group *NodeGroup) GetNode(i int) Node {
 	if i < 0 || group == nil || len(group.nodes) <= i {
 		return Node{}
 	}
-	return group.nodes[i].Clone()
-}
-
-// MarkDeadNode marks the node with ID nid status to dead.
-func (group *NodeGroup) MarkDeadNode(nid int) {
-	group.mux.RLock()
-	defer group.mux.RUnlock()
-
-	if group == nil || nid <= 0 {
-		return
-	}
-
-	for i := range group.nodes {
-		if group.nodes[i].ID == nid {
-			atomic.AddUint32(&group.nodes[i].failCount, 1)
-			atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix())
-			break
-		}
-	}
-}
-
-// ResetDeadNode resets the node with ID nid status.
-func (group *NodeGroup) ResetDeadNode(nid int) {
-	group.mux.RLock()
-	defer group.mux.RUnlock()
-
-	if group == nil || nid <= 0 {
-		return
-	}
-
-	for i := range group.nodes {
-		if group.nodes[i].ID == nid {
-			atomic.StoreUint32(&group.nodes[i].failCount, 0)
-			atomic.StoreInt64(&group.nodes[i].failTime, 0)
-			break
-		}
-	}
+	return group.nodes[i]
 }
 
 // Next selects a node from group.
@@ -272,7 +234,6 @@ func (group *NodeGroup) Next() (node Node, err error) {
 	if err != nil {
 		return
 	}
-	node.group = group
 
 	return
 }
diff --git a/selector.go b/selector.go
index f83a601..f4120c5 100644
--- a/selector.go
+++ b/selector.go
@@ -71,7 +71,7 @@ type Strategy interface {
 // RoundStrategy is a strategy for node selector.
 // The node will be selected by round-robin algorithm.
 type RoundStrategy struct {
-	count uint64
+	counter uint64
 }
 
 // Apply applies the round-robin strategy for the nodes.
@@ -79,9 +79,9 @@ func (s *RoundStrategy) Apply(nodes []Node) Node {
 	if len(nodes) == 0 {
 		return Node{}
 	}
-	old := atomic.LoadUint64(&s.count)
-	atomic.AddUint64(&s.count, 1)
-	return nodes[int(old%uint64(len(nodes)))]
+
+	n := atomic.AddUint64(&s.counter, 1) - 1
+	return nodes[int(n%uint64(len(nodes)))]
 }
 
 func (s *RoundStrategy) String() string {
@@ -158,9 +158,11 @@ func (f *FailFilter) Filter(nodes []Node) []Node {
 	}
 	nl := []Node{}
 	for i := range nodes {
-		if atomic.LoadUint32(&nodes[i].failCount) < uint32(f.MaxFails) ||
-			time.Since(time.Unix(atomic.LoadInt64(&nodes[i].failTime), 0)) >= f.FailTimeout {
-			nl = append(nl, nodes[i].Clone())
+		marker := nodes[i].marker.Clone()
+		// log.Logf("%s: %d/%d %d/%d", nodes[i], marker.failCount, f.MaxFails, marker.failTime, f.FailTimeout)
+		if marker.failCount < uint32(f.MaxFails) ||
+			time.Since(time.Unix(marker.failTime, 0)) >= f.FailTimeout {
+			nl = append(nl, nodes[i])
 		}
 	}
 	return nl
@@ -169,3 +171,37 @@
 func (f *FailFilter) String() string {
 	return "fail"
 }
+
+type failMarker struct {
+	failTime  int64
+	failCount uint32
+	mux       sync.RWMutex
+}
+
+func (m *failMarker) Mark() {
+	m.mux.Lock()
+	defer m.mux.Unlock()
+
+	m.failTime = time.Now().Unix()
+	m.failCount++
+}
+
+func (m *failMarker) Reset() {
+	m.mux.Lock()
+	defer m.mux.Unlock()
+
+	m.failTime = 0
+	m.failCount = 0
+}
+
+func (m *failMarker) Clone() *failMarker {
+	m.mux.RLock()
+	defer m.mux.RUnlock()
+
+	fc, ft := m.failCount, m.failTime
+
+	return &failMarker{
+		failCount: fc,
+		failTime:  ft,
+	}
+}
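Below is a short, self-contained sketch of the per-node fail-marker pattern this change introduces: handlers call node.MarkDead() after a failed dial or handshake and node.ResetDead() once a connection succeeds, and FailFilter skips nodes whose marker has reached MaxFails within FailTimeout. It mirrors the failMarker code in the diff but is only an illustration, not the gost API itself; the node type, the alive helper, and the maxFails/failTimeout values are assumptions made for the example.

// Standalone sketch of the fail-marker pattern from the diff above.
// The node type, alive helper, and threshold values are illustrative only.
package main

import (
	"fmt"
	"sync"
	"time"
)

type failMarker struct {
	failTime  int64
	failCount uint32
	mux       sync.RWMutex
}

// Mark records one more failure and the time it happened.
func (m *failMarker) Mark() {
	m.mux.Lock()
	defer m.mux.Unlock()
	m.failTime = time.Now().Unix()
	m.failCount++
}

// Reset clears the failure state after a successful connection.
func (m *failMarker) Reset() {
	m.mux.Lock()
	defer m.mux.Unlock()
	m.failTime = 0
	m.failCount = 0
}

// snapshot returns a consistent view of the failure state.
func (m *failMarker) snapshot() (count uint32, at int64) {
	m.mux.RLock()
	defer m.mux.RUnlock()
	return m.failCount, m.failTime
}

type node struct {
	Addr   string
	marker *failMarker
}

// alive reports whether the node should still be offered to the selector:
// it has failed fewer than maxFails times, or its last failure is older
// than failTimeout (the same rule FailFilter.Filter applies in the diff).
func (n *node) alive(maxFails int, failTimeout time.Duration) bool {
	count, at := n.marker.snapshot()
	return count < uint32(maxFails) || time.Since(time.Unix(at, 0)) >= failTimeout
}

func main() {
	n := &node{Addr: "192.0.2.1:8080", marker: &failMarker{}}

	// A handler would call Mark() on each failed dial/handshake and Reset()
	// after a successful one, just as the handlers above now call
	// node.MarkDead() / node.ResetDead().
	n.marker.Mark()
	n.marker.Mark()
	fmt.Println("alive after 2 failures:", n.alive(3, 30*time.Second)) // true
	n.marker.Mark()
	fmt.Println("alive after 3 failures:", n.alive(3, 30*time.Second)) // false
	n.marker.Reset()
	fmt.Println("alive after reset:", n.alive(3, 30*time.Second)) // true
}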