add support custom dns resolver

This commit is contained in:
ginuerzh 2018-05-19 11:12:18 +08:00
parent d39e211f3e
commit 0695bb5e9a
6 changed files with 477 additions and 252 deletions

2
.gitignore vendored
View File

@ -26,5 +26,5 @@ _testmain.go
*.bak
cmd/gost
cmd/gost/gost
snap

View File

@ -19,6 +19,7 @@ var (
type Chain struct {
isRoute bool
Retries int
Resolver Resolver
nodeGroups []*NodeGroup
}
@ -101,7 +102,15 @@ func (c *Chain) IsEmpty() bool {
// Dial connects to the target address addr through the chain.
// If the chain is empty, it will use the net.Dial directly.
func (c *Chain) Dial(addr string) (conn net.Conn, err error) {
for i := 0; i < c.Retries; i++ {
var retries int
if c != nil {
retries = c.Retries
}
if retries == 0 {
retries = 1
}
for i := 0; i < retries; i++ {
conn, err = c.dial(addr)
if err == nil {
break
@ -115,6 +124,18 @@ func (c *Chain) dial(addr string) (net.Conn, error) {
if err != nil {
return nil, err
}
if c != nil && c.Resolver != nil {
host, port, err := net.SplitHostPort(addr)
if err == nil {
addrs, _ := c.Resolver.Resolve(host)
log.Log(addr, addrs)
if len(addrs) > 0 {
addr = net.JoinHostPort(addrs[0].IP.String(), port)
}
}
}
if route.IsEmpty() {
return net.DialTimeout("tcp", addr, DialTimeout)
}
@ -204,7 +225,6 @@ func (c *Chain) selectRoute() (route *Chain, err error) {
buf := bytes.Buffer{}
route = newRoute()
route.Retries = c.Retries
for _, group := range c.nodeGroups {
node, err := group.Next()
@ -218,11 +238,13 @@ func (c *Chain) selectRoute() (route *Chain, err error) {
ChainDialOption(route),
)
route = newRoute() // cutoff the chain for multiplex.
route.Retries = c.Retries
}
route.AddNode(node)
}
route.Retries = c.Retries
route.Resolver = c.Resolver
if Debug {
log.Log("select route:", buf.String())
}
@ -237,7 +259,6 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
buf := bytes.Buffer{}
route = newRoute()
route.Retries = c.Retries
for _, group := range c.nodeGroups {
var node Node
@ -265,11 +286,14 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) {
ChainDialOption(route),
)
route = newRoute() // cutoff the chain for multiplex.
route.Retries = c.Retries
}
route.AddNode(node)
}
route.Retries = c.Retries
route.Resolver = c.Resolver
if Debug {
log.Log("select route:", buf.String())
}

323
cmd/gost/cfg.go Normal file
View File

@ -0,0 +1,323 @@
package main
import (
"bufio"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/ginuerzh/gost"
)
// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid.
func tlsConfig(certFile, keyFile string) (*tls.Config, error) {
if certFile == "" {
certFile = "cert.pem"
}
if keyFile == "" {
keyFile = "key.pem"
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
}
func loadCA(caFile string) (cp *x509.CertPool, err error) {
if caFile == "" {
return
}
cp = x509.NewCertPool()
data, err := ioutil.ReadFile(caFile)
if err != nil {
return nil, err
}
if !cp.AppendCertsFromPEM(data) {
return nil, errors.New("AppendCertsFromPEM failed")
}
return
}
func loadConfigureFile(configureFile string) error {
if configureFile == "" {
return nil
}
content, err := ioutil.ReadFile(configureFile)
if err != nil {
return err
}
var cfg struct {
route
Routes []route
}
if err := json.Unmarshal(content, &cfg); err != nil {
return err
}
if len(cfg.route.ServeNodes) > 0 {
routes = append(routes, cfg.route)
}
for _, route := range cfg.Routes {
if len(route.ServeNodes) > 0 {
routes = append(routes, route)
}
}
gost.Debug = cfg.Debug
return nil
}
type stringList []string
func (l *stringList) String() string {
return fmt.Sprintf("%s", *l)
}
func (l *stringList) Set(value string) error {
*l = append(*l, value)
return nil
}
func parseKCPConfig(configFile string) (*gost.KCPConfig, error) {
if configFile == "" {
return nil, nil
}
file, err := os.Open(configFile)
if err != nil {
return nil, err
}
defer file.Close()
config := &gost.KCPConfig{}
if err = json.NewDecoder(file).Decode(config); err != nil {
return nil, err
}
return config, nil
}
func parseUsers(authFile string) (users []*url.Userinfo, err error) {
if authFile == "" {
return
}
file, err := os.Open(authFile)
if err != nil {
return
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
s := strings.SplitN(line, " ", 2)
if len(s) == 1 {
users = append(users, url.User(strings.TrimSpace(s[0])))
} else if len(s) == 2 {
users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1])))
}
}
err = scanner.Err()
return
}
func parseIP(s string, port string) (ips []string) {
if s == "" {
return
}
if port == "" {
port = "8080" // default port
}
file, err := os.Open(s)
if err != nil {
ss := strings.Split(s, ",")
for _, s := range ss {
s = strings.TrimSpace(s)
if s != "" {
if !strings.Contains(s, ":") {
s = s + ":" + port
}
ips = append(ips, s)
}
}
return
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if !strings.Contains(line, ":") {
line = line + ":" + port
}
ips = append(ips, line)
}
return
}
type peerConfig struct {
Strategy string `json:"strategy"`
Filters []string `json:"filters"`
MaxFails int `json:"max_fails"`
FailTimeout int `json:"fail_timeout"`
Nodes []string `json:"nodes"`
Bypass *bypass `json:"bypass"` // global bypass
}
type bypass struct {
Reverse bool `json:"reverse"`
Patterns []string `json:"patterns"`
}
func loadPeerConfig(peer string) (config peerConfig, err error) {
if peer == "" {
return
}
content, err := ioutil.ReadFile(peer)
if err != nil {
return
}
err = json.Unmarshal(content, &config)
return
}
func (cfg *peerConfig) Validate() {
if cfg.MaxFails <= 0 {
cfg.MaxFails = 1
}
if cfg.FailTimeout <= 0 {
cfg.FailTimeout = 30 // seconds
}
}
func parseStrategy(s string) gost.Strategy {
switch s {
case "random":
return &gost.RandomStrategy{}
case "fifo":
return &gost.FIFOStrategy{}
case "round":
fallthrough
default:
return &gost.RoundStrategy{}
}
}
func parseBypass(s string) *gost.Bypass {
if s == "" {
return nil
}
var matchers []gost.Matcher
var reversed bool
if strings.HasPrefix(s, "~") {
reversed = true
s = strings.TrimLeft(s, "~")
}
f, err := os.Open(s)
if err != nil {
for _, s := range strings.Split(s, ",") {
s = strings.TrimSpace(s)
if s == "" {
continue
}
matchers = append(matchers, gost.NewMatcher(s))
}
return gost.NewBypass(matchers, reversed)
}
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
matchers = append(matchers, gost.NewMatcher(line))
}
return gost.NewBypass(matchers, reversed)
}
func parseResolver(cfg string) gost.Resolver {
if cfg == "" {
return nil
}
f, err := os.Open(cfg)
if err != nil {
for _, s := range strings.Split(cfg, ",") {
s = strings.TrimSpace(s)
if s == "" {
continue
}
}
// return gost.NewBypass(matchers, reversed)
}
timeout := 30 * time.Second
var nss []gost.NameServer
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
if len(ss) == 0 {
continue
}
if ss[0] == "timeout" {
if len(ss) >= 2 {
if n, _ := strconv.Atoi(ss[1]); n > 0 {
timeout = time.Second * time.Duration(n)
}
}
continue
}
var ns gost.NameServer
if len(ss) == 1 {
ns.Addr = ss[0]
}
if len(ss) == 2 {
ns.Addr = ss[0]
ns.Protocol = ss[1]
}
if len(ss) == 3 {
ns.Addr = ss[0]
ns.Protocol = ss[1]
ns.Hostname = ss[2]
}
nss = append(nss, ns)
}
return gost.NewResolver(nss, timeout)
}

