From 644d22d7c38cf6a8892dfadd7409d27ad930e4aa Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 4 Jul 2018 19:34:22 +0800 Subject: [PATCH] add static hosts support --- chain.go | 35 +++++++++------ cmd/gost/hosts.txt | 14 ++++++ cmd/gost/main.go | 19 ++++++++- hosts.go | 104 +++++++++++++++++++++++++++++++++++++++++++++ resolver.go | 56 ++++++++++++++---------- server.go | 19 ++++++++- 6 files changed, 208 insertions(+), 39 deletions(-) create mode 100644 cmd/gost/hosts.txt create mode 100644 hosts.go diff --git a/chain.go b/chain.go index 45831e3..fc70a82 100644 --- a/chain.go +++ b/chain.go @@ -18,6 +18,7 @@ var ( type Chain struct { isRoute bool Retries int + Hosts *Hosts Resolver Resolver nodeGroups []*NodeGroup } @@ -124,18 +125,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) { return nil, err } - if c != nil && c.Resolver != nil { - host, port, err := net.SplitHostPort(addr) - if err == nil { - addrs, er := c.Resolver.Resolve(host) - if er != nil { - log.Logf("[resolver] %s: %v", host, er) - } - if len(addrs) > 0 { - addr = net.JoinHostPort(addrs[0].IP.String(), port) - } - } - } + addr = c.resolve(addr) if route.IsEmpty() { return net.DialTimeout("tcp", addr, DialTimeout) @@ -154,6 +144,27 @@ func (c *Chain) dial(addr string) (net.Conn, error) { return cc, nil } +func (c *Chain) resolve(addr string) string { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return addr + } + + if ip := c.Hosts.Lookup(host); ip != nil { + return net.JoinHostPort(ip.String(), port) + } + if c.Resolver != nil { + ips, err := c.Resolver.Resolve(host) + if err != nil { + log.Logf("[resolver] %s: %v", host, err) + } + if len(ips) > 0 { + return net.JoinHostPort(ips[0].String(), port) + } + } + return addr +} + // Conn obtains a handshaked connection to the last node of the chain. // If the chain is empty, it returns an ErrEmptyChain error. func (c *Chain) Conn() (conn net.Conn, err error) { diff --git a/cmd/gost/hosts.txt b/cmd/gost/hosts.txt new file mode 100644 index 0000000..6944f41 --- /dev/null +++ b/cmd/gost/hosts.txt @@ -0,0 +1,14 @@ +# The following lines are desirable for IPv4 capable hosts +127.0.0.1 localhost + +# 127.0.1.1 is often used for the FQDN of the machine +127.0.1.1 thishost.mydomain.org thishost +192.168.1.10 foo.mydomain.org foo +192.168.1.13 bar.mydomain.org bar +146.82.138.7 master.debian.org master +209.237.226.90 www.opensource.org + +# The following lines are desirable for IPv6 capable hosts +::1 localhost ip6-localhost ip6-loopback +ff02::1 ip6-allnodes +ff02::2 ip6-allrouters \ No newline at end of file diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 067f620..779596c 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -6,6 +6,8 @@ import ( "flag" "fmt" "net" + "net/http" + _ "net/http/pprof" "os" "runtime" "time" @@ -57,6 +59,9 @@ func init() { } func main() { + go func() { + log.Log(http.ListenAndServe("localhost:6060", nil)) + }() // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. config, err := tlsConfig(defaultCertFile, defaultKeyFile) if err != nil { @@ -336,12 +341,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { } func (r *route) serve() error { - chain, err := r.initChain() + baseChain, err := r.initChain() if err != nil { return err } for _, ns := range r.ServeNodes { + chain := &gost.Chain{} + *chain = *baseChain + node, err := gost.ParseNode(ns) if err != nil { return err @@ -462,7 +470,6 @@ func (r *route) serve() error { } var handlerOptions []gost.HandlerOption - handlerOptions = append(handlerOptions, gost.AddrHandlerOption(node.Addr), gost.ChainHandlerOption(chain), @@ -516,6 +523,14 @@ func (r *route) serve() error { if gost.Debug { log.Logf("[resolver]\n%v", chain.Resolver) } + + if f, _ := os.Open(node.Get("hosts")); f != nil { + chain.Hosts, err = gost.ParseHosts(f) + if err != nil { + log.Logf("[hosts] %s: %v", f.Name(), err) + } + } + go srv.Serve(handler) } diff --git a/hosts.go b/hosts.go new file mode 100644 index 0000000..81bac00 --- /dev/null +++ b/hosts.go @@ -0,0 +1,104 @@ +package gost + +import ( + "bufio" + "io" + "net" + "strings" + + "github.com/go-log/log" +) + +// Host is a static mapping from hostname to IP. +type Host struct { + IP net.IP + Hostname string + Aliases []string +} + +// Hosts is a static table lookup for hostnames. +type Hosts struct { + hosts []Host +} + +// NewHosts creates a Hosts with optional list of host +func NewHosts(hosts ...Host) *Hosts { + return &Hosts{ + hosts: hosts, + } +} + +// ParseHosts parses host table from r. +// For each host a single line should be present with the following information: +// IP_address canonical_hostname [aliases...] +// Fields of the entry are separated by any number of blanks and/or tab characters. +// Text from a "#" character until the end of the line is a comment, and is ignored. +func ParseHosts(r io.Reader) (*Hosts, error) { + hosts := NewHosts() + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + if line == "" { + continue + } + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + if len(ss) < 2 { + continue // invalid lines are ignored + } + ip := net.ParseIP(ss[0]) + if ip == nil { + continue // invalid IP addresses are ignored + } + host := Host{ + IP: ip, + Hostname: ss[1], + } + if len(ss) > 2 { + host.Aliases = ss[2:] + } + hosts.AddHost(host) + } + if err := scanner.Err(); err != nil { + return nil, err + } + + return hosts, nil +} + +// AddHost adds host(s) to the host table. +func (h *Hosts) AddHost(host ...Host) { + h.hosts = append(h.hosts, host...) +} + +// Lookup searches the IP address corresponds to the given host from the host table. +func (h *Hosts) Lookup(host string) (ip net.IP) { + if h == nil { + return + } + for _, h := range h.hosts { + if h.Hostname == host { + ip = h.IP + break + } + for _, alias := range h.Aliases { + if alias == host { + ip = h.IP + break + } + } + } + if ip != nil && Debug { + log.Logf("[hosts] hit: %s %s", host, ip.String()) + } + return +} diff --git a/resolver.go b/resolver.go index b1f26f8..08c38ff 100644 --- a/resolver.go +++ b/resolver.go @@ -24,7 +24,7 @@ var ( // It contains a list of name servers. type Resolver interface { // Resolve returns a slice of that host's IPv4 and IPv6 addresses. - Resolve(host string) ([]net.IPAddr, error) + Resolve(host string) ([]net.IP, error) } // NameServer is a name server. @@ -39,8 +39,8 @@ func (ns NameServer) String() string { addr := ns.Addr prot := ns.Protocol host := ns.Hostname - if !strings.Contains(addr, ":") { - addr += ":53" + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "53") } if prot == "" { prot = "udp" @@ -49,8 +49,8 @@ func (ns NameServer) String() string { } type resolverCacheItem struct { - IPAddrs []net.IPAddr - ts int64 + IPs []net.IP + ts int64 } type resolver struct { @@ -100,8 +100,8 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { var d net.Dialer addr := ns.Addr - if !strings.Contains(addr, ":") { - addr += ":53" + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "53") } switch strings.ToLower(ns.Protocol) { case "tcp": @@ -125,28 +125,38 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { } } -func (r *resolver) Resolve(name string) (addrs []net.IPAddr, err error) { +func (r *resolver) Resolve(name string) (ips []net.IP, err error) { + if r == nil { + return + } timeout := r.Timeout - addrs = r.loadCache(name) - if len(addrs) > 0 { + if ip := net.ParseIP(name); ip != nil { + return []net.IP{ip}, nil + } + + ips = r.loadCache(name) + if len(ips) > 0 { if Debug { - log.Logf("[resolver] cache hit: %s %v", name, addrs) + log.Logf("[resolver] cache hit: %s %v", name, ips) } return } ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - addrs, err = r.Resolver.LookupIPAddr(ctx, name) - r.storeCache(name, addrs) - if len(addrs) > 0 && Debug { - log.Logf("[resolver] %s %v", name, addrs) + addrs, err := r.Resolver.LookupIPAddr(ctx, name) + for _, addr := range addrs { + ips = append(ips, addr.IP) + } + r.storeCache(name, ips) + if len(ips) > 0 && Debug { + log.Logf("[resolver] %s %v", name, ips) } return } -func (r *resolver) loadCache(name string) []net.IPAddr { +func (r *resolver) loadCache(name string) []net.IP { ttl := r.TTL if ttl < 0 { return nil @@ -157,20 +167,20 @@ func (r *resolver) loadCache(name string) []net.IPAddr { if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl { return nil } - return item.IPAddrs + return item.IPs } return nil } -func (r *resolver) storeCache(name string, addrs []net.IPAddr) { +func (r *resolver) storeCache(name string, ips []net.IP) { ttl := r.TTL - if ttl < 0 || name == "" || len(addrs) == 0 { + if ttl < 0 || name == "" || len(ips) == 0 { return } r.mCache.Store(name, &resolverCacheItem{ - IPAddrs: addrs, - ts: time.Now().Unix(), + IPs: ips, + ts: time.Now().Unix(), }) } @@ -180,8 +190,8 @@ func (r *resolver) String() string { } b := &bytes.Buffer{} - fmt.Fprintf(b, "timeout %v\n", r.Timeout) - fmt.Fprintf(b, "ttl %v\n", r.TTL) + fmt.Fprintf(b, "Timeout %v\n", r.Timeout) + fmt.Fprintf(b, "TTL %v\n", r.TTL) for i := range r.Servers { fmt.Fprintln(b, r.Servers[i]) } diff --git a/server.go b/server.go index 29c1212..a9fbd0e 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package gost import ( "io" "net" + "sync" "time" "github.com/go-log/log" @@ -132,15 +133,29 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { return tc, nil } +var ( + trPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 32*1024) + }, + } +) + func transport(rw1, rw2 io.ReadWriter) error { errc := make(chan error, 1) go func() { - _, err := io.Copy(rw1, rw2) + buf := trPool.Get().([]byte) + defer trPool.Put(buf) + + _, err := io.CopyBuffer(rw1, rw2, buf) errc <- err }() go func() { - _, err := io.Copy(rw2, rw1) + buf := trPool.Get().([]byte) + defer trPool.Put(buf) + + _, err := io.CopyBuffer(rw2, rw1, buf) errc <- err }()