add static hosts support

This commit is contained in:
ginuerzh 2018-07-04 19:34:22 +08:00
parent e56fbfa809
commit 644d22d7c3
6 changed files with 208 additions and 39 deletions

View File

@ -18,6 +18,7 @@ var (
type Chain struct {
isRoute bool
Retries int
Hosts *Hosts
Resolver Resolver
nodeGroups []*NodeGroup
}
@ -124,18 +125,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
return nil, err
}
if c != nil && c.Resolver != nil {
host, port, err := net.SplitHostPort(addr)
if err == nil {
addrs, er := c.Resolver.Resolve(host)
if er != nil {
log.Logf("[resolver] %s: %v", host, er)
}
if len(addrs) > 0 {
addr = net.JoinHostPort(addrs[0].IP.String(), port)
}
}
}
addr = c.resolve(addr)
if route.IsEmpty() {
return net.DialTimeout("tcp", addr, DialTimeout)
@ -154,6 +144,27 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
return cc, nil
}
func (c *Chain) resolve(addr string) string {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
if ip := c.Hosts.Lookup(host); ip != nil {
return net.JoinHostPort(ip.String(), port)
}
if c.Resolver != nil {
ips, err := c.Resolver.Resolve(host)
if err != nil {
log.Logf("[resolver] %s: %v", host, err)
}
if len(ips) > 0 {
return net.JoinHostPort(ips[0].String(), port)
}
}
return addr
}
// Conn obtains a handshaked connection to the last node of the chain.
// If the chain is empty, it returns an ErrEmptyChain error.
func (c *Chain) Conn() (conn net.Conn, err error) {

14
cmd/gost/hosts.txt Normal file
View File

@ -0,0 +1,14 @@
# The following lines are desirable for IPv4 capable hosts
127.0.0.1 localhost
# 127.0.1.1 is often used for the FQDN of the machine
127.0.1.1 thishost.mydomain.org thishost
192.168.1.10 foo.mydomain.org foo
192.168.1.13 bar.mydomain.org bar
146.82.138.7 master.debian.org master
209.237.226.90 www.opensource.org
# The following lines are desirable for IPv6 capable hosts
::1 localhost ip6-localhost ip6-loopback
ff02::1 ip6-allnodes
ff02::2 ip6-allrouters

View File

@ -6,6 +6,8 @@ import (
"flag"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"os"
"runtime"
"time"
@ -57,6 +59,9 @@ func init() {
}
func main() {
go func() {
log.Log(http.ListenAndServe("localhost:6060", nil))
}()
// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
config, err := tlsConfig(defaultCertFile, defaultKeyFile)
if err != nil {
@ -336,12 +341,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
}
func (r *route) serve() error {
chain, err := r.initChain()
baseChain, err := r.initChain()
if err != nil {
return err
}
for _, ns := range r.ServeNodes {
chain := &gost.Chain{}
*chain = *baseChain
node, err := gost.ParseNode(ns)
if err != nil {
return err
@ -462,7 +470,6 @@ func (r *route) serve() error {
}
var handlerOptions []gost.HandlerOption
handlerOptions = append(handlerOptions,
gost.AddrHandlerOption(node.Addr),
gost.ChainHandlerOption(chain),
@ -516,6 +523,14 @@ func (r *route) serve() error {
if gost.Debug {
log.Logf("[resolver]\n%v", chain.Resolver)
}
if f, _ := os.Open(node.Get("hosts")); f != nil {
chain.Hosts, err = gost.ParseHosts(f)
if err != nil {
log.Logf("[hosts] %s: %v", f.Name(), err)
}
}
go srv.Serve(handler)
}

104
hosts.go Normal file
View File

@ -0,0 +1,104 @@
package gost
import (
"bufio"
"io"
"net"
"strings"
"github.com/go-log/log"
)
// Host is a static mapping from hostname to IP.
type Host struct {
IP net.IP
Hostname string
Aliases []string
}
// Hosts is a static table lookup for hostnames.
type Hosts struct {
hosts []Host
}
// NewHosts creates a Hosts with optional list of host
func NewHosts(hosts ...Host) *Hosts {
return &Hosts{
hosts: hosts,
}
}
// ParseHosts parses host table from r.
// For each host a single line should be present with the following information:
// IP_address canonical_hostname [aliases...]
// Fields of the entry are separated by any number of blanks and/or tab characters.
// Text from a "#" character until the end of the line is a comment, and is ignored.
func ParseHosts(r io.Reader) (*Hosts, error) {
hosts := NewHosts()
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) < 2 {
continue // invalid lines are ignored
}
ip := net.ParseIP(ss[0])
if ip == nil {
continue // invalid IP addresses are ignored
}
host := Host{
IP: ip,
Hostname: ss[1],
}
if len(ss) > 2 {
host.Aliases = ss[2:]
}
hosts.AddHost(host)
}
if err := scanner.Err(); err != nil {
return nil, err
}
return hosts, nil
}
// AddHost adds host(s) to the host table.
func (h *Hosts) AddHost(host ...Host) {
h.hosts = append(h.hosts, host...)
}
// Lookup searches the IP address corresponds to the given host from the host table.
func (h *Hosts) Lookup(host string) (ip net.IP) {
if h == nil {
return
}
for _, h := range h.hosts {
if h.Hostname == host {
ip = h.IP
break
}
for _, alias := range h.Aliases {
if alias == host {
ip = h.IP
break
}
}
}
if ip != nil && Debug {
log.Logf("[hosts] hit: %s %s", host, ip.String())
}
return
}

View File

@ -24,7 +24,7 @@ var (
// It contains a list of name servers.
type Resolver interface {
// Resolve returns a slice of that host's IPv4 and IPv6 addresses.
Resolve(host string) ([]net.IPAddr, error)
Resolve(host string) ([]net.IP, error)
}
// NameServer is a name server.
@ -39,8 +39,8 @@ func (ns NameServer) String() string {
addr := ns.Addr
prot := ns.Protocol
host := ns.Hostname
if !strings.Contains(addr, ":") {
addr += ":53"
if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "53")
}
if prot == "" {
prot = "udp"
@ -49,8 +49,8 @@ func (ns NameServer) String() string {
}
type resolverCacheItem struct {
IPAddrs []net.IPAddr
ts int64
IPs []net.IP
ts int64
}
type resolver struct {
@ -100,8 +100,8 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
var d net.Dialer
addr := ns.Addr
if !strings.Contains(addr, ":") {
addr += ":53"
if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "53")
}
switch strings.ToLower(ns.Protocol) {
case "tcp":
@ -125,28 +125,38 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
}
}
func (r *resolver) Resolve(name string) (addrs []net.IPAddr, err error) {
func (r *resolver) Resolve(name string) (ips []net.IP, err error) {
if r == nil {
return
}
timeout := r.Timeout
addrs = r.loadCache(name)
if len(addrs) > 0 {
if ip := net.ParseIP(name); ip != nil {
return []net.IP{ip}, nil
}
ips = r.loadCache(name)
if len(ips) > 0 {
if Debug {
log.Logf("[resolver] cache hit: %s %v", name, addrs)
log.Logf("[resolver] cache hit: %s %v", name, ips)
}
return
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
addrs, err = r.Resolver.LookupIPAddr(ctx, name)
r.storeCache(name, addrs)
if len(addrs) > 0 && Debug {
log.Logf("[resolver] %s %v", name, addrs)
addrs, err := r.Resolver.LookupIPAddr(ctx, name)
for _, addr := range addrs {
ips = append(ips, addr.IP)
}
r.storeCache(name, ips)
if len(ips) > 0 && Debug {
log.Logf("[resolver] %s %v", name, ips)
}
return
}
func (r *resolver) loadCache(name string) []net.IPAddr {
func (r *resolver) loadCache(name string) []net.IP {
ttl := r.TTL
if ttl < 0 {
return nil
@ -157,20 +167,20 @@ func (r *resolver) loadCache(name string) []net.IPAddr {
if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl {
return nil
}
return item.IPAddrs
return item.IPs
}
return nil
}
func (r *resolver) storeCache(name string, addrs []net.IPAddr) {
func (r *resolver) storeCache(name string, ips []net.IP) {
ttl := r.TTL
if ttl < 0 || name == "" || len(addrs) == 0 {
if ttl < 0 || name == "" || len(ips) == 0 {
return
}
r.mCache.Store(name, &resolverCacheItem{
IPAddrs: addrs,
ts: time.Now().Unix(),
IPs: ips,
ts: time.Now().Unix(),
})
}
@ -180,8 +190,8 @@ func (r *resolver) String() string {
}
b := &bytes.Buffer{}
fmt.Fprintf(b, "timeout %v\n", r.Timeout)
fmt.Fprintf(b, "ttl %v\n", r.TTL)
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
fmt.Fprintf(b, "TTL %v\n", r.TTL)
for i := range r.Servers {
fmt.Fprintln(b, r.Servers[i])
}

View File

@ -3,6 +3,7 @@ package gost
import (
"io"
"net"
"sync"
"time"
"github.com/go-log/log"
@ -132,15 +133,29 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
return tc, nil
}
var (
trPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
)
func transport(rw1, rw2 io.ReadWriter) error {
errc := make(chan error, 1)
go func() {
_, err := io.Copy(rw1, rw2)
buf := trPool.Get().([]byte)
defer trPool.Put(buf)
_, err := io.CopyBuffer(rw1, rw2, buf)
errc <- err
}()
go func() {
_, err := io.Copy(rw2, rw1)
buf := trPool.Get().([]byte)
defer trPool.Put(buf)
_, err := io.CopyBuffer(rw2, rw1, buf)
errc <- err
}()