10
cmd/gost/dns.txt Normal file
View File

@ -0,0 +1,10 @@
# ip[:port] [protocol] [hostname]
# resolver timeout
timeout 10
1.1.1.1:853 tls cloudflare-dns.com
8.8.8.8
8.8.8.8 tcp
1.1.1.1 udp
1.1.1.1:53 tcp

View File

@ -1,20 +1,13 @@
package main
import (
"bufio"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"flag"
"fmt"
"io/ioutil"
"net"
"net/url"
"os"
"runtime"
"strings"
"time"
"github.com/ginuerzh/gost"
@ -506,247 +499,12 @@ func (r *route) serve() error {
srv.Init(
gost.BypassServerOption(parseBypass(node.Get("bypass"))),
)
chain.Resolver = parseResolver(node.Get("dns"))
log.Log(chain.Resolver)
go srv.Serve(handler)
}
return nil
}
// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid.
func tlsConfig(certFile, keyFile string) (*tls.Config, error) {
if certFile == "" {
certFile = "cert.pem"
}
if keyFile == "" {
keyFile = "key.pem"
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
}
func loadCA(caFile string) (cp *x509.CertPool, err error) {
if caFile == "" {
return
}
cp = x509.NewCertPool()
data, err := ioutil.ReadFile(caFile)
if err != nil {
return nil, err
}
if !cp.AppendCertsFromPEM(data) {
return nil, errors.New("AppendCertsFromPEM failed")
}
return
}
func loadConfigureFile(configureFile string) error {
if configureFile == "" {
return nil
}
content, err := ioutil.ReadFile(configureFile)
if err != nil {
return err
}
var cfg struct {
route
Routes []route
}
if err := json.Unmarshal(content, &cfg); err != nil {
return err
}
if len(cfg.route.ServeNodes) > 0 {
routes = append(routes, cfg.route)
}
for _, route := range cfg.Routes {
if len(route.ServeNodes) > 0 {
routes = append(routes, route)
}
}
gost.Debug = cfg.Debug
return nil
}
type stringList []string
func (l *stringList) String() string {
return fmt.Sprintf("%s", *l)
}
func (l *stringList) Set(value string) error {
*l = append(*l, value)
return nil
}
func parseKCPConfig(configFile string) (*gost.KCPConfig, error) {
if configFile == "" {
return nil, nil
}
file, err := os.Open(configFile)
if err != nil {
return nil, err
}
defer file.Close()
config := &gost.KCPConfig{}
if err = json.NewDecoder(file).Decode(config); err != nil {
return nil, err
}
return config, nil
}
func parseUsers(authFile string) (users []*url.Userinfo, err error) {
if authFile == "" {
return
}
file, err := os.Open(authFile)
if err != nil {
return
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
s := strings.SplitN(line, " ", 2)
if len(s) == 1 {
users = append(users, url.User(strings.TrimSpace(s[0])))
} else if len(s) == 2 {
users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1])))
}
}
err = scanner.Err()
return
}
func parseIP(s string, port string) (ips []string) {
if s == "" {
return
}
if port == "" {
port = "8080" // default port
}
file, err := os.Open(s)
if err != nil {
ss := strings.Split(s, ",")
for _, s := range ss {
s = strings.TrimSpace(s)
if s != "" {
if !strings.Contains(s, ":") {
s = s + ":" + port
}
ips = append(ips, s)
}
}
return
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if !strings.Contains(line, ":") {
line = line + ":" + port
}
ips = append(ips, line)
}
return
}
type peerConfig struct {
Strategy string `json:"strategy"`
Filters []string `json:"filters"`
MaxFails int `json:"max_fails"`
FailTimeout int `json:"fail_timeout"`
Nodes []string `json:"nodes"`
Bypass *bypass `json:"bypass"` // global bypass
}
type bypass struct {
Reverse bool `json:"reverse"`
Patterns []string `json:"patterns"`
}
func loadPeerConfig(peer string) (config peerConfig, err error) {
if peer == "" {
return
}
content, err := ioutil.ReadFile(peer)
if err != nil {
return
}
err = json.Unmarshal(content, &config)
return
}
func (cfg *peerConfig) Validate() {
if cfg.MaxFails <= 0 {
cfg.MaxFails = 1
}
if cfg.FailTimeout <= 0 {
cfg.FailTimeout = 30 // seconds
}
}
func parseStrategy(s string) gost.Strategy {
switch s {
case "random":
return &gost.RandomStrategy{}
case "fifo":
return &gost.FIFOStrategy{}
case "round":
fallthrough
default:
return &gost.RoundStrategy{}
}
}
func parseBypass(s string) *gost.Bypass {
if s == "" {
return nil
}
var matchers []gost.Matcher
var reversed bool
if strings.HasPrefix(s, "~") {
reversed = true
s = strings.TrimLeft(s, "~")
}
f, err := os.Open(s)
if err != nil {
for _, s := range strings.Split(s, ",") {
s = strings.TrimSpace(s)
if s == "" {
continue
}
matchers = append(matchers, gost.NewMatcher(s))
}
return gost.NewBypass(matchers, reversed)
}
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if n := strings.IndexByte(line, '#'); n >= 0 {
line = line[:n]
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
matchers = append(matchers, gost.NewMatcher(line))
}
return gost.NewBypass(matchers, reversed)
}

