fix data race

ginuerzh 2018-11-21 23:42:45 +08:00
parent e9b872c4cf
commit a020c7bc33
11 changed files with 227 additions and 99 deletions


@@ -124,7 +124,7 @@ type Bypass struct {
     matchers []Matcher
     reversed bool
     period   time.Duration // the period for live reloading
-    mux      sync.Mutex
+    mux      sync.RWMutex
 }

 // NewBypass creates and initializes a new Bypass using matchers as its match rules.
@@ -160,8 +160,8 @@ func (bp *Bypass) Contains(addr string) bool {
         }
     }
-    bp.mux.Lock()
-    defer bp.mux.Unlock()
+    bp.mux.RLock()
+    defer bp.mux.RUnlock()

     var matched bool
     for _, matcher := range bp.matchers {
@@ -179,22 +179,33 @@ func (bp *Bypass) Contains(addr string) bool {

 // AddMatchers appends matchers to the bypass matcher list.
 func (bp *Bypass) AddMatchers(matchers ...Matcher) {
+    bp.mux.Lock()
+    defer bp.mux.Unlock()
     bp.matchers = append(bp.matchers, matchers...)
 }

 // Matchers return the bypass matcher list.
 func (bp *Bypass) Matchers() []Matcher {
+    bp.mux.RLock()
+    defer bp.mux.RUnlock()
     return bp.matchers
 }

 // Reversed reports whether the rules of the bypass are reversed.
 func (bp *Bypass) Reversed() bool {
+    bp.mux.RLock()
+    defer bp.mux.RUnlock()
     return bp.reversed
 }

 // Reload parses config from r, then live reloads the bypass.
 func (bp *Bypass) Reload(r io.Reader) error {
     var matchers []Matcher
+    var period time.Duration
+    var reversed bool

     scanner := bufio.NewScanner(r)
     for scanner.Scan() {
@@ -217,7 +228,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
             }
         }
         if len(ss) == 2 {
-            bp.period, _ = time.ParseDuration(ss[1])
+            period, _ = time.ParseDuration(ss[1])
             continue
         }
     }
@@ -231,7 +242,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
             }
         }
         if len(ss) == 2 {
-            bp.reversed, _ = strconv.ParseBool(ss[1])
+            reversed, _ = strconv.ParseBool(ss[1])
             continue
         }
     }
@@ -247,19 +258,28 @@ func (bp *Bypass) Reload(r io.Reader) error {
     defer bp.mux.Unlock()

     bp.matchers = matchers
+    bp.period = period
+    bp.reversed = reversed

     return nil
 }

 // Period returns the reload period
 func (bp *Bypass) Period() time.Duration {
+    bp.mux.RLock()
+    defer bp.mux.RUnlock()
     return bp.period
 }

 func (bp *Bypass) String() string {
+    bp.mux.RLock()
+    defer bp.mux.RUnlock()
     b := &bytes.Buffer{}
-    fmt.Fprintf(b, "reversed: %v\n", bp.Reversed())
-    for _, m := range bp.Matchers() {
+    fmt.Fprintf(b, "reversed: %v\n", bp.reversed)
+    fmt.Fprintf(b, "reload: %v\n", bp.period)
+    for _, m := range bp.matchers {
         b.WriteString(m.String())
         b.WriteByte('\n')
     }
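
The Bypass change above is the template for the rest of the commit: read-only accessors (Contains, Matchers, Reversed, Period, String) now take the read lock, writers (AddMatchers, Reload) take the write lock, and Reload parses into locals before publishing them in one short critical section. A minimal, self-contained sketch of that read/write split, with illustrative types rather than the gost API:

package main

import (
	"fmt"
	"sync"
)

// rules mimics the Bypass shape: shared state guarded by a sync.RWMutex.
type rules struct {
	mux      sync.RWMutex
	patterns []string
}

// Contains only reads, so it takes the read lock; many readers may run concurrently.
func (r *rules) Contains(s string) bool {
	r.mux.RLock()
	defer r.mux.RUnlock()
	for _, p := range r.patterns {
		if p == s {
			return true
		}
	}
	return false
}

// Reload builds the new value outside the lock and publishes it under the write lock.
func (r *rules) Reload(patterns []string) {
	r.mux.Lock()
	defer r.mux.Unlock()
	r.patterns = patterns
}

func main() {
	r := &rules{}
	var wg sync.WaitGroup
	wg.Add(1)
	go func() { // one writer, reloading repeatedly
		defer wg.Done()
		for i := 0; i < 1000; i++ {
			r.Reload([]string{"example.com"})
		}
	}()
	for i := 0; i < 4; i++ { // several readers; a data race without the RWMutex
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 1000; j++ {
				r.Contains("example.com")
			}
		}()
	}
	wg.Wait()
	fmt.Println("no race")
}

Run with the race detector (go run -race) and this version should stay quiet; with the locks removed, the detector flags the concurrent slice access.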


@@ -38,7 +38,7 @@ func newRoute(nodes ...Node) *Chain {
 }

 // Nodes returns the proxy nodes that the chain holds.
-// If a node is a node group, the first node in the group will be returned.
+// The first node in each group will be returned.
 func (c *Chain) Nodes() (nodes []Node) {
     for _, group := range c.nodeGroups {
         if ns := group.Nodes(); len(ns) > 0 {
@@ -61,7 +61,7 @@ func (c *Chain) LastNode() Node {
         return Node{}
     }
     group := c.nodeGroups[len(c.nodeGroups)-1]
-    return group.nodes[0].Clone()
+    return group.GetNode(0)
 }

 // LastNodeGroup returns the last group of the group list.
@@ -173,7 +173,6 @@ func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
 }

 // Conn obtains a handshaked connection to the last node of the chain.
-// If the chain is empty, it returns an ErrEmptyChain error.
 func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
     options := &ChainOptions{}
     for _, opt := range opts {
@@ -206,6 +205,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
 }

 // getConn obtains a connection to the last node of the chain.
+// It does not handshake with the last node.
 func (c *Chain) getConn() (conn net.Conn, err error) {
     if c.IsEmpty() {
         err = ErrEmptyChain
@@ -216,16 +216,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
     cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
     if err != nil {
-        node.MarkDead()
+        node.group.MarkDeadNode(node.ID)
         return
     }
     cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
     if err != nil {
-        node.MarkDead()
+        node.group.MarkDeadNode(node.ID)
         return
     }
-    node.ResetDead()
+    node.group.ResetDeadNode(node.ID)

     preNode := node
     for _, node := range nodes[1:] {
@@ -233,16 +233,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
         cc, err = preNode.Client.Connect(cn, node.Addr)
         if err != nil {
             cn.Close()
-            node.MarkDead()
+            node.group.MarkDeadNode(node.ID)
             return
         }
         cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
         if err != nil {
             cn.Close()
-            node.MarkDead()
+            node.group.MarkDeadNode(node.ID)
             return
         }
-        node.ResetDead()
+        node.group.ResetDeadNode(node.ID)

         cn = cc
         preNode = node
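
LastNode now goes through group.GetNode(0), which (per the node.go changes later in this diff) returns a clone taken under the group's read lock instead of a reference into the shared slice. A hedged sketch of why handing out a copy avoids the race, using hypothetical types:

package main

import (
	"fmt"
	"sync"
)

type node struct {
	ID   int
	Addr string
}

type group struct {
	mux   sync.RWMutex
	nodes []node
}

// setNodes swaps the whole slice under the write lock.
func (g *group) setNodes(ns ...node) {
	g.mux.Lock()
	defer g.mux.Unlock()
	g.nodes = ns
}

// getNode copies the element out while holding the read lock, so no pointer
// into the shared slice escapes to the caller.
func (g *group) getNode(i int) node {
	g.mux.RLock()
	defer g.mux.RUnlock()
	if i < 0 || i >= len(g.nodes) {
		return node{}
	}
	return g.nodes[i] // value copy
}

func main() {
	g := &group{}
	g.setNodes(node{ID: 1, Addr: "a:8080"})

	n := g.getNode(0)       // safe snapshot
	g.setNodes(node{ID: 2}) // later reloads cannot touch n
	fmt.Println(n.ID, n.Addr)
}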


@@ -60,7 +60,8 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
             strategy = s
         }
     }
-    group.Options = append([]gost.SelectOption{},
+    group.SetSelector(
+        nil,
         gost.WithFilter(&gost.FailFilter{
             MaxFails:    cfg.MaxFails,
             FailTimeout: time.Duration(cfg.FailTimeout) * time.Second,


@@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
         )
         if err != nil {
             log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
-            node.MarkDead()
+            node.group.MarkDeadNode(node.ID)
         } else {
             break
         }
@@ -116,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
         return
     }
-    node.ResetDead()
+    node.group.ResetDeadNode(node.ID)
     defer cc.Close()

     log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
@@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
     if h.options.Chain.IsEmpty() {
         raddr, err := net.ResolveUDPAddr("udp", node.Addr)
         if err != nil {
-            node.MarkDead()
+            node.group.MarkDeadNode(node.ID)
             log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
             return
         }
         cc, err = net.DialUDP("udp", nil, raddr)
         if err != nil {
-            node.MarkDead()
+            node.group.MarkDeadNode(node.ID)
             log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
             return
         }
@@ -212,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
     }
     defer cc.Close()
-    node.ResetDead()
+    node.group.ResetDeadNode(node.ID)

     log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
     transport(conn, cc)
@@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
         cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
         if err != nil {
             log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
-            node.MarkDead()
+            node.group.MarkDeadNode(node.ID)
         } else {
             break
         }
@@ -301,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
     }
     defer cc.Close()
-    node.ResetDead()
+    node.group.ResetDeadNode(node.ID)

     log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
     transport(cc, conn)
@@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
     raddr, err := net.ResolveUDPAddr("udp", node.Addr)
     if err != nil {
-        node.MarkDead()
+        node.group.MarkDeadNode(node.ID)
         log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
         return
     }
     cc, err := net.DialUDP("udp", nil, raddr)
     if err != nil {
-        node.MarkDead()
+        node.group.MarkDeadNode(node.ID)
         log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
         return
     }
     defer cc.Close()
-    node.ResetDead()
+    node.group.ResetDeadNode(node.ID)

     log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
     transport(conn, cc)


@@ -5,6 +5,7 @@ import (
     "io"
     "net"
     "strings"
+    "sync"
     "time"

     "github.com/go-log/log"
@@ -25,6 +26,7 @@ type Host struct {
 type Hosts struct {
     hosts  []Host
     period time.Duration
+    mux    sync.RWMutex
 }

 // NewHosts creates a Hosts with optional list of host
@@ -36,6 +38,9 @@ func NewHosts(hosts ...Host) *Hosts {
 // AddHost adds host(s) to the host table.
 func (h *Hosts) AddHost(host ...Host) {
+    h.mux.Lock()
+    defer h.mux.Unlock()
     h.hosts = append(h.hosts, host...)
 }
@@ -44,6 +49,10 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
     if h == nil {
         return
     }
+    h.mux.RLock()
+    defer h.mux.RUnlock()
     for _, h := range h.hosts {
         if h.Hostname == host {
             ip = h.IP
@@ -64,6 +73,7 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
 // Reload parses config from r, then live reloads the hosts.
 func (h *Hosts) Reload(r io.Reader) error {
+    var period time.Duration
     var hosts []Host

     scanner := bufio.NewScanner(r)
@@ -89,7 +99,7 @@ func (h *Hosts) Reload(r io.Reader) error {
         // reload option
         if strings.ToLower(ss[0]) == "reload" {
-            h.period, _ = time.ParseDuration(ss[1])
+            period, _ = time.ParseDuration(ss[1])
             continue
         }
@@ -110,11 +120,18 @@ func (h *Hosts) Reload(r io.Reader) error {
         return err
     }

+    h.mux.Lock()
+    h.period = period
     h.hosts = hosts
+    h.mux.Unlock()

     return nil
 }

 // Period returns the reload period
 func (h *Hosts) Period() time.Duration {
+    h.mux.RLock()
+    defer h.mux.RUnlock()
     return h.period
 }


@@ -321,7 +321,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
     u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
     if Debug && (u != "" || p != "") {
-        log.Logf("[http] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
+        log.Logf("[http2] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
     }
     if !authenticate(u, p, h.options.Users...) {
         // probing resistance is enabled

node.go

@@ -5,6 +5,7 @@ import (
     "net/url"
     "strconv"
     "strings"
+    "sync"
     "sync/atomic"
     "time"
 )
@@ -88,41 +89,6 @@ func ParseNode(s string) (node Node, err error) {
     return
 }

-// MarkDead marks the node fail status.
-func (node *Node) MarkDead() {
-    atomic.AddUint32(&node.failCount, 1)
-    atomic.StoreInt64(&node.failTime, time.Now().Unix())
-
-    if node.group == nil {
-        return
-    }
-
-    for i := range node.group.nodes {
-        if node.group.nodes[i].ID == node.ID {
-            atomic.AddUint32(&node.group.nodes[i].failCount, 1)
-            atomic.StoreInt64(&node.group.nodes[i].failTime, time.Now().Unix())
-            break
-        }
-    }
-}
-
-// ResetDead resets the node fail status.
-func (node *Node) ResetDead() {
-    atomic.StoreUint32(&node.failCount, 0)
-    atomic.StoreInt64(&node.failTime, 0)
-
-    if node.group == nil {
-        return
-    }
-
-    for i := range node.group.nodes {
-        if node.group.nodes[i].ID == node.ID {
-            atomic.StoreUint32(&node.group.nodes[i].failCount, 0)
-            atomic.StoreInt64(&node.group.nodes[i].failTime, 0)
-            break
-        }
-    }
-}
-
 // Clone clones the node, it will prevent data race.
 func (node *Node) Clone() Node {
     return Node{
@@ -167,10 +133,11 @@ func (node *Node) String() string {
 // NodeGroup is a group of nodes.
 type NodeGroup struct {
     ID              int
     nodes           []Node
-    Options         []SelectOption
-    Selector        NodeSelector
+    selectorOptions []SelectOption
+    selector        NodeSelector
+    mux             sync.RWMutex
 }

 // NewNodeGroup creates a node group
@@ -185,11 +152,21 @@ func (group *NodeGroup) AddNode(node ...Node) {
     if group == nil {
         return
     }
+    group.mux.Lock()
+    defer group.mux.Unlock()
     group.nodes = append(group.nodes, node...)
 }

 // SetNodes replaces the group nodes to the specified nodes.
 func (group *NodeGroup) SetNodes(nodes ...Node) {
+    if group == nil {
+        return
+    }
+    group.mux.Lock()
+    defer group.mux.Unlock()
     group.nodes = nodes
 }
@@ -198,27 +175,100 @@ func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption)
     if group == nil {
         return
     }
-    group.Selector = selector
-    group.Options = opts
+    group.mux.Lock()
+    defer group.mux.Unlock()
+    group.selector = selector
+    group.selectorOptions = opts
 }

-// Nodes returns node list in the group
+// Nodes returns the node list in the group
 func (group *NodeGroup) Nodes() []Node {
     if group == nil {
         return nil
     }
+    group.mux.RLock()
+    defer group.mux.RUnlock()
     return group.nodes
 }

-// Next selects the next node from group.
+func (group *NodeGroup) copyNodes() []Node {
+    group.mux.RLock()
+    defer group.mux.RUnlock()
+
+    var nodes []Node
+    for i := range group.nodes {
+        nodes = append(nodes, group.nodes[i])
+    }
+    return nodes
+}
+
+// GetNode returns a copy of the node specified by index in the group.
+func (group *NodeGroup) GetNode(i int) Node {
+    group.mux.RLock()
+    defer group.mux.RUnlock()
+    if i < 0 || group == nil || len(group.nodes) <= i {
+        return Node{}
+    }
+    return group.nodes[i].Clone()
+}
+
+// MarkDeadNode marks the node with ID nid status to dead.
+func (group *NodeGroup) MarkDeadNode(nid int) {
+    group.mux.RLock()
+    defer group.mux.RUnlock()
+    if group == nil || nid <= 0 {
+        return
+    }
+    for i := range group.nodes {
+        if group.nodes[i].ID == nid {
+            atomic.AddUint32(&group.nodes[i].failCount, 1)
+            atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix())
+            break
+        }
+    }
+}
+
+// ResetDeadNode resets the node with ID nid status.
+func (group *NodeGroup) ResetDeadNode(nid int) {
+    group.mux.RLock()
+    defer group.mux.RUnlock()
+    if group == nil || nid <= 0 {
+        return
+    }
+    for i := range group.nodes {
+        if group.nodes[i].ID == nid {
+            atomic.StoreUint32(&group.nodes[i].failCount, 0)
+            atomic.StoreInt64(&group.nodes[i].failTime, 0)
+            break
+        }
+    }
+}
+
+// Next selects a node from group.
 // It also selects IP if the IP list exists.
 func (group *NodeGroup) Next() (node Node, err error) {
-    selector := group.Selector
+    if group == nil {
+        return
+    }
+    group.mux.RLock()
+    defer group.mux.RUnlock()
+
+    selector := group.selector
     if selector == nil {
         selector = &defaultSelector{}
     }
     // select node from node group
-    node, err = selector.Select(group.Nodes(), group.Options...)
+    node, err = selector.Select(group.nodes, group.selectorOptions...)
     if err != nil {
         return
     }
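
MarkDeadNode and ResetDeadNode above only take the read lock: the lock protects the nodes slice header against a concurrent AddNode/SetNodes, while the per-node failCount/failTime fields are updated atomically so many connections can record failures on the same group at once. A small sketch of combining the two, with an illustrative struct rather than the real Node:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

type node struct {
	ID        int
	failCount uint32
	failTime  int64
}

type group struct {
	mux   sync.RWMutex
	nodes []node
}

// markDead bumps the fail counters of the node with the given ID.
// The read lock guards the slice header; the counters themselves are atomic,
// so many goroutines may call this concurrently.
func (g *group) markDead(id int) {
	g.mux.RLock()
	defer g.mux.RUnlock()
	for i := range g.nodes {
		if g.nodes[i].ID == id {
			atomic.AddUint32(&g.nodes[i].failCount, 1)
			atomic.StoreInt64(&g.nodes[i].failTime, time.Now().Unix())
			return
		}
	}
}

func main() {
	g := &group{nodes: []node{{ID: 1}, {ID: 2}}}
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			g.markDead(1)
		}()
	}
	wg.Wait()
	fmt.Println("fails:", atomic.LoadUint32(&g.nodes[0].failCount))
}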


@@ -26,6 +26,7 @@ func PeriodReload(r Reloader, configFile string) error {
         finfo, err := f.Stat()
         if err != nil {
+            f.Close()
             return err
         }
         mt := finfo.ModTime()


@@ -68,6 +68,7 @@ type resolver struct {
     TTL     time.Duration
     period  time.Duration
     domain  string
+    mux     sync.RWMutex
 }

 // NewResolver create a new Resolver with the given name servers and resolution timeout.
@@ -78,17 +79,23 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver {
         TTL:    ttl,
         mCache: &sync.Map{},
     }
+    r.init()
+    return r
+}
+
+func (r *resolver) init() {
     if r.Timeout <= 0 {
         r.Timeout = DefaultResolverTimeout
     }
     if r.TTL == 0 {
         r.TTL = DefaultResolverTTL
     }
-    return r
+}
+
+func (r *resolver) copyServers() []NameServer {
+    var servers []NameServer
+    for i := range r.Servers {
+        servers = append(servers, r.Servers[i])
+    }
+    return servers
 }

 func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
@@ -96,14 +103,24 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
         return
     }

+    var domain string
+    var timeout, ttl time.Duration
+    var servers []NameServer
+
+    r.mux.RLock()
+    domain = r.domain
+    timeout = r.Timeout
+    servers = r.copyServers()
+    r.mux.RUnlock()
+
     if ip := net.ParseIP(host); ip != nil {
         return []net.IP{ip}, nil
     }

-    if !strings.Contains(host, ".") && r.domain != "" {
-        host = host + "." + r.domain
+    if !strings.Contains(host, ".") && domain != "" {
+        host = host + "." + domain
     }
-    ips = r.loadCache(host)
+    ips = r.loadCache(host, ttl)
     if len(ips) > 0 {
         if Debug {
             log.Logf("[resolver] cache hit %s: %v", host, ips)
@@ -111,8 +128,8 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
         return
     }

-    for _, ns := range r.Servers {
-        ips, err = r.resolve(ns, host)
+    for _, ns := range servers {
+        ips, err = r.resolve(ns, host, timeout)
         if err != nil {
             log.Logf("[resolver] %s via %s : %s", host, ns, err)
             continue
@@ -130,14 +147,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
     return
 }

-func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) {
+func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) {
     addr := ns.Addr
     if _, port, _ := net.SplitHostPort(addr); port == "" {
         addr = net.JoinHostPort(addr, "53")
     }

     client := dns.Client{
-        Timeout: r.Timeout,
+        Timeout: timeout,
     }

     switch strings.ToLower(ns.Protocol) {
     case "tcp":
@@ -171,8 +188,7 @@ func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) {
     return
 }

-func (r *resolver) loadCache(name string) []net.IP {
-    ttl := r.TTL
+func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
     if ttl < 0 {
         return nil
     }
@@ -189,8 +205,7 @@ func (r *resolver) loadCache(name string) []net.IP {
 }

 func (r *resolver) storeCache(name string, ips []net.IP) {
-    ttl := r.TTL
-    if ttl < 0 || name == "" || len(ips) == 0 {
+    if name == "" || len(ips) == 0 {
         return
     }
     r.mCache.Store(name, &resolverCacheItem{
@@ -200,6 +215,8 @@ func (r *resolver) storeCache(name string, ips []net.IP) {
 }

 func (r *resolver) Reload(rd io.Reader) error {
+    var ttl, timeout, period time.Duration
+    var domain string
     var nss []NameServer

     split := func(line string) []string {
@@ -232,19 +249,19 @@ func (r *resolver) Reload(rd io.Reader) error {
         switch ss[0] {
         case "timeout": // timeout option
             if len(ss) > 1 {
-                r.Timeout, _ = time.ParseDuration(ss[1])
+                timeout, _ = time.ParseDuration(ss[1])
             }
         case "ttl": // ttl option
             if len(ss) > 1 {
-                r.TTL, _ = time.ParseDuration(ss[1])
+                ttl, _ = time.ParseDuration(ss[1])
             }
         case "reload": // reload option
             if len(ss) > 1 {
-                r.period, _ = time.ParseDuration(ss[1])
+                period, _ = time.ParseDuration(ss[1])
             }
         case "domain":
             if len(ss) > 1 {
-                r.domain = ss[1]
+                domain = ss[1]
             }
         case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf
         case "nameserver": // nameserver option, compatible with /etc/resolv.conf
@@ -276,11 +293,21 @@ func (r *resolver) Reload(rd io.Reader) error {
         return err
     }

+    r.mux.Lock()
+    r.Timeout = timeout
+    r.TTL = ttl
+    r.domain = domain
+    r.period = period
     r.Servers = nss
+    r.mux.Unlock()

     return nil
 }

 func (r *resolver) Period() time.Duration {
+    r.mux.RLock()
+    defer r.mux.RUnlock()
     return r.period
 }
@@ -289,6 +316,9 @@ func (r *resolver) String() string {
         return ""
     }

+    r.mux.RLock()
+    defer r.mux.RUnlock()
+
     b := &bytes.Buffer{}
     fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
     fmt.Fprintf(b, "TTL %v\n", r.TTL)
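
Resolve now snapshots domain, Timeout and a copy of the server list under the read lock via copyServers, then releases the lock before doing any DNS I/O, so a concurrent Reload cannot race with an in-flight lookup and the lock is never held across the network. A sketch of that snapshot-then-release pattern, with made-up fields rather than the resolver's real ones:

package main

import (
	"fmt"
	"sync"
	"time"
)

type config struct {
	mux     sync.RWMutex
	timeout time.Duration
	servers []string
}

// snapshot copies everything a lookup needs while holding the read lock;
// the caller then works on the copies with the lock released.
func (c *config) snapshot() (time.Duration, []string) {
	c.mux.RLock()
	defer c.mux.RUnlock()
	servers := make([]string, len(c.servers))
	copy(servers, c.servers)
	return c.timeout, servers
}

// reload replaces the configuration under the write lock.
func (c *config) reload(timeout time.Duration, servers []string) {
	c.mux.Lock()
	defer c.mux.Unlock()
	c.timeout = timeout
	c.servers = servers
}

func lookup(c *config, host string) string {
	timeout, servers := c.snapshot()
	// The slow part (network I/O in the real resolver) runs with no lock held.
	return fmt.Sprintf("resolve %s via %v (timeout %v)", host, servers, timeout)
}

func main() {
	c := &config{timeout: time.Second, servers: []string{"8.8.8.8:53"}}
	go c.reload(2*time.Second, []string{"1.1.1.1:53"}) // concurrent reload is safe
	fmt.Println(lookup(c, "example.com"))
}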


@@ -94,6 +94,7 @@ type RandomStrategy struct {
     Seed int64
     rand *rand.Rand
     once sync.Once
+    mux  sync.Mutex
 }

 // Apply applies the random strategy for the nodes.
@@ -109,7 +110,11 @@ func (s *RandomStrategy) Apply(nodes []Node) Node {
         return Node{}
     }
-    return nodes[s.rand.Int()%len(nodes)]
+    s.mux.Lock()
+    r := s.rand.Int()
+    s.mux.Unlock()
+
+    return nodes[r%len(nodes)]
 }

 func (s *RandomStrategy) String() string {
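
The mutex is needed because a *rand.Rand created with rand.New is not safe for concurrent use (only the top-level math/rand functions lock internally), and Apply can be called from many connections at once. A minimal illustration of the guarded draw, using a hypothetical picker type:

package main

import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)

type picker struct {
	mux  sync.Mutex
	rand *rand.Rand
	once sync.Once
}

// pick lazily seeds the generator, then serializes access to it:
// rand.Rand methods are not goroutine-safe.
func (p *picker) pick(n int) int {
	p.once.Do(func() {
		p.rand = rand.New(rand.NewSource(time.Now().UnixNano()))
	})
	p.mux.Lock()
	defer p.mux.Unlock()
	return p.rand.Intn(n)
}

func main() {
	p := &picker{}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(p.pick(10))
		}()
	}
	wg.Wait()
}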


@@ -72,11 +72,13 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error {
         }
         tempDelay = 0

-        if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
-            log.Log("[bypass]", conn.RemoteAddr())
-            conn.Close()
-            continue
-        }
+        /*
+            if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
+                log.Log("[bypass]", conn.RemoteAddr())
+                conn.Close()
+                continue
+            }
+        */

         go h.Handle(conn)
     }
@@ -90,12 +92,14 @@ type ServerOptions struct {
 // ServerOption allows a common way to set server options.
 type ServerOption func(opts *ServerOptions)

+/*
 // BypassServerOption sets the bypass option of ServerOptions.
 func BypassServerOption(bypass *Bypass) ServerOption {
     return func(opts *ServerOptions) {
         opts.Bypass = bypass
     }
 }
+*/

 // Listener is a proxy server listener, just like a net.Listener.
 type Listener interface {