From dc4c78ca4449b84f72eb428df72e3100142115ea Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 29 Nov 2018 22:09:10 +0800 Subject: [PATCH] add stop for live reloading --- bypass.go | 31 ++++++++++++- chain.go | 5 ++- cmd/gost/cfg.go | 109 +++++++++++++++++---------------------------- cmd/gost/gost.json | 30 ------------- cmd/gost/main.go | 19 ++++++-- cmd/gost/peer.go | 64 +++++++++++++------------- cmd/gost/route.go | 98 +++++++++++++++++++++++++++++----------- hosts.go | 37 +++++++++++++-- http.go | 2 +- http2.go | 6 ++- node.go | 17 ++++--- node_test.go | 2 +- reload.go | 50 ++++++++++++++++----- resolver.go | 33 +++++++++++++- selector.go | 14 ++++++ server.go | 10 ----- ss.go | 5 ++- ws.go | 8 ++-- 18 files changed, 337 insertions(+), 203 deletions(-) delete mode 100644 cmd/gost/gost.json diff --git a/bypass.go b/bypass.go index acd9be9..a64a5fc 100644 --- a/bypass.go +++ b/bypass.go @@ -124,6 +124,7 @@ type Bypass struct { matchers []Matcher period time.Duration // the period for live reloading reversed bool + stopped chan struct{} mux sync.RWMutex } @@ -133,6 +134,7 @@ func NewBypass(reversed bool, matchers ...Matcher) *Bypass { return &Bypass{ matchers: matchers, reversed: reversed, + stopped: make(chan struct{}), } } @@ -207,6 +209,10 @@ func (bp *Bypass) Reload(r io.Reader) error { var period time.Duration var reversed bool + if bp.Stopped() { + return nil + } + scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() @@ -264,14 +270,37 @@ func (bp *Bypass) Reload(r io.Reader) error { return nil } -// Period returns the reload period +// Period returns the reload period. func (bp *Bypass) Period() time.Duration { + if bp.Stopped() { + return -1 + } + bp.mux.RLock() defer bp.mux.RUnlock() return bp.period } +// Stop stops reloading. +func (bp *Bypass) Stop() { + select { + case <-bp.stopped: + default: + close(bp.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (bp *Bypass) Stopped() bool { + select { + case <-bp.stopped: + return true + default: + return false + } +} + func (bp *Bypass) String() string { bp.mux.RLock() defer bp.mux.RUnlock() diff --git a/chain.go b/chain.go index 49fe439..dde827c 100644 --- a/chain.go +++ b/chain.go @@ -15,7 +15,7 @@ var ( ErrEmptyChain = errors.New("empty chain") ) -// Chain is a proxy chain that holds a list of proxy nodes. +// Chain is a proxy chain that holds a list of proxy node groups. type Chain struct { isRoute bool Retries int @@ -23,6 +23,7 @@ type Chain struct { } // NewChain creates a proxy chain with a list of proxy nodes. +// It creates the node groups automatically, one group per node. func NewChain(nodes ...Node) *Chain { chain := &Chain{} for _, node := range nodes { @@ -31,6 +32,8 @@ func NewChain(nodes ...Node) *Chain { return chain } +// newRoute creates a chain route. +// a chain route is the final route after node selection. func newRoute(nodes ...Node) *Chain { chain := NewChain(nodes...) chain.isRoute = true diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index e05e900..d78700a 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -6,8 +6,6 @@ import ( "crypto/x509" "encoding/json" "errors" - "fmt" - "io" "io/ioutil" "net/url" "os" @@ -17,6 +15,34 @@ import ( "github.com/ginuerzh/gost" ) +var ( + routers []router +) + +type baseConfig struct { + route + Routes []route + Debug bool +} + +func parseBaseConfig(s string) (*baseConfig, error) { + file, err := os.Open(s) + if err != nil { + return nil, err + } + defer file.Close() + + if err := json.NewDecoder(file).Decode(baseCfg); err != nil { + return nil, err + } + + return baseCfg, nil +} + +func (cfg *baseConfig) IsValid() bool { + return len(cfg.route.ServeNodes) > 0 +} + var ( defaultCertFile = "cert.pem" defaultKeyFile = "key.pem" @@ -52,70 +78,6 @@ func loadCA(caFile string) (cp *x509.CertPool, err error) { return } -type baseConfig struct { - route - Routes []route - ReloadPeriod string - Debug bool -} - -func parseBaseConfig(s string) (*baseConfig, error) { - file, err := os.Open(s) - if err != nil { - return nil, err - } - defer file.Close() - - if err := json.NewDecoder(file).Decode(baseCfg); err != nil { - return nil, err - } - - return baseCfg, nil -} - -func (cfg *baseConfig) IsValid() bool { - return len(cfg.route.ServeNodes) > 0 -} - -func (cfg *baseConfig) Reload(r io.Reader) error { - c := baseConfig{} - if err := json.NewDecoder(r).Decode(&c); err != nil { - return err - } - - cfg.route.Close() - for _, r := range cfg.Routes { - r.Close() - } - *cfg = c - gost.Debug = cfg.Debug - - if err := cfg.route.serve(); err != nil { - return err - } - for _, route := range cfg.Routes { - if err := route.serve(); err != nil { - return err - } - } - return nil -} - -func (cfg *baseConfig) Period() time.Duration { - d, _ := time.ParseDuration(cfg.ReloadPeriod) - return d -} - -type stringList []string - -func (l *stringList) String() string { - return fmt.Sprintf("%s", *l) -} -func (l *stringList) Set(value string) error { - *l = append(*l, value) - return nil -} - func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { if configFile == "" { return nil, nil @@ -221,9 +183,10 @@ func parseBypass(s string) *gost.Bypass { } return gost.NewBypass(reversed, matchers...) } - f.Close() + defer f.Close() bp := gost.NewBypass(reversed) + bp.Reload(f) go gost.PeriodReload(bp, s) return bp @@ -259,16 +222,26 @@ func parseResolver(cfg string) gost.Resolver { } return gost.NewResolver(timeout, ttl, nss...) } - f.Close() + defer f.Close() resolver := gost.NewResolver(timeout, ttl) + resolver.Reload(f) + go gost.PeriodReload(resolver, cfg) return resolver } func parseHosts(s string) *gost.Hosts { + f, err := os.Open(s) + if err != nil { + return nil + } + defer f.Close() + hosts := gost.NewHosts() + hosts.Reload(f) + go gost.PeriodReload(hosts, s) return hosts diff --git a/cmd/gost/gost.json b/cmd/gost/gost.json deleted file mode 100644 index e7217c9..0000000 --- a/cmd/gost/gost.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "Debug": false, - "Retries": 1, - "ServeNodes": [ - ":8080", - "ss://chacha20:12345678@:8338" - ], - "ChainNodes": [ - "http://192.168.1.1:8080", - "https://10.0.2.1:443" - ], - - "Routes": [ - { - "Retries": 1, - "ServeNodes": [ - "ws://:1443" - ], - "ChainNodes": [ - "socks://:192.168.1.1:1080" - ] - }, - { - "Retries": 3, - "ServeNodes": [ - "quic://:443" - ] - } - ] -} \ No newline at end of file diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 8a585fa..32abfaa 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -71,7 +71,10 @@ func main() { } gost.DefaultTLSConfig = tlsConfig - start() + if err := start(); err != nil { + log.Log(err) + os.Exit(1) + } select {} } @@ -79,16 +82,24 @@ func main() { func start() error { gost.Debug = baseCfg.Debug - if err := baseCfg.route.serve(); err != nil { + var routers []router + rts, err := baseCfg.route.GenRouters() + if err != nil { return err } + routers = append(routers, rts...) + for _, route := range baseCfg.Routes { - if err := route.serve(); err != nil { + rts, err := route.GenRouters() + if err != nil { return err } + routers = append(routers, rts...) } - go gost.PeriodReload(baseCfg, configureFile) + for i := range routers { + go routers[i].Serve() + } return nil } diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go index f615d73..cec083b 100644 --- a/cmd/gost/peer.go +++ b/cmd/gost/peer.go @@ -26,20 +26,13 @@ type peerConfig struct { Nodes []string `json:"nodes"` group *gost.NodeGroup baseNodes []gost.Node + stopped chan struct{} } -type bypass struct { - Reverse bool `json:"reverse"` - Patterns []string `json:"patterns"` -} - -func parsePeerConfig(cfg string, group *gost.NodeGroup, baseNodes []gost.Node) *peerConfig { - pc := &peerConfig{ - group: group, - baseNodes: baseNodes, +func newPeerConfig() *peerConfig { + return &peerConfig{ + stopped: make(chan struct{}), } - go gost.PeriodReload(pc, cfg) - return pc } func (cfg *peerConfig) Validate() { @@ -52,28 +45,23 @@ func (cfg *peerConfig) Validate() { } func (cfg *peerConfig) Reload(r io.Reader) error { + if cfg.Stopped() { + return nil + } + if err := cfg.parse(r); err != nil { return err } cfg.Validate() group := cfg.group - /* - strategy := cfg.Strategy - if len(cfg.baseNodes) > 0 { - // overwrite the strategry in the peer config if `strategy` param exists. - if s := cfg.baseNodes[0].Get("strategy"); s != "" { - strategy = s - } - } - */ group.SetSelector( nil, gost.WithFilter(&gost.FailFilter{ MaxFails: cfg.MaxFails, FailTimeout: cfg.FailTimeout, }), - gost.WithStrategy(parseStrategy(cfg.Strategy)), + gost.WithStrategy(gost.NewStrategy(cfg.Strategy)), ) gNodes := cfg.baseNodes @@ -92,7 +80,12 @@ func (cfg *peerConfig) Reload(r io.Reader) error { gNodes = append(gNodes, nodes...) } - group.SetNodes(gNodes...) + nodes := group.SetNodes(gNodes...) + for _, node := range nodes[len(cfg.baseNodes):] { + if node.Bypass != nil { + node.Bypass.Stop() // clear the old nodes + } + } return nil } @@ -154,18 +147,27 @@ func (cfg *peerConfig) parse(r io.Reader) error { } func (cfg *peerConfig) Period() time.Duration { + if cfg.Stopped() { + return -1 + } return cfg.period } -func parseStrategy(s string) gost.Strategy { - switch s { - case "random": - return &gost.RandomStrategy{} - case "fifo": - return &gost.FIFOStrategy{} - case "round": - fallthrough +// Stop stops reloading. +func (cfg *peerConfig) Stop() { + select { + case <-cfg.stopped: default: - return &gost.RoundStrategy{} + close(cfg.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (cfg *peerConfig) Stopped() bool { + select { + case <-cfg.stopped: + return true + default: + return false } } diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 28318e9..ecb7e5f 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -3,20 +3,32 @@ package main import ( "crypto/sha256" "crypto/tls" + "fmt" "net" + "os" "time" "github.com/ginuerzh/gost" + "github.com/go-log/log" ) +type stringList []string + +func (l *stringList) String() string { + return fmt.Sprintf("%s", *l) +} +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + type route struct { ServeNodes stringList ChainNodes stringList Retries int - server *gost.Server } -func (r *route) initChain() (*gost.Chain, error) { +func (r *route) parseChain() (*gost.Chain, error) { chain := gost.NewChain() chain.Retries = r.Retries gid := 1 // group ID @@ -44,13 +56,20 @@ func (r *route) initChain() (*gost.Chain, error) { MaxFails: defaultMaxFails, FailTimeout: defaultFailTimeout, }), - gost.WithStrategy(parseStrategy(nodes[0].Get("strategy"))), + gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))), ) - go gost.PeriodReload(&peerConfig{ - group: ngroup, - baseNodes: nodes, - }, nodes[0].Get("peer")) + cfg := nodes[0].Get("peer") + f, err := os.Open(cfg) + if err == nil { + peerCfg := newPeerConfig() + peerCfg.group = ngroup + peerCfg.baseNodes = nodes + peerCfg.Reload(f) + f.Close() + + go gost.PeriodReload(peerCfg, cfg) + } chain.AddNodeGroup(ngroup) } @@ -219,20 +238,22 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { return } -func (r *route) serve() error { - chain, err := r.initChain() +func (r *route) GenRouters() ([]router, error) { + chain, err := r.parseChain() if err != nil { - return err + return nil, err } + var rts []router + for _, ns := range r.ServeNodes { node, err := gost.ParseNode(ns) if err != nil { - return err + return nil, err } users, err := parseUsers(node.Get("secrets")) if err != nil { - return err + return nil, err } if node.User != nil { users = append(users, node.User) @@ -240,7 +261,7 @@ func (r *route) serve() error { certFile, keyFile := node.Get("cert"), node.Get("key") tlsCfg, err := tlsConfig(certFile, keyFile) if err != nil && certFile != "" && keyFile != "" { - return err + return nil, err } wsOpts := &gost.WSOptions{} @@ -266,7 +287,7 @@ func (r *route) serve() error { case "kcp": config, er := parseKCPConfig(node.Get("c")) if er != nil { - return er + return nil, er } ln, err = gost.KCPListener(node.Addr, config) case "ssh": @@ -320,7 +341,7 @@ func (r *route) serve() error { ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second) case "obfs4": if err = gost.Obfs4Init(node, true); err != nil { - return err + return nil, err } ln, err = gost.Obfs4Listener(node.Addr) case "ohttp": @@ -329,7 +350,7 @@ func (r *route) serve() error { ln, err = gost.TCPListener(node.Addr) } if err != nil { - return err + return nil, err } var handler gost.Handler @@ -372,15 +393,19 @@ func (r *route) serve() error { var whitelist, blacklist *gost.Permissions if node.Values.Get("whitelist") != "" { if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil { - return err + return nil, err } } if node.Values.Get("blacklist") != "" { if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil { - return err + return nil, err } } + node.Bypass = parseBypass(node.Get("bypass")) + resolver := parseResolver(node.Get("dns")) + hosts := parseHosts(node.Get("hosts")) + handler.Init( gost.AddrHandlerOption(node.Addr), gost.ChainHandlerOption(chain), @@ -388,23 +413,44 @@ func (r *route) serve() error { gost.TLSConfigHandlerOption(tlsCfg), gost.WhitelistHandlerOption(whitelist), gost.BlacklistHandlerOption(blacklist), - gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))), - gost.BypassHandlerOption(parseBypass(node.Get("bypass"))), - gost.ResolverHandlerOption(parseResolver(node.Get("dns"))), - gost.HostsHandlerOption(parseHosts(node.Get("hosts"))), + gost.StrategyHandlerOption(gost.NewStrategy(node.Get("strategy"))), + gost.BypassHandlerOption(node.Bypass), + gost.ResolverHandlerOption(resolver), + gost.HostsHandlerOption(hosts), gost.RetryHandlerOption(node.GetInt("retry")), gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second), gost.ProbeResistHandlerOption(node.Get("probe_resist")), ) - r.server = &gost.Server{Listener: ln} - go r.server.Serve(handler) + rt := router{ + node: node, + server: &gost.Server{Listener: ln}, + handler: handler, + chain: chain, + resolver: resolver, + hosts: hosts, + } + rts = append(rts, rt) } - return nil + return rts, nil } -func (r *route) Close() error { +type router struct { + node gost.Node + server *gost.Server + handler gost.Handler + chain *gost.Chain + resolver gost.Resolver + hosts *gost.Hosts +} + +func (r *router) Serve() error { + log.Logf("[route] start %s on %s", r.node.String(), r.server.Addr()) + return r.server.Serve(r.handler) +} + +func (r *router) Close() error { if r == nil || r.server == nil { return nil } diff --git a/hosts.go b/hosts.go index b7744d5..4c492fb 100644 --- a/hosts.go +++ b/hosts.go @@ -24,15 +24,17 @@ type Host struct { // Fields of the entry are separated by any number of blanks and/or tab characters. // Text from a "#" character until the end of the line is a comment, and is ignored. type Hosts struct { - hosts []Host - period time.Duration - mux sync.RWMutex + hosts []Host + period time.Duration + stopped chan struct{} + mux sync.RWMutex } // NewHosts creates a Hosts with optional list of host func NewHosts(hosts ...Host) *Hosts { return &Hosts{ - hosts: hosts, + hosts: hosts, + stopped: make(chan struct{}), } } @@ -76,6 +78,10 @@ func (h *Hosts) Reload(r io.Reader) error { var period time.Duration var hosts []Host + if h.Stopped() { + return nil + } + scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() @@ -130,8 +136,31 @@ func (h *Hosts) Reload(r io.Reader) error { // Period returns the reload period func (h *Hosts) Period() time.Duration { + if h.Stopped() { + return -1 + } + h.mux.RLock() defer h.mux.RUnlock() return h.period } + +// Stop stops reloading. +func (h *Hosts) Stop() { + select { + case <-h.stopped: + default: + close(h.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (h *Hosts) Stopped() bool { + select { + case <-h.stopped: + return true + default: + return false + } +} diff --git a/http.go b/http.go index f54b2de..e436f60 100644 --- a/http.go +++ b/http.go @@ -263,7 +263,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { if err == nil { return } - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + // log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) continue } diff --git a/http2.go b/http2.go index 922b02a..a33b57c 100644 --- a/http2.go +++ b/http2.go @@ -468,6 +468,7 @@ func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response) type http2Listener struct { server *http.Server connChan chan *http2ServerConn + addr net.Addr errChan chan error } @@ -494,6 +495,8 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { if err != nil { return nil, err } + l.addr = ln.Addr() + go func() { err := server.Serve(ln) if err != nil { @@ -532,8 +535,7 @@ func (l *http2Listener) Accept() (conn net.Conn, err error) { } func (l *http2Listener) Addr() net.Addr { - addr, _ := net.ResolveTCPAddr("tcp", l.server.Addr) - return addr + return l.addr } func (l *http2Listener) Close() (err error) { diff --git a/node.go b/node.go index fbe929e..6eb47b6 100644 --- a/node.go +++ b/node.go @@ -2,7 +2,6 @@ package gost import ( "errors" - "fmt" "net/url" "strconv" "strings" @@ -22,6 +21,7 @@ type Node struct { Protocol string Transport string Remote string // remote address, used by tcp/udp port forwarding + url string // raw url User *url.Userinfo Values url.Values DialOptions []DialOption @@ -57,6 +57,9 @@ func ParseNode(s string) (node Node, err error) { marker: &failMarker{}, } + u.RawQuery = "" + node.url = u.String() + schemes := strings.Split(u.Scheme, "+") if len(schemes) == 1 { node.Protocol = schemes[0] @@ -136,8 +139,7 @@ func (node *Node) GetInt(key string) int { } func (node Node) String() string { - return fmt.Sprintf("%d@%s+%s://%s", - node.ID, node.Protocol, node.Transport, node.Addr) + return node.url } // NodeGroup is a group of nodes. @@ -167,16 +169,19 @@ func (group *NodeGroup) AddNode(node ...Node) { group.nodes = append(group.nodes, node...) } -// SetNodes replaces the group nodes to the specified nodes. -func (group *NodeGroup) SetNodes(nodes ...Node) { +// SetNodes replaces the group nodes to the specified nodes, +// and returns the previous nodes. +func (group *NodeGroup) SetNodes(nodes ...Node) []Node { if group == nil { - return + return nil } group.mux.Lock() defer group.mux.Unlock() + old := group.nodes group.nodes = nodes + return old } // SetSelector sets node selector with options for the group. diff --git a/node_test.go b/node_test.go index 48b8ec6..31718f4 100644 --- a/node_test.go +++ b/node_test.go @@ -8,7 +8,7 @@ var nodeTests = []struct { out Node hasError bool }{ - {"", Node{}, false}, + {"", Node{}, true}, {"://", Node{}, true}, {"localhost", Node{Addr: "localhost", Transport: "tcp"}, false}, {":", Node{Addr: ":", Transport: "tcp"}, false}, diff --git a/reload.go b/reload.go index db79b8c..e6d5648 100644 --- a/reload.go +++ b/reload.go @@ -14,43 +14,71 @@ type Reloader interface { Period() time.Duration } -// PeriodReload reloads the config periodically according to the period of the reloader. +// Stoppable is the interface that indicates a Reloader can be stopped. +type Stoppable interface { + Stop() +} + +//StopReloader is the interface that adds Stop method to the Reloader. +type StopReloader interface { + Reloader + Stoppable +} + +type nopStoppable struct { + Reloader +} + +func (nopStoppable) Stop() { + return +} + +// NopStoppable returns a StopReloader with a no-op Stop method, +// wrapping the provided Reloader r. +func NopStoppable(r Reloader) StopReloader { + return nopStoppable{r} +} + +// PeriodReload reloads the config configFile periodically according to the period of the Reloader r. func PeriodReload(r Reloader, configFile string) error { - if configFile == "" { + if r == nil || configFile == "" { return nil } var lastMod time.Time for { + if r.Period() < 0 { + log.Log("[reload] stopped:", configFile) + return nil + } + f, err := os.Open(configFile) if err != nil { return err } - finfo, err := f.Stat() - if err != nil { - f.Close() - return err + mt := lastMod + if finfo, err := f.Stat(); err == nil { + mt = finfo.ModTime() } - mt := finfo.ModTime() - if !mt.Equal(lastMod) { + + if !lastMod.IsZero() && !mt.Equal(lastMod) { log.Log("[reload]", configFile) if err := r.Reload(f); err != nil { log.Logf("[reload] %s: %s", configFile, err) } - lastMod = mt } f.Close() + lastMod = mt period := r.Period() - if period <= 0 { + if period == 0 { log.Log("[reload] disabled:", configFile) return nil } if period < time.Second { period = time.Second } - <-time.After(period) } } diff --git a/resolver.go b/resolver.go index 5d25fc7..2296817 100644 --- a/resolver.go +++ b/resolver.go @@ -29,10 +29,11 @@ type Resolver interface { Resolve(host string) ([]net.IP, error) } -// ReloadResolver is resolover that support live reloading +// ReloadResolver is resolover that support live reloading. type ReloadResolver interface { Resolver Reloader + Stoppable } // NameServer is a name server. @@ -68,6 +69,7 @@ type resolver struct { TTL time.Duration period time.Duration domain string + stopped chan struct{} mux sync.RWMutex } @@ -78,6 +80,7 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv Timeout: timeout, TTL: ttl, mCache: &sync.Map{}, + stopped: make(chan struct{}), } if r.Timeout <= 0 { @@ -110,6 +113,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { r.mux.RLock() domain = r.domain timeout = r.Timeout + ttl = r.TTL servers = r.copyServers() r.mux.RUnlock() @@ -219,6 +223,10 @@ func (r *resolver) Reload(rd io.Reader) error { var domain string var nss []NameServer + if r.Stopped() { + return nil + } + split := func(line string) []string { if line == "" { return nil @@ -305,12 +313,35 @@ func (r *resolver) Reload(rd io.Reader) error { } func (r *resolver) Period() time.Duration { + if r.Stopped() { + return -1 + } + r.mux.RLock() defer r.mux.RUnlock() return r.period } +// Stop stops reloading. +func (r *resolver) Stop() { + select { + case <-r.stopped: + default: + close(r.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (r *resolver) Stopped() bool { + select { + case <-r.stopped: + return true + default: + return false + } +} + func (r *resolver) String() string { if r == nil { return "" diff --git a/selector.go b/selector.go index f4120c5..d4dcc7e 100644 --- a/selector.go +++ b/selector.go @@ -68,6 +68,20 @@ type Strategy interface { String() string } +// NewStrategy creates a Strategy by the name s. +func NewStrategy(s string) Strategy { + switch s { + case "random": + return &RandomStrategy{} + case "fifo": + return &FIFOStrategy{} + case "round": + fallthrough + default: + return &RoundStrategy{} + } +} + // RoundStrategy is a strategy for node selector. // The node will be selected by round-robin algorithm. type RoundStrategy struct { diff --git a/server.go b/server.go index e4cdfce..d3804cf 100644 --- a/server.go +++ b/server.go @@ -86,21 +86,11 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error { // ServerOptions holds the options for Server. type ServerOptions struct { - Bypass *Bypass } // ServerOption allows a common way to set server options. type ServerOption func(opts *ServerOptions) -/* -// BypassServerOption sets the bypass option of ServerOptions. -func BypassServerOption(bypass *Bypass) ServerOption { - return func(opts *ServerOptions) { - opts.Bypass = bypass - } -} -*/ - // Listener is a proxy server listener, just like a net.Listener. type Listener interface { net.Listener diff --git a/ss.go b/ss.go index b44d726..5138881 100644 --- a/ss.go +++ b/ss.go @@ -84,8 +84,9 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect return nil, err } - sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher) - if err != nil { + sc := ss.NewConn(conn, cipher) + // sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher) + if _, err := sc.Write(rawaddr); err != nil { return nil, err } return &shadowConn{conn: sc}, nil diff --git a/ws.go b/ws.go index 0b9c61f..fdc97b4 100644 --- a/ws.go +++ b/ws.go @@ -384,7 +384,6 @@ func WSListener(addr string, options *WSOptions) (Listener, error) { options = &WSOptions{} } l := &wsListener{ - addr: tcpAddr, upgrader: &websocket.Upgrader{ ReadBufferSize: options.ReadBufferSize, WriteBufferSize: options.WriteBufferSize, @@ -403,6 +402,7 @@ func WSListener(addr string, options *WSOptions) (Listener, error) { if err != nil { return nil, err } + l.addr = ln.Addr() go func() { err := l.srv.Serve(tcpKeepAliveListener{ln}) @@ -473,7 +473,6 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) { options = &WSOptions{} } l := &mwsListener{ - addr: tcpAddr, upgrader: &websocket.Upgrader{ ReadBufferSize: options.ReadBufferSize, WriteBufferSize: options.WriteBufferSize, @@ -492,6 +491,7 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) { if err != nil { return nil, err } + l.addr = ln.Addr() go func() { err := l.srv.Serve(tcpKeepAliveListener{ln}) @@ -584,7 +584,6 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen } l := &wssListener{ wsListener: &wsListener{ - addr: tcpAddr, upgrader: &websocket.Upgrader{ ReadBufferSize: options.ReadBufferSize, WriteBufferSize: options.WriteBufferSize, @@ -612,6 +611,7 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen if err != nil { return nil, err } + l.addr = ln.Addr() go func() { err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) @@ -644,7 +644,6 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste } l := &mwssListener{ mwsListener: &mwsListener{ - addr: tcpAddr, upgrader: &websocket.Upgrader{ ReadBufferSize: options.ReadBufferSize, WriteBufferSize: options.WriteBufferSize, @@ -672,6 +671,7 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste if err != nil { return nil, err } + l.addr = ln.Addr() go func() { err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig))