From 5e0e08d5b0918ad0258776c2d852d84f0c8ebe68 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 27 Nov 2018 10:08:18 +0800 Subject: [PATCH] live reloading for base config --- bypass.go | 2 +- cmd/gost/cfg.go | 67 +++++-- cmd/gost/main.go | 444 +++------------------------------------------- cmd/gost/peer.go | 27 +-- cmd/gost/route.go | 412 ++++++++++++++++++++++++++++++++++++++++++ reload.go | 9 +- 6 files changed, 516 insertions(+), 445 deletions(-) create mode 100644 cmd/gost/route.go diff --git a/bypass.go b/bypass.go index 5b98cfd..acd9be9 100644 --- a/bypass.go +++ b/bypass.go @@ -122,8 +122,8 @@ func (m *domainMatcher) String() string { // It contains a list of matchers. type Bypass struct { matchers []Matcher - reversed bool period time.Duration // the period for live reloading + reversed bool mux sync.RWMutex } diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index 4eb9dfc..e05e900 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "net/url" "os" @@ -51,35 +52,60 @@ func loadCA(caFile string) (cp *x509.CertPool, err error) { return } -func loadConfigureFile(configureFile string) error { - if configureFile == "" { - return nil - } - content, err := ioutil.ReadFile(configureFile) +type baseConfig struct { + route + Routes []route + ReloadPeriod string + Debug bool +} + +func parseBaseConfig(s string) (*baseConfig, error) { + file, err := os.Open(s) if err != nil { - return err + return nil, err } - var cfg struct { - route - Routes []route + defer file.Close() + + if err := json.NewDecoder(file).Decode(baseCfg); err != nil { + return nil, err } - if err := json.Unmarshal(content, &cfg); err != nil { + + return baseCfg, nil +} + +func (cfg *baseConfig) IsValid() bool { + return len(cfg.route.ServeNodes) > 0 +} + +func (cfg *baseConfig) Reload(r io.Reader) error { + c := baseConfig{} + if err := json.NewDecoder(r).Decode(&c); err != nil { return err } - if len(cfg.route.ServeNodes) > 0 { - routes = append(routes, cfg.route) - } - for _, route := range cfg.Routes { - if len(route.ServeNodes) > 0 { - routes = append(routes, route) - } + cfg.route.Close() + for _, r := range cfg.Routes { + r.Close() } + *cfg = c gost.Debug = cfg.Debug + if err := cfg.route.serve(); err != nil { + return err + } + for _, route := range cfg.Routes { + if err := route.serve(); err != nil { + return err + } + } return nil } +func (cfg *baseConfig) Period() time.Duration { + d, _ := time.ParseDuration(cfg.ReloadPeriod) + return d +} + type stringList []string func (l *stringList) String() string { @@ -240,3 +266,10 @@ func parseResolver(cfg string) gost.Resolver { return resolver } + +func parseHosts(s string) *gost.Hosts { + hosts := gost.NewHosts() + go gost.PeriodReload(hosts, s) + + return hosts +} diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 529f80c..8a585fa 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -1,69 +1,63 @@ package main import ( - "crypto/sha256" "crypto/tls" "flag" "fmt" - "net" - - // _ "net/http/pprof" "os" "runtime" - "time" + + // _ "net/http/pprof" "github.com/ginuerzh/gost" "github.com/go-log/log" ) var ( - options route - routes []route + configureFile string + baseCfg = &baseConfig{} ) func init() { gost.SetLogger(&gost.LogLogger{}) var ( - configureFile string - printVersion bool + printVersion bool ) - flag.Var(&options.ChainNodes, "F", "forward address, can make a forward chain") - flag.Var(&options.ServeNodes, "L", "listen address, can listen on multiple ports") + flag.Var(&baseCfg.route.ChainNodes, "F", "forward address, can make a forward chain") + flag.Var(&baseCfg.route.ServeNodes, "L", "listen address, can listen on multiple ports") flag.StringVar(&configureFile, "C", "", "configure file") - flag.BoolVar(&options.Debug, "D", false, "enable debug log") + flag.BoolVar(&baseCfg.Debug, "D", false, "enable debug log") flag.BoolVar(&printVersion, "V", false, "print version") flag.Parse() if printVersion { - fmt.Fprintf(os.Stderr, "gost %s (%s)\n", gost.Version, runtime.Version()) + fmt.Fprintf(os.Stderr, "gost %s (%s %s/%s)\n", + gost.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH) os.Exit(0) } - if len(options.ServeNodes) > 0 { - routes = append(routes, options) + if configureFile != "" { + _, err := parseBaseConfig(configureFile) + if err != nil { + log.Log(err) + os.Exit(1) + } } - gost.Debug = options.Debug - - if err := loadConfigureFile(configureFile); err != nil { - log.Log(err) - os.Exit(1) - } - - if flag.NFlag() == 0 || len(routes) == 0 { + if flag.NFlag() == 0 || !baseCfg.IsValid() { flag.PrintDefaults() os.Exit(0) } - } func main() { // go func() { // log.Log(http.ListenAndServe("localhost:6060", nil)) // }() + // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. - config, err := tlsConfig(defaultCertFile, defaultKeyFile) + tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile) if err != nil { // generate random self-signed certificate. cert, err := gost.GenCertificate() @@ -71,410 +65,30 @@ func main() { log.Log(err) os.Exit(1) } - config = &tls.Config{ + tlsConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, } } - gost.DefaultTLSConfig = config + gost.DefaultTLSConfig = tlsConfig - for _, route := range routes { - if err := route.serve(); err != nil { - log.Log(err) - os.Exit(1) - } - } + start() select {} } -type route struct { - ChainNodes, ServeNodes stringList - Retries int - Debug bool -} +func start() error { + gost.Debug = baseCfg.Debug -func (r *route) initChain() (*gost.Chain, error) { - chain := gost.NewChain() - chain.Retries = r.Retries - gid := 1 // group ID - - for _, ns := range r.ChainNodes { - ngroup := gost.NewNodeGroup() - ngroup.ID = gid - gid++ - - // parse the base nodes - nodes, err := parseChainNode(ns) - if err != nil { - return nil, err - } - - nid := 1 // node ID - for i := range nodes { - nodes[i].ID = nid - nid++ - } - ngroup.AddNode(nodes...) - - go gost.PeriodReload(&peerConfig{ - group: ngroup, - baseNodes: nodes, - }, nodes[0].Get("peer")) - - chain.AddNodeGroup(ngroup) - } - - return chain, nil -} - -func parseChainNode(ns string) (nodes []gost.Node, err error) { - node, err := gost.ParseNode(ns) - if err != nil { - return - } - - users, err := parseUsers(node.Get("secrets")) - if err != nil { - return - } - if node.User == nil && len(users) > 0 { - node.User = users[0] - } - serverName, sport, _ := net.SplitHostPort(node.Addr) - if serverName == "" { - serverName = "localhost" // default server name - } - - rootCAs, err := loadCA(node.Get("ca")) - if err != nil { - return - } - tlsCfg := &tls.Config{ - ServerName: serverName, - InsecureSkipVerify: !node.GetBool("secure"), - RootCAs: rootCAs, - } - wsOpts := &gost.WSOptions{} - wsOpts.EnableCompression = node.GetBool("compression") - wsOpts.ReadBufferSize = node.GetInt("rbuf") - wsOpts.WriteBufferSize = node.GetInt("wbuf") - wsOpts.UserAgent = node.Get("agent") - - var tr gost.Transporter - switch node.Transport { - case "tls": - tr = gost.TLSTransporter() - case "mtls": - tr = gost.MTLSTransporter() - case "ws": - tr = gost.WSTransporter(wsOpts) - case "mws": - tr = gost.MWSTransporter(wsOpts) - case "wss": - tr = gost.WSSTransporter(wsOpts) - case "mwss": - tr = gost.MWSSTransporter(wsOpts) - case "kcp": - config, err := parseKCPConfig(node.Get("c")) - if err != nil { - return nil, err - } - tr = gost.KCPTransporter(config) - case "ssh": - if node.Protocol == "direct" || node.Protocol == "remote" { - tr = gost.SSHForwardTransporter() - } else { - tr = gost.SSHTunnelTransporter() - } - case "quic": - config := &gost.QUICConfig{ - TLSConfig: tlsCfg, - KeepAlive: node.GetBool("keepalive"), - Timeout: time.Duration(node.GetInt("timeout")) * time.Second, - IdleTimeout: time.Duration(node.GetInt("idle")) * time.Second, - } - - if cipher := node.Get("cipher"); cipher != "" { - sum := sha256.Sum256([]byte(cipher)) - config.Key = sum[:] - } - - tr = gost.QUICTransporter(config) - case "http2": - tr = gost.HTTP2Transporter(tlsCfg) - case "h2": - tr = gost.H2Transporter(tlsCfg) - case "h2c": - tr = gost.H2CTransporter() - - case "obfs4": - tr = gost.Obfs4Transporter() - case "ohttp": - tr = gost.ObfsHTTPTransporter() - default: - tr = gost.TCPTransporter() - } - - var connector gost.Connector - switch node.Protocol { - case "http2": - connector = gost.HTTP2Connector(node.User) - case "socks", "socks5": - connector = gost.SOCKS5Connector(node.User) - case "socks4": - connector = gost.SOCKS4Connector() - case "socks4a": - connector = gost.SOCKS4AConnector() - case "ss": - connector = gost.ShadowConnector(node.User) - case "direct": - connector = gost.SSHDirectForwardConnector() - case "remote": - connector = gost.SSHRemoteForwardConnector() - case "forward": - connector = gost.ForwardConnector() - case "sni": - connector = gost.SNIConnector(node.Get("host")) - case "http": - fallthrough - default: - node.Protocol = "http" // default protocol is HTTP - connector = gost.HTTPConnector(node.User) - } - - timeout := node.GetInt("timeout") - node.DialOptions = append(node.DialOptions, - gost.TimeoutDialOption(time.Duration(timeout)*time.Second), - ) - - handshakeOptions := []gost.HandshakeOption{ - gost.AddrHandshakeOption(node.Addr), - gost.HostHandshakeOption(node.Host), - gost.UserHandshakeOption(node.User), - gost.TLSConfigHandshakeOption(tlsCfg), - gost.IntervalHandshakeOption(time.Duration(node.GetInt("ping")) * time.Second), - gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second), - gost.RetryHandshakeOption(node.GetInt("retry")), - } - node.Client = &gost.Client{ - Connector: connector, - Transporter: tr, - } - - node.Bypass = parseBypass(node.Get("bypass")) - - ips := parseIP(node.Get("ip"), sport) - for _, ip := range ips { - node.Addr = ip - // override the default node address - node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip)) - // One node per IP - nodes = append(nodes, node) - } - if len(ips) == 0 { - node.HandshakeOptions = handshakeOptions - nodes = []gost.Node{node} - } - - if node.Transport == "obfs4" { - for i := range nodes { - if err := gost.Obfs4Init(nodes[i], false); err != nil { - return nil, err - } - } - } - - return -} - -func (r *route) serve() error { - chain, err := r.initChain() - if err != nil { + if err := baseCfg.route.serve(); err != nil { return err } - - for _, ns := range r.ServeNodes { - node, err := gost.ParseNode(ns) - if err != nil { + for _, route := range baseCfg.Routes { + if err := route.serve(); err != nil { return err } - users, err := parseUsers(node.Get("secrets")) - if err != nil { - return err - } - if node.User != nil { - users = append(users, node.User) - } - certFile, keyFile := node.Get("cert"), node.Get("key") - tlsCfg, err := tlsConfig(certFile, keyFile) - if err != nil && certFile != "" && keyFile != "" { - return err - } - - wsOpts := &gost.WSOptions{} - wsOpts.EnableCompression = node.GetBool("compression") - wsOpts.ReadBufferSize = node.GetInt("rbuf") - wsOpts.WriteBufferSize = node.GetInt("wbuf") - - var ln gost.Listener - switch node.Transport { - case "tls": - ln, err = gost.TLSListener(node.Addr, tlsCfg) - case "mtls": - ln, err = gost.MTLSListener(node.Addr, tlsCfg) - case "ws": - wsOpts.WriteBufferSize = node.GetInt("wbuf") - ln, err = gost.WSListener(node.Addr, wsOpts) - case "mws": - ln, err = gost.MWSListener(node.Addr, wsOpts) - case "wss": - ln, err = gost.WSSListener(node.Addr, tlsCfg, wsOpts) - case "mwss": - ln, err = gost.MWSSListener(node.Addr, tlsCfg, wsOpts) - case "kcp": - config, er := parseKCPConfig(node.Get("c")) - if er != nil { - return er - } - ln, err = gost.KCPListener(node.Addr, config) - case "ssh": - config := &gost.SSHConfig{ - Users: users, - TLSConfig: tlsCfg, - } - if node.Protocol == "forward" { - ln, err = gost.TCPListener(node.Addr) - } else { - ln, err = gost.SSHTunnelListener(node.Addr, config) - } - case "quic": - config := &gost.QUICConfig{ - TLSConfig: tlsCfg, - KeepAlive: node.GetBool("keepalive"), - Timeout: time.Duration(node.GetInt("timeout")) * time.Second, - IdleTimeout: time.Duration(node.GetInt("idle")) * time.Second, - } - if cipher := node.Get("cipher"); cipher != "" { - sum := sha256.Sum256([]byte(cipher)) - config.Key = sum[:] - } - - ln, err = gost.QUICListener(node.Addr, config) - case "http2": - ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) - case "h2": - ln, err = gost.H2Listener(node.Addr, tlsCfg) - case "h2c": - ln, err = gost.H2CListener(node.Addr) - case "tcp": - // Directly use SSH port forwarding if the last chain node is forward+ssh - if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { - chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHDirectForwardConnector() - chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() - } - ln, err = gost.TCPListener(node.Addr) - case "rtcp": - // Directly use SSH port forwarding if the last chain node is forward+ssh - if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { - chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHRemoteForwardConnector() - chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() - } - ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) - case "udp": - ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(node.GetInt("ttl"))*time.Second) - case "rudp": - ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(node.GetInt("ttl"))*time.Second) - case "ssu": - ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second) - case "obfs4": - if err = gost.Obfs4Init(node, true); err != nil { - return err - } - ln, err = gost.Obfs4Listener(node.Addr) - case "ohttp": - ln, err = gost.ObfsHTTPListener(node.Addr) - default: - ln, err = gost.TCPListener(node.Addr) - } - if err != nil { - return err - } - - var handler gost.Handler - switch node.Protocol { - case "http2": - handler = gost.HTTP2Handler() - case "socks", "socks5": - handler = gost.SOCKS5Handler() - case "socks4", "socks4a": - handler = gost.SOCKS4Handler() - case "ss": - handler = gost.ShadowHandler() - case "http": - handler = gost.HTTPHandler() - case "tcp": - handler = gost.TCPDirectForwardHandler(node.Remote) - case "rtcp": - handler = gost.TCPRemoteForwardHandler(node.Remote) - case "udp": - handler = gost.UDPDirectForwardHandler(node.Remote) - case "rudp": - handler = gost.UDPRemoteForwardHandler(node.Remote) - case "forward": - handler = gost.SSHForwardHandler() - case "redirect": - handler = gost.TCPRedirectHandler() - case "ssu": - handler = gost.ShadowUDPdHandler() - case "sni": - handler = gost.SNIHandler() - default: - // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. - if node.Remote != "" { - handler = gost.TCPDirectForwardHandler(node.Remote) - } else { - handler = gost.AutoHandler() - } - } - - var whitelist, blacklist *gost.Permissions - if node.Values.Get("whitelist") != "" { - if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil { - return err - } - } - if node.Values.Get("blacklist") != "" { - if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil { - return err - } - } - - var hosts *gost.Hosts - if f, _ := os.Open(node.Get("hosts")); f != nil { - f.Close() - hosts = gost.NewHosts() - go gost.PeriodReload(hosts, node.Get("hosts")) - } - - handler.Init( - gost.AddrHandlerOption(node.Addr), - gost.ChainHandlerOption(chain), - gost.UsersHandlerOption(users...), - gost.TLSConfigHandlerOption(tlsCfg), - gost.WhitelistHandlerOption(whitelist), - gost.BlacklistHandlerOption(blacklist), - gost.BypassHandlerOption(parseBypass(node.Get("bypass"))), - gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))), - gost.ResolverHandlerOption(parseResolver(node.Get("dns"))), - gost.HostsHandlerOption(hosts), - gost.RetryHandlerOption(node.GetInt("retry")), - gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second), - gost.ProbeResistHandlerOption(node.Get("probe_resist")), - ) - - srv := &gost.Server{Listener: ln} - go srv.Serve(handler) } + go gost.PeriodReload(baseCfg, configureFile) + return nil } diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go index a0e43a7..f615d73 100644 --- a/cmd/gost/peer.go +++ b/cmd/gost/peer.go @@ -13,6 +13,11 @@ import ( "github.com/ginuerzh/gost" ) +const ( + defaultMaxFails = 1 + defaultFailTimeout = 30 * time.Second +) + type peerConfig struct { Strategy string `json:"strategy"` MaxFails int `json:"max_fails"` @@ -39,10 +44,10 @@ func parsePeerConfig(cfg string, group *gost.NodeGroup, baseNodes []gost.Node) * func (cfg *peerConfig) Validate() { if cfg.MaxFails <= 0 { - cfg.MaxFails = 1 + cfg.MaxFails = defaultMaxFails } if cfg.FailTimeout <= 0 { - cfg.FailTimeout = 30 // seconds + cfg.FailTimeout = defaultFailTimeout // seconds } } @@ -53,20 +58,22 @@ func (cfg *peerConfig) Reload(r io.Reader) error { cfg.Validate() group := cfg.group - strategy := cfg.Strategy - if len(cfg.baseNodes) > 0 { - // overwrite the strategry in the peer config if `strategy` param exists. - if s := cfg.baseNodes[0].Get("strategy"); s != "" { - strategy = s + /* + strategy := cfg.Strategy + if len(cfg.baseNodes) > 0 { + // overwrite the strategry in the peer config if `strategy` param exists. + if s := cfg.baseNodes[0].Get("strategy"); s != "" { + strategy = s + } } - } + */ group.SetSelector( nil, gost.WithFilter(&gost.FailFilter{ MaxFails: cfg.MaxFails, - FailTimeout: time.Duration(cfg.FailTimeout) * time.Second, + FailTimeout: cfg.FailTimeout, }), - gost.WithStrategy(parseStrategy(strategy)), + gost.WithStrategy(parseStrategy(cfg.Strategy)), ) gNodes := cfg.baseNodes diff --git a/cmd/gost/route.go b/cmd/gost/route.go new file mode 100644 index 0000000..28318e9 --- /dev/null +++ b/cmd/gost/route.go @@ -0,0 +1,412 @@ +package main + +import ( + "crypto/sha256" + "crypto/tls" + "net" + "time" + + "github.com/ginuerzh/gost" +) + +type route struct { + ServeNodes stringList + ChainNodes stringList + Retries int + server *gost.Server +} + +func (r *route) initChain() (*gost.Chain, error) { + chain := gost.NewChain() + chain.Retries = r.Retries + gid := 1 // group ID + + for _, ns := range r.ChainNodes { + ngroup := gost.NewNodeGroup() + ngroup.ID = gid + gid++ + + // parse the base nodes + nodes, err := parseChainNode(ns) + if err != nil { + return nil, err + } + + nid := 1 // node ID + for i := range nodes { + nodes[i].ID = nid + nid++ + } + ngroup.AddNode(nodes...) + + ngroup.SetSelector(nil, + gost.WithFilter(&gost.FailFilter{ + MaxFails: defaultMaxFails, + FailTimeout: defaultFailTimeout, + }), + gost.WithStrategy(parseStrategy(nodes[0].Get("strategy"))), + ) + + go gost.PeriodReload(&peerConfig{ + group: ngroup, + baseNodes: nodes, + }, nodes[0].Get("peer")) + + chain.AddNodeGroup(ngroup) + } + + return chain, nil +} + +func parseChainNode(ns string) (nodes []gost.Node, err error) { + node, err := gost.ParseNode(ns) + if err != nil { + return + } + + users, err := parseUsers(node.Get("secrets")) + if err != nil { + return + } + if node.User == nil && len(users) > 0 { + node.User = users[0] + } + serverName, sport, _ := net.SplitHostPort(node.Addr) + if serverName == "" { + serverName = "localhost" // default server name + } + + rootCAs, err := loadCA(node.Get("ca")) + if err != nil { + return + } + tlsCfg := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: !node.GetBool("secure"), + RootCAs: rootCAs, + } + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = node.GetBool("compression") + wsOpts.ReadBufferSize = node.GetInt("rbuf") + wsOpts.WriteBufferSize = node.GetInt("wbuf") + wsOpts.UserAgent = node.Get("agent") + + var tr gost.Transporter + switch node.Transport { + case "tls": + tr = gost.TLSTransporter() + case "mtls": + tr = gost.MTLSTransporter() + case "ws": + tr = gost.WSTransporter(wsOpts) + case "mws": + tr = gost.MWSTransporter(wsOpts) + case "wss": + tr = gost.WSSTransporter(wsOpts) + case "mwss": + tr = gost.MWSSTransporter(wsOpts) + case "kcp": + config, err := parseKCPConfig(node.Get("c")) + if err != nil { + return nil, err + } + tr = gost.KCPTransporter(config) + case "ssh": + if node.Protocol == "direct" || node.Protocol == "remote" { + tr = gost.SSHForwardTransporter() + } else { + tr = gost.SSHTunnelTransporter() + } + case "quic": + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: node.GetBool("keepalive"), + Timeout: time.Duration(node.GetInt("timeout")) * time.Second, + IdleTimeout: time.Duration(node.GetInt("idle")) * time.Second, + } + + if cipher := node.Get("cipher"); cipher != "" { + sum := sha256.Sum256([]byte(cipher)) + config.Key = sum[:] + } + + tr = gost.QUICTransporter(config) + case "http2": + tr = gost.HTTP2Transporter(tlsCfg) + case "h2": + tr = gost.H2Transporter(tlsCfg) + case "h2c": + tr = gost.H2CTransporter() + + case "obfs4": + tr = gost.Obfs4Transporter() + case "ohttp": + tr = gost.ObfsHTTPTransporter() + default: + tr = gost.TCPTransporter() + } + + var connector gost.Connector + switch node.Protocol { + case "http2": + connector = gost.HTTP2Connector(node.User) + case "socks", "socks5": + connector = gost.SOCKS5Connector(node.User) + case "socks4": + connector = gost.SOCKS4Connector() + case "socks4a": + connector = gost.SOCKS4AConnector() + case "ss": + connector = gost.ShadowConnector(node.User) + case "direct": + connector = gost.SSHDirectForwardConnector() + case "remote": + connector = gost.SSHRemoteForwardConnector() + case "forward": + connector = gost.ForwardConnector() + case "sni": + connector = gost.SNIConnector(node.Get("host")) + case "http": + fallthrough + default: + node.Protocol = "http" // default protocol is HTTP + connector = gost.HTTPConnector(node.User) + } + + timeout := node.GetInt("timeout") + node.DialOptions = append(node.DialOptions, + gost.TimeoutDialOption(time.Duration(timeout)*time.Second), + ) + + handshakeOptions := []gost.HandshakeOption{ + gost.AddrHandshakeOption(node.Addr), + gost.HostHandshakeOption(node.Host), + gost.UserHandshakeOption(node.User), + gost.TLSConfigHandshakeOption(tlsCfg), + gost.IntervalHandshakeOption(time.Duration(node.GetInt("ping")) * time.Second), + gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second), + gost.RetryHandshakeOption(node.GetInt("retry")), + } + node.Client = &gost.Client{ + Connector: connector, + Transporter: tr, + } + + node.Bypass = parseBypass(node.Get("bypass")) + + ips := parseIP(node.Get("ip"), sport) + for _, ip := range ips { + nd := node.Clone() + nd.Addr = ip + // override the default node address + nd.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip)) + // One node per IP + nodes = append(nodes, nd) + } + if len(ips) == 0 { + node.HandshakeOptions = handshakeOptions + nodes = []gost.Node{node} + } + + if node.Transport == "obfs4" { + for i := range nodes { + if err := gost.Obfs4Init(nodes[i], false); err != nil { + return nil, err + } + } + } + + return +} + +func (r *route) serve() error { + chain, err := r.initChain() + if err != nil { + return err + } + + for _, ns := range r.ServeNodes { + node, err := gost.ParseNode(ns) + if err != nil { + return err + } + users, err := parseUsers(node.Get("secrets")) + if err != nil { + return err + } + if node.User != nil { + users = append(users, node.User) + } + certFile, keyFile := node.Get("cert"), node.Get("key") + tlsCfg, err := tlsConfig(certFile, keyFile) + if err != nil && certFile != "" && keyFile != "" { + return err + } + + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = node.GetBool("compression") + wsOpts.ReadBufferSize = node.GetInt("rbuf") + wsOpts.WriteBufferSize = node.GetInt("wbuf") + + var ln gost.Listener + switch node.Transport { + case "tls": + ln, err = gost.TLSListener(node.Addr, tlsCfg) + case "mtls": + ln, err = gost.MTLSListener(node.Addr, tlsCfg) + case "ws": + wsOpts.WriteBufferSize = node.GetInt("wbuf") + ln, err = gost.WSListener(node.Addr, wsOpts) + case "mws": + ln, err = gost.MWSListener(node.Addr, wsOpts) + case "wss": + ln, err = gost.WSSListener(node.Addr, tlsCfg, wsOpts) + case "mwss": + ln, err = gost.MWSSListener(node.Addr, tlsCfg, wsOpts) + case "kcp": + config, er := parseKCPConfig(node.Get("c")) + if er != nil { + return er + } + ln, err = gost.KCPListener(node.Addr, config) + case "ssh": + config := &gost.SSHConfig{ + Users: users, + TLSConfig: tlsCfg, + } + if node.Protocol == "forward" { + ln, err = gost.TCPListener(node.Addr) + } else { + ln, err = gost.SSHTunnelListener(node.Addr, config) + } + case "quic": + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: node.GetBool("keepalive"), + Timeout: time.Duration(node.GetInt("timeout")) * time.Second, + IdleTimeout: time.Duration(node.GetInt("idle")) * time.Second, + } + if cipher := node.Get("cipher"); cipher != "" { + sum := sha256.Sum256([]byte(cipher)) + config.Key = sum[:] + } + + ln, err = gost.QUICListener(node.Addr, config) + case "http2": + ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) + case "h2": + ln, err = gost.H2Listener(node.Addr, tlsCfg) + case "h2c": + ln, err = gost.H2CListener(node.Addr) + case "tcp": + // Directly use SSH port forwarding if the last chain node is forward+ssh + if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { + chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHDirectForwardConnector() + chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() + } + ln, err = gost.TCPListener(node.Addr) + case "rtcp": + // Directly use SSH port forwarding if the last chain node is forward+ssh + if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { + chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHRemoteForwardConnector() + chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() + } + ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) + case "udp": + ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(node.GetInt("ttl"))*time.Second) + case "rudp": + ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(node.GetInt("ttl"))*time.Second) + case "ssu": + ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second) + case "obfs4": + if err = gost.Obfs4Init(node, true); err != nil { + return err + } + ln, err = gost.Obfs4Listener(node.Addr) + case "ohttp": + ln, err = gost.ObfsHTTPListener(node.Addr) + default: + ln, err = gost.TCPListener(node.Addr) + } + if err != nil { + return err + } + + var handler gost.Handler + switch node.Protocol { + case "http2": + handler = gost.HTTP2Handler() + case "socks", "socks5": + handler = gost.SOCKS5Handler() + case "socks4", "socks4a": + handler = gost.SOCKS4Handler() + case "ss": + handler = gost.ShadowHandler() + case "http": + handler = gost.HTTPHandler() + case "tcp": + handler = gost.TCPDirectForwardHandler(node.Remote) + case "rtcp": + handler = gost.TCPRemoteForwardHandler(node.Remote) + case "udp": + handler = gost.UDPDirectForwardHandler(node.Remote) + case "rudp": + handler = gost.UDPRemoteForwardHandler(node.Remote) + case "forward": + handler = gost.SSHForwardHandler() + case "redirect": + handler = gost.TCPRedirectHandler() + case "ssu": + handler = gost.ShadowUDPdHandler() + case "sni": + handler = gost.SNIHandler() + default: + // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. + if node.Remote != "" { + handler = gost.TCPDirectForwardHandler(node.Remote) + } else { + handler = gost.AutoHandler() + } + } + + var whitelist, blacklist *gost.Permissions + if node.Values.Get("whitelist") != "" { + if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil { + return err + } + } + if node.Values.Get("blacklist") != "" { + if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil { + return err + } + } + + handler.Init( + gost.AddrHandlerOption(node.Addr), + gost.ChainHandlerOption(chain), + gost.UsersHandlerOption(users...), + gost.TLSConfigHandlerOption(tlsCfg), + gost.WhitelistHandlerOption(whitelist), + gost.BlacklistHandlerOption(blacklist), + gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))), + gost.BypassHandlerOption(parseBypass(node.Get("bypass"))), + gost.ResolverHandlerOption(parseResolver(node.Get("dns"))), + gost.HostsHandlerOption(parseHosts(node.Get("hosts"))), + gost.RetryHandlerOption(node.GetInt("retry")), + gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second), + gost.ProbeResistHandlerOption(node.Get("probe_resist")), + ) + + r.server = &gost.Server{Listener: ln} + go r.server.Serve(handler) + } + + return nil +} + +func (r *route) Close() error { + if r == nil || r.server == nil { + return nil + } + return r.server.Close() +} diff --git a/reload.go b/reload.go index e5ced63..db79b8c 100644 --- a/reload.go +++ b/reload.go @@ -16,8 +16,11 @@ type Reloader interface { // PeriodReload reloads the config periodically according to the period of the reloader. func PeriodReload(r Reloader, configFile string) error { - var lastMod time.Time + if configFile == "" { + return nil + } + var lastMod time.Time for { f, err := os.Open(configFile) if err != nil { @@ -32,7 +35,9 @@ func PeriodReload(r Reloader, configFile string) error { mt := finfo.ModTime() if !mt.Equal(lastMod) { log.Log("[reload]", configFile) - r.Reload(f) + if err := r.Reload(f); err != nil { + log.Logf("[reload] %s: %s", configFile, err) + } lastMod = mt } f.Close()