add static hosts support
This commit is contained in:
parent
e56fbfa809
commit
644d22d7c3
35
chain.go
35
chain.go
@ -18,6 +18,7 @@ var (
|
||||
type Chain struct {
|
||||
isRoute bool
|
||||
Retries int
|
||||
Hosts *Hosts
|
||||
Resolver Resolver
|
||||
nodeGroups []*NodeGroup
|
||||
}
|
||||
@ -124,18 +125,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c != nil && c.Resolver != nil {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err == nil {
|
||||
addrs, er := c.Resolver.Resolve(host)
|
||||
if er != nil {
|
||||
log.Logf("[resolver] %s: %v", host, er)
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
addr = net.JoinHostPort(addrs[0].IP.String(), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
addr = c.resolve(addr)
|
||||
|
||||
if route.IsEmpty() {
|
||||
return net.DialTimeout("tcp", addr, DialTimeout)
|
||||
@ -154,6 +144,27 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (c *Chain) resolve(addr string) string {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
if ip := c.Hosts.Lookup(host); ip != nil {
|
||||
return net.JoinHostPort(ip.String(), port)
|
||||
}
|
||||
if c.Resolver != nil {
|
||||
ips, err := c.Resolver.Resolve(host)
|
||||
if err != nil {
|
||||
log.Logf("[resolver] %s: %v", host, err)
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
return net.JoinHostPort(ips[0].String(), port)
|
||||
}
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// Conn obtains a handshaked connection to the last node of the chain.
|
||||
// If the chain is empty, it returns an ErrEmptyChain error.
|
||||
func (c *Chain) Conn() (conn net.Conn, err error) {
|
||||
|
14
cmd/gost/hosts.txt
Normal file
14
cmd/gost/hosts.txt
Normal file
@ -0,0 +1,14 @@
|
||||
# The following lines are desirable for IPv4 capable hosts
|
||||
127.0.0.1 localhost
|
||||
|
||||
# 127.0.1.1 is often used for the FQDN of the machine
|
||||
127.0.1.1 thishost.mydomain.org thishost
|
||||
192.168.1.10 foo.mydomain.org foo
|
||||
192.168.1.13 bar.mydomain.org bar
|
||||
146.82.138.7 master.debian.org master
|
||||
209.237.226.90 www.opensource.org
|
||||
|
||||
# The following lines are desirable for IPv6 capable hosts
|
||||
::1 localhost ip6-localhost ip6-loopback
|
||||
ff02::1 ip6-allnodes
|
||||
ff02::2 ip6-allrouters
|
@ -6,6 +6,8 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
@ -57,6 +59,9 @@ func init() {
|
||||
}
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Log(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
|
||||
config, err := tlsConfig(defaultCertFile, defaultKeyFile)
|
||||
if err != nil {
|
||||
@ -336,12 +341,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
||||
}
|
||||
|
||||
func (r *route) serve() error {
|
||||
chain, err := r.initChain()
|
||||
baseChain, err := r.initChain()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, ns := range r.ServeNodes {
|
||||
chain := &gost.Chain{}
|
||||
*chain = *baseChain
|
||||
|
||||
node, err := gost.ParseNode(ns)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -462,7 +470,6 @@ func (r *route) serve() error {
|
||||
}
|
||||
|
||||
var handlerOptions []gost.HandlerOption
|
||||
|
||||
handlerOptions = append(handlerOptions,
|
||||
gost.AddrHandlerOption(node.Addr),
|
||||
gost.ChainHandlerOption(chain),
|
||||
@ -516,6 +523,14 @@ func (r *route) serve() error {
|
||||
if gost.Debug {
|
||||
log.Logf("[resolver]\n%v", chain.Resolver)
|
||||
}
|
||||
|
||||
if f, _ := os.Open(node.Get("hosts")); f != nil {
|
||||
chain.Hosts, err = gost.ParseHosts(f)
|
||||
if err != nil {
|
||||
log.Logf("[hosts] %s: %v", f.Name(), err)
|
||||
}
|
||||
}
|
||||
|
||||
go srv.Serve(handler)
|
||||
}
|
||||
|
||||
|
104
hosts.go
Normal file
104
hosts.go
Normal file
@ -0,0 +1,104 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/go-log/log"
|
||||
)
|
||||
|
||||
// Host is a static mapping from hostname to IP.
|
||||
type Host struct {
|
||||
IP net.IP
|
||||
Hostname string
|
||||
Aliases []string
|
||||
}
|
||||
|
||||
// Hosts is a static table lookup for hostnames.
|
||||
type Hosts struct {
|
||||
hosts []Host
|
||||
}
|
||||
|
||||
// NewHosts creates a Hosts with optional list of host
|
||||
func NewHosts(hosts ...Host) *Hosts {
|
||||
return &Hosts{
|
||||
hosts: hosts,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseHosts parses host table from r.
|
||||
// For each host a single line should be present with the following information:
|
||||
// IP_address canonical_hostname [aliases...]
|
||||
// Fields of the entry are separated by any number of blanks and/or tab characters.
|
||||
// Text from a "#" character until the end of the line is a comment, and is ignored.
|
||||
func ParseHosts(r io.Reader) (*Hosts, error) {
|
||||
hosts := NewHosts()
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||
line = line[:n]
|
||||
}
|
||||
line = strings.Replace(line, "\t", " ", -1)
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var ss []string
|
||||
for _, s := range strings.Split(line, " ") {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
ss = append(ss, s)
|
||||
}
|
||||
}
|
||||
if len(ss) < 2 {
|
||||
continue // invalid lines are ignored
|
||||
}
|
||||
ip := net.ParseIP(ss[0])
|
||||
if ip == nil {
|
||||
continue // invalid IP addresses are ignored
|
||||
}
|
||||
host := Host{
|
||||
IP: ip,
|
||||
Hostname: ss[1],
|
||||
}
|
||||
if len(ss) > 2 {
|
||||
host.Aliases = ss[2:]
|
||||
}
|
||||
hosts.AddHost(host)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
// AddHost adds host(s) to the host table.
|
||||
func (h *Hosts) AddHost(host ...Host) {
|
||||
h.hosts = append(h.hosts, host...)
|
||||
}
|
||||
|
||||
// Lookup searches the IP address corresponds to the given host from the host table.
|
||||
func (h *Hosts) Lookup(host string) (ip net.IP) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
for _, h := range h.hosts {
|
||||
if h.Hostname == host {
|
||||
ip = h.IP
|
||||
break
|
||||
}
|
||||
for _, alias := range h.Aliases {
|
||||
if alias == host {
|
||||
ip = h.IP
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if ip != nil && Debug {
|
||||
log.Logf("[hosts] hit: %s %s", host, ip.String())
|
||||
}
|
||||
return
|
||||
}
|
52
resolver.go
52
resolver.go
@ -24,7 +24,7 @@ var (
|
||||
// It contains a list of name servers.
|
||||
type Resolver interface {
|
||||
// Resolve returns a slice of that host's IPv4 and IPv6 addresses.
|
||||
Resolve(host string) ([]net.IPAddr, error)
|
||||
Resolve(host string) ([]net.IP, error)
|
||||
}
|
||||
|
||||
// NameServer is a name server.
|
||||
@ -39,8 +39,8 @@ func (ns NameServer) String() string {
|
||||
addr := ns.Addr
|
||||
prot := ns.Protocol
|
||||
host := ns.Hostname
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr += ":53"
|
||||
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
||||
addr = net.JoinHostPort(addr, "53")
|
||||
}
|
||||
if prot == "" {
|
||||
prot = "udp"
|
||||
@ -49,7 +49,7 @@ func (ns NameServer) String() string {
|
||||
}
|
||||
|
||||
type resolverCacheItem struct {
|
||||
IPAddrs []net.IPAddr
|
||||
IPs []net.IP
|
||||
ts int64
|
||||
}
|
||||
|
||||
@ -100,8 +100,8 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
|
||||
addr := ns.Addr
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr += ":53"
|
||||
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
||||
addr = net.JoinHostPort(addr, "53")
|
||||
}
|
||||
switch strings.ToLower(ns.Protocol) {
|
||||
case "tcp":
|
||||
@ -125,28 +125,38 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) Resolve(name string) (addrs []net.IPAddr, err error) {
|
||||
func (r *resolver) Resolve(name string) (ips []net.IP, err error) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
timeout := r.Timeout
|
||||
|
||||
addrs = r.loadCache(name)
|
||||
if len(addrs) > 0 {
|
||||
if ip := net.ParseIP(name); ip != nil {
|
||||
return []net.IP{ip}, nil
|
||||
}
|
||||
|
||||
ips = r.loadCache(name)
|
||||
if len(ips) > 0 {
|
||||
if Debug {
|
||||
log.Logf("[resolver] cache hit: %s %v", name, addrs)
|
||||
log.Logf("[resolver] cache hit: %s %v", name, ips)
|
||||
}
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
addrs, err = r.Resolver.LookupIPAddr(ctx, name)
|
||||
r.storeCache(name, addrs)
|
||||
if len(addrs) > 0 && Debug {
|
||||
log.Logf("[resolver] %s %v", name, addrs)
|
||||
addrs, err := r.Resolver.LookupIPAddr(ctx, name)
|
||||
for _, addr := range addrs {
|
||||
ips = append(ips, addr.IP)
|
||||
}
|
||||
r.storeCache(name, ips)
|
||||
if len(ips) > 0 && Debug {
|
||||
log.Logf("[resolver] %s %v", name, ips)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *resolver) loadCache(name string) []net.IPAddr {
|
||||
func (r *resolver) loadCache(name string) []net.IP {
|
||||
ttl := r.TTL
|
||||
if ttl < 0 {
|
||||
return nil
|
||||
@ -157,19 +167,19 @@ func (r *resolver) loadCache(name string) []net.IPAddr {
|
||||
if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl {
|
||||
return nil
|
||||
}
|
||||
return item.IPAddrs
|
||||
return item.IPs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resolver) storeCache(name string, addrs []net.IPAddr) {
|
||||
func (r *resolver) storeCache(name string, ips []net.IP) {
|
||||
ttl := r.TTL
|
||||
if ttl < 0 || name == "" || len(addrs) == 0 {
|
||||
if ttl < 0 || name == "" || len(ips) == 0 {
|
||||
return
|
||||
}
|
||||
r.mCache.Store(name, &resolverCacheItem{
|
||||
IPAddrs: addrs,
|
||||
IPs: ips,
|
||||
ts: time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
@ -180,8 +190,8 @@ func (r *resolver) String() string {
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
fmt.Fprintf(b, "timeout %v\n", r.Timeout)
|
||||
fmt.Fprintf(b, "ttl %v\n", r.TTL)
|
||||
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
|
||||
fmt.Fprintf(b, "TTL %v\n", r.TTL)
|
||||
for i := range r.Servers {
|
||||
fmt.Fprintln(b, r.Servers[i])
|
||||
}
|
||||
|
19
server.go
19
server.go
@ -3,6 +3,7 @@ package gost
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-log/log"
|
||||
@ -132,15 +133,29 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
var (
|
||||
trPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 32*1024)
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func transport(rw1, rw2 io.ReadWriter) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := io.Copy(rw1, rw2)
|
||||
buf := trPool.Get().([]byte)
|
||||
defer trPool.Put(buf)
|
||||
|
||||
_, err := io.CopyBuffer(rw1, rw2, buf)
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(rw2, rw1)
|
||||
buf := trPool.Get().([]byte)
|
||||
defer trPool.Put(buf)
|
||||
|
||||
_, err := io.CopyBuffer(rw2, rw1, buf)
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user