add live reloading support for bypass,resolver and hosts

This commit is contained in:
ginuerzh 2018-11-03 10:56:55 +08:00
parent ca3853e8cb
commit 38827782e1
9 changed files with 313 additions and 134 deletions

View File

@ -1,11 +1,15 @@
package gost
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
glob "github.com/gobwas/glob"
)
@ -118,28 +122,30 @@ func (m *domainMatcher) String() string {
// It contains a list of matchers.
type Bypass struct {
matchers []Matcher
reverse bool
reversed bool
period time.Duration // the period for live reloading
mux sync.Mutex
}
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
// The rules will be reversed if the reversed is true.
func NewBypass(matchers []Matcher, reverse bool) *Bypass {
func NewBypass(reversed bool, matchers ...Matcher) *Bypass {
return &Bypass{
matchers: matchers,
reverse: reverse,
reversed: reversed,
}
}
// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules.
// The rules will be reversed if the reverse is true.
func NewBypassPatterns(patterns []string, reverse bool) *Bypass {
func NewBypassPatterns(reversed bool, patterns ...string) *Bypass {
var matchers []Matcher
for _, pattern := range patterns {
if pattern != "" {
matchers = append(matchers, NewMatcher(pattern))
}
}
return NewBypass(matchers, reverse)
return NewBypass(reversed, matchers...)
}
// Contains reports whether the bypass includes addr.
@ -153,6 +159,10 @@ func (bp *Bypass) Contains(addr string) bool {
addr = host
}
}
bp.mux.Lock()
defer bp.mux.Unlock()
var matched bool
for _, matcher := range bp.matchers {
if matcher == nil {
@ -163,8 +173,8 @@ func (bp *Bypass) Contains(addr string) bool {
break
}
}
return !bp.reverse && matched ||
bp.reverse && !matched
return !bp.reversed && matched ||
bp.reversed && !matched
}
// AddMatchers appends matchers to the bypass matcher list.
@ -179,7 +189,71 @@ func (bp *Bypass) Matchers() []Matcher {
// Reversed reports whether the rules of the bypass are reversed.
func (bp *Bypass) Reversed() bool {
return bp.reverse
return bp.reversed
}
// Reload parses config from r, then live reloads the bypass.
func (bp *Bypass) Reload(r io.Reader) error {
var matchers []Matcher
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue
}
// reload option
if strings.HasPrefix(line, "reload ") {
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) == 2 {
bp.period, _ = time.ParseDuration(ss[1])
continue
}
}
// reverse option
if strings.HasPrefix(line, "reverse ") {
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) == 2 {
bp.reversed, _ = strconv.ParseBool(ss[1])
continue
}
}
matchers = append(matchers, NewMatcher(line))
}
if err := scanner.Err(); err != nil {
return err
}
bp.mux.Lock()
defer bp.mux.Unlock()
bp.matchers = matchers
return nil
}
// Period returns the reload period
func (bp *Bypass) Period() time.Duration {
return bp.period
}
func (bp *Bypass) String() string {

View File

@ -1,5 +1,12 @@
# period for live reloading
reload 10s
# matcher reversed
reverse true
10.0.0.1
192.168.0.0/24
172.1.0.0/16
192.168.100.190/32
*.example.com
.example.org

View File

@ -10,7 +10,6 @@ import (
"io/ioutil"
"net/url"
"os"
"strconv"
"strings"
"time"
@ -243,22 +242,14 @@ func parseBypass(s string) *gost.Bypass {
}
matchers = append(matchers, gost.NewMatcher(s))
}
return gost.NewBypass(matchers, reversed)
return gost.NewBypass(reversed, matchers...)
}
f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
matchers = append(matchers, gost.NewMatcher(line))
}
return gost.NewBypass(matchers, reversed)
bp := gost.NewBypass(reversed)
go gost.PeriodReload(bp, s)
return bp
}
func parseResolver(cfg string) gost.Resolver {
@ -289,59 +280,12 @@ func parseResolver(cfg string) gost.Resolver {
})
}
}
return gost.NewResolver(nss, timeout, ttl)
return gost.NewResolver(timeout, ttl, nss...)
}
f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
resolver := gost.NewResolver(timeout, ttl)
go gost.PeriodReload(resolver, cfg)
if len(ss) == 0 {
continue
}
if strings.ToLower(ss[0]) == "timeout" {
if len(ss) >= 2 {
if n, _ := strconv.Atoi(ss[1]); n > 0 {
timeout = time.Second * time.Duration(n)
}
}
continue
}
if strings.ToLower(ss[0]) == "ttl" {
if len(ss) >= 2 {
n, _ := strconv.Atoi(ss[1])
ttl = time.Second * time.Duration(n)
}
continue
}
var ns gost.NameServer
switch len(ss) {
case 1:
ns.Addr = ss[0]
case 2:
ns.Addr = ss[0]
ns.Protocol = ss[1]
default:
ns.Addr = ss[0]
ns.Protocol = ss[1]
ns.Hostname = ss[2]
}
nss = append(nss, ns)
}
return gost.NewResolver(nss, timeout, ttl)
return resolver
}

