fix data race
This commit is contained in:
parent
e9b872c4cf
commit
a020c7bc33
34
bypass.go
34
bypass.go
@ -124,7 +124,7 @@ type Bypass struct {
|
||||
matchers []Matcher
|
||||
reversed bool
|
||||
period time.Duration // the period for live reloading
|
||||
mux sync.Mutex
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
|
||||
@ -160,8 +160,8 @@ func (bp *Bypass) Contains(addr string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
bp.mux.Lock()
|
||||
defer bp.mux.Unlock()
|
||||
bp.mux.RLock()
|
||||
defer bp.mux.RUnlock()
|
||||
|
||||
var matched bool
|
||||
for _, matcher := range bp.matchers {
|
||||
@ -179,22 +179,33 @@ func (bp *Bypass) Contains(addr string) bool {
|
||||
|
||||
// AddMatchers appends matchers to the bypass matcher list.
|
||||
func (bp *Bypass) AddMatchers(matchers ...Matcher) {
|
||||
bp.mux.Lock()
|
||||
defer bp.mux.Unlock()
|
||||
|
||||
bp.matchers = append(bp.matchers, matchers...)
|
||||
}
|
||||
|
||||
// Matchers return the bypass matcher list.
|
||||
func (bp *Bypass) Matchers() []Matcher {
|
||||
bp.mux.RLock()
|
||||
defer bp.mux.RUnlock()
|
||||
|
||||
return bp.matchers
|
||||
}
|
||||
|
||||
// Reversed reports whether the rules of the bypass are reversed.
|
||||
func (bp *Bypass) Reversed() bool {
|
||||
bp.mux.RLock()
|
||||
defer bp.mux.RUnlock()
|
||||
|
||||
return bp.reversed
|
||||
}
|
||||
|
||||
// Reload parses config from r, then live reloads the bypass.
|
||||
func (bp *Bypass) Reload(r io.Reader) error {
|
||||
var matchers []Matcher
|
||||
var period time.Duration
|
||||
var reversed bool
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
@ -217,7 +228,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
|
||||
}
|
||||
}
|
||||
if len(ss) == 2 {
|
||||
bp.period, _ = time.ParseDuration(ss[1])
|
||||
period, _ = time.ParseDuration(ss[1])
|
||||
continue
|
||||
}
|
||||
}
|
||||
@ -231,7 +242,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
|
||||
}
|
||||
}
|
||||
if len(ss) == 2 {
|
||||
bp.reversed, _ = strconv.ParseBool(ss[1])
|
||||
reversed, _ = strconv.ParseBool(ss[1])
|
||||
continue
|
||||
}
|
||||
}
|
||||
@ -247,19 +258,28 @@ func (bp *Bypass) Reload(r io.Reader) error {
|
||||
defer bp.mux.Unlock()
|
||||
|
||||
bp.matchers = matchers
|
||||
bp.period = period
|
||||
bp.reversed = reversed
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Period returns the reload period
|
||||
func (bp *Bypass) Period() time.Duration {
|
||||
bp.mux.RLock()
|
||||
defer bp.mux.RUnlock()
|
||||
|
||||
return bp.period
|
||||
}
|
||||
|
||||
func (bp *Bypass) String() string {
|
||||
bp.mux.RLock()
|
||||
defer bp.mux.RUnlock()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
fmt.Fprintf(b, "reversed: %v\n", bp.Reversed())
|
||||
for _, m := range bp.Matchers() {
|
||||
fmt.Fprintf(b, "reversed: %v\n", bp.reversed)
|
||||
fmt.Fprintf(b, "reload: %v\n", bp.period)
|
||||
for _, m := range bp.matchers {
|
||||
b.WriteString(m.String())
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
18
chain.go
18
chain.go
@ -38,7 +38,7 @@ func newRoute(nodes ...Node) *Chain {
|
||||
}
|
||||
|
||||
// Nodes returns the proxy nodes that the chain holds.
|
||||
// If a node is a node group, the first node in the group will be returned.
|
||||
// The first node in each group will be returned.
|
||||
func (c *Chain) Nodes() (nodes []Node) {
|
||||
for _, group := range c.nodeGroups {
|
||||
if ns := group.Nodes(); len(ns) > 0 {
|
||||
@ -61,7 +61,7 @@ func (c *Chain) LastNode() Node {
|
||||
return Node{}
|
||||
}
|
||||
group := c.nodeGroups[len(c.nodeGroups)-1]
|
||||
return group.nodes[0].Clone()
|
||||
return group.GetNode(0)
|
||||
}
|
||||
|
||||
// LastNodeGroup returns the last group of the group list.
|
||||
@ -173,7 +173,6 @@ func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
|
||||
}
|
||||
|
||||
// Conn obtains a handshaked connection to the last node of the chain.
|
||||
// If the chain is empty, it returns an ErrEmptyChain error.
|
||||
func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
|
||||
options := &ChainOptions{}
|
||||
for _, opt := range opts {
|
||||
@ -206,6 +205,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
|
||||
}
|
||||
|
||||
// getConn obtains a connection to the last node of the chain.
|
||||
// It does not handshake with the last node.
|
||||
func (c *Chain) getConn() (conn net.Conn, err error) {
|
||||
if c.IsEmpty() {
|
||||
err = ErrEmptyChain
|
||||
@ -216,16 +216,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
|
||||
|
||||
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
||||
if err != nil {
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
return
|
||||
}
|
||||
|
||||
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
|
||||
if err != nil {
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
return
|
||||
}
|
||||
node.ResetDead()
|
||||
node.group.ResetDeadNode(node.ID)
|
||||
|
||||
preNode := node
|
||||
for _, node := range nodes[1:] {
|
||||
@ -233,16 +233,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
|
||||
cc, err = preNode.Client.Connect(cn, node.Addr)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
return
|
||||
}
|
||||
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
return
|
||||
}
|
||||
node.ResetDead()
|
||||
node.group.ResetDeadNode(node.ID)
|
||||
|
||||
cn = cc
|
||||
preNode = node
|
||||
|
@ -60,7 +60,8 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
|
||||
strategy = s
|
||||
}
|
||||
}
|
||||
group.Options = append([]gost.SelectOption{},
|
||||
group.SetSelector(
|
||||
nil,
|
||||
gost.WithFilter(&gost.FailFilter{
|
||||
MaxFails: cfg.MaxFails,
|
||||
FailTimeout: time.Duration(cfg.FailTimeout) * time.Second,
|
||||
|
20
forward.go
20
forward.go
@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
@ -116,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
node.ResetDead()
|
||||
node.group.ResetDeadNode(node.ID)
|
||||
defer cc.Close()
|
||||
|
||||
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
||||
@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
|
||||
if h.options.Chain.IsEmpty() {
|
||||
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
|
||||
if err != nil {
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
|
||||
return
|
||||
}
|
||||
cc, err = net.DialUDP("udp", nil, raddr)
|
||||
if err != nil {
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
|
||||
return
|
||||
}
|
||||
@ -212,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
|
||||
}
|
||||
|
||||
defer cc.Close()
|
||||
node.ResetDead()
|
||||
node.group.ResetDeadNode(node.ID)
|
||||
|
||||
log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
||||
transport(conn, cc)
|
||||
@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
|
||||
cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
|
||||
if err != nil {
|
||||
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
@ -301,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
|
||||
}
|
||||
|
||||
defer cc.Close()
|
||||
node.ResetDead()
|
||||
node.group.ResetDeadNode(node.ID)
|
||||
|
||||
log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
|
||||
transport(cc, conn)
|
||||
@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
|
||||
|
||||
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
|
||||
if err != nil {
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
|
||||
return
|
||||
}
|
||||
cc, err := net.DialUDP("udp", nil, raddr)
|
||||
if err != nil {
|
||||
node.MarkDead()
|
||||
node.group.MarkDeadNode(node.ID)
|
||||
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
|
||||
return
|
||||
}
|
||||
defer cc.Close()
|
||||
node.ResetDead()
|
||||
node.group.ResetDeadNode(node.ID)
|
||||
|
||||
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
||||
transport(conn, cc)
|
||||
|
19
hosts.go
19
hosts.go
@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-log/log"
|
||||
@ -25,6 +26,7 @@ type Host struct {
|
||||
type Hosts struct {
|
||||
hosts []Host
|
||||
period time.Duration
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewHosts creates a Hosts with optional list of host
|
||||
@ -36,6 +38,9 @@ func NewHosts(hosts ...Host) *Hosts {
|
||||
|
||||
// AddHost adds host(s) to the host table.
|
||||
func (h *Hosts) AddHost(host ...Host) {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
h.hosts = append(h.hosts, host...)
|
||||
}
|
||||
|
||||
@ -44,6 +49,10 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mux.RLock()
|
||||
defer h.mux.RUnlock()
|
||||
|
||||
for _, h := range h.hosts {
|
||||
if h.Hostname == host {
|
||||
ip = h.IP
|
||||
@ -64,6 +73,7 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
|
||||
|
||||
// Reload parses config from r, then live reloads the hosts.
|
||||
func (h *Hosts) Reload(r io.Reader) error {
|
||||
var period time.Duration
|
||||
var hosts []Host
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
@ -89,7 +99,7 @@ func (h *Hosts) Reload(r io.Reader) error {
|
||||
|
||||
// reload option
|
||||
if strings.ToLower(ss[0]) == "reload" {
|
||||
h.period, _ = time.ParseDuration(ss[1])
|
||||
period, _ = time.ParseDuration(ss[1])
|
||||
continue
|
||||
}
|
||||
|
||||
@ -110,11 +120,18 @@ func (h *Hosts) Reload(r io.Reader) error {
|
||||
return err
|
||||
}
|
||||
|
||||
h.mux.Lock()
|
||||
h.period = period
|
||||
h.hosts = hosts
|
||||
h.mux.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Period returns the reload period
|
||||
func (h *Hosts) Period() time.Duration {
|
||||
h.mux.RLock()
|
||||
defer h.mux.RUnlock()
|
||||
|
||||
return h.period
|
||||
}
|
||||
|
2
http2.go
2
http2.go
@ -321,7 +321,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
|
||||
if Debug && (u != "" || p != "") {
|
||||
log.Logf("[http] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
|
||||
log.Logf("[http2] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
|
||||
}
|
||||
if !authenticate(u, p, h.options.Users...) {
|
||||
// probing resistance is enabled
|
||||
|
140
node.go
140
node.go
@ -5,6 +5,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
@ -88,41 +89,6 @@ func ParseNode(s string) (node Node, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// MarkDead marks the node fail status.
|
||||
func (node *Node) MarkDead() {
|
||||
atomic.AddUint32(&node.failCount, 1)
|
||||
atomic.StoreInt64(&node.failTime, time.Now().Unix())
|
||||
|
||||
if node.group == nil {
|
||||
return
|
||||
}
|
||||
for i := range node.group.nodes {
|
||||
if node.group.nodes[i].ID == node.ID {
|
||||
atomic.AddUint32(&node.group.nodes[i].failCount, 1)
|
||||
atomic.StoreInt64(&node.group.nodes[i].failTime, time.Now().Unix())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ResetDead resets the node fail status.
|
||||
func (node *Node) ResetDead() {
|
||||
atomic.StoreUint32(&node.failCount, 0)
|
||||
atomic.StoreInt64(&node.failTime, 0)
|
||||
|
||||
if node.group == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range node.group.nodes {
|
||||
if node.group.nodes[i].ID == node.ID {
|
||||
atomic.StoreUint32(&node.group.nodes[i].failCount, 0)
|
||||
atomic.StoreInt64(&node.group.nodes[i].failTime, 0)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clone clones the node, it will prevent data race.
|
||||
func (node *Node) Clone() Node {
|
||||
return Node{
|
||||
@ -167,10 +133,11 @@ func (node *Node) String() string {
|
||||
|
||||
// NodeGroup is a group of nodes.
|
||||
type NodeGroup struct {
|
||||
ID int
|
||||
nodes []Node
|
||||
Options []SelectOption
|
||||
Selector NodeSelector
|
||||
ID int
|
||||
nodes []Node
|
||||
selectorOptions []SelectOption
|
||||
selector NodeSelector
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewNodeGroup creates a node group
|
||||
@ -185,11 +152,21 @@ func (group *NodeGroup) AddNode(node ...Node) {
|
||||
if group == nil {
|
||||
return
|
||||
}
|
||||
group.mux.Lock()
|
||||
defer group.mux.Unlock()
|
||||
|
||||
group.nodes = append(group.nodes, node...)
|
||||
}
|
||||
|
||||
// SetNodes replaces the group nodes to the specified nodes.
|
||||
func (group *NodeGroup) SetNodes(nodes ...Node) {
|
||||
if group == nil {
|
||||
return
|
||||
}
|
||||
|
||||
group.mux.Lock()
|
||||
defer group.mux.Unlock()
|
||||
|
||||
group.nodes = nodes
|
||||
}
|
||||
|
||||
@ -198,27 +175,100 @@ func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption)
|
||||
if group == nil {
|
||||
return
|
||||
}
|
||||
group.Selector = selector
|
||||
group.Options = opts
|
||||
group.mux.Lock()
|
||||
defer group.mux.Unlock()
|
||||
|
||||
group.selector = selector
|
||||
group.selectorOptions = opts
|
||||
}
|
||||
|
||||
// Nodes returns node list in the group
|
||||
// Nodes returns the node list in the group
|
||||
func (group *NodeGroup) Nodes() []Node {
|
||||
if group == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
group.mux.RLock()
|
||||
defer group.mux.RUnlock()
|
||||
|
||||
return group.nodes
|
||||
}
|
||||
|
||||
// Next selects the next node from group.
|
||||
func (group *NodeGroup) copyNodes() []Node {
|
||||
group.mux.RLock()
|
||||
defer group.mux.RUnlock()
|
||||
|
||||
var nodes []Node
|
||||
for i := range group.nodes {
|
||||
nodes = append(nodes, group.nodes[i])
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// GetNode returns a copy of the node specified by index in the group.
|
||||
func (group *NodeGroup) GetNode(i int) Node {
|
||||
group.mux.RLock()
|
||||
defer group.mux.RUnlock()
|
||||
|
||||
if i < 0 || group == nil || len(group.nodes) <= i {
|
||||
return Node{}
|
||||
}
|
||||
return group.nodes[i].Clone()
|
||||
}
|
||||
|
||||
// MarkDeadNode marks the node with ID nid status to dead.
|
||||
func (group *NodeGroup) MarkDeadNode(nid int) {
|
||||
group.mux.RLock()
|
||||
defer group.mux.RUnlock()
|
||||
|
||||
if group == nil || nid <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range group.nodes {
|
||||
if group.nodes[i].ID == nid {
|
||||
atomic.AddUint32(&group.nodes[i].failCount, 1)
|
||||
atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ResetDeadNode resets the node with ID nid status.
|
||||
func (group *NodeGroup) ResetDeadNode(nid int) {
|
||||
group.mux.RLock()
|
||||
defer group.mux.RUnlock()
|
||||
|
||||
if group == nil || nid <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range group.nodes {
|
||||
if group.nodes[i].ID == nid {
|
||||
atomic.StoreUint32(&group.nodes[i].failCount, 0)
|
||||
atomic.StoreInt64(&group.nodes[i].failTime, 0)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Next selects a node from group.
|
||||
// It also selects IP if the IP list exists.
|
||||
func (group *NodeGroup) Next() (node Node, err error) {
|
||||
selector := group.Selector
|
||||
if group == nil {
|
||||
return
|
||||
}
|
||||
|
||||
group.mux.RLock()
|
||||
defer group.mux.RUnlock()
|
||||
|
||||
selector := group.selector
|
||||
if selector == nil {
|
||||
selector = &defaultSelector{}
|
||||
}
|
||||
|
||||
// select node from node group
|
||||
node, err = selector.Select(group.Nodes(), group.Options...)
|
||||
node, err = selector.Select(group.nodes, group.selectorOptions...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ func PeriodReload(r Reloader, configFile string) error {
|
||||
|
||||
finfo, err := f.Stat()
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return err
|
||||
}
|
||||
mt := finfo.ModTime()
|
||||
|
68
resolver.go
68
resolver.go
@ -68,6 +68,7 @@ type resolver struct {
|
||||
TTL time.Duration
|
||||
period time.Duration
|
||||
domain string
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewResolver create a new Resolver with the given name servers and resolution timeout.
|
||||
@ -78,17 +79,23 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
|
||||
TTL: ttl,
|
||||
mCache: &sync.Map{},
|
||||
}
|
||||
r.init()
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *resolver) init() {
|
||||
if r.Timeout <= 0 {
|
||||
r.Timeout = DefaultResolverTimeout
|
||||
}
|
||||
if r.TTL == 0 {
|
||||
r.TTL = DefaultResolverTTL
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *resolver) copyServers() []NameServer {
|
||||
var servers []NameServer
|
||||
for i := range r.Servers {
|
||||
servers = append(servers, r.Servers[i])
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
@ -96,14 +103,24 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
var domain string
|
||||
var timeout, ttl time.Duration
|
||||
var servers []NameServer
|
||||
|
||||
r.mux.RLock()
|
||||
domain = r.domain
|
||||
timeout = r.Timeout
|
||||
servers = r.copyServers()
|
||||
r.mux.RUnlock()
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return []net.IP{ip}, nil
|
||||
}
|
||||
|
||||
if !strings.Contains(host, ".") && r.domain != "" {
|
||||
host = host + "." + r.domain
|
||||
if !strings.Contains(host, ".") && domain != "" {
|
||||
host = host + "." + domain
|
||||
}
|
||||
ips = r.loadCache(host)
|
||||
ips = r.loadCache(host, ttl)
|
||||
if len(ips) > 0 {
|
||||
if Debug {
|
||||
log.Logf("[resolver] cache hit %s: %v", host, ips)
|
||||
@ -111,8 +128,8 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ns := range r.Servers {
|
||||
ips, err = r.resolve(ns, host)
|
||||
for _, ns := range servers {
|
||||
ips, err = r.resolve(ns, host, timeout)
|
||||
if err != nil {
|
||||
log.Logf("[resolver] %s via %s : %s", host, ns, err)
|
||||
continue
|
||||
@ -130,14 +147,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) {
|
||||
func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) {
|
||||
addr := ns.Addr
|
||||
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
||||
addr = net.JoinHostPort(addr, "53")
|
||||
}
|
||||
|
||||
client := dns.Client{
|
||||
Timeout: r.Timeout,
|
||||
Timeout: timeout,
|
||||
}
|
||||
switch strings.ToLower(ns.Protocol) {
|
||||
case "tcp":
|
||||
@ -171,8 +188,7 @@ func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *resolver) loadCache(name string) []net.IP {
|
||||
ttl := r.TTL
|
||||
func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
|
||||
if ttl < 0 {
|
||||
return nil
|
||||
}
|
||||
@ -189,8 +205,7 @@ func (r *resolver) loadCache(name string) []net.IP {
|
||||
}
|
||||
|
||||
func (r *resolver) storeCache(name string, ips []net.IP) {
|
||||
ttl := r.TTL
|
||||
if ttl < 0 || name == "" || len(ips) == 0 {
|
||||
if name == "" || len(ips) == 0 {
|
||||
return
|
||||
}
|
||||
r.mCache.Store(name, &resolverCacheItem{
|
||||
@ -200,6 +215,8 @@ func (r *resolver) storeCache(name string, ips []net.IP) {
|
||||
}
|
||||
|
||||
func (r *resolver) Reload(rd io.Reader) error {
|
||||
var ttl, timeout, period time.Duration
|
||||
var domain string
|
||||
var nss []NameServer
|
||||
|
||||
split := func(line string) []string {
|
||||
@ -232,19 +249,19 @@ func (r *resolver) Reload(rd io.Reader) error {
|
||||
switch ss[0] {
|
||||
case "timeout": // timeout option
|
||||
if len(ss) > 1 {
|
||||
r.Timeout, _ = time.ParseDuration(ss[1])
|
||||
timeout, _ = time.ParseDuration(ss[1])
|
||||
}
|
||||
case "ttl": // ttl option
|
||||
if len(ss) > 1 {
|
||||
r.TTL, _ = time.ParseDuration(ss[1])
|
||||
ttl, _ = time.ParseDuration(ss[1])
|
||||
}
|
||||
case "reload": // reload option
|
||||
if len(ss) > 1 {
|
||||
r.period, _ = time.ParseDuration(ss[1])
|
||||
period, _ = time.ParseDuration(ss[1])
|
||||
}
|
||||
case "domain":
|
||||
if len(ss) > 1 {
|
||||
r.domain = ss[1]
|
||||
domain = ss[1]
|
||||
}
|
||||
case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf
|
||||
case "nameserver": // nameserver option, compatible with /etc/resolv.conf
|
||||
@ -276,11 +293,21 @@ func (r *resolver) Reload(rd io.Reader) error {
|
||||
return err
|
||||
}
|
||||
|
||||
r.mux.Lock()
|
||||
r.Timeout = timeout
|
||||
r.TTL = ttl
|
||||
r.domain = domain
|
||||
r.period = period
|
||||
r.Servers = nss
|
||||
r.mux.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolver) Period() time.Duration {
|
||||
r.mux.RLock()
|
||||
defer r.mux.RUnlock()
|
||||
|
||||
return r.period
|
||||
}
|
||||
|
||||
@ -289,6 +316,9 @@ func (r *resolver) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
r.mux.RLock()
|
||||
defer r.mux.RUnlock()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
|
||||
fmt.Fprintf(b, "TTL %v\n", r.TTL)
|
||||
|
@ -94,6 +94,7 @@ type RandomStrategy struct {
|
||||
Seed int64
|
||||
rand *rand.Rand
|
||||
once sync.Once
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// Apply applies the random strategy for the nodes.
|
||||
@ -109,7 +110,11 @@ func (s *RandomStrategy) Apply(nodes []Node) Node {
|
||||
return Node{}
|
||||
}
|
||||
|
||||
return nodes[s.rand.Int()%len(nodes)]
|
||||
s.mux.Lock()
|
||||
r := s.rand.Int()
|
||||
s.mux.Unlock()
|
||||
|
||||
return nodes[r%len(nodes)]
|
||||
}
|
||||
|
||||
func (s *RandomStrategy) String() string {
|
||||
|
14
server.go
14
server.go
@ -72,11 +72,13 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error {
|
||||
}
|
||||
tempDelay = 0
|
||||
|
||||
if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
|
||||
log.Log("[bypass]", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
/*
|
||||
if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
|
||||
log.Log("[bypass]", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
*/
|
||||
|
||||
go h.Handle(conn)
|
||||
}
|
||||
@ -90,12 +92,14 @@ type ServerOptions struct {
|
||||
// ServerOption allows a common way to set server options.
|
||||
type ServerOption func(opts *ServerOptions)
|
||||
|
||||
/*
|
||||
// BypassServerOption sets the bypass option of ServerOptions.
|
||||
func BypassServerOption(bypass *Bypass) ServerOption {
|
||||
return func(opts *ServerOptions) {
|
||||
opts.Bypass = bypass
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// Listener is a proxy server listener, just like a net.Listener.
|
||||
type Listener interface {
|
||||
|
Loading…
Reference in New Issue
Block a user