add live reloading support for bypass,resolver and hosts
This commit is contained in:
parent
ca3853e8cb
commit
38827782e1
90
bypass.go
90
bypass.go
@ -1,11 +1,15 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
glob "github.com/gobwas/glob"
|
||||
)
|
||||
@ -118,28 +122,30 @@ func (m *domainMatcher) String() string {
|
||||
// It contains a list of matchers.
|
||||
type Bypass struct {
|
||||
matchers []Matcher
|
||||
reverse bool
|
||||
reversed bool
|
||||
period time.Duration // the period for live reloading
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
|
||||
// The rules will be reversed if the reversed is true.
|
||||
func NewBypass(matchers []Matcher, reverse bool) *Bypass {
|
||||
func NewBypass(reversed bool, matchers ...Matcher) *Bypass {
|
||||
return &Bypass{
|
||||
matchers: matchers,
|
||||
reverse: reverse,
|
||||
reversed: reversed,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules.
|
||||
// The rules will be reversed if the reverse is true.
|
||||
func NewBypassPatterns(patterns []string, reverse bool) *Bypass {
|
||||
func NewBypassPatterns(reversed bool, patterns ...string) *Bypass {
|
||||
var matchers []Matcher
|
||||
for _, pattern := range patterns {
|
||||
if pattern != "" {
|
||||
matchers = append(matchers, NewMatcher(pattern))
|
||||
}
|
||||
}
|
||||
return NewBypass(matchers, reverse)
|
||||
return NewBypass(reversed, matchers...)
|
||||
}
|
||||
|
||||
// Contains reports whether the bypass includes addr.
|
||||
@ -153,6 +159,10 @@ func (bp *Bypass) Contains(addr string) bool {
|
||||
addr = host
|
||||
}
|
||||
}
|
||||
|
||||
bp.mux.Lock()
|
||||
defer bp.mux.Unlock()
|
||||
|
||||
var matched bool
|
||||
for _, matcher := range bp.matchers {
|
||||
if matcher == nil {
|
||||
@ -163,8 +173,8 @@ func (bp *Bypass) Contains(addr string) bool {
|
||||
break
|
||||
}
|
||||
}
|
||||
return !bp.reverse && matched ||
|
||||
bp.reverse && !matched
|
||||
return !bp.reversed && matched ||
|
||||
bp.reversed && !matched
|
||||
}
|
||||
|
||||
// AddMatchers appends matchers to the bypass matcher list.
|
||||
@ -179,7 +189,71 @@ func (bp *Bypass) Matchers() []Matcher {
|
||||
|
||||
// Reversed reports whether the rules of the bypass are reversed.
|
||||
func (bp *Bypass) Reversed() bool {
|
||||
return bp.reverse
|
||||
return bp.reversed
|
||||
}
|
||||
|
||||
// Reload parses config from r, then live reloads the bypass.
|
||||
func (bp *Bypass) Reload(r io.Reader) error {
|
||||
var matchers []Matcher
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||
line = line[:n]
|
||||
}
|
||||
line = strings.Replace(line, "\t", " ", -1)
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// reload option
|
||||
if strings.HasPrefix(line, "reload ") {
|
||||
var ss []string
|
||||
for _, s := range strings.Split(line, " ") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
ss = append(ss, s)
|
||||
}
|
||||
}
|
||||
if len(ss) == 2 {
|
||||
bp.period, _ = time.ParseDuration(ss[1])
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// reverse option
|
||||
if strings.HasPrefix(line, "reverse ") {
|
||||
var ss []string
|
||||
for _, s := range strings.Split(line, " ") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
ss = append(ss, s)
|
||||
}
|
||||
}
|
||||
if len(ss) == 2 {
|
||||
bp.reversed, _ = strconv.ParseBool(ss[1])
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
matchers = append(matchers, NewMatcher(line))
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bp.mux.Lock()
|
||||
defer bp.mux.Unlock()
|
||||
|
||||
bp.matchers = matchers
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Period returns the reload period
|
||||
func (bp *Bypass) Period() time.Duration {
|
||||
return bp.period
|
||||
}
|
||||
|
||||
func (bp *Bypass) String() string {
|
||||
|
@ -1,5 +1,12 @@
|
||||
# period for live reloading
|
||||
reload 10s
|
||||
|
||||
# matcher reversed
|
||||
reverse true
|
||||
|
||||
10.0.0.1
|
||||
192.168.0.0/24
|
||||
172.1.0.0/16
|
||||
192.168.100.190/32
|
||||
*.example.com
|
||||
*.example.com
|
||||
.example.org
|
@ -10,7 +10,6 @@ import (
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -243,22 +242,14 @@ func parseBypass(s string) *gost.Bypass {
|
||||
}
|
||||
matchers = append(matchers, gost.NewMatcher(s))
|
||||
}
|
||||
return gost.NewBypass(matchers, reversed)
|
||||
return gost.NewBypass(reversed, matchers...)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||
line = line[:n]
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
matchers = append(matchers, gost.NewMatcher(line))
|
||||
}
|
||||
return gost.NewBypass(matchers, reversed)
|
||||
bp := gost.NewBypass(reversed)
|
||||
go gost.PeriodReload(bp, s)
|
||||
|
||||
return bp
|
||||
}
|
||||
|
||||
func parseResolver(cfg string) gost.Resolver {
|
||||
@ -289,59 +280,12 @@ func parseResolver(cfg string) gost.Resolver {
|
||||
})
|
||||
}
|
||||
}
|
||||
return gost.NewResolver(nss, timeout, ttl)
|
||||
return gost.NewResolver(timeout, ttl, nss...)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||
line = line[:n]
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var ss []string
|
||||
for _, s := range strings.Split(line, " ") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
ss = append(ss, s)
|
||||
}
|
||||
}
|
||||
resolver := gost.NewResolver(timeout, ttl)
|
||||
go gost.PeriodReload(resolver, cfg)
|
||||
|
||||
if len(ss) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.ToLower(ss[0]) == "timeout" {
|
||||
if len(ss) >= 2 {
|
||||
if n, _ := strconv.Atoi(ss[1]); n > 0 {
|
||||
timeout = time.Second * time.Duration(n)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(ss[0]) == "ttl" {
|
||||
if len(ss) >= 2 {
|
||||
n, _ := strconv.Atoi(ss[1])
|
||||
ttl = time.Second * time.Duration(n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
var ns gost.NameServer
|
||||
switch len(ss) {
|
||||
case 1:
|
||||
ns.Addr = ss[0]
|
||||
case 2:
|
||||
ns.Addr = ss[0]
|
||||
ns.Protocol = ss[1]
|
||||
default:
|
||||
ns.Addr = ss[0]
|
||||
ns.Protocol = ss[1]
|
||||
ns.Hostname = ss[2]
|
||||
}
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
return gost.NewResolver(nss, timeout, ttl)
|
||||
return resolver
|
||||
}
|
||||
|
@ -1,8 +1,11 @@
|
||||
# resolver timeout, default 30s.
|
||||
timeout 10
|
||||
timeout 10s
|
||||
|
||||
# resolver cache TTL, default 60s, minus value means that cache is disabled.
|
||||
ttl 300
|
||||
ttl 300s
|
||||
|
||||
# period for live reloading
|
||||
reload 10s
|
||||
|
||||
# ip[:port] [protocol] [hostname]
|
||||
|
||||
|
@ -1,3 +1,6 @@
|
||||
# period for live reloading
|
||||
reload 10s
|
||||
|
||||
# The following lines are desirable for IPv4 capable hosts
|
||||
127.0.0.1 localhost
|
||||
|
||||
@ -11,4 +14,4 @@
|
||||
# The following lines are desirable for IPv6 capable hosts
|
||||
::1 localhost ip6-localhost ip6-loopback
|
||||
ff02::1 ip6-allnodes
|
||||
ff02::2 ip6-allrouters
|
||||
ff02::2 ip6-allrouters
|
||||
|
@ -153,7 +153,7 @@ func (r *route) initChain() (*gost.Chain, error) {
|
||||
var bypass *gost.Bypass
|
||||
// global bypass
|
||||
if peerCfg.Bypass != nil {
|
||||
bypass = gost.NewBypassPatterns(peerCfg.Bypass.Patterns, peerCfg.Bypass.Reverse)
|
||||
bypass = gost.NewBypassPatterns(peerCfg.Bypass.Reverse, peerCfg.Bypass.Patterns...)
|
||||
}
|
||||
nodes = ngroup.Nodes()
|
||||
for i := range nodes {
|
||||
@ -492,10 +492,9 @@ func (r *route) serve() error {
|
||||
|
||||
var hosts *gost.Hosts
|
||||
if f, _ := os.Open(node.Get("hosts")); f != nil {
|
||||
hosts, err = gost.ParseHosts(f)
|
||||
if err != nil {
|
||||
log.Logf("[hosts] %s: %v", f.Name(), err)
|
||||
}
|
||||
f.Close()
|
||||
hosts = gost.NewHosts()
|
||||
go gost.PeriodReload(hosts, node.Get("hosts"))
|
||||
}
|
||||
|
||||
handler.Init(
|
||||
|
112
hosts.go
112
hosts.go
@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-log/log"
|
||||
)
|
||||
@ -17,8 +18,13 @@ type Host struct {
|
||||
}
|
||||
|
||||
// Hosts is a static table lookup for hostnames.
|
||||
// For each host a single line should be present with the following information:
|
||||
// IP_address canonical_hostname [aliases...]
|
||||
// Fields of the entry are separated by any number of blanks and/or tab characters.
|
||||
// Text from a "#" character until the end of the line is a comment, and is ignored.
|
||||
type Hosts struct {
|
||||
hosts []Host
|
||||
hosts []Host
|
||||
period time.Duration
|
||||
}
|
||||
|
||||
// NewHosts creates a Hosts with optional list of host
|
||||
@ -28,53 +34,6 @@ func NewHosts(hosts ...Host) *Hosts {
|
||||
}
|
||||
}
|
||||
|
||||
// ParseHosts parses host table from r.
|
||||
// For each host a single line should be present with the following information:
|
||||
// IP_address canonical_hostname [aliases...]
|
||||
// Fields of the entry are separated by any number of blanks and/or tab characters.
|
||||
// Text from a "#" character until the end of the line is a comment, and is ignored.
|
||||
func ParseHosts(r io.Reader) (*Hosts, error) {
|
||||
hosts := NewHosts()
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||
line = line[:n]
|
||||
}
|
||||
line = strings.Replace(line, "\t", " ", -1)
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var ss []string
|
||||
for _, s := range strings.Split(line, " ") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
ss = append(ss, s)
|
||||
}
|
||||
}
|
||||
if len(ss) < 2 {
|
||||
continue // invalid lines are ignored
|
||||
}
|
||||
ip := net.ParseIP(ss[0])
|
||||
if ip == nil {
|
||||
continue // invalid IP addresses are ignored
|
||||
}
|
||||
host := Host{
|
||||
IP: ip,
|
||||
Hostname: ss[1],
|
||||
}
|
||||
if len(ss) > 2 {
|
||||
host.Aliases = ss[2:]
|
||||
}
|
||||
hosts.AddHost(host)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
// AddHost adds host(s) to the host table.
|
||||
func (h *Hosts) AddHost(host ...Host) {
|
||||
h.hosts = append(h.hosts, host...)
|
||||
@ -102,3 +61,60 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Reload parses config from r, then live reloads the hosts.
|
||||
func (h *Hosts) Reload(r io.Reader) error {
|
||||
var hosts []Host
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||
line = line[:n]
|
||||
}
|
||||
line = strings.Replace(line, "\t", " ", -1)
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var ss []string
|
||||
for _, s := range strings.Split(line, " ") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
ss = append(ss, s)
|
||||
}
|
||||
}
|
||||
if len(ss) < 2 {
|
||||
continue // invalid lines are ignored
|
||||
}
|
||||
|
||||
// reload option
|
||||
if strings.ToLower(ss[0]) == "reload" {
|
||||
h.period, _ = time.ParseDuration(ss[1])
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(ss[0])
|
||||
if ip == nil {
|
||||
continue // invalid IP addresses are ignored
|
||||
}
|
||||
host := Host{
|
||||
IP: ip,
|
||||
Hostname: ss[1],
|
||||
}
|
||||
if len(ss) > 2 {
|
||||
host.Aliases = ss[2:]
|
||||
}
|
||||
hosts = append(hosts, host)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.hosts = hosts
|
||||
return nil
|
||||
}
|
||||
|
||||
// Period returns the reload period
|
||||
func (h *Hosts) Period() time.Duration {
|
||||
return h.period
|
||||
}
|
||||
|
52
reload.go
Normal file
52
reload.go
Normal file
@ -0,0 +1,52 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-log/log"
|
||||
)
|
||||
|
||||
// Reloader is the interface for objects that support live reloading.
|
||||
type Reloader interface {
|
||||
Reload(r io.Reader) error
|
||||
Period() time.Duration
|
||||
}
|
||||
|
||||
// PeriodReload reloads the config periodically according to the period of the reloader.
|
||||
func PeriodReload(r Reloader, configFile string) error {
|
||||
var lastMod time.Time
|
||||
|
||||
for {
|
||||
f, err := os.Open(configFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finfo, err := f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mt := finfo.ModTime()
|
||||
if !mt.Equal(lastMod) {
|
||||
if Debug {
|
||||
log.Log("[reload]", configFile)
|
||||
}
|
||||
r.Reload(f)
|
||||
lastMod = mt
|
||||
}
|
||||
f.Close()
|
||||
|
||||
period := r.Period()
|
||||
if period <= 0 {
|
||||
log.Log("[reload] disabled:", configFile)
|
||||
return nil
|
||||
}
|
||||
if period < time.Second {
|
||||
period = time.Second
|
||||
}
|
||||
|
||||
<-time.After(period)
|
||||
}
|
||||
}
|
85
resolver.go
85
resolver.go
@ -1,10 +1,12 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -27,6 +29,12 @@ type Resolver interface {
|
||||
Resolve(host string) ([]net.IP, error)
|
||||
}
|
||||
|
||||
// ReloadResolver is resolover that support live reloading
|
||||
type ReloadResolver interface {
|
||||
Resolver
|
||||
Reloader
|
||||
}
|
||||
|
||||
// NameServer is a name server.
|
||||
// Currently supported protocol: TCP, UDP and TLS.
|
||||
type NameServer struct {
|
||||
@ -56,13 +64,14 @@ type resolverCacheItem struct {
|
||||
type resolver struct {
|
||||
Resolver *net.Resolver
|
||||
Servers []NameServer
|
||||
mCache *sync.Map
|
||||
Timeout time.Duration
|
||||
TTL time.Duration
|
||||
mCache *sync.Map
|
||||
period time.Duration
|
||||
}
|
||||
|
||||
// NewResolver create a new Resolver with the given name servers and resolution timeout.
|
||||
func NewResolver(servers []NameServer, timeout, ttl time.Duration) Resolver {
|
||||
func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver {
|
||||
r := &resolver{
|
||||
Servers: servers,
|
||||
Timeout: timeout,
|
||||
@ -184,6 +193,77 @@ func (r *resolver) storeCache(name string, ips []net.IP) {
|
||||
})
|
||||
}
|
||||
|
||||
func (r *resolver) Reload(rd io.Reader) error {
|
||||
var nss []NameServer
|
||||
|
||||
scanner := bufio.NewScanner(rd)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||
line = line[:n]
|
||||
}
|
||||
line = strings.Replace(line, "\t", " ", -1)
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var ss []string
|
||||
for _, s := range strings.Split(line, " ") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
ss = append(ss, s)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ss) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(ss) >= 2 {
|
||||
// timeout option
|
||||
if strings.ToLower(ss[0]) == "timeout" {
|
||||
r.Timeout, _ = time.ParseDuration(ss[1])
|
||||
continue
|
||||
}
|
||||
|
||||
// ttl option
|
||||
if strings.ToLower(ss[0]) == "ttl" {
|
||||
r.TTL, _ = time.ParseDuration(ss[1])
|
||||
continue
|
||||
}
|
||||
|
||||
// reload option
|
||||
if strings.ToLower(ss[0]) == "reload" {
|
||||
r.period, _ = time.ParseDuration(ss[1])
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var ns NameServer
|
||||
switch len(ss) {
|
||||
case 1:
|
||||
ns.Addr = ss[0]
|
||||
case 2:
|
||||
ns.Addr = ss[0]
|
||||
ns.Protocol = ss[1]
|
||||
default:
|
||||
ns.Addr = ss[0]
|
||||
ns.Protocol = ss[1]
|
||||
ns.Hostname = ss[2]
|
||||
}
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Servers = nss
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolver) Period() time.Duration {
|
||||
return r.period
|
||||
}
|
||||
|
||||
func (r *resolver) String() string {
|
||||
if r == nil {
|
||||
return ""
|
||||
@ -192,6 +272,7 @@ func (r *resolver) String() string {
|
||||
b := &bytes.Buffer{}
|
||||
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
|
||||
fmt.Fprintf(b, "TTL %v\n", r.TTL)
|
||||
fmt.Fprintf(b, "Reload %v\n", r.period)
|
||||
for i := range r.Servers {
|
||||
fmt.Fprintln(b, r.Servers[i])
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user