View File

@ -1,8 +1,11 @@
# resolver timeout, default 30s.
timeout 10
timeout 10s
# resolver cache TTL, default 60s, minus value means that cache is disabled.
ttl 300
ttl 300s
# period for live reloading
reload 10s
# ip[:port] [protocol] [hostname]

View File

@ -1,3 +1,6 @@
# period for live reloading
reload 10s
# The following lines are desirable for IPv4 capable hosts
127.0.0.1 localhost

View File

@ -153,7 +153,7 @@ func (r *route) initChain() (*gost.Chain, error) {
var bypass *gost.Bypass
// global bypass
if peerCfg.Bypass != nil {
bypass = gost.NewBypassPatterns(peerCfg.Bypass.Patterns, peerCfg.Bypass.Reverse)
bypass = gost.NewBypassPatterns(peerCfg.Bypass.Reverse, peerCfg.Bypass.Patterns...)
}
nodes = ngroup.Nodes()
for i := range nodes {
@ -492,10 +492,9 @@ func (r *route) serve() error {
var hosts *gost.Hosts
if f, _ := os.Open(node.Get("hosts")); f != nil {
hosts, err = gost.ParseHosts(f)
if err != nil {
log.Logf("[hosts] %s: %v", f.Name(), err)
}
f.Close()
hosts = gost.NewHosts()
go gost.PeriodReload(hosts, node.Get("hosts"))
}
handler.Init(

110
hosts.go
View File

@ -5,6 +5,7 @@ import (
"io"
"net"
"strings"
"time"
"github.com/go-log/log"
)
@ -17,8 +18,13 @@ type Host struct {
}
// Hosts is a static table lookup for hostnames.
// For each host a single line should be present with the following information:
// IP_address canonical_hostname [aliases...]
// Fields of the entry are separated by any number of blanks and/or tab characters.
// Text from a "#" character until the end of the line is a comment, and is ignored.
type Hosts struct {
hosts []Host
period time.Duration
}
// NewHosts creates a Hosts with optional list of host
@ -28,53 +34,6 @@ func NewHosts(hosts ...Host) *Hosts {
}
}
// ParseHosts parses host table from r.
// For each host a single line should be present with the following information:
// IP_address canonical_hostname [aliases...]
// Fields of the entry are separated by any number of blanks and/or tab characters.
// Text from a "#" character until the end of the line is a comment, and is ignored.
func ParseHosts(r io.Reader) (*Hosts, error) {
hosts := NewHosts()
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) < 2 {
continue // invalid lines are ignored
}
ip := net.ParseIP(ss[0])
if ip == nil {
continue // invalid IP addresses are ignored
}
host := Host{
IP: ip,
Hostname: ss[1],
}
if len(ss) > 2 {
host.Aliases = ss[2:]
}
hosts.AddHost(host)
}
if err := scanner.Err(); err != nil {
return nil, err
}
return hosts, nil
}
// AddHost adds host(s) to the host table.
func (h *Hosts) AddHost(host ...Host) {
h.hosts = append(h.hosts, host...)
@ -102,3 +61,60 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
}
return
}
// Reload parses config from r, then live reloads the hosts.
func (h *Hosts) Reload(r io.Reader) error {
var hosts []Host
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) < 2 {
continue // invalid lines are ignored
}
// reload option
if strings.ToLower(ss[0]) == "reload" {
h.period, _ = time.ParseDuration(ss[1])
continue
}
ip := net.ParseIP(ss[0])
if ip == nil {
continue // invalid IP addresses are ignored
}
host := Host{
IP: ip,
Hostname: ss[1],
}
if len(ss) > 2 {
host.Aliases = ss[2:]
}
hosts = append(hosts, host)
}
if err := scanner.Err(); err != nil {
return err
}
h.hosts = hosts
return nil
}
// Period returns the reload period
func (h *Hosts) Period() time.Duration {
return h.period
}

52
reload.go Normal file
View File

@ -0,0 +1,52 @@
package gost
import (
"io"
"os"
"time"
"github.com/go-log/log"
)
// Reloader is the interface for objects that support live reloading.
type Reloader interface {
Reload(r io.Reader) error
Period() time.Duration
}
// PeriodReload reloads the config periodically according to the period of the reloader.
func PeriodReload(r Reloader, configFile string) error {
var lastMod time.Time
for {
f, err := os.Open(configFile)
if err != nil {
return err
}
finfo, err := f.Stat()
if err != nil {
return err
}
mt := finfo.ModTime()
if !mt.Equal(lastMod) {
if Debug {
log.Log("[reload]", configFile)
}
r.Reload(f)
lastMod = mt
}
f.Close()
period := r.Period()
if period <= 0 {
log.Log("[reload] disabled:", configFile)
return nil
}
if period < time.Second {
period = time.Second
}
<-time.After(period)
}
}

View File

@ -1,10 +1,12 @@
package gost
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"strings"
"sync"
@ -27,6 +29,12 @@ type Resolver interface {
Resolve(host string) ([]net.IP, error)
}
// ReloadResolver is resolover that support live reloading
type ReloadResolver interface {
Resolver
Reloader
}
// NameServer is a name server.
// Currently supported protocol: TCP, UDP and TLS.
type NameServer struct {
@ -56,13 +64,14 @@ type resolverCacheItem struct {
type resolver struct {
Resolver *net.Resolver
Servers []NameServer
mCache *sync.Map
Timeout time.Duration
TTL time.Duration
mCache *sync.Map
period time.Duration
}
// NewResolver create a new Resolver with the given name servers and resolution timeout.
func NewResolver(servers []NameServer, timeout, ttl time.Duration) Resolver {
func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolver {
r := &resolver{
Servers: servers,
Timeout: timeout,
@ -184,6 +193,77 @@ func (r *resolver) storeCache(name string, ips []net.IP) {
})
}
func (r *resolver) Reload(rd io.Reader) error {
var nss []NameServer
scanner := bufio.NewScanner(rd)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) == 0 {
continue
}
if len(ss) >= 2 {
// timeout option
if strings.ToLower(ss[0]) == "timeout" {
r.Timeout, _ = time.ParseDuration(ss[1])
continue
}
// ttl option
if strings.ToLower(ss[0]) == "ttl" {
r.TTL, _ = time.ParseDuration(ss[1])
continue
}
// reload option
if strings.ToLower(ss[0]) == "reload" {
r.period, _ = time.ParseDuration(ss[1])
continue
}
}
var ns NameServer
switch len(ss) {
case 1:
ns.Addr = ss[0]
case 2:
ns.Addr = ss[0]
ns.Protocol = ss[1]
default:
ns.Addr = ss[0]
ns.Protocol = ss[1]
ns.Hostname = ss[2]
}
nss = append(nss, ns)
}
if err := scanner.Err(); err != nil {
return err
}
r.Servers = nss
return nil
}
func (r *resolver) Period() time.Duration {
return r.period
}
func (r *resolver) String() string {
if r == nil {
return ""
@ -192,6 +272,7 @@ func (r *resolver) String() string {
b := &bytes.Buffer{}
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
fmt.Fprintf(b, "TTL %v\n", r.TTL)
fmt.Fprintf(b, "Reload %v\n", r.period)
for i := range r.Servers {
fmt.Fprintln(b, r.Servers[i])
}