110
resolver.go Normal file
View File

@ -0,0 +1,110 @@
package gost
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"net"
"time"
)
var (
// DefaultResolverTimeout is the default timeout for name resolution.
DefaultResolverTimeout = 30 * time.Second
)
// Resolver is a name resolver for domain name.
// It contains a list of name servers.
type Resolver interface {
// Resolve returns a slice of that host's IPv4 and IPv6 addresses.
Resolve(host string) ([]net.IPAddr, error)
}
// NameServer is a name server.
// Currently supported protocol: TCP, UDP and TLS.
type NameServer struct {
Addr string
Protocol string
Hostname string // for TLS handshake verification
}
type resolver struct {
Resolver *net.Resolver
Servers []NameServer
Timeout time.Duration
}
// NewResolver create a new Resolver with the given name servers and resolution timeout.
func NewResolver(servers []NameServer, timeout time.Duration) Resolver {
r := &resolver{
Servers: servers,
Timeout: timeout,
}
r.init()
return r
}
func (r *resolver) init() {
r.Resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) {
for _, ns := range r.Servers {
conn, err = r.dial(ctx, ns)
if err == nil {
break
}
}
return
},
}
}
func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) {
var d net.Dialer
switch ns.Protocol {
case "tcp":
return d.DialContext(ctx, "tcp", ns.Addr)
case "tls":
conn, err := d.DialContext(ctx, "tcp", ns.Addr)
if err != nil {
return nil, err
}
cfg := &tls.Config{
ServerName: ns.Hostname,
}
if cfg.ServerName == "" {
cfg.InsecureSkipVerify = true
}
return tls.Client(conn, cfg), nil
case "udp":
fallthrough
default:
return d.DialContext(ctx, "udp", ns.Addr)
}
}
func (r *resolver) Resolve(name string) ([]net.IPAddr, error) {
timeout := r.Timeout
if timeout <= 0 {
timeout = DefaultResolverTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return r.Resolver.LookupIPAddr(ctx, name)
}
func (r *resolver) String() string {
if r == nil {
return ""
}
b := &bytes.Buffer{}
fmt.Fprintf(b, "timeout %v\n", r.Timeout)
for i := range r.Servers {
fmt.Fprintf(b, "%s/%s %s\n", r.Servers[i].Addr, r.Servers[i].Protocol, r.Servers[i].Hostname)
}
return b.String()
}