diff --git a/chain.go b/chain.go index 1a24e3b..60b6871 100644 --- a/chain.go +++ b/chain.go @@ -6,6 +6,7 @@ import ( "net" "time" + "github.com/go-gost/hosts" "github.com/go-gost/log" ) @@ -179,14 +180,16 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op return cc, nil } -func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { +func (*Chain) resolve(addr string, resolver Resolver, hosts hosts.Hosts) string { host, port, err := net.SplitHostPort(addr) if err != nil { return addr } - if ip := hosts.Lookup(host); ip != nil { - return net.JoinHostPort(ip.String(), port) + if hosts != nil { + if ip := hosts.Lookup(host); ip != nil { + return net.JoinHostPort(ip.String(), port) + } } if resolver != nil { ips, err := resolver.Resolve(host) @@ -326,7 +329,7 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { type ChainOptions struct { Retries int Timeout time.Duration - Hosts *Hosts + Hosts hosts.Hosts Resolver Resolver } @@ -348,7 +351,7 @@ func TimeoutChainOption(timeout time.Duration) ChainOption { } // HostsChainOption specifies the hosts used by Chain.Dial. -func HostsChainOption(hosts *Hosts) ChainOption { +func HostsChainOption(hosts hosts.Hosts) ChainOption { return func(opts *ChainOptions) { opts.Hosts = hosts } diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index ff264b1..5dbfbda 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -14,6 +14,7 @@ import ( "github.com/ginuerzh/gost" "github.com/go-gost/bypass" + "github.com/go-gost/hosts" "github.com/go-gost/reloader" ) @@ -266,19 +267,20 @@ func parseResolver(cfg string) gost.Resolver { return resolver } -func parseHosts(s string) *gost.Hosts { +func parseHosts(s string) hosts.Hosts { f, err := os.Open(s) if err != nil { return nil } defer f.Close() - hosts := gost.NewHosts() - hosts.Reload(f) + hsts := hosts.NewHosts() + if r, ok := hsts.(reloader.Reloader); ok { + r.Reload(f) + go reloader.PeriodReload(r, s) + } - go reloader.PeriodReload(hosts, s) - - return hosts + return hsts } func parseIPRoutes(s string) (routes []gost.IPRoute) { diff --git a/cmd/gost/route.go b/cmd/gost/route.go index b21f165..4247a17 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/ginuerzh/gost" + "github.com/go-gost/hosts" "github.com/go-gost/log" "github.com/go-gost/reloader" ) @@ -640,7 +641,7 @@ type router struct { handler gost.Handler chain *gost.Chain resolver gost.Resolver - hosts *gost.Hosts + hosts hosts.Hosts } func (r *router) Serve() error { diff --git a/go.mod b/go.mod index 54aef86..31b048c 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/ginuerzh/tls-dissector v0.0.2-0.20200224064855-24ab2b3a3796 github.com/go-gost/bpool v1.0.0 github.com/go-gost/bypass v1.0.0 + github.com/go-gost/hosts v1.0.0 github.com/go-gost/log v1.0.0 github.com/go-gost/relay v0.1.0 github.com/go-gost/reloader v1.1.0 diff --git a/go.sum b/go.sum index 29bfea4..ea3026e 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,8 @@ github.com/go-gost/bpool v1.0.0 h1:Og+6SH6SooHqf8CIwuxcPRHq8k0Si4YEfO2mBEi3/Uk= github.com/go-gost/bpool v1.0.0/go.mod h1:y/Pywm22A4OrJqNF/mL4nW7yb9fCdhlO8cxjyparkFI= github.com/go-gost/bypass v1.0.0 h1:ZhLzA3WY9JDxmpyuGwVUH0ubaFOmWWzNpejzwExvhMA= github.com/go-gost/bypass v1.0.0/go.mod h1:r2MYlxn1/fs24NFs+h/m9HiZKfckBrTnONXklxRUwcE= +github.com/go-gost/hosts v1.0.0 h1:KnUBEuIZ6CtMAvGN0n8vYvYahPuMMAu58HGFm2QKEhE= +github.com/go-gost/hosts v1.0.0/go.mod h1:6MrBWeZaRKo/ZwKhtSRK49C+rd3b8Hun4rhuRiYT6bI= github.com/go-gost/log v1.0.0 h1:maSjjMvQqLSQYb0Ta5nJTdlRI+aiLMt6WIBYVxajJgs= github.com/go-gost/log v1.0.0/go.mod h1:FCOaaJQ7moHTlLxYk7dsFewlS68U9A1GG3OR+yXkF6s= github.com/go-gost/relay v0.1.0 h1:UOf2YwAzzaUjY5mdpMuLfSw0vz62iIFYk7oJQkuhlGw= diff --git a/handler.go b/handler.go index 5aa9ac1..fdece30 100644 --- a/handler.go +++ b/handler.go @@ -10,6 +10,7 @@ import ( "github.com/ginuerzh/gosocks4" "github.com/ginuerzh/gosocks5" "github.com/go-gost/bypass" + "github.com/go-gost/hosts" "github.com/go-gost/log" ) @@ -35,7 +36,7 @@ type HandlerOptions struct { Retries int Timeout time.Duration Resolver Resolver - Hosts *Hosts + Hosts hosts.Hosts ProbeResist string KnockingHost string Node Node @@ -157,7 +158,7 @@ func ResolverHandlerOption(resolver Resolver) HandlerOption { } // HostsHandlerOption sets the Hosts option of HandlerOptions. -func HostsHandlerOption(hosts *Hosts) HandlerOption { +func HostsHandlerOption(hosts hosts.Hosts) HandlerOption { return func(opts *HandlerOptions) { opts.Hosts = hosts } diff --git a/hosts.go b/hosts.go deleted file mode 100644 index 19e1260..0000000 --- a/hosts.go +++ /dev/null @@ -1,160 +0,0 @@ -package gost - -import ( - "bufio" - "io" - "net" - "sync" - "time" - - "github.com/go-gost/log" -) - -// Host is a static mapping from hostname to IP. -type Host struct { - IP net.IP - Hostname string - Aliases []string -} - -// NewHost creates a Host. -func NewHost(ip net.IP, hostname string, aliases ...string) Host { - return Host{ - IP: ip, - Hostname: hostname, - Aliases: aliases, - } -} - -// Hosts is a static table lookup for hostnames. -// For each host a single line should be present with the following information: -// IP_address canonical_hostname [aliases...] -// Fields of the entry are separated by any number of blanks and/or tab characters. -// Text from a "#" character until the end of the line is a comment, and is ignored. -type Hosts struct { - hosts []Host - period time.Duration - stopped chan struct{} - mux sync.RWMutex -} - -// NewHosts creates a Hosts with optional list of hosts. -func NewHosts(hosts ...Host) *Hosts { - return &Hosts{ - hosts: hosts, - stopped: make(chan struct{}), - } -} - -// AddHost adds host(s) to the host table. -func (h *Hosts) AddHost(host ...Host) { - h.mux.Lock() - defer h.mux.Unlock() - - h.hosts = append(h.hosts, host...) -} - -// Lookup searches the IP address corresponds to the given host from the host table. -func (h *Hosts) Lookup(host string) (ip net.IP) { - if h == nil || host == "" { - return - } - - h.mux.RLock() - defer h.mux.RUnlock() - - for _, h := range h.hosts { - if h.Hostname == host { - ip = h.IP - break - } - for _, alias := range h.Aliases { - if alias == host { - ip = h.IP - break - } - } - } - if ip != nil && Debug { - log.Logf("[hosts] hit: %s %s", host, ip.String()) - } - return -} - -// Reload parses config from r, then live reloads the hosts. -func (h *Hosts) Reload(r io.Reader) error { - var period time.Duration - var hosts []Host - - if r == nil || h.Stopped() { - return nil - } - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - ss := splitLine(line) - if len(ss) < 2 { - continue // invalid lines are ignored - } - - switch ss[0] { - case "reload": // reload option - period, _ = time.ParseDuration(ss[1]) - default: - ip := net.ParseIP(ss[0]) - if ip == nil { - break // invalid IP addresses are ignored - } - host := Host{ - IP: ip, - Hostname: ss[1], - } - if len(ss) > 2 { - host.Aliases = ss[2:] - } - hosts = append(hosts, host) - } - } - if err := scanner.Err(); err != nil { - return err - } - - h.mux.Lock() - h.period = period - h.hosts = hosts - h.mux.Unlock() - - return nil -} - -// Period returns the reload period -func (h *Hosts) Period() time.Duration { - if h.Stopped() { - return -1 - } - - h.mux.RLock() - defer h.mux.RUnlock() - - return h.period -} - -// Stop stops reloading. -func (h *Hosts) Stop() { - select { - case <-h.stopped: - default: - close(h.stopped) - } -} - -// Stopped checks whether the reloader is stopped. -func (h *Hosts) Stopped() bool { - select { - case <-h.stopped: - return true - default: - return false - } -} diff --git a/hosts_test.go b/hosts_test.go deleted file mode 100644 index 2fbae32..0000000 --- a/hosts_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package gost - -import ( - "bytes" - "io" - "net" - "testing" - "time" -) - -var hostsLookupTests = []struct { - hosts []Host - host string - ip net.IP -}{ - {nil, "", nil}, - {nil, "example.com", nil}, - {[]Host{}, "", nil}, - {[]Host{}, "example.com", nil}, - {[]Host{NewHost(nil, "")}, "", nil}, - {[]Host{NewHost(nil, "example.com")}, "example.com", nil}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "")}, "", nil}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com")}, "example.com", net.IPv4(192, 168, 1, 1)}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com")}, "example", nil}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com", "example", "examples")}, "example", net.IPv4(192, 168, 1, 1)}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com", "example", "examples")}, "examples", net.IPv4(192, 168, 1, 1)}, -} - -func TestHostsLookup(t *testing.T) { - for i, tc := range hostsLookupTests { - hosts := NewHosts() - hosts.AddHost(tc.hosts...) - ip := hosts.Lookup(tc.host) - if !ip.Equal(tc.ip) { - t.Errorf("#%d test failed: lookup should be %s, got %s", i, tc.ip, ip) - } - } -} - -var HostsReloadTests = []struct { - r io.Reader - period time.Duration - host string - ip net.IP - stopped bool -}{ - { - r: nil, - period: 0, - host: "", - ip: nil, - }, - { - r: bytes.NewBufferString(""), - period: 0, - host: "example.com", - ip: nil, - }, - { - r: bytes.NewBufferString("reload 10s"), - period: 10 * time.Second, - host: "example.com", - ip: nil, - }, - { - r: bytes.NewBufferString("#reload 10s\ninvalid.ip.addr example.com"), - period: 0, - ip: nil, - }, - { - r: bytes.NewBufferString("reload 10s\n192.168.1.1"), - period: 10 * time.Second, - host: "", - ip: nil, - }, - { - r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com"), - period: 0, - host: "example.com", - ip: net.IPv4(192, 168, 1, 1), - }, - { - r: bytes.NewBufferString("#reload 10s\n#192.168.1.1 example.com"), - period: 0, - host: "example.com", - ip: nil, - stopped: true, - }, - { - r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com example examples"), - period: 0, - host: "example", - ip: net.IPv4(192, 168, 1, 1), - stopped: true, - }, - { - r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com example examples"), - period: 0, - host: "examples", - ip: net.IPv4(192, 168, 1, 1), - stopped: true, - }, -} - -func TestHostsReload(t *testing.T) { - for i, tc := range HostsReloadTests { - hosts := NewHosts() - if err := hosts.Reload(tc.r); err != nil { - t.Error(err) - } - if hosts.Period() != tc.period { - t.Errorf("#%d test failed: period value should be %v, got %v", - i, tc.period, hosts.Period()) - } - ip := hosts.Lookup(tc.host) - if !ip.Equal(tc.ip) { - t.Errorf("#%d test failed: lookup should be %s, got %s", i, tc.ip, ip) - } - if tc.stopped { - hosts.Stop() - if hosts.Period() >= 0 { - t.Errorf("period of the stopped reloader should be minus value") - } - } - if hosts.Stopped() != tc.stopped { - t.Errorf("#%d test failed: stopped value should be %v, got %v", - i, tc.stopped, hosts.Stopped()) - } - } -}