From 38827782e1c19a2c43ca9602856eb4ecb123c39f Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 3 Nov 2018 10:56:55 +0800 Subject: [PATCH] add live reloading support for bypass,resolver and hosts --- bypass.go | 90 +++++++++++++++++++++++++++++++---- cmd/gost/bypass.txt | 9 +++- cmd/gost/cfg.go | 78 +++++------------------------- cmd/gost/dns.txt | 7 ++- cmd/gost/hosts.txt | 5 +- cmd/gost/main.go | 9 ++-- hosts.go | 112 +++++++++++++++++++++++++------------------- reload.go | 52 ++++++++++++++++++++ resolver.go | 85 ++++++++++++++++++++++++++++++++- 9 files changed, 313 insertions(+), 134 deletions(-) create mode 100644 reload.go diff --git a/bypass.go b/bypass.go index 416d045..792c33d 100644 --- a/bypass.go +++ b/bypass.go @@ -1,11 +1,15 @@ package gost import ( + "bufio" "bytes" "fmt" + "io" "net" "strconv" "strings" + "sync" + "time" glob "github.com/gobwas/glob" ) @@ -118,28 +122,30 @@ func (m *domainMatcher) String() string { // It contains a list of matchers. type Bypass struct { matchers []Matcher - reverse bool + reversed bool + period time.Duration // the period for live reloading + mux sync.Mutex } // NewBypass creates and initializes a new Bypass using matchers as its match rules. // The rules will be reversed if the reversed is true. -func NewBypass(matchers []Matcher, reverse bool) *Bypass { +func NewBypass(reversed bool, matchers ...Matcher) *Bypass { return &Bypass{ matchers: matchers, - reverse: reverse, + reversed: reversed, } } // NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. // The rules will be reversed if the reverse is true. -func NewBypassPatterns(patterns []string, reverse bool) *Bypass { +func NewBypassPatterns(reversed bool, patterns ...string) *Bypass { var matchers []Matcher for _, pattern := range patterns { if pattern != "" { matchers = append(matchers, NewMatcher(pattern)) } } - return NewBypass(matchers, reverse) + return NewBypass(reversed, matchers...) } // Contains reports whether the bypass includes addr. @@ -153,6 +159,10 @@ func (bp *Bypass) Contains(addr string) bool { addr = host } } + + bp.mux.Lock() + defer bp.mux.Unlock() + var matched bool for _, matcher := range bp.matchers { if matcher == nil { @@ -163,8 +173,8 @@ func (bp *Bypass) Contains(addr string) bool { break } } - return !bp.reverse && matched || - bp.reverse && !matched + return !bp.reversed && matched || + bp.reversed && !matched } // AddMatchers appends matchers to the bypass matcher list. @@ -179,7 +189,71 @@ func (bp *Bypass) Matchers() []Matcher { // Reversed reports whether the rules of the bypass are reversed. func (bp *Bypass) Reversed() bool { - return bp.reverse + return bp.reversed +} + +// Reload parses config from r, then live reloads the bypass. +func (bp *Bypass) Reload(r io.Reader) error { + var matchers []Matcher + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // reload option + if strings.HasPrefix(line, "reload ") { + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + if len(ss) == 2 { + bp.period, _ = time.ParseDuration(ss[1]) + continue + } + } + + // reverse option + if strings.HasPrefix(line, "reverse ") { + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + if len(ss) == 2 { + bp.reversed, _ = strconv.ParseBool(ss[1]) + continue + } + } + + matchers = append(matchers, NewMatcher(line)) + } + + if err := scanner.Err(); err != nil { + return err + } + + bp.mux.Lock() + defer bp.mux.Unlock() + + bp.matchers = matchers + + return nil +} + +// Period returns the reload period +func (bp *Bypass) Period() time.Duration { + return bp.period } func (bp *Bypass) String() string { diff --git a/cmd/gost/bypass.txt b/cmd/gost/bypass.txt index 04bf178..5afe8c0 100644 --- a/cmd/gost/bypass.txt +++ b/cmd/gost/bypass.txt @@ -1,5 +1,12 @@ +# period for live reloading +reload 10s + +# matcher reversed + reverse true + 10.0.0.1 192.168.0.0/24 172.1.0.0/16 192.168.100.190/32 -*.example.com \ No newline at end of file +*.example.com +.example.org \ No newline at end of file diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index 1b758dc..30eb321 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -10,7 +10,6 @@ import ( "io/ioutil" "net/url" "os" - "strconv" "strings" "time" @@ -243,22 +242,14 @@ func parseBypass(s string) *gost.Bypass { } matchers = append(matchers, gost.NewMatcher(s)) } - return gost.NewBypass(matchers, reversed) + return gost.NewBypass(reversed, matchers...) } + f.Close() - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Text() - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.TrimSpace(line) - if line == "" { - continue - } - matchers = append(matchers, gost.NewMatcher(line)) - } - return gost.NewBypass(matchers, reversed) + bp := gost.NewBypass(reversed) + go gost.PeriodReload(bp, s) + + return bp } func parseResolver(cfg string) gost.Resolver { @@ -289,59 +280,12 @@ func parseResolver(cfg string) gost.Resolver { }) } } - return gost.NewResolver(nss, timeout, ttl) + return gost.NewResolver(timeout, ttl, nss...) } + f.Close() - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Text() - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.TrimSpace(line) - if line == "" { - continue - } - var ss []string - for _, s := range strings.Split(line, " ") { - if s = strings.TrimSpace(s); s != "" { - ss = append(ss, s) - } - } + resolver := gost.NewResolver(timeout, ttl) + go gost.PeriodReload(resolver, cfg) - if len(ss) == 0 { - continue - } - - if strings.ToLower(ss[0]) == "timeout" { - if len(ss) >= 2 { - if n, _ := strconv.Atoi(ss[1]); n > 0 { - timeout = time.Second * time.Duration(n) - } - } - continue - } - if strings.ToLower(ss[0]) == "ttl" { - if len(ss) >= 2 { - n, _ := strconv.Atoi(ss[1]) - ttl = time.Second * time.Duration(n) - } - continue - } - - var ns gost.NameServer - switch len(ss) { - case 1: - ns.Addr = ss[0] - case 2: - ns.Addr = ss[0] - ns.Protocol = ss[1] - default: - ns.Addr = ss[0] - ns.Protocol = ss[1] - ns.Hostname = ss[2] - } - nss = append(nss, ns) - } - return gost.NewResolver(nss, timeout, ttl) + return resolver } diff --git a/cmd/gost/dns.txt b/cmd/gost/dns.txt index d2ddc45..303ad2c 100644 --- a/cmd/gost/dns.txt +++ b/cmd/gost/dns.txt @@ -1,8 +1,11 @@ # resolver timeout, default 30s. -timeout 10 +timeout 10s # resolver cache TTL, default 60s, minus value means that cache is disabled. -ttl 300 +ttl 300s + +# period for live reloading +reload 10s # ip[:port] [protocol] [hostname] diff --git a/cmd/gost/hosts.txt b/cmd/gost/hosts.txt index 6944f41..a2bf11d 100644 --- a/cmd/gost/hosts.txt +++ b/cmd/gost/hosts.txt @@ -1,3 +1,6 @@ +# period for live reloading +reload 10s + # The following lines are desirable for IPv4 capable hosts 127.0.0.1 localhost @@ -11,4 +14,4 @@ # The following lines are desirable for IPv6 capable hosts ::1 localhost ip6-localhost ip6-loopback ff02::1 ip6-allnodes -ff02::2 ip6-allrouters \ No newline at end of file +ff02::2 ip6-allrouters diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 331ac1b..7aa26e1 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -153,7 +153,7 @@ func (r *route) initChain() (*gost.Chain, error) { var bypass *gost.Bypass // global bypass if peerCfg.Bypass != nil { - bypass = gost.NewBypassPatterns(peerCfg.Bypass.Patterns, peerCfg.Bypass.Reverse) + bypass = gost.NewBypassPatterns(peerCfg.Bypass.Reverse, peerCfg.Bypass.Patterns...) } nodes = ngroup.Nodes() for i := range nodes { @@ -492,10 +492,9 @@ func (r *route) serve() error { var hosts *gost.Hosts if f, _ := os.Open(node.Get("hosts")); f != nil { - hosts, err = gost.ParseHosts(f) - if err != nil { - log.Logf("[hosts] %s: %v", f.Name(), err) - } + f.Close() + hosts = gost.NewHosts() + go gost.PeriodReload(hosts, node.Get("hosts")) } handler.Init( diff --git a/hosts.go b/hosts.go index 81bac00..de2866f 100644 --- a/hosts.go +++ b/hosts.go @@ -5,6 +5,7 @@ import ( "io" "net" "strings" + "time" "github.com/go-log/log" ) @@ -17,8 +18,13 @@ type Host struct { } // Hosts is a static table lookup for hostnames. +// For each host a single line should be present with the following information: +// IP_address canonical_hostname [aliases...] +// Fields of the entry are separated by any number of blanks and/or tab characters. +// Text from a "#" character until the end of the line is a comment, and is ignored. type Hosts struct { - hosts []Host + hosts []Host + period time.Duration } // NewHosts creates a Hosts with optional list of host @@ -28,53 +34,6 @@ func NewHosts(hosts ...Host) *Hosts { } } -// ParseHosts parses host table from r. -// For each host a single line should be present with the following information: -// IP_address canonical_hostname [aliases...] -// Fields of the entry are separated by any number of blanks and/or tab characters. -// Text from a "#" character until the end of the line is a comment, and is ignored. -func ParseHosts(r io.Reader) (*Hosts, error) { - hosts := NewHosts() - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.Replace(line, "\t", " ", -1) - line = strings.TrimSpace(line) - if line == "" { - continue - } - var ss []string - for _, s := range strings.Split(line, " ") { - if s = strings.TrimSpace(s); s != "" { - ss = append(ss, s) - } - } - if len(ss) < 2 { - continue // invalid lines are ignored - } - ip := net.ParseIP(ss[0]) - if ip == nil { - continue // invalid IP addresses are ignored - } - host := Host{ - IP: ip, - Hostname: ss[1], - } - if len(ss) > 2 { - host.Aliases = ss[2:] - } - hosts.AddHost(host) - } - if err := scanner.Err(); err != nil { - return nil, err - } - - return hosts, nil -} - // AddHost adds host(s) to the host table. func (h *Hosts) AddHost(host ...Host) { h.hosts = append(h.hosts, host...) @@ -102,3 +61,60 @@ func (h *Hosts) Lookup(host string) (ip net.IP) { } return } + +// Reload parses config from r, then live reloads the hosts. +func (h *Hosts) Reload(r io.Reader) error { + var hosts []Host + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + if line == "" { + continue + } + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + if len(ss) < 2 { + continue // invalid lines are ignored + } + + // reload option + if strings.ToLower(ss[0]) == "reload" { + h.period, _ = time.ParseDuration(ss[1]) + continue + } + + ip := net.ParseIP(ss[0]) + if ip == nil { + continue // invalid IP addresses are ignored + } + host := Host{ + IP: ip, + Hostname: ss[1], + } + if len(ss) > 2 { + host.Aliases = ss[2:] + } + hosts = append(hosts, host) + } + if err := scanner.Err(); err != nil { + return err + } + + h.hosts = hosts + return nil +} + +// Period returns the reload period +func (h *Hosts) Period() time.Duration { + return h.period +} diff --git a/reload.go b/reload.go new file mode 100644 index 0000000..6e7bfa5 --- /dev/null +++ b/reload.go @@ -0,0 +1,52 @@ +package gost + +import ( + "io" + "os" + "time" + + "github.com/go-log/log" +) + +// Reloader is the interface for objects that support live reloading. +type Reloader interface { + Reload(r io.Reader) error + Period() time.Duration +} + +// PeriodReload reloads the config periodically according to the period of the reloader. +func PeriodReload(r Reloader, configFile string) error { + var lastMod time.Time + + for { + f, err := os.Open(configFile) + if err != nil { + return err + } + + finfo, err := f.Stat() + if err != nil { + return err + } + mt := finfo.ModTime() + if !mt.Equal(lastMod) { + if Debug { + log.Log("[reload]", configFile) + } + r.Reload(f) + lastMod = mt + } + f.Close() + + period := r.Period() + if period <= 0 { + log.Log("[reload] disabled:", configFile) + return nil + } + if period < time.Second { + period = time.Second + } + + <-time.After(period) + } +} diff --git a/resolver.go b/resolver.go index 08c38ff..87e81eb 100644 --- a/resolver.go +++ b/resolver.go @@ -1,10 +1,12 @@ package gost import ( + "bufio" "bytes" "context" "crypto/tls" "fmt" + "io" "net" "strings" "sync" @@ -27,6 +29,12 @@ type Resolver interface { Resolve(host string) ([]net.IP, error) } +// ReloadResolver is resolover that support live reloading +type ReloadResolver interface { + Resolver + Reloader +} + // NameServer is a name server. // Currently supported protocol: TCP, UDP and TLS. type NameServer struct { @@ -56,13 +64,14 @@ type resolverCacheItem struct { type resolver struct { Resolver *net.Resolver Servers []NameServer + mCache *sync.Map Timeout time.Duration TTL time.Duration - mCache *sync.Map + period time.Duration } // NewResolver create a new Resolver with the given name servers and resolution timeout. -func NewResolver(servers []NameServer, timeout, ttl time.Duration) Resolver { +func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver { r := &resolver{ Servers: servers, Timeout: timeout, @@ -184,6 +193,77 @@ func (r *resolver) storeCache(name string, ips []net.IP) { }) } +func (r *resolver) Reload(rd io.Reader) error { + var nss []NameServer + + scanner := bufio.NewScanner(rd) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + if line == "" { + continue + } + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + + if len(ss) == 0 { + continue + } + + if len(ss) >= 2 { + // timeout option + if strings.ToLower(ss[0]) == "timeout" { + r.Timeout, _ = time.ParseDuration(ss[1]) + continue + } + + // ttl option + if strings.ToLower(ss[0]) == "ttl" { + r.TTL, _ = time.ParseDuration(ss[1]) + continue + } + + // reload option + if strings.ToLower(ss[0]) == "reload" { + r.period, _ = time.ParseDuration(ss[1]) + continue + } + } + + var ns NameServer + switch len(ss) { + case 1: + ns.Addr = ss[0] + case 2: + ns.Addr = ss[0] + ns.Protocol = ss[1] + default: + ns.Addr = ss[0] + ns.Protocol = ss[1] + ns.Hostname = ss[2] + } + nss = append(nss, ns) + } + if err := scanner.Err(); err != nil { + return err + } + + r.Servers = nss + return nil +} + +func (r *resolver) Period() time.Duration { + return r.period +} + func (r *resolver) String() string { if r == nil { return "" @@ -192,6 +272,7 @@ func (r *resolver) String() string { b := &bytes.Buffer{} fmt.Fprintf(b, "Timeout %v\n", r.Timeout) fmt.Fprintf(b, "TTL %v\n", r.TTL) + fmt.Fprintf(b, "Reload %v\n", r.period) for i := range r.Servers { fmt.Fprintln(b, r.Servers[i]) }