add stop for live reloading

This commit is contained in:
ginuerzh 2018-11-29 22:09:10 +08:00
parent 5e0e08d5b0
commit dc4c78ca44
18 changed files with 337 additions and 203 deletions

View File

@ -124,6 +124,7 @@ type Bypass struct {
matchers []Matcher matchers []Matcher
period time.Duration // the period for live reloading period time.Duration // the period for live reloading
reversed bool reversed bool
stopped chan struct{}
mux sync.RWMutex mux sync.RWMutex
} }
@ -133,6 +134,7 @@ func NewBypass(reversed bool, matchers ...Matcher) *Bypass {
return &Bypass{ return &Bypass{
matchers: matchers, matchers: matchers,
reversed: reversed, reversed: reversed,
stopped: make(chan struct{}),
} }
} }
@ -207,6 +209,10 @@ func (bp *Bypass) Reload(r io.Reader) error {
var period time.Duration var period time.Duration
var reversed bool var reversed bool
if bp.Stopped() {
return nil
}
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@ -264,14 +270,37 @@ func (bp *Bypass) Reload(r io.Reader) error {
return nil return nil
} }
// Period returns the reload period // Period returns the reload period.
func (bp *Bypass) Period() time.Duration { func (bp *Bypass) Period() time.Duration {
if bp.Stopped() {
return -1
}
bp.mux.RLock() bp.mux.RLock()
defer bp.mux.RUnlock() defer bp.mux.RUnlock()
return bp.period return bp.period
} }
// Stop stops reloading.
func (bp *Bypass) Stop() {
select {
case <-bp.stopped:
default:
close(bp.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (bp *Bypass) Stopped() bool {
select {
case <-bp.stopped:
return true
default:
return false
}
}
func (bp *Bypass) String() string { func (bp *Bypass) String() string {
bp.mux.RLock() bp.mux.RLock()
defer bp.mux.RUnlock() defer bp.mux.RUnlock()

View File

@ -15,7 +15,7 @@ var (
ErrEmptyChain = errors.New("empty chain") ErrEmptyChain = errors.New("empty chain")
) )
// Chain is a proxy chain that holds a list of proxy nodes. // Chain is a proxy chain that holds a list of proxy node groups.
type Chain struct { type Chain struct {
isRoute bool isRoute bool
Retries int Retries int
@ -23,6 +23,7 @@ type Chain struct {
} }
// NewChain creates a proxy chain with a list of proxy nodes. // NewChain creates a proxy chain with a list of proxy nodes.
// It creates the node groups automatically, one group per node.
func NewChain(nodes ...Node) *Chain { func NewChain(nodes ...Node) *Chain {
chain := &Chain{} chain := &Chain{}
for _, node := range nodes { for _, node := range nodes {
@ -31,6 +32,8 @@ func NewChain(nodes ...Node) *Chain {
return chain return chain
} }
// newRoute creates a chain route.
// a chain route is the final route after node selection.
func newRoute(nodes ...Node) *Chain { func newRoute(nodes ...Node) *Chain {
chain := NewChain(nodes...) chain := NewChain(nodes...)
chain.isRoute = true chain.isRoute = true

View File

@ -6,8 +6,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io"
"io/ioutil" "io/ioutil"
"net/url" "net/url"
"os" "os"
@ -17,6 +15,34 @@ import (
"github.com/ginuerzh/gost" "github.com/ginuerzh/gost"
) )
var (
routers []router
)
type baseConfig struct {
route
Routes []route
Debug bool
}
func parseBaseConfig(s string) (*baseConfig, error) {
file, err := os.Open(s)
if err != nil {
return nil, err
}
defer file.Close()
if err := json.NewDecoder(file).Decode(baseCfg); err != nil {
return nil, err
}
return baseCfg, nil
}
func (cfg *baseConfig) IsValid() bool {
return len(cfg.route.ServeNodes) > 0
}
var ( var (
defaultCertFile = "cert.pem" defaultCertFile = "cert.pem"
defaultKeyFile = "key.pem" defaultKeyFile = "key.pem"
@ -52,70 +78,6 @@ func loadCA(caFile string) (cp *x509.CertPool, err error) {
return return
} }
type baseConfig struct {
route
Routes []route
ReloadPeriod string
Debug bool
}
func parseBaseConfig(s string) (*baseConfig, error) {
file, err := os.Open(s)
if err != nil {
return nil, err
}
defer file.Close()
if err := json.NewDecoder(file).Decode(baseCfg); err != nil {
return nil, err
}
return baseCfg, nil
}
func (cfg *baseConfig) IsValid() bool {
return len(cfg.route.ServeNodes) > 0
}
func (cfg *baseConfig) Reload(r io.Reader) error {
c := baseConfig{}
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
cfg.route.Close()
for _, r := range cfg.Routes {
r.Close()
}
*cfg = c
gost.Debug = cfg.Debug
if err := cfg.route.serve(); err != nil {
return err
}
for _, route := range cfg.Routes {
if err := route.serve(); err != nil {
return err
}
}
return nil
}
func (cfg *baseConfig) Period() time.Duration {
d, _ := time.ParseDuration(cfg.ReloadPeriod)
return d
}
type stringList []string
func (l *stringList) String() string {
return fmt.Sprintf("%s", *l)
}
func (l *stringList) Set(value string) error {
*l = append(*l, value)
return nil
}
func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { func parseKCPConfig(configFile string) (*gost.KCPConfig, error) {
if configFile == "" { if configFile == "" {
return nil, nil return nil, nil
@ -221,9 +183,10 @@ func parseBypass(s string) *gost.Bypass {
} }
return gost.NewBypass(reversed, matchers...) return gost.NewBypass(reversed, matchers...)
} }
f.Close() defer f.Close()
bp := gost.NewBypass(reversed) bp := gost.NewBypass(reversed)
bp.Reload(f)
go gost.PeriodReload(bp, s) go gost.PeriodReload(bp, s)
return bp return bp
@ -259,16 +222,26 @@ func parseResolver(cfg string) gost.Resolver {
} }
return gost.NewResolver(timeout, ttl, nss...) return gost.NewResolver(timeout, ttl, nss...)
} }
f.Close() defer f.Close()
resolver := gost.NewResolver(timeout, ttl) resolver := gost.NewResolver(timeout, ttl)
resolver.Reload(f)
go gost.PeriodReload(resolver, cfg) go gost.PeriodReload(resolver, cfg)
return resolver return resolver
} }
func parseHosts(s string) *gost.Hosts { func parseHosts(s string) *gost.Hosts {
f, err := os.Open(s)
if err != nil {
return nil
}
defer f.Close()
hosts := gost.NewHosts() hosts := gost.NewHosts()
hosts.Reload(f)
go gost.PeriodReload(hosts, s) go gost.PeriodReload(hosts, s)
return hosts return hosts

View File

@ -1,30 +0,0 @@
{
"Debug": false,
"Retries": 1,
"ServeNodes": [
":8080",
"ss://chacha20:12345678@:8338"
],
"ChainNodes": [
"http://192.168.1.1:8080",
"https://10.0.2.1:443"
],
"Routes": [
{
"Retries": 1,
"ServeNodes": [
"ws://:1443"
],
"ChainNodes": [
"socks://:192.168.1.1:1080"
]
},
{
"Retries": 3,
"ServeNodes": [
"quic://:443"
]
}
]
}

View File

@ -71,7 +71,10 @@ func main() {
} }
gost.DefaultTLSConfig = tlsConfig gost.DefaultTLSConfig = tlsConfig
start() if err := start(); err != nil {
log.Log(err)
os.Exit(1)
}
select {} select {}
} }
@ -79,16 +82,24 @@ func main() {
func start() error { func start() error {
gost.Debug = baseCfg.Debug gost.Debug = baseCfg.Debug
if err := baseCfg.route.serve(); err != nil { var routers []router
rts, err := baseCfg.route.GenRouters()
if err != nil {
return err return err
} }
routers = append(routers, rts...)
for _, route := range baseCfg.Routes { for _, route := range baseCfg.Routes {
if err := route.serve(); err != nil { rts, err := route.GenRouters()
if err != nil {
return err return err
} }
routers = append(routers, rts...)
} }
go gost.PeriodReload(baseCfg, configureFile) for i := range routers {
go routers[i].Serve()
}
return nil return nil
} }

View File

@ -26,20 +26,13 @@ type peerConfig struct {
Nodes []string `json:"nodes"` Nodes []string `json:"nodes"`
group *gost.NodeGroup group *gost.NodeGroup
baseNodes []gost.Node baseNodes []gost.Node
stopped chan struct{}
} }
type bypass struct { func newPeerConfig() *peerConfig {
Reverse bool `json:"reverse"` return &peerConfig{
Patterns []string `json:"patterns"` stopped: make(chan struct{}),
}
func parsePeerConfig(cfg string, group *gost.NodeGroup, baseNodes []gost.Node) *peerConfig {
pc := &peerConfig{
group: group,
baseNodes: baseNodes,
} }
go gost.PeriodReload(pc, cfg)
return pc
} }
func (cfg *peerConfig) Validate() { func (cfg *peerConfig) Validate() {
@ -52,28 +45,23 @@ func (cfg *peerConfig) Validate() {
} }
func (cfg *peerConfig) Reload(r io.Reader) error { func (cfg *peerConfig) Reload(r io.Reader) error {
if cfg.Stopped() {
return nil
}
if err := cfg.parse(r); err != nil { if err := cfg.parse(r); err != nil {
return err return err
} }
cfg.Validate() cfg.Validate()
group := cfg.group group := cfg.group
/*
strategy := cfg.Strategy
if len(cfg.baseNodes) > 0 {
// overwrite the strategry in the peer config if `strategy` param exists.
if s := cfg.baseNodes[0].Get("strategy"); s != "" {
strategy = s
}
}
*/
group.SetSelector( group.SetSelector(
nil, nil,
gost.WithFilter(&gost.FailFilter{ gost.WithFilter(&gost.FailFilter{
MaxFails: cfg.MaxFails, MaxFails: cfg.MaxFails,
FailTimeout: cfg.FailTimeout, FailTimeout: cfg.FailTimeout,
}), }),
gost.WithStrategy(parseStrategy(cfg.Strategy)), gost.WithStrategy(gost.NewStrategy(cfg.Strategy)),
) )
gNodes := cfg.baseNodes gNodes := cfg.baseNodes
@ -92,7 +80,12 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
gNodes = append(gNodes, nodes...) gNodes = append(gNodes, nodes...)
} }
group.SetNodes(gNodes...) nodes := group.SetNodes(gNodes...)
for _, node := range nodes[len(cfg.baseNodes):] {
if node.Bypass != nil {
node.Bypass.Stop() // clear the old nodes
}
}
return nil return nil
} }
@ -154,18 +147,27 @@ func (cfg *peerConfig) parse(r io.Reader) error {
} }
func (cfg *peerConfig) Period() time.Duration { func (cfg *peerConfig) Period() time.Duration {
if cfg.Stopped() {
return -1
}
return cfg.period return cfg.period
} }
func parseStrategy(s string) gost.Strategy { // Stop stops reloading.
switch s { func (cfg *peerConfig) Stop() {
case "random": select {
return &gost.RandomStrategy{} case <-cfg.stopped:
case "fifo":
return &gost.FIFOStrategy{}
case "round":
fallthrough
default: default:
return &gost.RoundStrategy{} close(cfg.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (cfg *peerConfig) Stopped() bool {
select {
case <-cfg.stopped:
return true
default:
return false
} }
} }

View File

@ -3,20 +3,32 @@ package main
import ( import (
"crypto/sha256" "crypto/sha256"
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"os"
"time" "time"
"github.com/ginuerzh/gost" "github.com/ginuerzh/gost"
"github.com/go-log/log"
) )
type stringList []string
func (l *stringList) String() string {
return fmt.Sprintf("%s", *l)
}
func (l *stringList) Set(value string) error {
*l = append(*l, value)
return nil
}
type route struct { type route struct {
ServeNodes stringList ServeNodes stringList
ChainNodes stringList ChainNodes stringList
Retries int Retries int
server *gost.Server
} }
func (r *route) initChain() (*gost.Chain, error) { func (r *route) parseChain() (*gost.Chain, error) {
chain := gost.NewChain() chain := gost.NewChain()
chain.Retries = r.Retries chain.Retries = r.Retries
gid := 1 // group ID gid := 1 // group ID
@ -44,13 +56,20 @@ func (r *route) initChain() (*gost.Chain, error) {
MaxFails: defaultMaxFails, MaxFails: defaultMaxFails,
FailTimeout: defaultFailTimeout, FailTimeout: defaultFailTimeout,
}), }),
gost.WithStrategy(parseStrategy(nodes[0].Get("strategy"))), gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))),
) )
go gost.PeriodReload(&peerConfig{ cfg := nodes[0].Get("peer")
group: ngroup, f, err := os.Open(cfg)
baseNodes: nodes, if err == nil {
}, nodes[0].Get("peer")) peerCfg := newPeerConfig()
peerCfg.group = ngroup
peerCfg.baseNodes = nodes
peerCfg.Reload(f)
f.Close()
go gost.PeriodReload(peerCfg, cfg)
}
chain.AddNodeGroup(ngroup) chain.AddNodeGroup(ngroup)
} }
@ -219,20 +238,22 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
return return
} }
func (r *route) serve() error { func (r *route) GenRouters() ([]router, error) {
chain, err := r.initChain() chain, err := r.parseChain()
if err != nil { if err != nil {
return err return nil, err
} }
var rts []router
for _, ns := range r.ServeNodes { for _, ns := range r.ServeNodes {
node, err := gost.ParseNode(ns) node, err := gost.ParseNode(ns)
if err != nil { if err != nil {
return err return nil, err
} }
users, err := parseUsers(node.Get("secrets")) users, err := parseUsers(node.Get("secrets"))
if err != nil { if err != nil {
return err return nil, err
} }
if node.User != nil { if node.User != nil {
users = append(users, node.User) users = append(users, node.User)
@ -240,7 +261,7 @@ func (r *route) serve() error {
certFile, keyFile := node.Get("cert"), node.Get("key") certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile) tlsCfg, err := tlsConfig(certFile, keyFile)
if err != nil && certFile != "" && keyFile != "" { if err != nil && certFile != "" && keyFile != "" {
return err return nil, err
} }
wsOpts := &gost.WSOptions{} wsOpts := &gost.WSOptions{}
@ -266,7 +287,7 @@ func (r *route) serve() error {
case "kcp": case "kcp":
config, er := parseKCPConfig(node.Get("c")) config, er := parseKCPConfig(node.Get("c"))
if er != nil { if er != nil {
return er return nil, er
} }
ln, err = gost.KCPListener(node.Addr, config) ln, err = gost.KCPListener(node.Addr, config)
case "ssh": case "ssh":
@ -320,7 +341,7 @@ func (r *route) serve() error {
ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second) ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second)
case "obfs4": case "obfs4":
if err = gost.Obfs4Init(node, true); err != nil { if err = gost.Obfs4Init(node, true); err != nil {
return err return nil, err
} }
ln, err = gost.Obfs4Listener(node.Addr) ln, err = gost.Obfs4Listener(node.Addr)
case "ohttp": case "ohttp":
@ -329,7 +350,7 @@ func (r *route) serve() error {
ln, err = gost.TCPListener(node.Addr) ln, err = gost.TCPListener(node.Addr)
} }
if err != nil { if err != nil {
return err return nil, err
} }
var handler gost.Handler var handler gost.Handler
@ -372,15 +393,19 @@ func (r *route) serve() error {
var whitelist, blacklist *gost.Permissions var whitelist, blacklist *gost.Permissions
if node.Values.Get("whitelist") != "" { if node.Values.Get("whitelist") != "" {
if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil { if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil {
return err return nil, err
} }
} }
if node.Values.Get("blacklist") != "" { if node.Values.Get("blacklist") != "" {
if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil { if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil {
return err return nil, err
} }
} }
node.Bypass = parseBypass(node.Get("bypass"))
resolver := parseResolver(node.Get("dns"))
hosts := parseHosts(node.Get("hosts"))
handler.Init( handler.Init(
gost.AddrHandlerOption(node.Addr), gost.AddrHandlerOption(node.Addr),
gost.ChainHandlerOption(chain), gost.ChainHandlerOption(chain),
@ -388,23 +413,44 @@ func (r *route) serve() error {
gost.TLSConfigHandlerOption(tlsCfg), gost.TLSConfigHandlerOption(tlsCfg),
gost.WhitelistHandlerOption(whitelist), gost.WhitelistHandlerOption(whitelist),
gost.BlacklistHandlerOption(blacklist), gost.BlacklistHandlerOption(blacklist),
gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))), gost.StrategyHandlerOption(gost.NewStrategy(node.Get("strategy"))),
gost.BypassHandlerOption(parseBypass(node.Get("bypass"))), gost.BypassHandlerOption(node.Bypass),
gost.ResolverHandlerOption(parseResolver(node.Get("dns"))), gost.ResolverHandlerOption(resolver),
gost.HostsHandlerOption(parseHosts(node.Get("hosts"))), gost.HostsHandlerOption(hosts),
gost.RetryHandlerOption(node.GetInt("retry")), gost.RetryHandlerOption(node.GetInt("retry")),
gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second), gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second),
gost.ProbeResistHandlerOption(node.Get("probe_resist")), gost.ProbeResistHandlerOption(node.Get("probe_resist")),
) )
r.server = &gost.Server{Listener: ln} rt := router{
go r.server.Serve(handler) node: node,
server: &gost.Server{Listener: ln},
handler: handler,
chain: chain,
resolver: resolver,
hosts: hosts,
}
rts = append(rts, rt)
} }
return nil return rts, nil
} }
func (r *route) Close() error { type router struct {
node gost.Node
server *gost.Server
handler gost.Handler
chain *gost.Chain
resolver gost.Resolver
hosts *gost.Hosts
}
func (r *router) Serve() error {
log.Logf("[route] start %s on %s", r.node.String(), r.server.Addr())
return r.server.Serve(r.handler)
}
func (r *router) Close() error {
if r == nil || r.server == nil { if r == nil || r.server == nil {
return nil return nil
} }

View File

@ -24,15 +24,17 @@ type Host struct {
// Fields of the entry are separated by any number of blanks and/or tab characters. // Fields of the entry are separated by any number of blanks and/or tab characters.
// Text from a "#" character until the end of the line is a comment, and is ignored. // Text from a "#" character until the end of the line is a comment, and is ignored.
type Hosts struct { type Hosts struct {
hosts []Host hosts []Host
period time.Duration period time.Duration
mux sync.RWMutex stopped chan struct{}
mux sync.RWMutex
} }
// NewHosts creates a Hosts with optional list of host // NewHosts creates a Hosts with optional list of host
func NewHosts(hosts ...Host) *Hosts { func NewHosts(hosts ...Host) *Hosts {
return &Hosts{ return &Hosts{
hosts: hosts, hosts: hosts,
stopped: make(chan struct{}),
} }
} }
@ -76,6 +78,10 @@ func (h *Hosts) Reload(r io.Reader) error {
var period time.Duration var period time.Duration
var hosts []Host var hosts []Host
if h.Stopped() {
return nil
}
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@ -130,8 +136,31 @@ func (h *Hosts) Reload(r io.Reader) error {
// Period returns the reload period // Period returns the reload period
func (h *Hosts) Period() time.Duration { func (h *Hosts) Period() time.Duration {
if h.Stopped() {
return -1
}
h.mux.RLock() h.mux.RLock()
defer h.mux.RUnlock() defer h.mux.RUnlock()
return h.period return h.period
} }
// Stop stops reloading.
func (h *Hosts) Stop() {
select {
case <-h.stopped:
default:
close(h.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (h *Hosts) Stopped() bool {
select {
case <-h.stopped:
return true
default:
return false
}
}

View File

@ -263,7 +263,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
if err == nil { if err == nil {
return return
} }
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) // log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err)
continue continue
} }

View File

@ -468,6 +468,7 @@ func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response)
type http2Listener struct { type http2Listener struct {
server *http.Server server *http.Server
connChan chan *http2ServerConn connChan chan *http2ServerConn
addr net.Addr
errChan chan error errChan chan error
} }
@ -494,6 +495,8 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
l.addr = ln.Addr()
go func() { go func() {
err := server.Serve(ln) err := server.Serve(ln)
if err != nil { if err != nil {
@ -532,8 +535,7 @@ func (l *http2Listener) Accept() (conn net.Conn, err error) {
} }
func (l *http2Listener) Addr() net.Addr { func (l *http2Listener) Addr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", l.server.Addr) return l.addr
return addr
} }
func (l *http2Listener) Close() (err error) { func (l *http2Listener) Close() (err error) {

17
node.go
View File

@ -2,7 +2,6 @@ package gost
import ( import (
"errors" "errors"
"fmt"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
@ -22,6 +21,7 @@ type Node struct {
Protocol string Protocol string
Transport string Transport string
Remote string // remote address, used by tcp/udp port forwarding Remote string // remote address, used by tcp/udp port forwarding
url string // raw url
User *url.Userinfo User *url.Userinfo
Values url.Values Values url.Values
DialOptions []DialOption DialOptions []DialOption
@ -57,6 +57,9 @@ func ParseNode(s string) (node Node, err error) {
marker: &failMarker{}, marker: &failMarker{},
} }
u.RawQuery = ""
node.url = u.String()
schemes := strings.Split(u.Scheme, "+") schemes := strings.Split(u.Scheme, "+")
if len(schemes) == 1 { if len(schemes) == 1 {
node.Protocol = schemes[0] node.Protocol = schemes[0]
@ -136,8 +139,7 @@ func (node *Node) GetInt(key string) int {
} }
func (node Node) String() string { func (node Node) String() string {
return fmt.Sprintf("%d@%s+%s://%s", return node.url
node.ID, node.Protocol, node.Transport, node.Addr)
} }
// NodeGroup is a group of nodes. // NodeGroup is a group of nodes.
@ -167,16 +169,19 @@ func (group *NodeGroup) AddNode(node ...Node) {
group.nodes = append(group.nodes, node...) group.nodes = append(group.nodes, node...)
} }
// SetNodes replaces the group nodes to the specified nodes. // SetNodes replaces the group nodes to the specified nodes,
func (group *NodeGroup) SetNodes(nodes ...Node) { // and returns the previous nodes.
func (group *NodeGroup) SetNodes(nodes ...Node) []Node {
if group == nil { if group == nil {
return return nil
} }
group.mux.Lock() group.mux.Lock()
defer group.mux.Unlock() defer group.mux.Unlock()
old := group.nodes
group.nodes = nodes group.nodes = nodes
return old
} }
// SetSelector sets node selector with options for the group. // SetSelector sets node selector with options for the group.

View File

@ -8,7 +8,7 @@ var nodeTests = []struct {
out Node out Node
hasError bool hasError bool
}{ }{
{"", Node{}, false}, {"", Node{}, true},
{"://", Node{}, true}, {"://", Node{}, true},
{"localhost", Node{Addr: "localhost", Transport: "tcp"}, false}, {"localhost", Node{Addr: "localhost", Transport: "tcp"}, false},
{":", Node{Addr: ":", Transport: "tcp"}, false}, {":", Node{Addr: ":", Transport: "tcp"}, false},

View File

@ -14,43 +14,71 @@ type Reloader interface {
Period() time.Duration Period() time.Duration
} }
// PeriodReload reloads the config periodically according to the period of the reloader. // Stoppable is the interface that indicates a Reloader can be stopped.
type Stoppable interface {
Stop()
}
//StopReloader is the interface that adds Stop method to the Reloader.
type StopReloader interface {
Reloader
Stoppable
}
type nopStoppable struct {
Reloader
}
func (nopStoppable) Stop() {
return
}
// NopStoppable returns a StopReloader with a no-op Stop method,
// wrapping the provided Reloader r.
func NopStoppable(r Reloader) StopReloader {
return nopStoppable{r}
}
// PeriodReload reloads the config configFile periodically according to the period of the Reloader r.
func PeriodReload(r Reloader, configFile string) error { func PeriodReload(r Reloader, configFile string) error {
if configFile == "" { if r == nil || configFile == "" {
return nil return nil
} }
var lastMod time.Time var lastMod time.Time
for { for {
if r.Period() < 0 {
log.Log("[reload] stopped:", configFile)
return nil
}
f, err := os.Open(configFile) f, err := os.Open(configFile)
if err != nil { if err != nil {
return err return err
} }
finfo, err := f.Stat() mt := lastMod
if err != nil { if finfo, err := f.Stat(); err == nil {
f.Close() mt = finfo.ModTime()
return err
} }
mt := finfo.ModTime()
if !mt.Equal(lastMod) { if !lastMod.IsZero() && !mt.Equal(lastMod) {
log.Log("[reload]", configFile) log.Log("[reload]", configFile)
if err := r.Reload(f); err != nil { if err := r.Reload(f); err != nil {
log.Logf("[reload] %s: %s", configFile, err) log.Logf("[reload] %s: %s", configFile, err)
} }
lastMod = mt
} }
f.Close() f.Close()
lastMod = mt
period := r.Period() period := r.Period()
if period <= 0 { if period == 0 {
log.Log("[reload] disabled:", configFile) log.Log("[reload] disabled:", configFile)
return nil return nil
} }
if period < time.Second { if period < time.Second {
period = time.Second period = time.Second
} }
<-time.After(period) <-time.After(period)
} }
} }

View File

@ -29,10 +29,11 @@ type Resolver interface {
Resolve(host string) ([]net.IP, error) Resolve(host string) ([]net.IP, error)
} }
// ReloadResolver is resolover that support live reloading // ReloadResolver is resolover that support live reloading.
type ReloadResolver interface { type ReloadResolver interface {
Resolver Resolver
Reloader Reloader
Stoppable
} }
// NameServer is a name server. // NameServer is a name server.
@ -68,6 +69,7 @@ type resolver struct {
TTL time.Duration TTL time.Duration
period time.Duration period time.Duration
domain string domain string
stopped chan struct{}
mux sync.RWMutex mux sync.RWMutex
} }
@ -78,6 +80,7 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
Timeout: timeout, Timeout: timeout,
TTL: ttl, TTL: ttl,
mCache: &sync.Map{}, mCache: &sync.Map{},
stopped: make(chan struct{}),
} }
if r.Timeout <= 0 { if r.Timeout <= 0 {
@ -110,6 +113,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
r.mux.RLock() r.mux.RLock()
domain = r.domain domain = r.domain
timeout = r.Timeout timeout = r.Timeout
ttl = r.TTL
servers = r.copyServers() servers = r.copyServers()
r.mux.RUnlock() r.mux.RUnlock()
@ -219,6 +223,10 @@ func (r *resolver) Reload(rd io.Reader) error {
var domain string var domain string
var nss []NameServer var nss []NameServer
if r.Stopped() {
return nil
}
split := func(line string) []string { split := func(line string) []string {
if line == "" { if line == "" {
return nil return nil
@ -305,12 +313,35 @@ func (r *resolver) Reload(rd io.Reader) error {
} }
func (r *resolver) Period() time.Duration { func (r *resolver) Period() time.Duration {
if r.Stopped() {
return -1
}
r.mux.RLock() r.mux.RLock()
defer r.mux.RUnlock() defer r.mux.RUnlock()
return r.period return r.period
} }
// Stop stops reloading.
func (r *resolver) Stop() {
select {
case <-r.stopped:
default:
close(r.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (r *resolver) Stopped() bool {
select {
case <-r.stopped:
return true
default:
return false
}
}
func (r *resolver) String() string { func (r *resolver) String() string {
if r == nil { if r == nil {
return "" return ""

View File

@ -68,6 +68,20 @@ type Strategy interface {
String() string String() string
} }
// NewStrategy creates a Strategy by the name s.
func NewStrategy(s string) Strategy {
switch s {
case "random":
return &RandomStrategy{}
case "fifo":
return &FIFOStrategy{}
case "round":
fallthrough
default:
return &RoundStrategy{}
}
}
// RoundStrategy is a strategy for node selector. // RoundStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm. // The node will be selected by round-robin algorithm.
type RoundStrategy struct { type RoundStrategy struct {

View File

@ -86,21 +86,11 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error {
// ServerOptions holds the options for Server. // ServerOptions holds the options for Server.
type ServerOptions struct { type ServerOptions struct {
Bypass *Bypass
} }
// ServerOption allows a common way to set server options. // ServerOption allows a common way to set server options.
type ServerOption func(opts *ServerOptions) type ServerOption func(opts *ServerOptions)
/*
// BypassServerOption sets the bypass option of ServerOptions.
func BypassServerOption(bypass *Bypass) ServerOption {
return func(opts *ServerOptions) {
opts.Bypass = bypass
}
}
*/
// Listener is a proxy server listener, just like a net.Listener. // Listener is a proxy server listener, just like a net.Listener.
type Listener interface { type Listener interface {
net.Listener net.Listener

5
ss.go
View File

@ -84,8 +84,9 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect
return nil, err return nil, err
} }
sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher) sc := ss.NewConn(conn, cipher)
if err != nil { // sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher)
if _, err := sc.Write(rawaddr); err != nil {
return nil, err return nil, err
} }
return &shadowConn{conn: sc}, nil return &shadowConn{conn: sc}, nil

8
ws.go
View File

@ -384,7 +384,6 @@ func WSListener(addr string, options *WSOptions) (Listener, error) {
options = &WSOptions{} options = &WSOptions{}
} }
l := &wsListener{ l := &wsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize, ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize, WriteBufferSize: options.WriteBufferSize,
@ -403,6 +402,7 @@ func WSListener(addr string, options *WSOptions) (Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
l.addr = ln.Addr()
go func() { go func() {
err := l.srv.Serve(tcpKeepAliveListener{ln}) err := l.srv.Serve(tcpKeepAliveListener{ln})
@ -473,7 +473,6 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) {
options = &WSOptions{} options = &WSOptions{}
} }
l := &mwsListener{ l := &mwsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize, ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize, WriteBufferSize: options.WriteBufferSize,
@ -492,6 +491,7 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
l.addr = ln.Addr()
go func() { go func() {
err := l.srv.Serve(tcpKeepAliveListener{ln}) err := l.srv.Serve(tcpKeepAliveListener{ln})
@ -584,7 +584,6 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen
} }
l := &wssListener{ l := &wssListener{
wsListener: &wsListener{ wsListener: &wsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize, ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize, WriteBufferSize: options.WriteBufferSize,
@ -612,6 +611,7 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen
if err != nil { if err != nil {
return nil, err return nil, err
} }
l.addr = ln.Addr()
go func() { go func() {
err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig))
@ -644,7 +644,6 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste
} }
l := &mwssListener{ l := &mwssListener{
mwsListener: &mwsListener{ mwsListener: &mwsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize, ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize, WriteBufferSize: options.WriteBufferSize,
@ -672,6 +671,7 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste
if err != nil { if err != nil {
return nil, err return nil, err
} }
l.addr = ln.Addr()
go func() { go func() {
err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig))