fix data race
This commit is contained in:
parent
e9b872c4cf
commit
a020c7bc33
34
bypass.go
34
bypass.go
@ -124,7 +124,7 @@ type Bypass struct {
|
|||||||
matchers []Matcher
|
matchers []Matcher
|
||||||
reversed bool
|
reversed bool
|
||||||
period time.Duration // the period for live reloading
|
period time.Duration // the period for live reloading
|
||||||
mux sync.Mutex
|
mux sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
|
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
|
||||||
@ -160,8 +160,8 @@ func (bp *Bypass) Contains(addr string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bp.mux.Lock()
|
bp.mux.RLock()
|
||||||
defer bp.mux.Unlock()
|
defer bp.mux.RUnlock()
|
||||||
|
|
||||||
var matched bool
|
var matched bool
|
||||||
for _, matcher := range bp.matchers {
|
for _, matcher := range bp.matchers {
|
||||||
@ -179,22 +179,33 @@ func (bp *Bypass) Contains(addr string) bool {
|
|||||||
|
|
||||||
// AddMatchers appends matchers to the bypass matcher list.
|
// AddMatchers appends matchers to the bypass matcher list.
|
||||||
func (bp *Bypass) AddMatchers(matchers ...Matcher) {
|
func (bp *Bypass) AddMatchers(matchers ...Matcher) {
|
||||||
|
bp.mux.Lock()
|
||||||
|
defer bp.mux.Unlock()
|
||||||
|
|
||||||
bp.matchers = append(bp.matchers, matchers...)
|
bp.matchers = append(bp.matchers, matchers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matchers return the bypass matcher list.
|
// Matchers return the bypass matcher list.
|
||||||
func (bp *Bypass) Matchers() []Matcher {
|
func (bp *Bypass) Matchers() []Matcher {
|
||||||
|
bp.mux.RLock()
|
||||||
|
defer bp.mux.RUnlock()
|
||||||
|
|
||||||
return bp.matchers
|
return bp.matchers
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reversed reports whether the rules of the bypass are reversed.
|
// Reversed reports whether the rules of the bypass are reversed.
|
||||||
func (bp *Bypass) Reversed() bool {
|
func (bp *Bypass) Reversed() bool {
|
||||||
|
bp.mux.RLock()
|
||||||
|
defer bp.mux.RUnlock()
|
||||||
|
|
||||||
return bp.reversed
|
return bp.reversed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reload parses config from r, then live reloads the bypass.
|
// Reload parses config from r, then live reloads the bypass.
|
||||||
func (bp *Bypass) Reload(r io.Reader) error {
|
func (bp *Bypass) Reload(r io.Reader) error {
|
||||||
var matchers []Matcher
|
var matchers []Matcher
|
||||||
|
var period time.Duration
|
||||||
|
var reversed bool
|
||||||
|
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
@ -217,7 +228,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(ss) == 2 {
|
if len(ss) == 2 {
|
||||||
bp.period, _ = time.ParseDuration(ss[1])
|
period, _ = time.ParseDuration(ss[1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -231,7 +242,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(ss) == 2 {
|
if len(ss) == 2 {
|
||||||
bp.reversed, _ = strconv.ParseBool(ss[1])
|
reversed, _ = strconv.ParseBool(ss[1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -247,19 +258,28 @@ func (bp *Bypass) Reload(r io.Reader) error {
|
|||||||
defer bp.mux.Unlock()
|
defer bp.mux.Unlock()
|
||||||
|
|
||||||
bp.matchers = matchers
|
bp.matchers = matchers
|
||||||
|
bp.period = period
|
||||||
|
bp.reversed = reversed
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Period returns the reload period
|
// Period returns the reload period
|
||||||
func (bp *Bypass) Period() time.Duration {
|
func (bp *Bypass) Period() time.Duration {
|
||||||
|
bp.mux.RLock()
|
||||||
|
defer bp.mux.RUnlock()
|
||||||
|
|
||||||
return bp.period
|
return bp.period
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bp *Bypass) String() string {
|
func (bp *Bypass) String() string {
|
||||||
|
bp.mux.RLock()
|
||||||
|
defer bp.mux.RUnlock()
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
fmt.Fprintf(b, "reversed: %v\n", bp.Reversed())
|
fmt.Fprintf(b, "reversed: %v\n", bp.reversed)
|
||||||
for _, m := range bp.Matchers() {
|
fmt.Fprintf(b, "reload: %v\n", bp.period)
|
||||||
|
for _, m := range bp.matchers {
|
||||||
b.WriteString(m.String())
|
b.WriteString(m.String())
|
||||||
b.WriteByte('\n')
|
b.WriteByte('\n')
|
||||||
}
|
}
|
||||||
|
18
chain.go
18
chain.go
@ -38,7 +38,7 @@ func newRoute(nodes ...Node) *Chain {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Nodes returns the proxy nodes that the chain holds.
|
// Nodes returns the proxy nodes that the chain holds.
|
||||||
// If a node is a node group, the first node in the group will be returned.
|
// The first node in each group will be returned.
|
||||||
func (c *Chain) Nodes() (nodes []Node) {
|
func (c *Chain) Nodes() (nodes []Node) {
|
||||||
for _, group := range c.nodeGroups {
|
for _, group := range c.nodeGroups {
|
||||||
if ns := group.Nodes(); len(ns) > 0 {
|
if ns := group.Nodes(); len(ns) > 0 {
|
||||||
@ -61,7 +61,7 @@ func (c *Chain) LastNode() Node {
|
|||||||
return Node{}
|
return Node{}
|
||||||
}
|
}
|
||||||
group := c.nodeGroups[len(c.nodeGroups)-1]
|
group := c.nodeGroups[len(c.nodeGroups)-1]
|
||||||
return group.nodes[0].Clone()
|
return group.GetNode(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LastNodeGroup returns the last group of the group list.
|
// LastNodeGroup returns the last group of the group list.
|
||||||
@ -173,7 +173,6 @@ func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Conn obtains a handshaked connection to the last node of the chain.
|
// Conn obtains a handshaked connection to the last node of the chain.
|
||||||
// If the chain is empty, it returns an ErrEmptyChain error.
|
|
||||||
func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
|
func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
|
||||||
options := &ChainOptions{}
|
options := &ChainOptions{}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
@ -206,6 +205,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getConn obtains a connection to the last node of the chain.
|
// getConn obtains a connection to the last node of the chain.
|
||||||
|
// It does not handshake with the last node.
|
||||||
func (c *Chain) getConn() (conn net.Conn, err error) {
|
func (c *Chain) getConn() (conn net.Conn, err error) {
|
||||||
if c.IsEmpty() {
|
if c.IsEmpty() {
|
||||||
err = ErrEmptyChain
|
err = ErrEmptyChain
|
||||||
@ -216,16 +216,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
|
|||||||
|
|
||||||
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
|
cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
node.ResetDead()
|
node.group.ResetDeadNode(node.ID)
|
||||||
|
|
||||||
preNode := node
|
preNode := node
|
||||||
for _, node := range nodes[1:] {
|
for _, node := range nodes[1:] {
|
||||||
@ -233,16 +233,16 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
|
|||||||
cc, err = preNode.Client.Connect(cn, node.Addr)
|
cc, err = preNode.Client.Connect(cn, node.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cn.Close()
|
cn.Close()
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
|
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cn.Close()
|
cn.Close()
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
node.ResetDead()
|
node.group.ResetDeadNode(node.ID)
|
||||||
|
|
||||||
cn = cc
|
cn = cc
|
||||||
preNode = node
|
preNode = node
|
||||||
|
@ -60,7 +60,8 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
|
|||||||
strategy = s
|
strategy = s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
group.Options = append([]gost.SelectOption{},
|
group.SetSelector(
|
||||||
|
nil,
|
||||||
gost.WithFilter(&gost.FailFilter{
|
gost.WithFilter(&gost.FailFilter{
|
||||||
MaxFails: cfg.MaxFails,
|
MaxFails: cfg.MaxFails,
|
||||||
FailTimeout: time.Duration(cfg.FailTimeout) * time.Second,
|
FailTimeout: time.Duration(cfg.FailTimeout) * time.Second,
|
||||||
|
20
forward.go
20
forward.go
@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
|
|||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
|
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
} else {
|
} else {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -116,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
node.ResetDead()
|
node.group.ResetDeadNode(node.ID)
|
||||||
defer cc.Close()
|
defer cc.Close()
|
||||||
|
|
||||||
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
||||||
@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
|
|||||||
if h.options.Chain.IsEmpty() {
|
if h.options.Chain.IsEmpty() {
|
||||||
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
|
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
|
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cc, err = net.DialUDP("udp", nil, raddr)
|
cc, err = net.DialUDP("udp", nil, raddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
|
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -212,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer cc.Close()
|
defer cc.Close()
|
||||||
node.ResetDead()
|
node.group.ResetDeadNode(node.ID)
|
||||||
|
|
||||||
log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
||||||
transport(conn, cc)
|
transport(conn, cc)
|
||||||
@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
|
|||||||
cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
|
cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
|
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
} else {
|
} else {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -301,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer cc.Close()
|
defer cc.Close()
|
||||||
node.ResetDead()
|
node.group.ResetDeadNode(node.ID)
|
||||||
|
|
||||||
log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
|
log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
|
||||||
transport(cc, conn)
|
transport(cc, conn)
|
||||||
@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
|
|||||||
|
|
||||||
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
|
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
|
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cc, err := net.DialUDP("udp", nil, raddr)
|
cc, err := net.DialUDP("udp", nil, raddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
node.MarkDead()
|
node.group.MarkDeadNode(node.ID)
|
||||||
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
|
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer cc.Close()
|
defer cc.Close()
|
||||||
node.ResetDead()
|
node.group.ResetDeadNode(node.ID)
|
||||||
|
|
||||||
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
|
||||||
transport(conn, cc)
|
transport(conn, cc)
|
||||||
|
19
hosts.go
19
hosts.go
@ -5,6 +5,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-log/log"
|
"github.com/go-log/log"
|
||||||
@ -25,6 +26,7 @@ type Host struct {
|
|||||||
type Hosts struct {
|
type Hosts struct {
|
||||||
hosts []Host
|
hosts []Host
|
||||||
period time.Duration
|
period time.Duration
|
||||||
|
mux sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHosts creates a Hosts with optional list of host
|
// NewHosts creates a Hosts with optional list of host
|
||||||
@ -36,6 +38,9 @@ func NewHosts(hosts ...Host) *Hosts {
|
|||||||
|
|
||||||
// AddHost adds host(s) to the host table.
|
// AddHost adds host(s) to the host table.
|
||||||
func (h *Hosts) AddHost(host ...Host) {
|
func (h *Hosts) AddHost(host ...Host) {
|
||||||
|
h.mux.Lock()
|
||||||
|
defer h.mux.Unlock()
|
||||||
|
|
||||||
h.hosts = append(h.hosts, host...)
|
h.hosts = append(h.hosts, host...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,6 +49,10 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
|
|||||||
if h == nil {
|
if h == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.mux.RLock()
|
||||||
|
defer h.mux.RUnlock()
|
||||||
|
|
||||||
for _, h := range h.hosts {
|
for _, h := range h.hosts {
|
||||||
if h.Hostname == host {
|
if h.Hostname == host {
|
||||||
ip = h.IP
|
ip = h.IP
|
||||||
@ -64,6 +73,7 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
|
|||||||
|
|
||||||
// Reload parses config from r, then live reloads the hosts.
|
// Reload parses config from r, then live reloads the hosts.
|
||||||
func (h *Hosts) Reload(r io.Reader) error {
|
func (h *Hosts) Reload(r io.Reader) error {
|
||||||
|
var period time.Duration
|
||||||
var hosts []Host
|
var hosts []Host
|
||||||
|
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
@ -89,7 +99,7 @@ func (h *Hosts) Reload(r io.Reader) error {
|
|||||||
|
|
||||||
// reload option
|
// reload option
|
||||||
if strings.ToLower(ss[0]) == "reload" {
|
if strings.ToLower(ss[0]) == "reload" {
|
||||||
h.period, _ = time.ParseDuration(ss[1])
|
period, _ = time.ParseDuration(ss[1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,11 +120,18 @@ func (h *Hosts) Reload(r io.Reader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.mux.Lock()
|
||||||
|
h.period = period
|
||||||
h.hosts = hosts
|
h.hosts = hosts
|
||||||
|
h.mux.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Period returns the reload period
|
// Period returns the reload period
|
||||||
func (h *Hosts) Period() time.Duration {
|
func (h *Hosts) Period() time.Duration {
|
||||||
|
h.mux.RLock()
|
||||||
|
defer h.mux.RUnlock()
|
||||||
|
|
||||||
return h.period
|
return h.period
|
||||||
}
|
}
|
||||||
|
2
http2.go
2
http2.go
@ -321,7 +321,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
|
u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
|
||||||
if Debug && (u != "" || p != "") {
|
if Debug && (u != "" || p != "") {
|
||||||
log.Logf("[http] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
|
log.Logf("[http2] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
|
||||||
}
|
}
|
||||||
if !authenticate(u, p, h.options.Users...) {
|
if !authenticate(u, p, h.options.Users...) {
|
||||||
// probing resistance is enabled
|
// probing resistance is enabled
|
||||||
|
140
node.go
140
node.go
@ -5,6 +5,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -88,41 +89,6 @@ func ParseNode(s string) (node Node, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkDead marks the node fail status.
|
|
||||||
func (node *Node) MarkDead() {
|
|
||||||
atomic.AddUint32(&node.failCount, 1)
|
|
||||||
atomic.StoreInt64(&node.failTime, time.Now().Unix())
|
|
||||||
|
|
||||||
if node.group == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for i := range node.group.nodes {
|
|
||||||
if node.group.nodes[i].ID == node.ID {
|
|
||||||
atomic.AddUint32(&node.group.nodes[i].failCount, 1)
|
|
||||||
atomic.StoreInt64(&node.group.nodes[i].failTime, time.Now().Unix())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetDead resets the node fail status.
|
|
||||||
func (node *Node) ResetDead() {
|
|
||||||
atomic.StoreUint32(&node.failCount, 0)
|
|
||||||
atomic.StoreInt64(&node.failTime, 0)
|
|
||||||
|
|
||||||
if node.group == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range node.group.nodes {
|
|
||||||
if node.group.nodes[i].ID == node.ID {
|
|
||||||
atomic.StoreUint32(&node.group.nodes[i].failCount, 0)
|
|
||||||
atomic.StoreInt64(&node.group.nodes[i].failTime, 0)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clone clones the node, it will prevent data race.
|
// Clone clones the node, it will prevent data race.
|
||||||
func (node *Node) Clone() Node {
|
func (node *Node) Clone() Node {
|
||||||
return Node{
|
return Node{
|
||||||
@ -167,10 +133,11 @@ func (node *Node) String() string {
|
|||||||
|
|
||||||
// NodeGroup is a group of nodes.
|
// NodeGroup is a group of nodes.
|
||||||
type NodeGroup struct {
|
type NodeGroup struct {
|
||||||
ID int
|
ID int
|
||||||
nodes []Node
|
nodes []Node
|
||||||
Options []SelectOption
|
selectorOptions []SelectOption
|
||||||
Selector NodeSelector
|
selector NodeSelector
|
||||||
|
mux sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNodeGroup creates a node group
|
// NewNodeGroup creates a node group
|
||||||
@ -185,11 +152,21 @@ func (group *NodeGroup) AddNode(node ...Node) {
|
|||||||
if group == nil {
|
if group == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
group.mux.Lock()
|
||||||
|
defer group.mux.Unlock()
|
||||||
|
|
||||||
group.nodes = append(group.nodes, node...)
|
group.nodes = append(group.nodes, node...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNodes replaces the group nodes to the specified nodes.
|
// SetNodes replaces the group nodes to the specified nodes.
|
||||||
func (group *NodeGroup) SetNodes(nodes ...Node) {
|
func (group *NodeGroup) SetNodes(nodes ...Node) {
|
||||||
|
if group == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
group.mux.Lock()
|
||||||
|
defer group.mux.Unlock()
|
||||||
|
|
||||||
group.nodes = nodes
|
group.nodes = nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,27 +175,100 @@ func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption)
|
|||||||
if group == nil {
|
if group == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
group.Selector = selector
|
group.mux.Lock()
|
||||||
group.Options = opts
|
defer group.mux.Unlock()
|
||||||
|
|
||||||
|
group.selector = selector
|
||||||
|
group.selectorOptions = opts
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nodes returns node list in the group
|
// Nodes returns the node list in the group
|
||||||
func (group *NodeGroup) Nodes() []Node {
|
func (group *NodeGroup) Nodes() []Node {
|
||||||
if group == nil {
|
if group == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
group.mux.RLock()
|
||||||
|
defer group.mux.RUnlock()
|
||||||
|
|
||||||
return group.nodes
|
return group.nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next selects the next node from group.
|
func (group *NodeGroup) copyNodes() []Node {
|
||||||
|
group.mux.RLock()
|
||||||
|
defer group.mux.RUnlock()
|
||||||
|
|
||||||
|
var nodes []Node
|
||||||
|
for i := range group.nodes {
|
||||||
|
nodes = append(nodes, group.nodes[i])
|
||||||
|
}
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNode returns a copy of the node specified by index in the group.
|
||||||
|
func (group *NodeGroup) GetNode(i int) Node {
|
||||||
|
group.mux.RLock()
|
||||||
|
defer group.mux.RUnlock()
|
||||||
|
|
||||||
|
if i < 0 || group == nil || len(group.nodes) <= i {
|
||||||
|
return Node{}
|
||||||
|
}
|
||||||
|
return group.nodes[i].Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkDeadNode marks the node with ID nid status to dead.
|
||||||
|
func (group *NodeGroup) MarkDeadNode(nid int) {
|
||||||
|
group.mux.RLock()
|
||||||
|
defer group.mux.RUnlock()
|
||||||
|
|
||||||
|
if group == nil || nid <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range group.nodes {
|
||||||
|
if group.nodes[i].ID == nid {
|
||||||
|
atomic.AddUint32(&group.nodes[i].failCount, 1)
|
||||||
|
atomic.StoreInt64(&group.nodes[i].failTime, time.Now().Unix())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDeadNode resets the node with ID nid status.
|
||||||
|
func (group *NodeGroup) ResetDeadNode(nid int) {
|
||||||
|
group.mux.RLock()
|
||||||
|
defer group.mux.RUnlock()
|
||||||
|
|
||||||
|
if group == nil || nid <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range group.nodes {
|
||||||
|
if group.nodes[i].ID == nid {
|
||||||
|
atomic.StoreUint32(&group.nodes[i].failCount, 0)
|
||||||
|
atomic.StoreInt64(&group.nodes[i].failTime, 0)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next selects a node from group.
|
||||||
// It also selects IP if the IP list exists.
|
// It also selects IP if the IP list exists.
|
||||||
func (group *NodeGroup) Next() (node Node, err error) {
|
func (group *NodeGroup) Next() (node Node, err error) {
|
||||||
selector := group.Selector
|
if group == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
group.mux.RLock()
|
||||||
|
defer group.mux.RUnlock()
|
||||||
|
|
||||||
|
selector := group.selector
|
||||||
if selector == nil {
|
if selector == nil {
|
||||||
selector = &defaultSelector{}
|
selector = &defaultSelector{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// select node from node group
|
// select node from node group
|
||||||
node, err = selector.Select(group.Nodes(), group.Options...)
|
node, err = selector.Select(group.nodes, group.selectorOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@ func PeriodReload(r Reloader, configFile string) error {
|
|||||||
|
|
||||||
finfo, err := f.Stat()
|
finfo, err := f.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mt := finfo.ModTime()
|
mt := finfo.ModTime()
|
||||||
|
68
resolver.go
68
resolver.go
@ -68,6 +68,7 @@ type resolver struct {
|
|||||||
TTL time.Duration
|
TTL time.Duration
|
||||||
period time.Duration
|
period time.Duration
|
||||||
domain string
|
domain string
|
||||||
|
mux sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver create a new Resolver with the given name servers and resolution timeout.
|
// NewResolver create a new Resolver with the given name servers and resolution timeout.
|
||||||
@ -78,17 +79,23 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
|
|||||||
TTL: ttl,
|
TTL: ttl,
|
||||||
mCache: &sync.Map{},
|
mCache: &sync.Map{},
|
||||||
}
|
}
|
||||||
r.init()
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *resolver) init() {
|
|
||||||
if r.Timeout <= 0 {
|
if r.Timeout <= 0 {
|
||||||
r.Timeout = DefaultResolverTimeout
|
r.Timeout = DefaultResolverTimeout
|
||||||
}
|
}
|
||||||
if r.TTL == 0 {
|
if r.TTL == 0 {
|
||||||
r.TTL = DefaultResolverTTL
|
r.TTL = DefaultResolverTTL
|
||||||
}
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolver) copyServers() []NameServer {
|
||||||
|
var servers []NameServer
|
||||||
|
for i := range r.Servers {
|
||||||
|
servers = append(servers, r.Servers[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return servers
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
||||||
@ -96,14 +103,24 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var domain string
|
||||||
|
var timeout, ttl time.Duration
|
||||||
|
var servers []NameServer
|
||||||
|
|
||||||
|
r.mux.RLock()
|
||||||
|
domain = r.domain
|
||||||
|
timeout = r.Timeout
|
||||||
|
servers = r.copyServers()
|
||||||
|
r.mux.RUnlock()
|
||||||
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
return []net.IP{ip}, nil
|
return []net.IP{ip}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(host, ".") && r.domain != "" {
|
if !strings.Contains(host, ".") && domain != "" {
|
||||||
host = host + "." + r.domain
|
host = host + "." + domain
|
||||||
}
|
}
|
||||||
ips = r.loadCache(host)
|
ips = r.loadCache(host, ttl)
|
||||||
if len(ips) > 0 {
|
if len(ips) > 0 {
|
||||||
if Debug {
|
if Debug {
|
||||||
log.Logf("[resolver] cache hit %s: %v", host, ips)
|
log.Logf("[resolver] cache hit %s: %v", host, ips)
|
||||||
@ -111,8 +128,8 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range r.Servers {
|
for _, ns := range servers {
|
||||||
ips, err = r.resolve(ns, host)
|
ips, err = r.resolve(ns, host, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Logf("[resolver] %s via %s : %s", host, ns, err)
|
log.Logf("[resolver] %s via %s : %s", host, ns, err)
|
||||||
continue
|
continue
|
||||||
@ -130,14 +147,14 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error) {
|
func (*resolver) resolve(ns NameServer, host string, timeout time.Duration) (ips []net.IP, err error) {
|
||||||
addr := ns.Addr
|
addr := ns.Addr
|
||||||
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
||||||
addr = net.JoinHostPort(addr, "53")
|
addr = net.JoinHostPort(addr, "53")
|
||||||
}
|
}
|
||||||
|
|
||||||
client := dns.Client{
|
client := dns.Client{
|
||||||
Timeout: r.Timeout,
|
Timeout: timeout,
|
||||||
}
|
}
|
||||||
switch strings.ToLower(ns.Protocol) {
|
switch strings.ToLower(ns.Protocol) {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
@ -171,8 +188,7 @@ func (r *resolver) resolve(ns NameServer, host string) (ips []net.IP, err error)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) loadCache(name string) []net.IP {
|
func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
|
||||||
ttl := r.TTL
|
|
||||||
if ttl < 0 {
|
if ttl < 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -189,8 +205,7 @@ func (r *resolver) loadCache(name string) []net.IP {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) storeCache(name string, ips []net.IP) {
|
func (r *resolver) storeCache(name string, ips []net.IP) {
|
||||||
ttl := r.TTL
|
if name == "" || len(ips) == 0 {
|
||||||
if ttl < 0 || name == "" || len(ips) == 0 {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.mCache.Store(name, &resolverCacheItem{
|
r.mCache.Store(name, &resolverCacheItem{
|
||||||
@ -200,6 +215,8 @@ func (r *resolver) storeCache(name string, ips []net.IP) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) Reload(rd io.Reader) error {
|
func (r *resolver) Reload(rd io.Reader) error {
|
||||||
|
var ttl, timeout, period time.Duration
|
||||||
|
var domain string
|
||||||
var nss []NameServer
|
var nss []NameServer
|
||||||
|
|
||||||
split := func(line string) []string {
|
split := func(line string) []string {
|
||||||
@ -232,19 +249,19 @@ func (r *resolver) Reload(rd io.Reader) error {
|
|||||||
switch ss[0] {
|
switch ss[0] {
|
||||||
case "timeout": // timeout option
|
case "timeout": // timeout option
|
||||||
if len(ss) > 1 {
|
if len(ss) > 1 {
|
||||||
r.Timeout, _ = time.ParseDuration(ss[1])
|
timeout, _ = time.ParseDuration(ss[1])
|
||||||
}
|
}
|
||||||
case "ttl": // ttl option
|
case "ttl": // ttl option
|
||||||
if len(ss) > 1 {
|
if len(ss) > 1 {
|
||||||
r.TTL, _ = time.ParseDuration(ss[1])
|
ttl, _ = time.ParseDuration(ss[1])
|
||||||
}
|
}
|
||||||
case "reload": // reload option
|
case "reload": // reload option
|
||||||
if len(ss) > 1 {
|
if len(ss) > 1 {
|
||||||
r.period, _ = time.ParseDuration(ss[1])
|
period, _ = time.ParseDuration(ss[1])
|
||||||
}
|
}
|
||||||
case "domain":
|
case "domain":
|
||||||
if len(ss) > 1 {
|
if len(ss) > 1 {
|
||||||
r.domain = ss[1]
|
domain = ss[1]
|
||||||
}
|
}
|
||||||
case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf
|
case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf
|
||||||
case "nameserver": // nameserver option, compatible with /etc/resolv.conf
|
case "nameserver": // nameserver option, compatible with /etc/resolv.conf
|
||||||
@ -276,11 +293,21 @@ func (r *resolver) Reload(rd io.Reader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.mux.Lock()
|
||||||
|
r.Timeout = timeout
|
||||||
|
r.TTL = ttl
|
||||||
|
r.domain = domain
|
||||||
|
r.period = period
|
||||||
r.Servers = nss
|
r.Servers = nss
|
||||||
|
r.mux.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) Period() time.Duration {
|
func (r *resolver) Period() time.Duration {
|
||||||
|
r.mux.RLock()
|
||||||
|
defer r.mux.RUnlock()
|
||||||
|
|
||||||
return r.period
|
return r.period
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,6 +316,9 @@ func (r *resolver) String() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.mux.RLock()
|
||||||
|
defer r.mux.RUnlock()
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
|
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
|
||||||
fmt.Fprintf(b, "TTL %v\n", r.TTL)
|
fmt.Fprintf(b, "TTL %v\n", r.TTL)
|
||||||
|
@ -94,6 +94,7 @@ type RandomStrategy struct {
|
|||||||
Seed int64
|
Seed int64
|
||||||
rand *rand.Rand
|
rand *rand.Rand
|
||||||
once sync.Once
|
once sync.Once
|
||||||
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply applies the random strategy for the nodes.
|
// Apply applies the random strategy for the nodes.
|
||||||
@ -109,7 +110,11 @@ func (s *RandomStrategy) Apply(nodes []Node) Node {
|
|||||||
return Node{}
|
return Node{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes[s.rand.Int()%len(nodes)]
|
s.mux.Lock()
|
||||||
|
r := s.rand.Int()
|
||||||
|
s.mux.Unlock()
|
||||||
|
|
||||||
|
return nodes[r%len(nodes)]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RandomStrategy) String() string {
|
func (s *RandomStrategy) String() string {
|
||||||
|
14
server.go
14
server.go
@ -72,11 +72,13 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error {
|
|||||||
}
|
}
|
||||||
tempDelay = 0
|
tempDelay = 0
|
||||||
|
|
||||||
if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
|
/*
|
||||||
log.Log("[bypass]", conn.RemoteAddr())
|
if s.options.Bypass.Contains(conn.RemoteAddr().String()) {
|
||||||
conn.Close()
|
log.Log("[bypass]", conn.RemoteAddr())
|
||||||
continue
|
conn.Close()
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
go h.Handle(conn)
|
go h.Handle(conn)
|
||||||
}
|
}
|
||||||
@ -90,12 +92,14 @@ type ServerOptions struct {
|
|||||||
// ServerOption allows a common way to set server options.
|
// ServerOption allows a common way to set server options.
|
||||||
type ServerOption func(opts *ServerOptions)
|
type ServerOption func(opts *ServerOptions)
|
||||||
|
|
||||||
|
/*
|
||||||
// BypassServerOption sets the bypass option of ServerOptions.
|
// BypassServerOption sets the bypass option of ServerOptions.
|
||||||
func BypassServerOption(bypass *Bypass) ServerOption {
|
func BypassServerOption(bypass *Bypass) ServerOption {
|
||||||
return func(opts *ServerOptions) {
|
return func(opts *ServerOptions) {
|
||||||
opts.Bypass = bypass
|
opts.Bypass = bypass
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
// Listener is a proxy server listener, just like a net.Listener.
|
// Listener is a proxy server listener, just like a net.Listener.
|
||||||
type Listener interface {
|
type Listener interface {
|
||||||
|
Loading…
Reference in New Issue
Block a user