fix data race

ginuerzh 2018-11-21 23:42:45 +08:00
parent e9b872c4cf
commit a020c7bc33
11 changed files with 227 additions and 99 deletions

View File

@@ -124,7 +124,7 @@ type Bypass struct {
matchers []Matcher
reversed bool
period time.Duration // the period for live reloading
mux sync.Mutex
mux sync.RWMutex
}
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
@@ -160,8 +160,8 @@ func (bp *Bypass) Contains(addr string) bool {
}
}
bp.mux.Lock()
defer bp.mux.Unlock()
bp.mux.RLock()
defer bp.mux.RUnlock()
var matched bool
for _, matcher := range bp.matchers {
@@ -179,22 +179,33 @@ func (bp *Bypass) Contains(addr string) bool {
// AddMatchers appends matchers to the bypass matcher list.
func (bp *Bypass) AddMatchers(matchers ...Matcher) {
bp.mux.Lock()
defer bp.mux.Unlock()
bp.matchers = append(bp.matchers, matchers...)
}
// Matchers returns the bypass matcher list.
func (bp *Bypass) Matchers() []Matcher {
bp.mux.RLock()
defer bp.mux.RUnlock()
return bp.matchers
}
// Reversed reports whether the rules of the bypass are reversed.
func (bp *Bypass) Reversed() bool {
bp.mux.RLock()
defer bp.mux.RUnlock()
return bp.reversed
}
// Reload parses config from r, then live reloads the bypass.
func (bp *Bypass) Reload(r io.Reader) error {
var matchers []Matcher
var period time.Duration
var reversed bool
scanner := bufio.NewScanner(r)
for scanner.Scan() {
@@ -217,7 +228,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
}
}
if len(ss) == 2 {
bp.period, _ = time.ParseDuration(ss[1])
period, _ = time.ParseDuration(ss[1])
continue
}
}
@@ -231,7 +242,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
}
}
if len(ss) == 2 {
bp.reversed, _ = strconv.ParseBool(ss[1])
reversed, _ = strconv.ParseBool(ss[1])
continue
}
}
@@ -247,19 +258,28 @@ func (bp *Bypass) Reload(r io.Reader) error {
defer bp.mux.Unlock()
bp.matchers = matchers
bp.period = period
bp.reversed = reversed
return nil
}
// Period returns the reload period
func (bp *Bypass) Period() time.Duration {
bp.mux.RLock()
defer bp.mux.RUnlock()
return bp.period
}
func (bp *Bypass) String() string {
bp.mux.RLock()
defer bp.mux.RUnlock()
b := &bytes.Buffer{}
fmt.Fprintf(b, "reversed: %v\n", bp.Reversed())
for _, m := range bp.Matchers() {
fmt.Fprintf(b, "reversed: %v\n", bp.reversed)
fmt.Fprintf(b, "reload: %v\n", bp.period)
for _, m := range bp.matchers {
b.WriteString(m.String())
b.WriteByte('\n')
}
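The bypass changes illustrate the locking pattern this commit applies across the codebase: readers take an RWMutex read lock, and Reload parses the whole config into locals before swapping every field in one short critical section. A minimal sketch of the pattern, using a hypothetical Rules type rather than gost's actual Bypass:

```go
package reloadable

import (
	"bufio"
	"io"
	"strings"
	"sync"
	"time"
)

// Rules is a hypothetical rule set following the same pattern as Bypass:
// an RWMutex-guarded struct whose fields are swapped wholesale on Reload.
type Rules struct {
	mux     sync.RWMutex
	entries []string
	period  time.Duration
}

// Contains takes only the read lock, so concurrent lookups don't serialize.
func (r *Rules) Contains(s string) bool {
	r.mux.RLock()
	defer r.mux.RUnlock()
	for _, e := range r.entries {
		if e == s {
			return true
		}
	}
	return false
}

// Reload parses everything into locals first; the write lock is held
// only for the final swap, never across parsing or I/O.
func (r *Rules) Reload(rd io.Reader) error {
	var entries []string
	var period time.Duration

	scanner := bufio.NewScanner(rd)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		if fields := strings.Fields(line); len(fields) == 2 && fields[0] == "reload" {
			period, _ = time.ParseDuration(fields[1])
			continue
		}
		entries = append(entries, line)
	}
	if err := scanner.Err(); err != nil {
		return err
	}

	r.mux.Lock()
	r.entries = entries
	r.period = period
	r.mux.Unlock()
	return nil
}
```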

View File

@@ -38,7 +38,7 @@ func newRoute(nodes ...Node) *Chain {
}
// Nodes returns the proxy nodes that the chain holds.
// If a node is a node group, the first node in the group will be returned.
// The first node in each group will be returned.
func (c *Chain) Nodes() (nodes []Node) {
for _, group := range c.nodeGroups {
if ns := group.Nodes(); len(ns) > 0 {
@@ -61,7 +61,7 @@ func (c *Chain) LastNode() Node {
return Node{}
}
group := c.nodeGroups[len(c.nodeGroups)-1]
return group.nodes[0].Clone()
return group.GetNode(0)
}
// LastNodeGroup returns the last group of the group list.
@@ -173,7 +173,6 @@ func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
}
// Conn obtains a handshaked connection to the last node of the chain.
// If the chain is empty, it returns an ErrEmptyChain error.
func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
options := &ChainOptions{}
for _, opt := range opts {
@@ -206,6 +205,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
}
// getConn obtains a connection to the last node of the chain.
// It does not handshake with the last node.
func (c *Chain) getConn() (conn net.Conn, err error) {
if c.IsEmpty() {
err = ErrEmptyChain
@@ -216,16 +216,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
node.ResetDead()
node.group.ResetDeadNode(node.ID)
preNode := node
for _, node := range nodes[1:] {
@@ -233,16 +233,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
cc, err = preNode.Client.Connect(cn, node.Addr)
if err != nil {
cn.Close()
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
if err != nil {
cn.Close()
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
node.ResetDead()
node.group.ResetDeadNode(node.ID)
cn = cc
preNode = node
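After this change the chain never touches shared Node state: it works on clones (GetNode returns a copy), and failures are reported back to the owning group by node ID. A hedged usage sketch of that convention, with net.Dial standing in for the Client.Dial/Handshake sequence in getConn:

```go
package main

import (
	"fmt"
	"net"

	"github.com/ginuerzh/gost"
)

func tryNode(group *gost.NodeGroup) error {
	node := group.GetNode(0) // a clone; safe to read without holding the group lock
	conn, err := net.Dial("tcp", node.Addr)
	if err != nil {
		group.MarkDeadNode(node.ID) // update the canonical slice entry via the group
		return err
	}
	conn.Close()
	group.ResetDeadNode(node.ID)
	return nil
}

func main() {
	group := gost.NewNodeGroup()
	if err := tryNode(group); err != nil {
		fmt.Println("node marked dead:", err)
	}
}
```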

View File

@@ -60,7 +60,8 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
strategy = s
}
}
group.Options = append([]gost.SelectOption{},
group.SetSelector(
nil,
gost.WithFilter(&gost.FailFilter{
MaxFails: cfg.MaxFails,
FailTimeout: time.Duration(cfg.FailTimeout) * time.Second,
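The peer config no longer writes the exported Options field directly; it goes through SetSelector, which takes the group's write lock. Passing a nil selector makes NodeGroup.Next fall back to the built-in default selector (see node.go below). A sketch of the call, with made-up filter values:

```go
package config

import (
	"time"

	"github.com/ginuerzh/gost"
)

// configure replaces the group's selection options in one locked call.
// A nil selector means Next() will use gost's default selector.
func configure(group *gost.NodeGroup) {
	group.SetSelector(nil,
		gost.WithFilter(&gost.FailFilter{
			MaxFails:    3,
			FailTimeout: 30 * time.Second,
		}),
		gost.WithStrategy(&gost.RoundStrategy{}),
	)
}
```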

View File

@@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
)
if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
node.MarkDead()
node.group.MarkDeadNode(node.ID)
} else {
break
}
@@ -116,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
return
}
node.ResetDead()
node.group.ResetDeadNode(node.ID)
defer cc.Close()
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
@@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
if h.options.Chain.IsEmpty() {
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
@@ -212,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
}
defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)
log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc)
@@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
if err != nil {
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
node.MarkDead()
node.group.MarkDeadNode(node.ID)
} else {
break
}
@@ -301,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
}
defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)
log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
transport(cc, conn)
@@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
cc, err := net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc)

View File

@@ -5,6 +5,7 @@ import (
"io"
"net"
"strings"
"sync"
"time"
"github.com/go-log/log"
@@ -25,6 +26,7 @@ type Host struct {
type Hosts struct {
hosts []Host
period time.Duration
mux sync.RWMutex
}
// NewHosts creates a Hosts with an optional list of hosts.
@@ -36,6 +38,9 @@ func NewHosts(hosts ...Host) *Hosts {
// AddHost adds host(s) to the host table.
func (h *Hosts) AddHost(host ...Host) {
h.mux.Lock()
defer h.mux.Unlock()
h.hosts = append(h.hosts, host...)
}
@@ -44,6 +49,10 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
if h == nil {
return
}
h.mux.RLock()
defer h.mux.RUnlock()
for _, h := range h.hosts {
if h.Hostname == host {
ip = h.IP
@@ -64,6 +73,7 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
// Reload parses config from r, then live reloads the hosts.
func (h *Hosts) Reload(r io.Reader) error {
var period time.Duration
var hosts []Host
scanner := bufio.NewScanner(r)
@@ -89,7 +99,7 @@ func (h *Hosts) Reload(r io.Reader) error {
// reload option
if strings.ToLower(ss[0]) == "reload" {
h.period, _ = time.ParseDuration(ss[1])
period, _ = time.ParseDuration(ss[1])
continue
}
@@ -110,11 +120,18 @@ func (h *Hosts) Reload(r io.Reader) error {
return err
}
h.mux.Lock()
h.period = period
h.hosts = hosts
h.mux.Unlock()
return nil
}
// Period returns the reload period
func (h *Hosts) Period() time.Duration {
h.mux.RLock()
defer h.mux.RUnlock()
return h.period
}
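For reference, a file accepted by this Reload logic looks like a plain hosts file plus a reload option line, as I read the parser; the addresses and names below are made up:

```
# period for live reloading
reload 30s

172.20.0.10  service1.example.com
172.20.0.11  service2.example.com svc2
```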

View File

@@ -321,7 +321,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
if Debug && (u != "" || p != "") {
log.Logf("[http] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
log.Logf("[http2] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
}
if !authenticate(u, p, h.options.Users...) {
// probing resistance is enabled

node.go (140 changed lines)
View File

@@ -5,6 +5,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
@@ -88,41 +89,6 @@ func ParseNode(s string) (node Node, err error) {
return
}
// MarkDead marks the node fail status.
func (node *Node) MarkDead() {
atomic.AddUint32(&node.failCount, 1)
atomic.StoreInt64(&node.failTime, time.Now().Unix())
if node.group == nil {
return
}
for i := range node.group.nodes {
if node.group.nodes[i].ID == node.ID {
atomic.AddUint32(&node.group.nodes[i].failCount, 1)
atomic.StoreInt64(&node.group.nodes[i].failTime, time.Now().Unix())
break
}
}
}
// ResetDead resets the node fail status.
func (node *Node) ResetDead() {
atomic.StoreUint32(&node.failCount, 0)
atomic.StoreInt64(&node.failTime, 0)
if node.group == nil {
return
}
for i := range node.group.nodes {
if node.group.nodes[i].ID == node.ID {
atomic.StoreUint32(&node.group.nodes[i].failCount, 0)
atomic.StoreInt64(&node.group.nodes[i].failTime, 0)
break
}
}
}
// Clone clones the node to prevent data races.
func (node *Node) Clone() Node {
return Node{
@@ -167,10 +133,11 @@ func (node *Node) String() string {
// NodeGroup is a group of nodes.
type NodeGroup struct {
ID int
nodes []Node
Options []SelectOption
Selector NodeSelector
ID int
nodes []Node
selectorOptions []SelectOption
selector NodeSelector
mux sync.RWMutex
}
// NewNodeGroup creates a node group
@@ -185,11 +152,21 @@ func (group *NodeGroup) AddNode(node ...Node) {
if group == nil {
return
}
group.mux.Lock()
defer group.mux.Unlock()
group.nodes = append(group.nodes, node...)
}
// SetNodes replaces the group's nodes with the specified nodes.
func (group *NodeGroup) SetNodes(nodes ...Node) {
if group == nil {
return
}
group.mux.Lock()
defer group.mux.Unlock()
group.nodes = nodes
}
@@ -198,27 +175,100 @@ func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption)
if group == nil {
return
}
group.Selector = selector
group.Options = opts
group.mux.Lock()
defer group.mux.Unlock()
group.selector = selector
group.selectorOptions = opts
}
// Nodes returns node list in the group
// Nodes returns the node list in the group
func (group *NodeGroup) Nodes() []Node {
if group == nil {
return nil
}
group.mux.RLock()
defer group.mux.RUnlock()
return group.nodes
}
// Next selects the next node from group.
func (group *NodeGroup) copyNodes() []Node {
group.mux.RLock()
defer group.mux.RUnlock()
var nodes []Node
for i := range group.nodes {
nodes = append(nodes, group.nodes[i])
}
return nodes
}
// GetNode returns a copy of the node specified by index in the group.
func (group *NodeGroup) GetNode(i int) Node {
if group == nil {
return Node{}
}
group.mux.RLock()
defer group.mux.RUnlock()
if i < 0 || len(group.nodes) <= i {
return Node{}
}
return group.nodes[i].Clone()
}
// MarkDeadNode marks the status of the node with ID nid as dead.
func (group *NodeGroup) MarkDeadNode(nid int) {
if group == nil || nid <= 0 {
return
}
group.mux.RLock()
defer group.mux.RUnlock()
for i := range group.nodes {
if group.nodes[i].ID == nid {
atomic.AddUint32(&group.nodes[i].failCount, 1)
atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix())
break
}
}
}
// ResetDeadNode resets the fail status of the node with ID nid.
func (group *NodeGroup) ResetDeadNode(nid int) {
if group == nil || nid <= 0 {
return
}
group.mux.RLock()
defer group.mux.RUnlock()
for i := range group.nodes {
if group.nodes[i].ID == nid {
atomic.StoreUint32(&group.nodes[i].failCount, 0)
atomic.StoreInt64(&group.nodes[i].failTime, 0)
break
}
}
}
// Next selects a node from the group.
// It also selects an IP if the IP list exists.
func (group *NodeGroup) Next() (node Node, err error) {
selector := group.Selector
if group == nil {
return
}
group.mux.RLock()
defer group.mux.RUnlock()
selector := group.selector
if selector == nil {
selector = &defaultSelector{}
}
// select node from node group
node, err = selector.Select(group.Nodes(), group.Options...)
node, err = selector.Select(group.nodes, group.selectorOptions...)
if err != nil {
return
}
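The payoff of moving MarkDead/ResetDead into the group is that concurrent selection and health marking now synchronize on the group's RWMutex plus per-field atomics, where the old code mutated failCount/failTime on diverging copies of a node. A small harness, with a made-up address, that should run cleanly under go run -race after this commit:

```go
package main

import (
	"sync"

	"github.com/ginuerzh/gost"
)

func main() {
	node, _ := gost.ParseNode("http://127.0.0.1:8080") // made-up local node
	node.ID = 1
	group := gost.NewNodeGroup(node)

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 1000; j++ {
				if n, err := group.Next(); err == nil {
					group.MarkDeadNode(n.ID)  // atomic update of the canonical entry
					group.ResetDeadNode(n.ID) // ditto
				}
			}
		}()
	}
	wg.Wait()
}
```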

View File

@@ -26,6 +26,7 @@ func PeriodReload(r Reloader, configFile string) error {
finfo, err := f.Stat()
if err != nil {
f.Close()
return err
}
mt := finfo.ModTime()
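For context, PeriodReload polls the config file's modification time and calls Reload when it changes; the one-line fix above closes the file handle on the Stat error path. A hedged sketch of the loop's overall shape (details such as logging and ordering differ from the upstream code):

```go
package reloadutil

import (
	"log"
	"os"
	"time"

	"github.com/ginuerzh/gost"
)

// periodReload sketches the loop PeriodReload implements: poll the file's
// mtime, reload on change, stop when the Reloader's period is non-positive.
func periodReload(r gost.Reloader, configFile string) error {
	var lastMod time.Time
	for {
		f, err := os.Open(configFile)
		if err != nil {
			return err
		}
		finfo, err := f.Stat()
		if err != nil {
			f.Close() // the fix in this hunk: don't leak the handle on the error path
			return err
		}
		if mt := finfo.ModTime(); mt.After(lastMod) {
			if err := r.Reload(f); err != nil {
				log.Printf("reload %s: %v", configFile, err)
			}
			lastMod = mt
		}
		f.Close()

		period := r.Period()
		if period <= 0 {
			return nil
		}
		time.Sleep(period)
	}
}
```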

View File

@@ -68,6 +68,7 @@ type resolver struct {
TTL time.Duration
period time.Duration
domain string
mux sync.RWMutex
}
// NewResolver creates a new Resolver with the given name servers and resolution timeout.
@@ -78,17 +79,23 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver
TTL: ttl,
mCache: &sync.Map{},
}
r.init()
return r
}
func (r *resolver) init() {
if r.Timeout <= 0 {
r.Timeout = DefaultResolverTimeout
}
if r.TTL == 0 {
r.TTL = DefaultResolverTTL
}
return r
}
func (r *resolver) copyServers() []NameServer {
var servers []NameServer
for i := range r.Servers {
servers = append(servers, r.Servers[i])
}
return servers
}
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
@@ -96,14 +103,24 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
return
}
var domain string
var timeout, ttl time.Duration
var servers []NameServer
r.mux.RLock()
domain = r.domain
timeout = r.Timeout
ttl = r.TTL
servers = r.copyServers()
r.mux.RUnlock()
if ip := net.ParseIP(host); ip != nil {
return []net.IP{ip}, nil
}
if !strings.Contains(host, ".") && r.domain != "" {
host = host + "." + r.domain
if !strings.Contains(host, ".") && domain != "" {
host = host + "." + domain
}
ips = r.loadCache(host)
ips = r.loadCache(host, ttl)
if len(ips) > 0 {
if Debug {
log.Logf("[resolver] cache hit %s: %v", host, ips)
@@ -111,8 +128,8 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
return
}
for _, ns := range r.Servers {
ips, err = r.resolve(ns, host)
for _, ns := range servers {
ips, err = r.resolve(ns, host, timeout)
if err != nil {
log.Logf("[resolver] %s via %s : %s", host, ns, err)
continue
@@ -130,14 +147,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
return
}
func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) {
func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) {
addr := ns.Addr
if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "53")
}
client := dns.Client{
Timeout: r.Timeout,
Timeout: timeout,
}
switch strings.ToLower(ns.Protocol) {
case "tcp":
@@ -171,8 +188,7 @@ func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error)
return
}
func (r *resolver) loadCache(name string) []net.IP {
ttl := r.TTL
func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
if ttl < 0 {
return nil
}
@@ -189,8 +205,7 @@ func (r *resolver) loadCache(name string) []net.IP {
}
func (r *resolver) storeCache(name string, ips []net.IP) {
ttl := r.TTL
if ttl < 0 || name == "" || len(ips) == 0 {
if name == "" || len(ips) == 0 {
return
}
r.mCache.Store(name, &resolverCacheItem{
@@ -200,6 +215,8 @@ func (r *resolver) storeCache(name string, ips []net.IP) {
}
func (r *resolver) Reload(rd io.Reader) error {
var ttl, timeout, period time.Duration
var domain string
var nss []NameServer
split := func(line string) []string {
@@ -232,19 +249,19 @@ func (r *resolver) Reload(rd io.Reader) error {
switch ss[0] {
case "timeout": // timeout option
if len(ss) > 1 {
r.Timeout, _ = time.ParseDuration(ss[1])
timeout, _ = time.ParseDuration(ss[1])
}
case "ttl": // ttl option
if len(ss) > 1 {
r.TTL, _ = time.ParseDuration(ss[1])
ttl, _ = time.ParseDuration(ss[1])
}
case "reload": // reload option
if len(ss) > 1 {
r.period, _ = time.ParseDuration(ss[1])
period, _ = time.ParseDuration(ss[1])
}
case "domain":
if len(ss) > 1 {
r.domain = ss[1]
domain = ss[1]
}
case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf
case "nameserver": // nameserver option, compatible with /etc/resolv.conf
@@ -276,11 +293,21 @@ func (r *resolver) Reload(rd io.Reader) error {
return err
}
r.mux.Lock()
r.Timeout = timeout
r.TTL = ttl
r.domain = domain
r.period = period
r.Servers = nss
r.mux.Unlock()
return nil
}
func (r *resolver) Period() time.Duration {
r.mux.RLock()
defer r.mux.RUnlock()
return r.period
}
@@ -289,6 +316,9 @@ func (r *resolver) String() string {
return ""
}
r.mux.RLock()
defer r.mux.RUnlock()
b := &bytes.Buffer{}
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
fmt.Fprintf(b, "TTL %v\n", r.TTL)

View File

@@ -94,6 +94,7 @@ type RandomStrategy struct {
Seed int64
rand *rand.Rand
once sync.Once
mux sync.Mutex
}
// Apply applies the random strategy for the nodes.
@@ -109,7 +110,11 @@ func (s *RandomStrategy) Apply(nodes []Node) Node {
return Node{}
}
return nodes[s.rand.Int()%len(nodes)]
s.mux.Lock()
r := s.rand.Int()
s.mux.Unlock()
return nodes[r%len(nodes)]
}
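A *math/rand.Rand instance is not safe for concurrent use (only the top-level package functions are), so Apply now guards s.rand.Int() with a mutex; note the lock covers only the generator call, not the node slice. A tiny self-contained reproduction: remove the lock and go run -race reports the race:

```go
package main

import (
	"math/rand"
	"sync"
	"time"
)

func main() {
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	var mu sync.Mutex
	var wg sync.WaitGroup

	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 1000; j++ {
				mu.Lock() // without this lock, the race detector flags r.Int()
				_ = r.Int()
				mu.Unlock()
			}
		}()
	}
	wg.Wait()
}
```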
func (s *RandomStrategy) String() string {

View File

@@ -72,11 +72,13 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error {
}
tempDelay = 0
if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
log.Log("[bypass]", conn.RemoteAddr())
conn.Close()
continue
}
/*
if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
log.Log("[bypass]", conn.RemoteAddr())
conn.Close()
continue
}
*/
go h.Handle(conn)
}
@@ -90,12 +92,14 @@ type ServerOptions struct {
// ServerOption allows a common way to set server options.
type ServerOption func(opts *ServerOptions)
/*
// BypassServerOption sets the bypass option of ServerOptions.
func BypassServerOption(bypass *Bypass) ServerOption {
return func(opts *ServerOptions) {
opts.Bypass = bypass
}
}
*/
// Listener is a proxy server listener, just like a net.Listener.
type Listener interface {