add static hosts support

This commit is contained in:
ginuerzh 2018-07-04 19:34:22 +08:00
parent e56fbfa809
commit 644d22d7c3
6 changed files with 208 additions and 39 deletions

View File

@ -18,6 +18,7 @@ var (
type Chain struct { type Chain struct {
isRoute bool isRoute bool
Retries int Retries int
Hosts *Hosts
Resolver Resolver Resolver Resolver
nodeGroups []*NodeGroup nodeGroups []*NodeGroup
} }
@ -124,18 +125,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
return nil, err return nil, err
} }
if c != nil && c.Resolver != nil { addr = c.resolve(addr)
host, port, err := net.SplitHostPort(addr)
if err == nil {
addrs, er := c.Resolver.Resolve(host)
if er != nil {
log.Logf("[resolver] %s: %v", host, er)
}
if len(addrs) > 0 {
addr = net.JoinHostPort(addrs[0].IP.String(), port)
}
}
}
if route.IsEmpty() { if route.IsEmpty() {
return net.DialTimeout("tcp", addr, DialTimeout) return net.DialTimeout("tcp", addr, DialTimeout)
@ -154,6 +144,27 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
return cc, nil return cc, nil
} }
func (c *Chain) resolve(addr string) string {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
if ip := c.Hosts.Lookup(host); ip != nil {
return net.JoinHostPort(ip.String(), port)
}
if c.Resolver != nil {
ips, err := c.Resolver.Resolve(host)
if err != nil {
log.Logf("[resolver] %s: %v", host, err)
}
if len(ips) > 0 {
return net.JoinHostPort(ips[0].String(), port)
}
}
return addr
}
// Conn obtains a handshaked connection to the last node of the chain. // Conn obtains a handshaked connection to the last node of the chain.
// If the chain is empty, it returns an ErrEmptyChain error. // If the chain is empty, it returns an ErrEmptyChain error.
func (c *Chain) Conn() (conn net.Conn, err error) { func (c *Chain) Conn() (conn net.Conn, err error) {

14
cmd/gost/hosts.txt Normal file
View File

@ -0,0 +1,14 @@
# The following lines are desirable for IPv4 capable hosts
127.0.0.1 localhost
# 127.0.1.1 is often used for the FQDN of the machine
127.0.1.1 thishost.mydomain.org thishost
192.168.1.10 foo.mydomain.org foo
192.168.1.13 bar.mydomain.org bar
146.82.138.7 master.debian.org master
209.237.226.90 www.opensource.org
# The following lines are desirable for IPv6 capable hosts
::1 localhost ip6-localhost ip6-loopback
ff02::1 ip6-allnodes
ff02::2 ip6-allrouters

View File

@ -6,6 +6,8 @@ import (
"flag" "flag"
"fmt" "fmt"
"net" "net"
"net/http"
_ "net/http/pprof"
"os" "os"
"runtime" "runtime"
"time" "time"
@ -57,6 +59,9 @@ func init() {
} }
func main() { func main() {
go func() {
log.Log(http.ListenAndServe("localhost:6060", nil))
}()
// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
config, err := tlsConfig(defaultCertFile, defaultKeyFile) config, err := tlsConfig(defaultCertFile, defaultKeyFile)
if err != nil { if err != nil {
@ -336,12 +341,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
} }
func (r *route) serve() error { func (r *route) serve() error {
chain, err := r.initChain() baseChain, err := r.initChain()
if err != nil { if err != nil {
return err return err
} }
for _, ns := range r.ServeNodes { for _, ns := range r.ServeNodes {
chain := &gost.Chain{}
*chain = *baseChain
node, err := gost.ParseNode(ns) node, err := gost.ParseNode(ns)
if err != nil { if err != nil {
return err return err
@ -462,7 +470,6 @@ func (r *route) serve() error {
} }
var handlerOptions []gost.HandlerOption var handlerOptions []gost.HandlerOption
handlerOptions = append(handlerOptions, handlerOptions = append(handlerOptions,
gost.AddrHandlerOption(node.Addr), gost.AddrHandlerOption(node.Addr),
gost.ChainHandlerOption(chain), gost.ChainHandlerOption(chain),
@ -516,6 +523,14 @@ func (r *route) serve() error {
if gost.Debug { if gost.Debug {
log.Logf("[resolver]\n%v", chain.Resolver) log.Logf("[resolver]\n%v", chain.Resolver)
} }
if f, _ := os.Open(node.Get("hosts")); f != nil {
chain.Hosts, err = gost.ParseHosts(f)
if err != nil {
log.Logf("[hosts] %s: %v", f.Name(), err)
}
}
go srv.Serve(handler) go srv.Serve(handler)
} }

104
hosts.go Normal file
View File

@ -0,0 +1,104 @@
package gost
import (
"bufio"
"io"
"net"
"strings"
"github.com/go-log/log"
)
// Host is a static mapping from hostname to IP.
type Host struct {
IP net.IP
Hostname string
Aliases []string
}
// Hosts is a static table lookup for hostnames.
type Hosts struct {
hosts []Host
}
// NewHosts creates a Hosts with optional list of host
func NewHosts(hosts ...Host) *Hosts {
return &Hosts{
hosts: hosts,
}
}
// ParseHosts parses host table from r.
// For each host a single line should be present with the following information:
// IP_address canonical_hostname [aliases...]
// Fields of the entry are separated by any number of blanks and/or tab characters.
// Text from a "#" character until the end of the line is a comment, and is ignored.
func ParseHosts(r io.Reader) (*Hosts, error) {
hosts := NewHosts()
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) < 2 {
continue // invalid lines are ignored
}
ip := net.ParseIP(ss[0])
if ip == nil {
continue // invalid IP addresses are ignored
}
host := Host{
IP: ip,
Hostname: ss[1],
}
if len(ss) > 2 {
host.Aliases = ss[2:]
}
hosts.AddHost(host)
}
if err := scanner.Err(); err != nil {
return nil, err
}
return hosts, nil
}
// AddHost adds host(s) to the host table.
func (h *Hosts) AddHost(host ...Host) {
h.hosts = append(h.hosts, host...)
}
// Lookup searches the IP address corresponds to the given host from the host table.
func (h *Hosts) Lookup(host string) (ip net.IP) {
if h == nil {
return
}
for _, h := range h.hosts {
if h.Hostname == host {
ip = h.IP
break
}
for _, alias := range h.Aliases {
if alias == host {
ip = h.IP
break
}
}
}
if ip != nil && Debug {
log.Logf("[hosts] hit: %s %s", host, ip.String())
}
return
}

View File

@ -24,7 +24,7 @@ var (
// It contains a list of name servers. // It contains a list of name servers.
type Resolver interface { type Resolver interface {
// Resolve returns a slice of that host's IPv4 and IPv6 addresses. // Resolve returns a slice of that host's IPv4 and IPv6 addresses.
Resolve(host string) ([]net.IPAddr, error) Resolve(host string) ([]net.IP, error)
} }
// NameServer is a name server. // NameServer is a name server.
@ -39,8 +39,8 @@ func (ns NameServer) String() string {
addr := ns.Addr addr := ns.Addr
prot := ns.Protocol prot := ns.Protocol
host := ns.Hostname host := ns.Hostname
if !strings.Contains(addr, ":") { if _, port, _ := net.SplitHostPort(addr); port == "" {
addr += ":53" addr = net.JoinHostPort(addr, "53")
} }
if prot == "" { if prot == "" {
prot = "udp" prot = "udp"
@ -49,8 +49,8 @@ func (ns NameServer) String() string {
} }
type resolverCacheItem struct { type resolverCacheItem struct {
IPAddrs []net.IPAddr IPs []net.IP
ts int64 ts int64
} }
type resolver struct { type resolver struct {
@ -100,8 +100,8 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
var d net.Dialer var d net.Dialer
addr := ns.Addr addr := ns.Addr
if !strings.Contains(addr, ":") { if _, port, _ := net.SplitHostPort(addr); port == "" {
addr += ":53" addr = net.JoinHostPort(addr, "53")
} }
switch strings.ToLower(ns.Protocol) { switch strings.ToLower(ns.Protocol) {
case "tcp": case "tcp":
@ -125,28 +125,38 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
} }
} }
func (r *resolver) Resolve(name string) (addrs []net.IPAddr, err error) { func (r *resolver) Resolve(name string) (ips []net.IP, err error) {
if r == nil {
return
}
timeout := r.Timeout timeout := r.Timeout
addrs = r.loadCache(name) if ip := net.ParseIP(name); ip != nil {
if len(addrs) > 0 { return []net.IP{ip}, nil
}
ips = r.loadCache(name)
if len(ips) > 0 {
if Debug { if Debug {
log.Logf("[resolver] cache hit: %s %v", name, addrs) log.Logf("[resolver] cache hit: %s %v", name, ips)
} }
return return
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
addrs, err = r.Resolver.LookupIPAddr(ctx, name) addrs, err := r.Resolver.LookupIPAddr(ctx, name)
r.storeCache(name, addrs) for _, addr := range addrs {
if len(addrs) > 0 && Debug { ips = append(ips, addr.IP)
log.Logf("[resolver] %s %v", name, addrs) }
r.storeCache(name, ips)
if len(ips) > 0 && Debug {
log.Logf("[resolver] %s %v", name, ips)
} }
return return
} }
func (r *resolver) loadCache(name string) []net.IPAddr { func (r *resolver) loadCache(name string) []net.IP {
ttl := r.TTL ttl := r.TTL
if ttl < 0 { if ttl < 0 {
return nil return nil
@ -157,20 +167,20 @@ func (r *resolver) loadCache(name string) []net.IPAddr {
if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl { if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl {
return nil return nil
} }
return item.IPAddrs return item.IPs
} }
return nil return nil
} }
func (r *resolver) storeCache(name string, addrs []net.IPAddr) { func (r *resolver) storeCache(name string, ips []net.IP) {
ttl := r.TTL ttl := r.TTL
if ttl < 0 || name == "" || len(addrs) == 0 { if ttl < 0 || name == "" || len(ips) == 0 {
return return
} }
r.mCache.Store(name, &resolverCacheItem{ r.mCache.Store(name, &resolverCacheItem{
IPAddrs: addrs, IPs: ips,
ts: time.Now().Unix(), ts: time.Now().Unix(),
}) })
} }
@ -180,8 +190,8 @@ func (r *resolver) String() string {
} }
b := &bytes.Buffer{} b := &bytes.Buffer{}
fmt.Fprintf(b, "timeout %v\n", r.Timeout) fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
fmt.Fprintf(b, "ttl %v\n", r.TTL) fmt.Fprintf(b, "TTL %v\n", r.TTL)
for i := range r.Servers { for i := range r.Servers {
fmt.Fprintln(b, r.Servers[i]) fmt.Fprintln(b, r.Servers[i])
} }

View File

@ -3,6 +3,7 @@ package gost
import ( import (
"io" "io"
"net" "net"
"sync"
"time" "time"
"github.com/go-log/log" "github.com/go-log/log"
@ -132,15 +133,29 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
return tc, nil return tc, nil
} }
var (
trPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
)
func transport(rw1, rw2 io.ReadWriter) error { func transport(rw1, rw2 io.ReadWriter) error {
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
_, err := io.Copy(rw1, rw2) buf := trPool.Get().([]byte)
defer trPool.Put(buf)
_, err := io.CopyBuffer(rw1, rw2, buf)
errc <- err errc <- err
}() }()
go func() { go func() {
_, err := io.Copy(rw2, rw1) buf := trPool.Get().([]byte)
defer trPool.Put(buf)
_, err := io.CopyBuffer(rw2, rw1, buf)
errc <- err errc <- err
}() }()