diff --git a/.gitignore b/.gitignore index 2f239f7..fc7960b 100644 --- a/.gitignore +++ b/.gitignore @@ -26,5 +26,5 @@ _testmain.go *.bak -cmd/gost +cmd/gost/gost snap diff --git a/chain.go b/chain.go index b12857a..af72047 100644 --- a/chain.go +++ b/chain.go @@ -19,6 +19,7 @@ var ( type Chain struct { isRoute bool Retries int + Resolver Resolver nodeGroups []*NodeGroup } @@ -101,7 +102,15 @@ func (c *Chain) IsEmpty() bool { // Dial connects to the target address addr through the chain. // If the chain is empty, it will use the net.Dial directly. func (c *Chain) Dial(addr string) (conn net.Conn, err error) { - for i := 0; i < c.Retries; i++ { + var retries int + if c != nil { + retries = c.Retries + } + if retries == 0 { + retries = 1 + } + + for i := 0; i < retries; i++ { conn, err = c.dial(addr) if err == nil { break @@ -115,6 +124,18 @@ func (c *Chain) dial(addr string) (net.Conn, error) { if err != nil { return nil, err } + + if c != nil && c.Resolver != nil { + host, port, err := net.SplitHostPort(addr) + if err == nil { + addrs, _ := c.Resolver.Resolve(host) + log.Log(addr, addrs) + if len(addrs) > 0 { + addr = net.JoinHostPort(addrs[0].IP.String(), port) + } + } + } + if route.IsEmpty() { return net.DialTimeout("tcp", addr, DialTimeout) } @@ -204,7 +225,6 @@ func (c *Chain) selectRoute() (route *Chain, err error) { buf := bytes.Buffer{} route = newRoute() - route.Retries = c.Retries for _, group := range c.nodeGroups { node, err := group.Next() @@ -218,11 +238,13 @@ func (c *Chain) selectRoute() (route *Chain, err error) { ChainDialOption(route), ) route = newRoute() // cutoff the chain for multiplex. - route.Retries = c.Retries } route.AddNode(node) } + route.Retries = c.Retries + route.Resolver = c.Resolver + if Debug { log.Log("select route:", buf.String()) } @@ -237,7 +259,6 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { buf := bytes.Buffer{} route = newRoute() - route.Retries = c.Retries for _, group := range c.nodeGroups { var node Node @@ -265,11 +286,14 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { ChainDialOption(route), ) route = newRoute() // cutoff the chain for multiplex. - route.Retries = c.Retries } route.AddNode(node) } + + route.Retries = c.Retries + route.Resolver = c.Resolver + if Debug { log.Log("select route:", buf.String()) } diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go new file mode 100644 index 0000000..9031775 --- /dev/null +++ b/cmd/gost/cfg.go @@ -0,0 +1,323 @@ +package main + +import ( + "bufio" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/ginuerzh/gost" +) + +// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid. +func tlsConfig(certFile, keyFile string) (*tls.Config, error) { + if certFile == "" { + certFile = "cert.pem" + } + if keyFile == "" { + keyFile = "key.pem" + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + return &tls.Config{Certificates: []tls.Certificate{cert}}, nil +} + +func loadCA(caFile string) (cp *x509.CertPool, err error) { + if caFile == "" { + return + } + cp = x509.NewCertPool() + data, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + if !cp.AppendCertsFromPEM(data) { + return nil, errors.New("AppendCertsFromPEM failed") + } + return +} + +func loadConfigureFile(configureFile string) error { + if configureFile == "" { + return nil + } + content, err := ioutil.ReadFile(configureFile) + if err != nil { + return err + } + var cfg struct { + route + Routes []route + } + if err := json.Unmarshal(content, &cfg); err != nil { + return err + } + + if len(cfg.route.ServeNodes) > 0 { + routes = append(routes, cfg.route) + } + for _, route := range cfg.Routes { + if len(route.ServeNodes) > 0 { + routes = append(routes, route) + } + } + gost.Debug = cfg.Debug + + return nil +} + +type stringList []string + +func (l *stringList) String() string { + return fmt.Sprintf("%s", *l) +} +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { + if configFile == "" { + return nil, nil + } + file, err := os.Open(configFile) + if err != nil { + return nil, err + } + defer file.Close() + + config := &gost.KCPConfig{} + if err = json.NewDecoder(file).Decode(config); err != nil { + return nil, err + } + return config, nil +} + +func parseUsers(authFile string) (users []*url.Userinfo, err error) { + if authFile == "" { + return + } + + file, err := os.Open(authFile) + if err != nil { + return + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + s := strings.SplitN(line, " ", 2) + if len(s) == 1 { + users = append(users, url.User(strings.TrimSpace(s[0]))) + } else if len(s) == 2 { + users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) + } + } + + err = scanner.Err() + return +} + +func parseIP(s string, port string) (ips []string) { + if s == "" { + return + } + if port == "" { + port = "8080" // default port + } + + file, err := os.Open(s) + if err != nil { + ss := strings.Split(s, ",") + for _, s := range ss { + s = strings.TrimSpace(s) + if s != "" { + if !strings.Contains(s, ":") { + s = s + ":" + port + } + ips = append(ips, s) + } + + } + return + } + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + if !strings.Contains(line, ":") { + line = line + ":" + port + } + ips = append(ips, line) + } + return +} + +type peerConfig struct { + Strategy string `json:"strategy"` + Filters []string `json:"filters"` + MaxFails int `json:"max_fails"` + FailTimeout int `json:"fail_timeout"` + Nodes []string `json:"nodes"` + Bypass *bypass `json:"bypass"` // global bypass +} + +type bypass struct { + Reverse bool `json:"reverse"` + Patterns []string `json:"patterns"` +} + +func loadPeerConfig(peer string) (config peerConfig, err error) { + if peer == "" { + return + } + content, err := ioutil.ReadFile(peer) + if err != nil { + return + } + err = json.Unmarshal(content, &config) + return +} + +func (cfg *peerConfig) Validate() { + if cfg.MaxFails <= 0 { + cfg.MaxFails = 1 + } + if cfg.FailTimeout <= 0 { + cfg.FailTimeout = 30 // seconds + } +} + +func parseStrategy(s string) gost.Strategy { + switch s { + case "random": + return &gost.RandomStrategy{} + case "fifo": + return &gost.FIFOStrategy{} + case "round": + fallthrough + default: + return &gost.RoundStrategy{} + + } +} + +func parseBypass(s string) *gost.Bypass { + if s == "" { + return nil + } + var matchers []gost.Matcher + var reversed bool + if strings.HasPrefix(s, "~") { + reversed = true + s = strings.TrimLeft(s, "~") + } + + f, err := os.Open(s) + if err != nil { + for _, s := range strings.Split(s, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + matchers = append(matchers, gost.NewMatcher(s)) + } + return gost.NewBypass(matchers, reversed) + } + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.TrimSpace(line) + if line == "" { + continue + } + matchers = append(matchers, gost.NewMatcher(line)) + } + return gost.NewBypass(matchers, reversed) +} + +func parseResolver(cfg string) gost.Resolver { + if cfg == "" { + return nil + } + f, err := os.Open(cfg) + if err != nil { + for _, s := range strings.Split(cfg, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + } + // return gost.NewBypass(matchers, reversed) + } + + timeout := 30 * time.Second + + var nss []gost.NameServer + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.TrimSpace(line) + if line == "" { + continue + } + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + + if len(ss) == 0 { + continue + } + + if ss[0] == "timeout" { + if len(ss) >= 2 { + if n, _ := strconv.Atoi(ss[1]); n > 0 { + timeout = time.Second * time.Duration(n) + } + } + continue + } + + var ns gost.NameServer + if len(ss) == 1 { + ns.Addr = ss[0] + } + if len(ss) == 2 { + ns.Addr = ss[0] + ns.Protocol = ss[1] + } + if len(ss) == 3 { + ns.Addr = ss[0] + ns.Protocol = ss[1] + ns.Hostname = ss[2] + } + nss = append(nss, ns) + } + return gost.NewResolver(nss, timeout) +} diff --git a/cmd/gost/dns.txt b/cmd/gost/dns.txt new file mode 100644 index 0000000..8d9d55b --- /dev/null +++ b/cmd/gost/dns.txt @@ -0,0 +1,10 @@ +# ip[:port] [protocol] [hostname] + +# resolver timeout +timeout 10 + +1.1.1.1:853 tls cloudflare-dns.com +8.8.8.8 +8.8.8.8 tcp +1.1.1.1 udp +1.1.1.1:53 tcp \ No newline at end of file diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 8abf2f3..dd99707 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -1,20 +1,13 @@ package main import ( - "bufio" "crypto/sha256" "crypto/tls" - "crypto/x509" - "encoding/json" - "errors" "flag" "fmt" - "io/ioutil" "net" - "net/url" "os" "runtime" - "strings" "time" "github.com/ginuerzh/gost" @@ -506,247 +499,12 @@ func (r *route) serve() error { srv.Init( gost.BypassServerOption(parseBypass(node.Get("bypass"))), ) + + chain.Resolver = parseResolver(node.Get("dns")) + log.Log(chain.Resolver) + go srv.Serve(handler) } return nil } - -// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid. -func tlsConfig(certFile, keyFile string) (*tls.Config, error) { - if certFile == "" { - certFile = "cert.pem" - } - if keyFile == "" { - keyFile = "key.pem" - } - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, err - } - return &tls.Config{Certificates: []tls.Certificate{cert}}, nil -} - -func loadCA(caFile string) (cp *x509.CertPool, err error) { - if caFile == "" { - return - } - cp = x509.NewCertPool() - data, err := ioutil.ReadFile(caFile) - if err != nil { - return nil, err - } - if !cp.AppendCertsFromPEM(data) { - return nil, errors.New("AppendCertsFromPEM failed") - } - return -} - -func loadConfigureFile(configureFile string) error { - if configureFile == "" { - return nil - } - content, err := ioutil.ReadFile(configureFile) - if err != nil { - return err - } - var cfg struct { - route - Routes []route - } - if err := json.Unmarshal(content, &cfg); err != nil { - return err - } - - if len(cfg.route.ServeNodes) > 0 { - routes = append(routes, cfg.route) - } - for _, route := range cfg.Routes { - if len(route.ServeNodes) > 0 { - routes = append(routes, route) - } - } - gost.Debug = cfg.Debug - - return nil -} - -type stringList []string - -func (l *stringList) String() string { - return fmt.Sprintf("%s", *l) -} -func (l *stringList) Set(value string) error { - *l = append(*l, value) - return nil -} - -func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { - if configFile == "" { - return nil, nil - } - file, err := os.Open(configFile) - if err != nil { - return nil, err - } - defer file.Close() - - config := &gost.KCPConfig{} - if err = json.NewDecoder(file).Decode(config); err != nil { - return nil, err - } - return config, nil -} - -func parseUsers(authFile string) (users []*url.Userinfo, err error) { - if authFile == "" { - return - } - - file, err := os.Open(authFile) - if err != nil { - return - } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - s := strings.SplitN(line, " ", 2) - if len(s) == 1 { - users = append(users, url.User(strings.TrimSpace(s[0]))) - } else if len(s) == 2 { - users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) - } - } - - err = scanner.Err() - return -} - -func parseIP(s string, port string) (ips []string) { - if s == "" { - return - } - if port == "" { - port = "8080" // default port - } - - file, err := os.Open(s) - if err != nil { - ss := strings.Split(s, ",") - for _, s := range ss { - s = strings.TrimSpace(s) - if s != "" { - if !strings.Contains(s, ":") { - s = s + ":" + port - } - ips = append(ips, s) - } - - } - return - } - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - if !strings.Contains(line, ":") { - line = line + ":" + port - } - ips = append(ips, line) - } - return -} - -type peerConfig struct { - Strategy string `json:"strategy"` - Filters []string `json:"filters"` - MaxFails int `json:"max_fails"` - FailTimeout int `json:"fail_timeout"` - Nodes []string `json:"nodes"` - Bypass *bypass `json:"bypass"` // global bypass -} - -type bypass struct { - Reverse bool `json:"reverse"` - Patterns []string `json:"patterns"` -} - -func loadPeerConfig(peer string) (config peerConfig, err error) { - if peer == "" { - return - } - content, err := ioutil.ReadFile(peer) - if err != nil { - return - } - err = json.Unmarshal(content, &config) - return -} - -func (cfg *peerConfig) Validate() { - if cfg.MaxFails <= 0 { - cfg.MaxFails = 1 - } - if cfg.FailTimeout <= 0 { - cfg.FailTimeout = 30 // seconds - } -} - -func parseStrategy(s string) gost.Strategy { - switch s { - case "random": - return &gost.RandomStrategy{} - case "fifo": - return &gost.FIFOStrategy{} - case "round": - fallthrough - default: - return &gost.RoundStrategy{} - - } -} - -func parseBypass(s string) *gost.Bypass { - if s == "" { - return nil - } - var matchers []gost.Matcher - var reversed bool - if strings.HasPrefix(s, "~") { - reversed = true - s = strings.TrimLeft(s, "~") - } - - f, err := os.Open(s) - if err != nil { - for _, s := range strings.Split(s, ",") { - s = strings.TrimSpace(s) - if s == "" { - continue - } - matchers = append(matchers, gost.NewMatcher(s)) - } - return gost.NewBypass(matchers, reversed) - } - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Text() - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.TrimSpace(line) - if line == "" { - continue - } - matchers = append(matchers, gost.NewMatcher(line)) - } - return gost.NewBypass(matchers, reversed) -} diff --git a/resolver.go b/resolver.go new file mode 100644 index 0000000..db06823 --- /dev/null +++ b/resolver.go @@ -0,0 +1,110 @@ +package gost + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "net" + "time" +) + +var ( + // DefaultResolverTimeout is the default timeout for name resolution. + DefaultResolverTimeout = 30 * time.Second +) + +// Resolver is a name resolver for domain name. +// It contains a list of name servers. +type Resolver interface { + // Resolve returns a slice of that host's IPv4 and IPv6 addresses. + Resolve(host string) ([]net.IPAddr, error) +} + +// NameServer is a name server. +// Currently supported protocol: TCP, UDP and TLS. +type NameServer struct { + Addr string + Protocol string + Hostname string // for TLS handshake verification +} + +type resolver struct { + Resolver *net.Resolver + Servers []NameServer + Timeout time.Duration +} + +// NewResolver create a new Resolver with the given name servers and resolution timeout. +func NewResolver(servers []NameServer, timeout time.Duration) Resolver { + r := &resolver{ + Servers: servers, + Timeout: timeout, + } + r.init() + return r +} + +func (r *resolver) init() { + r.Resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) { + for _, ns := range r.Servers { + conn, err = r.dial(ctx, ns) + if err == nil { + break + } + } + return + }, + } +} + +func (r *resolver) dial(ctx context.Context, ns NameServer) (net.Conn, error) { + var d net.Dialer + + switch ns.Protocol { + case "tcp": + return d.DialContext(ctx, "tcp", ns.Addr) + case "tls": + conn, err := d.DialContext(ctx, "tcp", ns.Addr) + if err != nil { + return nil, err + } + cfg := &tls.Config{ + ServerName: ns.Hostname, + } + if cfg.ServerName == "" { + cfg.InsecureSkipVerify = true + } + return tls.Client(conn, cfg), nil + case "udp": + fallthrough + default: + return d.DialContext(ctx, "udp", ns.Addr) + } +} + +func (r *resolver) Resolve(name string) ([]net.IPAddr, error) { + timeout := r.Timeout + if timeout <= 0 { + timeout = DefaultResolverTimeout + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return r.Resolver.LookupIPAddr(ctx, name) +} + +func (r *resolver) String() string { + if r == nil { + return "" + } + + b := &bytes.Buffer{} + fmt.Fprintf(b, "timeout %v\n", r.Timeout) + for i := range r.Servers { + fmt.Fprintf(b, "%s/%s %s\n", r.Servers[i].Addr, r.Servers[i].Protocol, r.Servers[i].Hostname) + } + return b.String() +}