diff --git a/bypass.go b/bypass.go index 792c33d..5b98cfd 100644 --- a/bypass.go +++ b/bypass.go @@ -124,7 +124,7 @@ type Bypass struct { matchers []Matcher reversed bool period time.Duration // the period for live reloading - mux sync.Mutex + mux sync.RWMutex } // NewBypass creates and initializes a new Bypass using matchers as its match rules. @@ -160,8 +160,8 @@ func (bp *Bypass) Contains(addr string) bool { } } - bp.mux.Lock() - defer bp.mux.Unlock() + bp.mux.RLock() + defer bp.mux.RUnlock() var matched bool for _, matcher := range bp.matchers { @@ -179,22 +179,33 @@ func (bp *Bypass) Contains(addr string) bool { // AddMatchers appends matchers to the bypass matcher list. func (bp *Bypass) AddMatchers(matchers ...Matcher) { + bp.mux.Lock() + defer bp.mux.Unlock() + bp.matchers = append(bp.matchers, matchers...) } // Matchers return the bypass matcher list. func (bp *Bypass) Matchers() []Matcher { + bp.mux.RLock() + defer bp.mux.RUnlock() + return bp.matchers } // Reversed reports whether the rules of the bypass are reversed. func (bp *Bypass) Reversed() bool { + bp.mux.RLock() + defer bp.mux.RUnlock() + return bp.reversed } // Reload parses config from r, then live reloads the bypass. func (bp *Bypass) Reload(r io.Reader) error { var matchers []Matcher + var period time.Duration + var reversed bool scanner := bufio.NewScanner(r) for scanner.Scan() { @@ -217,7 +228,7 @@ func (bp *Bypass) Reload(r io.Reader) error { } } if len(ss) == 2 { - bp.period, _ = time.ParseDuration(ss[1]) + period, _ = time.ParseDuration(ss[1]) continue } } @@ -231,7 +242,7 @@ func (bp *Bypass) Reload(r io.Reader) error { } } if len(ss) == 2 { - bp.reversed, _ = strconv.ParseBool(ss[1]) + reversed, _ = strconv.ParseBool(ss[1]) continue } } @@ -247,19 +258,28 @@ func (bp *Bypass) Reload(r io.Reader) error { defer bp.mux.Unlock() bp.matchers = matchers + bp.period = period + bp.reversed = reversed return nil } // Period returns the reload period func (bp *Bypass) Period() time.Duration { + bp.mux.RLock() + defer bp.mux.RUnlock() + return bp.period } func (bp *Bypass) String() string { + bp.mux.RLock() + defer bp.mux.RUnlock() + b := &bytes.Buffer{} - fmt.Fprintf(b, "reversed: %v\n", bp.Reversed()) - for _, m := range bp.Matchers() { + fmt.Fprintf(b, "reversed: %v\n", bp.reversed) + fmt.Fprintf(b, "reload: %v\n", bp.period) + for _, m := range bp.matchers { b.WriteString(m.String()) b.WriteByte('\n') } diff --git a/chain.go b/chain.go index cf77cf8..9116508 100644 --- a/chain.go +++ b/chain.go @@ -38,7 +38,7 @@ func newRoute(nodes ...Node) *Chain { } // Nodes returns the proxy nodes that the chain holds. -// If a node is a node group, the first node in the group will be returned. +// The first node in each group will be returned. func (c *Chain) Nodes() (nodes []Node) { for _, group := range c.nodeGroups { if ns := group.Nodes(); len(ns) > 0 { @@ -61,7 +61,7 @@ func (c *Chain) LastNode() Node { return Node{} } group := c.nodeGroups[len(c.nodeGroups)-1] - return group.nodes[0].Clone() + return group.GetNode(0) } // LastNodeGroup returns the last group of the group list. @@ -173,7 +173,6 @@ func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { } // Conn obtains a handshaked connection to the last node of the chain. -// If the chain is empty, it returns an ErrEmptyChain error. func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { options := &ChainOptions{} for _, opt := range opts { @@ -206,6 +205,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { } // getConn obtains a connection to the last node of the chain. +// It does not handshake with the last node. func (c *Chain) getConn() (conn net.Conn, err error) { if c.IsEmpty() { err = ErrEmptyChain @@ -216,16 +216,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) { cn, err := node.Client.Dial(node.Addr, node.DialOptions...) if err != nil { - node.MarkDead() + node.group.MarkDeadNode(node.ID) return } cn, err = node.Client.Handshake(cn, node.HandshakeOptions...) if err != nil { - node.MarkDead() + node.group.MarkDeadNode(node.ID) return } - node.ResetDead() + node.group.ResetDeadNode(node.ID) preNode := node for _, node := range nodes[1:] { @@ -233,16 +233,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) { cc, err = preNode.Client.Connect(cn, node.Addr) if err != nil { cn.Close() - node.MarkDead() + node.group.MarkDeadNode(node.ID) return } cc, err = node.Client.Handshake(cc, node.HandshakeOptions...) if err != nil { cn.Close() - node.MarkDead() + node.group.MarkDeadNode(node.ID) return } - node.ResetDead() + node.group.ResetDeadNode(node.ID) cn = cc preNode = node diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go index 7952ba2..a0e43a7 100644 --- a/cmd/gost/peer.go +++ b/cmd/gost/peer.go @@ -60,7 +60,8 @@ func (cfg *peerConfig) Reload(r io.Reader) error { strategy = s } } - group.Options = append([]gost.SelectOption{}, + group.SetSelector( + nil, gost.WithFilter(&gost.FailFilter{ MaxFails: cfg.MaxFails, FailTimeout: time.Duration(cfg.FailTimeout) * time.Second, diff --git a/forward.go b/forward.go index c901c60..8519c66 100644 --- a/forward.go +++ b/forward.go @@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { ) if err != nil { log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err) - node.MarkDead() + node.group.MarkDeadNode(node.ID) } else { break } @@ -116,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { return } - node.ResetDead() + node.group.ResetDeadNode(node.ID) defer cc.Close() log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr) @@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { if h.options.Chain.IsEmpty() { raddr, err := net.ResolveUDPAddr("udp", node.Addr) if err != nil { - node.MarkDead() + node.group.MarkDeadNode(node.ID) log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) return } cc, err = net.DialUDP("udp", nil, raddr) if err != nil { - node.MarkDead() + node.group.MarkDeadNode(node.ID) log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) return } @@ -212,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { } defer cc.Close() - node.ResetDead() + node.group.ResetDeadNode(node.ID) log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr) transport(conn, cc) @@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout) if err != nil { log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err) - node.MarkDead() + node.group.MarkDeadNode(node.ID) } else { break } @@ -301,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { } defer cc.Close() - node.ResetDead() + node.group.ResetDeadNode(node.ID) log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr) transport(cc, conn) @@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { raddr, err := net.ResolveUDPAddr("udp", node.Addr) if err != nil { - node.MarkDead() + node.group.MarkDeadNode(node.ID) log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) return } cc, err := net.DialUDP("udp", nil, raddr) if err != nil { - node.MarkDead() + node.group.MarkDeadNode(node.ID) log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) return } defer cc.Close() - node.ResetDead() + node.group.ResetDeadNode(node.ID) log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr) transport(conn, cc) diff --git a/hosts.go b/hosts.go index de2866f..b7744d5 100644 --- a/hosts.go +++ b/hosts.go @@ -5,6 +5,7 @@ import ( "io" "net" "strings" + "sync" "time" "github.com/go-log/log" @@ -25,6 +26,7 @@ type Host struct { type Hosts struct { hosts []Host period time.Duration + mux sync.RWMutex } // NewHosts creates a Hosts with optional list of host @@ -36,6 +38,9 @@ func NewHosts(hosts ...Host) *Hosts { // AddHost adds host(s) to the host table. func (h *Hosts) AddHost(host ...Host) { + h.mux.Lock() + defer h.mux.Unlock() + h.hosts = append(h.hosts, host...) } @@ -44,6 +49,10 @@ func (h *Hosts) Lookup(host string) (ip net.IP) { if h == nil { return } + + h.mux.RLock() + defer h.mux.RUnlock() + for _, h := range h.hosts { if h.Hostname == host { ip = h.IP @@ -64,6 +73,7 @@ func (h *Hosts) Lookup(host string) (ip net.IP) { // Reload parses config from r, then live reloads the hosts. func (h *Hosts) Reload(r io.Reader) error { + var period time.Duration var hosts []Host scanner := bufio.NewScanner(r) @@ -89,7 +99,7 @@ func (h *Hosts) Reload(r io.Reader) error { // reload option if strings.ToLower(ss[0]) == "reload" { - h.period, _ = time.ParseDuration(ss[1]) + period, _ = time.ParseDuration(ss[1]) continue } @@ -110,11 +120,18 @@ func (h *Hosts) Reload(r io.Reader) error { return err } + h.mux.Lock() + h.period = period h.hosts = hosts + h.mux.Unlock() + return nil } // Period returns the reload period func (h *Hosts) Period() time.Duration { + h.mux.RLock() + defer h.mux.RUnlock() + return h.period } diff --git a/http2.go b/http2.go index cdd7109..922b02a 100644 --- a/http2.go +++ b/http2.go @@ -321,7 +321,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) if Debug && (u != "" || p != "") { - log.Logf("[http] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p) + log.Logf("[http2] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p) } if !authenticate(u, p, h.options.Users...) { // probing resistance is enabled diff --git a/node.go b/node.go index 247dcff..d6c7e7e 100644 --- a/node.go +++ b/node.go @@ -5,6 +5,7 @@ import ( "net/url" "strconv" "strings" + "sync" "sync/atomic" "time" ) @@ -88,41 +89,6 @@ func ParseNode(s string) (node Node, err error) { return } -// MarkDead marks the node fail status. -func (node *Node) MarkDead() { - atomic.AddUint32(&node.failCount, 1) - atomic.StoreInt64(&node.failTime, time.Now().Unix()) - - if node.group == nil { - return - } - for i := range node.group.nodes { - if node.group.nodes[i].ID == node.ID { - atomic.AddUint32(&node.group.nodes[i].failCount, 1) - atomic.StoreInt64(&node.group.nodes[i].failTime, time.Now().Unix()) - break - } - } -} - -// ResetDead resets the node fail status. -func (node *Node) ResetDead() { - atomic.StoreUint32(&node.failCount, 0) - atomic.StoreInt64(&node.failTime, 0) - - if node.group == nil { - return - } - - for i := range node.group.nodes { - if node.group.nodes[i].ID == node.ID { - atomic.StoreUint32(&node.group.nodes[i].failCount, 0) - atomic.StoreInt64(&node.group.nodes[i].failTime, 0) - break - } - } -} - // Clone clones the node, it will prevent data race. func (node *Node) Clone() Node { return Node{ @@ -167,10 +133,11 @@ func (node *Node) String() string { // NodeGroup is a group of nodes. type NodeGroup struct { - ID int - nodes []Node - Options []SelectOption - Selector NodeSelector + ID int + nodes []Node + selectorOptions []SelectOption + selector NodeSelector + mux sync.RWMutex } // NewNodeGroup creates a node group @@ -185,11 +152,21 @@ func (group *NodeGroup) AddNode(node ...Node) { if group == nil { return } + group.mux.Lock() + defer group.mux.Unlock() + group.nodes = append(group.nodes, node...) } // SetNodes replaces the group nodes to the specified nodes. func (group *NodeGroup) SetNodes(nodes ...Node) { + if group == nil { + return + } + + group.mux.Lock() + defer group.mux.Unlock() + group.nodes = nodes } @@ -198,27 +175,100 @@ func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption) if group == nil { return } - group.Selector = selector - group.Options = opts + group.mux.Lock() + defer group.mux.Unlock() + + group.selector = selector + group.selectorOptions = opts } -// Nodes returns node list in the group +// Nodes returns the node list in the group func (group *NodeGroup) Nodes() []Node { if group == nil { return nil } + + group.mux.RLock() + defer group.mux.RUnlock() + return group.nodes } -// Next selects the next node from group. +func (group *NodeGroup) copyNodes() []Node { + group.mux.RLock() + defer group.mux.RUnlock() + + var nodes []Node + for i := range group.nodes { + nodes = append(nodes, group.nodes[i]) + } + return nodes +} + +// GetNode returns a copy of the node specified by index in the group. +func (group *NodeGroup) GetNode(i int) Node { + group.mux.RLock() + defer group.mux.RUnlock() + + if i < 0 || group == nil || len(group.nodes) <= i { + return Node{} + } + return group.nodes[i].Clone() +} + +// MarkDeadNode marks the node with ID nid status to dead. +func (group *NodeGroup) MarkDeadNode(nid int) { + group.mux.RLock() + defer group.mux.RUnlock() + + if group == nil || nid <= 0 { + return + } + + for i := range group.nodes { + if group.nodes[i].ID == nid { + atomic.AddUint32(&group.nodes[i].failCount, 1) + atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix()) + break + } + } +} + +// ResetDeadNode resets the node with ID nid status. +func (group *NodeGroup) ResetDeadNode(nid int) { + group.mux.RLock() + defer group.mux.RUnlock() + + if group == nil || nid <= 0 { + return + } + + for i := range group.nodes { + if group.nodes[i].ID == nid { + atomic.StoreUint32(&group.nodes[i].failCount, 0) + atomic.StoreInt64(&group.nodes[i].failTime, 0) + break + } + } +} + +// Next selects a node from group. // It also selects IP if the IP list exists. func (group *NodeGroup) Next() (node Node, err error) { - selector := group.Selector + if group == nil { + return + } + + group.mux.RLock() + defer group.mux.RUnlock() + + selector := group.selector if selector == nil { selector = &defaultSelector{} } + // select node from node group - node, err = selector.Select(group.Nodes(), group.Options...) + node, err = selector.Select(group.nodes, group.selectorOptions...) if err != nil { return } diff --git a/reload.go b/reload.go index 5b96d8c..e5ced63 100644 --- a/reload.go +++ b/reload.go @@ -26,6 +26,7 @@ func PeriodReload(r Reloader, configFile string) error { finfo, err := f.Stat() if err != nil { + f.Close() return err } mt := finfo.ModTime() diff --git a/resolver.go b/resolver.go index 25cb39b..5d25fc7 100644 --- a/resolver.go +++ b/resolver.go @@ -68,6 +68,7 @@ type resolver struct { TTL time.Duration period time.Duration domain string + mux sync.RWMutex } // NewResolver create a new Resolver with the given name servers and resolution timeout. @@ -78,17 +79,23 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv TTL: ttl, mCache: &sync.Map{}, } - r.init() - return r -} -func (r *resolver) init() { if r.Timeout <= 0 { r.Timeout = DefaultResolverTimeout } if r.TTL == 0 { r.TTL = DefaultResolverTTL } + return r +} + +func (r *resolver) copyServers() []NameServer { + var servers []NameServer + for i := range r.Servers { + servers = append(servers, r.Servers[i]) + } + + return servers } func (r *resolver) Resolve(host string) (ips []net.IP, err error) { @@ -96,14 +103,24 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { return } + var domain string + var timeout, ttl time.Duration + var servers []NameServer + + r.mux.RLock() + domain = r.domain + timeout = r.Timeout + servers = r.copyServers() + r.mux.RUnlock() + if ip := net.ParseIP(host); ip != nil { return []net.IP{ip}, nil } - if !strings.Contains(host, ".") && r.domain != "" { - host = host + "." + r.domain + if !strings.Contains(host, ".") && domain != "" { + host = host + "." + domain } - ips = r.loadCache(host) + ips = r.loadCache(host, ttl) if len(ips) > 0 { if Debug { log.Logf("[resolver] cache hit %s: %v", host, ips) @@ -111,8 +128,8 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { return } - for _, ns := range r.Servers { - ips, err = r.resolve(ns, host) + for _, ns := range servers { + ips, err = r.resolve(ns, host, timeout) if err != nil { log.Logf("[resolver] %s via %s : %s", host, ns, err) continue @@ -130,14 +147,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { return } -func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) { +func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) { addr := ns.Addr if _, port, _ := net.SplitHostPort(addr); port == "" { addr = net.JoinHostPort(addr, "53") } client := dns.Client{ - Timeout: r.Timeout, + Timeout: timeout, } switch strings.ToLower(ns.Protocol) { case "tcp": @@ -171,8 +188,7 @@ func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) return } -func (r *resolver) loadCache(name string) []net.IP { - ttl := r.TTL +func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP { if ttl < 0 { return nil } @@ -189,8 +205,7 @@ func (r *resolver) loadCache(name string) []net.IP { } func (r *resolver) storeCache(name string, ips []net.IP) { - ttl := r.TTL - if ttl < 0 || name == "" || len(ips) == 0 { + if name == "" || len(ips) == 0 { return } r.mCache.Store(name, &resolverCacheItem{ @@ -200,6 +215,8 @@ func (r *resolver) storeCache(name string, ips []net.IP) { } func (r *resolver) Reload(rd io.Reader) error { + var ttl, timeout, period time.Duration + var domain string var nss []NameServer split := func(line string) []string { @@ -232,19 +249,19 @@ func (r *resolver) Reload(rd io.Reader) error { switch ss[0] { case "timeout": // timeout option if len(ss) > 1 { - r.Timeout, _ = time.ParseDuration(ss[1]) + timeout, _ = time.ParseDuration(ss[1]) } case "ttl": // ttl option if len(ss) > 1 { - r.TTL, _ = time.ParseDuration(ss[1]) + ttl, _ = time.ParseDuration(ss[1]) } case "reload": // reload option if len(ss) > 1 { - r.period, _ = time.ParseDuration(ss[1]) + period, _ = time.ParseDuration(ss[1]) } case "domain": if len(ss) > 1 { - r.domain = ss[1] + domain = ss[1] } case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf case "nameserver": // nameserver option, compatible with /etc/resolv.conf @@ -276,11 +293,21 @@ func (r *resolver) Reload(rd io.Reader) error { return err } + r.mux.Lock() + r.Timeout = timeout + r.TTL = ttl + r.domain = domain + r.period = period r.Servers = nss + r.mux.Unlock() + return nil } func (r *resolver) Period() time.Duration { + r.mux.RLock() + defer r.mux.RUnlock() + return r.period } @@ -289,6 +316,9 @@ func (r *resolver) String() string { return "" } + r.mux.RLock() + defer r.mux.RUnlock() + b := &bytes.Buffer{} fmt.Fprintf(b, "Timeout %v\n", r.Timeout) fmt.Fprintf(b, "TTL %v\n", r.TTL) diff --git a/selector.go b/selector.go index cadf799..f83a601 100644 --- a/selector.go +++ b/selector.go @@ -94,6 +94,7 @@ type RandomStrategy struct { Seed int64 rand *rand.Rand once sync.Once + mux sync.Mutex } // Apply applies the random strategy for the nodes. @@ -109,7 +110,11 @@ func (s *RandomStrategy) Apply(nodes []Node) Node { return Node{} } - return nodes[s.rand.Int()%len(nodes)] + s.mux.Lock() + r := s.rand.Int() + s.mux.Unlock() + + return nodes[r%len(nodes)] } func (s *RandomStrategy) String() string { diff --git a/server.go b/server.go index a9fbd0e..e4cdfce 100644 --- a/server.go +++ b/server.go @@ -72,11 +72,13 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error { } tempDelay = 0 - if s.options.Bypass.Contains(conn.RemoteAddr().String()) { - log.Log("[bypass]", conn.RemoteAddr()) - conn.Close() - continue - } + /* + if s.options.Bypass.Contains(conn.RemoteAddr().String()) { + log.Log("[bypass]", conn.RemoteAddr()) + conn.Close() + continue + } + */ go h.Handle(conn) } @@ -90,12 +92,14 @@ type ServerOptions struct { // ServerOption allows a common way to set server options. type ServerOption func(opts *ServerOptions) +/* // BypassServerOption sets the bypass option of ServerOptions. func BypassServerOption(bypass *Bypass) ServerOption { return func(opts *ServerOptions) { opts.Bypass = bypass } } +*/ // Listener is a proxy server listener, just like a net.Listener. type Listener interface {