add static hosts support
This commit is contained in:
parent
e56fbfa809
commit
644d22d7c3
35
chain.go
35
chain.go
@ -18,6 +18,7 @@ var (
|
|||||||
type Chain struct {
|
type Chain struct {
|
||||||
isRoute bool
|
isRoute bool
|
||||||
Retries int
|
Retries int
|
||||||
|
Hosts *Hosts
|
||||||
Resolver Resolver
|
Resolver Resolver
|
||||||
nodeGroups []*NodeGroup
|
nodeGroups []*NodeGroup
|
||||||
}
|
}
|
||||||
@ -124,18 +125,7 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if c != nil && c.Resolver != nil {
|
addr = c.resolve(addr)
|
||||||
host, port, err := net.SplitHostPort(addr)
|
|
||||||
if err == nil {
|
|
||||||
addrs, er := c.Resolver.Resolve(host)
|
|
||||||
if er != nil {
|
|
||||||
log.Logf("[resolver] %s: %v", host, er)
|
|
||||||
}
|
|
||||||
if len(addrs) > 0 {
|
|
||||||
addr = net.JoinHostPort(addrs[0].IP.String(), port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if route.IsEmpty() {
|
if route.IsEmpty() {
|
||||||
return net.DialTimeout("tcp", addr, DialTimeout)
|
return net.DialTimeout("tcp", addr, DialTimeout)
|
||||||
@ -154,6 +144,27 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
|
|||||||
return cc, nil
|
return cc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Chain) resolve(addr string) string {
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip := c.Hosts.Lookup(host); ip != nil {
|
||||||
|
return net.JoinHostPort(ip.String(), port)
|
||||||
|
}
|
||||||
|
if c.Resolver != nil {
|
||||||
|
ips, err := c.Resolver.Resolve(host)
|
||||||
|
if err != nil {
|
||||||
|
log.Logf("[resolver] %s: %v", host, err)
|
||||||
|
}
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return net.JoinHostPort(ips[0].String(), port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
// Conn obtains a handshaked connection to the last node of the chain.
|
// Conn obtains a handshaked connection to the last node of the chain.
|
||||||
// If the chain is empty, it returns an ErrEmptyChain error.
|
// If the chain is empty, it returns an ErrEmptyChain error.
|
||||||
func (c *Chain) Conn() (conn net.Conn, err error) {
|
func (c *Chain) Conn() (conn net.Conn, err error) {
|
||||||
|
14
cmd/gost/hosts.txt
Normal file
14
cmd/gost/hosts.txt
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# The following lines are desirable for IPv4 capable hosts
|
||||||
|
127.0.0.1 localhost
|
||||||
|
|
||||||
|
# 127.0.1.1 is often used for the FQDN of the machine
|
||||||
|
127.0.1.1 thishost.mydomain.org thishost
|
||||||
|
192.168.1.10 foo.mydomain.org foo
|
||||||
|
192.168.1.13 bar.mydomain.org bar
|
||||||
|
146.82.138.7 master.debian.org master
|
||||||
|
209.237.226.90 www.opensource.org
|
||||||
|
|
||||||
|
# The following lines are desirable for IPv6 capable hosts
|
||||||
|
::1 localhost ip6-localhost ip6-loopback
|
||||||
|
ff02::1 ip6-allnodes
|
||||||
|
ff02::2 ip6-allrouters
|
@ -6,6 +6,8 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
@ -57,6 +59,9 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
go func() {
|
||||||
|
log.Log(http.ListenAndServe("localhost:6060", nil))
|
||||||
|
}()
|
||||||
// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
|
// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
|
||||||
config, err := tlsConfig(defaultCertFile, defaultKeyFile)
|
config, err := tlsConfig(defaultCertFile, defaultKeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -336,12 +341,15 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *route) serve() error {
|
func (r *route) serve() error {
|
||||||
chain, err := r.initChain()
|
baseChain, err := r.initChain()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range r.ServeNodes {
|
for _, ns := range r.ServeNodes {
|
||||||
|
chain := &gost.Chain{}
|
||||||
|
*chain = *baseChain
|
||||||
|
|
||||||
node, err := gost.ParseNode(ns)
|
node, err := gost.ParseNode(ns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -462,7 +470,6 @@ func (r *route) serve() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var handlerOptions []gost.HandlerOption
|
var handlerOptions []gost.HandlerOption
|
||||||
|
|
||||||
handlerOptions = append(handlerOptions,
|
handlerOptions = append(handlerOptions,
|
||||||
gost.AddrHandlerOption(node.Addr),
|
gost.AddrHandlerOption(node.Addr),
|
||||||
gost.ChainHandlerOption(chain),
|
gost.ChainHandlerOption(chain),
|
||||||
@ -516,6 +523,14 @@ func (r *route) serve() error {
|
|||||||
if gost.Debug {
|
if gost.Debug {
|
||||||
log.Logf("[resolver]\n%v", chain.Resolver)
|
log.Logf("[resolver]\n%v", chain.Resolver)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if f, _ := os.Open(node.Get("hosts")); f != nil {
|
||||||
|
chain.Hosts, err = gost.ParseHosts(f)
|
||||||
|
if err != nil {
|
||||||
|
log.Logf("[hosts] %s: %v", f.Name(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
go srv.Serve(handler)
|
go srv.Serve(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
104
hosts.go
Normal file
104
hosts.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package gost
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-log/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Host is a static mapping from hostname to IP.
|
||||||
|
type Host struct {
|
||||||
|
IP net.IP
|
||||||
|
Hostname string
|
||||||
|
Aliases []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hosts is a static table lookup for hostnames.
|
||||||
|
type Hosts struct {
|
||||||
|
hosts []Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHosts creates a Hosts with optional list of host
|
||||||
|
func NewHosts(hosts ...Host) *Hosts {
|
||||||
|
return &Hosts{
|
||||||
|
hosts: hosts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseHosts parses host table from r.
|
||||||
|
// For each host a single line should be present with the following information:
|
||||||
|
// IP_address canonical_hostname [aliases...]
|
||||||
|
// Fields of the entry are separated by any number of blanks and/or tab characters.
|
||||||
|
// Text from a "#" character until the end of the line is a comment, and is ignored.
|
||||||
|
func ParseHosts(r io.Reader) (*Hosts, error) {
|
||||||
|
hosts := NewHosts()
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if n := strings.IndexByte(line, '#'); n >= 0 {
|
||||||
|
line = line[:n]
|
||||||
|
}
|
||||||
|
line = strings.Replace(line, "\t", " ", -1)
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var ss []string
|
||||||
|
for _, s := range strings.Split(line, " ") {
|
||||||
|
if s = strings.TrimSpace(s); s != "" {
|
||||||
|
ss = append(ss, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(ss) < 2 {
|
||||||
|
continue // invalid lines are ignored
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(ss[0])
|
||||||
|
if ip == nil {
|
||||||
|
continue // invalid IP addresses are ignored
|
||||||
|
}
|
||||||
|
host := Host{
|
||||||
|
IP: ip,
|
||||||
|
Hostname: ss[1],
|
||||||
|
}
|
||||||
|
if len(ss) > 2 {
|
||||||
|
host.Aliases = ss[2:]
|
||||||
|
}
|
||||||
|
hosts.AddHost(host)
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return hosts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHost adds host(s) to the host table.
|
||||||
|
func (h *Hosts) AddHost(host ...Host) {
|
||||||
|
h.hosts = append(h.hosts, host...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup searches the IP address corresponds to the given host from the host table.
|
||||||
|
func (h *Hosts) Lookup(host string) (ip net.IP) {
|
||||||
|
if h == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, h := range h.hosts {
|
||||||
|
if h.Hostname == host {
|
||||||
|
ip = h.IP
|
||||||
|
break
|
||||||
|
}
|
||||||
|
for _, alias := range h.Aliases {
|
||||||
|
if alias == host {
|
||||||
|
ip = h.IP
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ip != nil && Debug {
|
||||||
|
log.Logf("[hosts] hit: %s %s", host, ip.String())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
52
resolver.go
52
resolver.go
@ -24,7 +24,7 @@ var (
|
|||||||
// It contains a list of name servers.
|
// It contains a list of name servers.
|
||||||
type Resolver interface {
|
type Resolver interface {
|
||||||
// Resolve returns a slice of that host's IPv4 and IPv6 addresses.
|
// Resolve returns a slice of that host's IPv4 and IPv6 addresses.
|
||||||
Resolve(host string) ([]net.IPAddr, error)
|
Resolve(host string) ([]net.IP, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NameServer is a name server.
|
// NameServer is a name server.
|
||||||
@ -39,8 +39,8 @@ func (ns NameServer) String() string {
|
|||||||
addr := ns.Addr
|
addr := ns.Addr
|
||||||
prot := ns.Protocol
|
prot := ns.Protocol
|
||||||
host := ns.Hostname
|
host := ns.Hostname
|
||||||
if !strings.Contains(addr, ":") {
|
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
||||||
addr += ":53"
|
addr = net.JoinHostPort(addr, "53")
|
||||||
}
|
}
|
||||||
if prot == "" {
|
if prot == "" {
|
||||||
prot = "udp"
|
prot = "udp"
|
||||||
@ -49,7 +49,7 @@ func (ns NameServer) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type resolverCacheItem struct {
|
type resolverCacheItem struct {
|
||||||
IPAddrs []net.IPAddr
|
IPs []net.IP
|
||||||
ts int64
|
ts int64
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,8 +100,8 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
|
|||||||
var d net.Dialer
|
var d net.Dialer
|
||||||
|
|
||||||
addr := ns.Addr
|
addr := ns.Addr
|
||||||
if !strings.Contains(addr, ":") {
|
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
||||||
addr += ":53"
|
addr = net.JoinHostPort(addr, "53")
|
||||||
}
|
}
|
||||||
switch strings.ToLower(ns.Protocol) {
|
switch strings.ToLower(ns.Protocol) {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
@ -125,28 +125,38 @@ func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) Resolve(name string) (addrs []net.IPAddr, err error) {
|
func (r *resolver) Resolve(name string) (ips []net.IP, err error) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
timeout := r.Timeout
|
timeout := r.Timeout
|
||||||
|
|
||||||
addrs = r.loadCache(name)
|
if ip := net.ParseIP(name); ip != nil {
|
||||||
if len(addrs) > 0 {
|
return []net.IP{ip}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ips = r.loadCache(name)
|
||||||
|
if len(ips) > 0 {
|
||||||
if Debug {
|
if Debug {
|
||||||
log.Logf("[resolver] cache hit: %s %v", name, addrs)
|
log.Logf("[resolver] cache hit: %s %v", name, ips)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
addrs, err = r.Resolver.LookupIPAddr(ctx, name)
|
addrs, err := r.Resolver.LookupIPAddr(ctx, name)
|
||||||
r.storeCache(name, addrs)
|
for _, addr := range addrs {
|
||||||
if len(addrs) > 0 && Debug {
|
ips = append(ips, addr.IP)
|
||||||
log.Logf("[resolver] %s %v", name, addrs)
|
}
|
||||||
|
r.storeCache(name, ips)
|
||||||
|
if len(ips) > 0 && Debug {
|
||||||
|
log.Logf("[resolver] %s %v", name, ips)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) loadCache(name string) []net.IPAddr {
|
func (r *resolver) loadCache(name string) []net.IP {
|
||||||
ttl := r.TTL
|
ttl := r.TTL
|
||||||
if ttl < 0 {
|
if ttl < 0 {
|
||||||
return nil
|
return nil
|
||||||
@ -157,19 +167,19 @@ func (r *resolver) loadCache(name string) []net.IPAddr {
|
|||||||
if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl {
|
if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return item.IPAddrs
|
return item.IPs
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolver) storeCache(name string, addrs []net.IPAddr) {
|
func (r *resolver) storeCache(name string, ips []net.IP) {
|
||||||
ttl := r.TTL
|
ttl := r.TTL
|
||||||
if ttl < 0 || name == "" || len(addrs) == 0 {
|
if ttl < 0 || name == "" || len(ips) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.mCache.Store(name, &resolverCacheItem{
|
r.mCache.Store(name, &resolverCacheItem{
|
||||||
IPAddrs: addrs,
|
IPs: ips,
|
||||||
ts: time.Now().Unix(),
|
ts: time.Now().Unix(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -180,8 +190,8 @@ func (r *resolver) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
fmt.Fprintf(b, "timeout %v\n", r.Timeout)
|
fmt.Fprintf(b, "Timeout %v\n", r.Timeout)
|
||||||
fmt.Fprintf(b, "ttl %v\n", r.TTL)
|
fmt.Fprintf(b, "TTL %v\n", r.TTL)
|
||||||
for i := range r.Servers {
|
for i := range r.Servers {
|
||||||
fmt.Fprintln(b, r.Servers[i])
|
fmt.Fprintln(b, r.Servers[i])
|
||||||
}
|
}
|
||||||
|
19
server.go
19
server.go
@ -3,6 +3,7 @@ package gost
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-log/log"
|
"github.com/go-log/log"
|
||||||
@ -132,15 +133,29 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
|||||||
return tc, nil
|
return tc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
trPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, 32*1024)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
func transport(rw1, rw2 io.ReadWriter) error {
|
func transport(rw1, rw2 io.ReadWriter) error {
|
||||||
errc := make(chan error, 1)
|
errc := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(rw1, rw2)
|
buf := trPool.Get().([]byte)
|
||||||
|
defer trPool.Put(buf)
|
||||||
|
|
||||||
|
_, err := io.CopyBuffer(rw1, rw2, buf)
|
||||||
errc <- err
|
errc <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(rw2, rw1)
|
buf := trPool.Get().([]byte)
|
||||||
|
defer trPool.Put(buf)
|
||||||
|
|
||||||
|
_, err := io.CopyBuffer(rw2, rw1, buf)
|
||||||
errc <- err
|
errc <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user