diff --git a/chain.go b/chain.go index 3cd0172..4c6ea2a 100644 --- a/chain.go +++ b/chain.go @@ -12,28 +12,51 @@ var ( // Chain is a proxy chain that holds a list of proxy nodes. type Chain struct { - nodes []Node + nodeGroups []*NodeGroup } -// NewChain creates a proxy chain with proxy nodes nodes. +// NewChain creates a proxy chain with a list of proxy nodes. func NewChain(nodes ...Node) *Chain { - return &Chain{ - nodes: nodes, + chain := &Chain{} + for _, node := range nodes { + chain.nodeGroups = append(chain.nodeGroups, NewNodeGroup(node)) } + return chain } // Nodes returns the proxy nodes that the chain holds. -func (c *Chain) Nodes() []Node { - return c.nodes +// If a node is a node group, the first node in the group will be returned. +func (c *Chain) Nodes() (nodes []Node) { + for _, group := range c.nodeGroups { + if ns := group.Nodes(); len(ns) > 0 { + nodes = append(nodes, ns[0]) + } + } + return +} + +// NodeGroups returns the list of node group. +func (c *Chain) NodeGroups() []*NodeGroup { + return c.nodeGroups } // LastNode returns the last node of the node list. // If the chain is empty, an empty node is returns. +// If the last node is a node group, the first node in the group will be returned. func (c *Chain) LastNode() Node { if c.IsEmpty() { return Node{} } - return c.nodes[len(c.nodes)-1] + last := c.nodeGroups[len(c.nodeGroups)-1] + return last.nodes[0] +} + +// LastNodeGroup returns the last group of the group list. +func (c *Chain) LastNodeGroup() *NodeGroup { + if c.IsEmpty() { + return nil + } + return c.nodeGroups[len(c.nodeGroups)-1] } // AddNode appends the node(s) to the chain. @@ -41,13 +64,25 @@ func (c *Chain) AddNode(nodes ...Node) { if c == nil { return } - c.nodes = append(c.nodes, nodes...) + for _, node := range nodes { + c.nodeGroups = append(c.nodeGroups, NewNodeGroup(node)) + } +} + +// AddNodeGroup appends the group(s) to the chain. +func (c *Chain) AddNodeGroup(groups ...*NodeGroup) { + if c == nil { + return + } + for _, group := range groups { + c.nodeGroups = append(c.nodeGroups, group) + } } // IsEmpty checks if the chain is empty. -// An empty chain means that there is no proxy node in the chain. +// An empty chain means that there is no proxy node or node group in the chain. func (c *Chain) IsEmpty() bool { - return c == nil || len(c.nodes) == 0 + return c == nil || len(c.nodeGroups) == 0 } // Dial connects to the target address addr through the chain. @@ -57,12 +92,12 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { return net.Dial("tcp", addr) } - conn, err := c.Conn() + conn, nodes, err := c.getConn() if err != nil { return nil, err } - cc, err := c.LastNode().Client.Connect(conn, addr) + cc, err := nodes[len(nodes)-1].Client.Connect(conn, addr) if err != nil { conn.Close() return nil, err @@ -72,39 +107,68 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { // Conn obtains a handshaked connection to the last node of the chain. // If the chain is empty, it returns an ErrEmptyChain error. -func (c *Chain) Conn() (net.Conn, error) { +func (c *Chain) Conn() (conn net.Conn, err error) { + conn, _, err = c.getConn() + return +} + +func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) { if c.IsEmpty() { - return nil, ErrEmptyChain + err = ErrEmptyChain + return } - - nodes := c.nodes - conn, err := nodes[0].Client.Dial(nodes[0].Addr, nodes[0].DialOptions...) + groups := c.nodeGroups + selector := groups[0].Selector + if selector == nil { + selector = &defaultSelector{} + } + node, err := selector.Select(groups[0].Nodes(), groups[0].Options...) if err != nil { - return nil, err + return } + nodes = append(nodes, node) - conn, err = nodes[0].Client.Handshake(conn, nodes[0].HandshakeOptions...) + cn, err := node.Client.Dial(node.Addr, node.DialOptions...) if err != nil { - return nil, err + return } - for i, node := range nodes { - if i == len(nodes)-1 { + cn, err = node.Client.Handshake(cn, node.HandshakeOptions...) + if err != nil { + return + } + + preNode := node + for i := range groups { + if i == len(groups)-1 { break } + selector = groups[i+1].Selector + if selector == nil { + selector = &defaultSelector{} + } + node, err = selector.Select(groups[i+1].Nodes(), groups[i+1].Options...) + if err != nil { + cn.Close() + return + } + nodes = append(nodes, node) - next := nodes[i+1] - cc, err := node.Client.Connect(conn, next.Addr) + var cc net.Conn + cc, err = preNode.Client.Connect(cn, node.Addr) if err != nil { - conn.Close() - return nil, err + cn.Close() + return } - cc, err = next.Client.Handshake(cc, next.HandshakeOptions...) + cc, err = node.Client.Handshake(cc, node.HandshakeOptions...) if err != nil { - conn.Close() - return nil, err + cn.Close() + return } - conn = cc + cn = cc + preNode = node } - return conn, nil + + conn = cn + return } diff --git a/node.go b/node.go index 629bec7..88b0245 100644 --- a/node.go +++ b/node.go @@ -13,9 +13,9 @@ type Node struct { Remote string // remote address, used by tcp/udp port forwarding User *url.Userinfo Values url.Values - Client *Client DialOptions []DialOption HandshakeOptions []HandshakeOption + Client *Client } // ParseNode parses the node info. @@ -75,3 +75,33 @@ func ParseNode(s string) (node Node, err error) { return } + +// NodeGroup is a group of nodes. +type NodeGroup struct { + nodes []Node + Options []SelectOption + Selector Selector +} + +// NewNodeGroup creates a node group +func NewNodeGroup(nodes ...Node) *NodeGroup { + return &NodeGroup{ + nodes: nodes, + } +} + +// AddNode adds node or node list into group +func (ng *NodeGroup) AddNode(node ...Node) { + if ng == nil { + return + } + ng.nodes = append(ng.nodes, node...) +} + +// Nodes returns node list in the group +func (ng *NodeGroup) Nodes() []Node { + if ng == nil { + return nil + } + return ng.nodes +} diff --git a/selector.go b/selector.go new file mode 100644 index 0000000..5a48982 --- /dev/null +++ b/selector.go @@ -0,0 +1,73 @@ +package gost + +import "errors" + +var ( + // ErrNoneAvailable indicates there is no node available + ErrNoneAvailable = errors.New("none available") +) + +// SelectOption used when making a select call +type SelectOption func(*SelectOptions) + +// Selector as a mechanism to pick nodes and mark their status. +type Selector interface { + Select(nodes []Node, opts ...SelectOption) (Node, error) + // Mark(node Node) + String() string +} + +type defaultSelector struct { +} + +func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, error) { + sopts := SelectOptions{ + Strategy: defaultStrategy, + } + for _, opt := range opts { + opt(&sopts) + } + + for _, filter := range sopts.Filters { + nodes = filter(nodes) + } + if len(nodes) == 0 { + return Node{}, ErrNoneAvailable + } + return sopts.Strategy(nodes), nil +} + +func (s *defaultSelector) String() string { + return "default" +} + +// Filter is used to filter a node during the selection process +type Filter func([]Node) []Node + +// Strategy is a selection strategy e.g random, round robin +type Strategy func([]Node) Node + +func defaultStrategy(nodes []Node) Node { + return nodes[0] +} + +// SelectOptions is the options for node selection +type SelectOptions struct { + Filters []Filter + Strategy Strategy +} + +// WithFilter adds a filter function to the list of filters +// used during the Select call. +func WithFilter(f ...Filter) SelectOption { + return func(o *SelectOptions) { + o.Filters = append(o.Filters, f...) + } +} + +// WithStrategy sets the selector strategy +func WithStrategy(s Strategy) SelectOption { + return func(o *SelectOptions) { + o.Strategy = s + } +}