add stop for live reloading

This commit is contained in:
ginuerzh 2018-11-29 22:09:10 +08:00
parent 5e0e08d5b0
commit dc4c78ca44
18 changed files with 337 additions and 203 deletions

View File

@ -124,6 +124,7 @@ type Bypass struct {
matchers []Matcher
period time.Duration // the period for live reloading
reversed bool
stopped chan struct{}
mux sync.RWMutex
}
@ -133,6 +134,7 @@ func NewBypass(reversed bool, matchers ...Matcher) *Bypass {
return &Bypass{
matchers: matchers,
reversed: reversed,
stopped: make(chan struct{}),
}
}
@ -207,6 +209,10 @@ func (bp *Bypass) Reload(r io.Reader) error {
var period time.Duration
var reversed bool
if bp.Stopped() {
return nil
}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
@ -264,14 +270,37 @@ func (bp *Bypass) Reload(r io.Reader) error {
return nil
}
// Period returns the reload period
// Period returns the reload period.
func (bp *Bypass) Period() time.Duration {
if bp.Stopped() {
return -1
}
bp.mux.RLock()
defer bp.mux.RUnlock()
return bp.period
}
// Stop stops reloading.
func (bp *Bypass) Stop() {
select {
case <-bp.stopped:
default:
close(bp.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (bp *Bypass) Stopped() bool {
select {
case <-bp.stopped:
return true
default:
return false
}
}
func (bp *Bypass) String() string {
bp.mux.RLock()
defer bp.mux.RUnlock()

View File

@ -15,7 +15,7 @@ var (
ErrEmptyChain = errors.New("empty chain")
)
// Chain is a proxy chain that holds a list of proxy nodes.
// Chain is a proxy chain that holds a list of proxy node groups.
type Chain struct {
isRoute bool
Retries int
@ -23,6 +23,7 @@ type Chain struct {
}
// NewChain creates a proxy chain with a list of proxy nodes.
// It creates the node groups automatically, one group per node.
func NewChain(nodes ...Node) *Chain {
chain := &Chain{}
for _, node := range nodes {
@ -31,6 +32,8 @@ func NewChain(nodes ...Node) *Chain {
return chain
}
// newRoute creates a chain route.
// a chain route is the final route after node selection.
func newRoute(nodes ...Node) *Chain {
chain := NewChain(nodes...)
chain.isRoute = true

View File

@ -6,8 +6,6 @@ import (
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/url"
"os"
@ -17,6 +15,34 @@ import (
"github.com/ginuerzh/gost"
)
var (
routers []router
)
type baseConfig struct {
route
Routes []route
Debug bool
}
func parseBaseConfig(s string) (*baseConfig, error) {
file, err := os.Open(s)
if err != nil {
return nil, err
}
defer file.Close()
if err := json.NewDecoder(file).Decode(baseCfg); err != nil {
return nil, err
}
return baseCfg, nil
}
func (cfg *baseConfig) IsValid() bool {
return len(cfg.route.ServeNodes) > 0
}
var (
defaultCertFile = "cert.pem"
defaultKeyFile = "key.pem"
@ -52,70 +78,6 @@ func loadCA(caFile string) (cp *x509.CertPool, err error) {
return
}
type baseConfig struct {
route
Routes []route
ReloadPeriod string
Debug bool
}
func parseBaseConfig(s string) (*baseConfig, error) {
file, err := os.Open(s)
if err != nil {
return nil, err
}
defer file.Close()
if err := json.NewDecoder(file).Decode(baseCfg); err != nil {
return nil, err
}
return baseCfg, nil
}
func (cfg *baseConfig) IsValid() bool {
return len(cfg.route.ServeNodes) > 0
}
func (cfg *baseConfig) Reload(r io.Reader) error {
c := baseConfig{}
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
cfg.route.Close()
for _, r := range cfg.Routes {
r.Close()
}
*cfg = c
gost.Debug = cfg.Debug
if err := cfg.route.serve(); err != nil {
return err
}
for _, route := range cfg.Routes {
if err := route.serve(); err != nil {
return err
}
}
return nil
}
func (cfg *baseConfig) Period() time.Duration {
d, _ := time.ParseDuration(cfg.ReloadPeriod)
return d
}
type stringList []string
func (l *stringList) String() string {
return fmt.Sprintf("%s", *l)
}
func (l *stringList) Set(value string) error {
*l = append(*l, value)
return nil
}
func parseKCPConfig(configFile string) (*gost.KCPConfig, error) {
if configFile == "" {
return nil, nil
@ -221,9 +183,10 @@ func parseBypass(s string) *gost.Bypass {
}
return gost.NewBypass(reversed, matchers...)
}
f.Close()
defer f.Close()
bp := gost.NewBypass(reversed)
bp.Reload(f)
go gost.PeriodReload(bp, s)
return bp
@ -259,16 +222,26 @@ func parseResolver(cfg string) gost.Resolver {
}
return gost.NewResolver(timeout, ttl, nss...)
}
f.Close()
defer f.Close()
resolver := gost.NewResolver(timeout, ttl)
resolver.Reload(f)
go gost.PeriodReload(resolver, cfg)
return resolver
}
func parseHosts(s string) *gost.Hosts {
f, err := os.Open(s)
if err != nil {
return nil
}
defer f.Close()
hosts := gost.NewHosts()
hosts.Reload(f)
go gost.PeriodReload(hosts, s)
return hosts

View File

@ -1,30 +0,0 @@
{
"Debug": false,
"Retries": 1,
"ServeNodes": [
":8080",
"ss://chacha20:12345678@:8338"
],
"ChainNodes": [
"http://192.168.1.1:8080",
"https://10.0.2.1:443"
],
"Routes": [
{
"Retries": 1,
"ServeNodes": [
"ws://:1443"
],
"ChainNodes": [
"socks://:192.168.1.1:1080"
]
},
{
"Retries": 3,
"ServeNodes": [
"quic://:443"
]
}
]
}

View File

@ -71,7 +71,10 @@ func main() {
}
gost.DefaultTLSConfig = tlsConfig
start()
if err := start(); err != nil {
log.Log(err)
os.Exit(1)
}
select {}
}
@ -79,16 +82,24 @@ func main() {
func start() error {
gost.Debug = baseCfg.Debug
if err := baseCfg.route.serve(); err != nil {
var routers []router
rts, err := baseCfg.route.GenRouters()
if err != nil {
return err
}
routers = append(routers, rts...)
for _, route := range baseCfg.Routes {
if err := route.serve(); err != nil {
rts, err := route.GenRouters()
if err != nil {
return err
}
routers = append(routers, rts...)
}
go gost.PeriodReload(baseCfg, configureFile)
for i := range routers {
go routers[i].Serve()
}
return nil
}

View File

@ -26,20 +26,13 @@ type peerConfig struct {
Nodes []string `json:"nodes"`
group *gost.NodeGroup
baseNodes []gost.Node
stopped chan struct{}
}
type bypass struct {
Reverse bool `json:"reverse"`
Patterns []string `json:"patterns"`
}
func parsePeerConfig(cfg string, group *gost.NodeGroup, baseNodes []gost.Node) *peerConfig {
pc := &peerConfig{
group: group,
baseNodes: baseNodes,
func newPeerConfig() *peerConfig {
return &peerConfig{
stopped: make(chan struct{}),
}
go gost.PeriodReload(pc, cfg)
return pc
}
func (cfg *peerConfig) Validate() {
@ -52,28 +45,23 @@ func (cfg *peerConfig) Validate() {
}
func (cfg *peerConfig) Reload(r io.Reader) error {
if cfg.Stopped() {
return nil
}
if err := cfg.parse(r); err != nil {
return err
}
cfg.Validate()
group := cfg.group
/*
strategy := cfg.Strategy
if len(cfg.baseNodes) > 0 {
// overwrite the strategry in the peer config if `strategy` param exists.
if s := cfg.baseNodes[0].Get("strategy"); s != "" {
strategy = s
}
}
*/
group.SetSelector(
nil,
gost.WithFilter(&gost.FailFilter{
MaxFails: cfg.MaxFails,
FailTimeout: cfg.FailTimeout,
}),
gost.WithStrategy(parseStrategy(cfg.Strategy)),
gost.WithStrategy(gost.NewStrategy(cfg.Strategy)),
)
gNodes := cfg.baseNodes
@ -92,7 +80,12 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
gNodes = append(gNodes, nodes...)
}
group.SetNodes(gNodes...)
nodes := group.SetNodes(gNodes...)
for _, node := range nodes[len(cfg.baseNodes):] {
if node.Bypass != nil {
node.Bypass.Stop() // clear the old nodes
}
}
return nil
}
@ -154,18 +147,27 @@ func (cfg *peerConfig) parse(r io.Reader) error {
}
func (cfg *peerConfig) Period() time.Duration {
if cfg.Stopped() {
return -1
}
return cfg.period
}
func parseStrategy(s string) gost.Strategy {
switch s {
case "random":
return &gost.RandomStrategy{}
case "fifo":
return &gost.FIFOStrategy{}
case "round":
fallthrough
// Stop stops reloading.
func (cfg *peerConfig) Stop() {
select {
case <-cfg.stopped:
default:
return &gost.RoundStrategy{}
close(cfg.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (cfg *peerConfig) Stopped() bool {
select {
case <-cfg.stopped:
return true
default:
return false
}
}

View File

@ -3,20 +3,32 @@ package main
import (
"crypto/sha256"
"crypto/tls"
"fmt"
"net"
"os"
"time"
"github.com/ginuerzh/gost"
"github.com/go-log/log"
)
type stringList []string
func (l *stringList) String() string {
return fmt.Sprintf("%s", *l)
}
func (l *stringList) Set(value string) error {
*l = append(*l, value)
return nil
}
type route struct {
ServeNodes stringList
ChainNodes stringList
Retries int
server *gost.Server
}
func (r *route) initChain() (*gost.Chain, error) {
func (r *route) parseChain() (*gost.Chain, error) {
chain := gost.NewChain()
chain.Retries = r.Retries
gid := 1 // group ID
@ -44,13 +56,20 @@ func (r *route) initChain() (*gost.Chain, error) {
MaxFails: defaultMaxFails,
FailTimeout: defaultFailTimeout,
}),
gost.WithStrategy(parseStrategy(nodes[0].Get("strategy"))),
gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))),
)
go gost.PeriodReload(&peerConfig{
group: ngroup,
baseNodes: nodes,
}, nodes[0].Get("peer"))
cfg := nodes[0].Get("peer")
f, err := os.Open(cfg)
if err == nil {
peerCfg := newPeerConfig()
peerCfg.group = ngroup
peerCfg.baseNodes = nodes
peerCfg.Reload(f)
f.Close()
go gost.PeriodReload(peerCfg, cfg)
}
chain.AddNodeGroup(ngroup)
}
@ -219,20 +238,22 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
return
}
func (r *route) serve() error {
chain, err := r.initChain()
func (r *route) GenRouters() ([]router, error) {
chain, err := r.parseChain()
if err != nil {
return err
return nil, err
}
var rts []router
for _, ns := range r.ServeNodes {
node, err := gost.ParseNode(ns)
if err != nil {
return err
return nil, err
}
users, err := parseUsers(node.Get("secrets"))
if err != nil {
return err
return nil, err
}
if node.User != nil {
users = append(users, node.User)
@ -240,7 +261,7 @@ func (r *route) serve() error {
certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile)
if err != nil && certFile != "" && keyFile != "" {
return err
return nil, err
}
wsOpts := &gost.WSOptions{}
@ -266,7 +287,7 @@ func (r *route) serve() error {
case "kcp":
config, er := parseKCPConfig(node.Get("c"))
if er != nil {
return er
return nil, er
}
ln, err = gost.KCPListener(node.Addr, config)
case "ssh":
@ -320,7 +341,7 @@ func (r *route) serve() error {
ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second)
case "obfs4":
if err = gost.Obfs4Init(node, true); err != nil {
return err
return nil, err
}
ln, err = gost.Obfs4Listener(node.Addr)
case "ohttp":
@ -329,7 +350,7 @@ func (r *route) serve() error {
ln, err = gost.TCPListener(node.Addr)
}
if err != nil {
return err
return nil, err
}
var handler gost.Handler
@ -372,15 +393,19 @@ func (r *route) serve() error {
var whitelist, blacklist *gost.Permissions
if node.Values.Get("whitelist") != "" {
if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil {
return err
return nil, err
}
}
if node.Values.Get("blacklist") != "" {
if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil {
return err
return nil, err
}
}
node.Bypass = parseBypass(node.Get("bypass"))
resolver := parseResolver(node.Get("dns"))
hosts := parseHosts(node.Get("hosts"))
handler.Init(
gost.AddrHandlerOption(node.Addr),
gost.ChainHandlerOption(chain),
@ -388,23 +413,44 @@ func (r *route) serve() error {
gost.TLSConfigHandlerOption(tlsCfg),
gost.WhitelistHandlerOption(whitelist),
gost.BlacklistHandlerOption(blacklist),
gost.StrategyHandlerOption(parseStrategy(node.Get("strategy"))),
gost.BypassHandlerOption(parseBypass(node.Get("bypass"))),
gost.ResolverHandlerOption(parseResolver(node.Get("dns"))),
gost.HostsHandlerOption(parseHosts(node.Get("hosts"))),
gost.StrategyHandlerOption(gost.NewStrategy(node.Get("strategy"))),
gost.BypassHandlerOption(node.Bypass),
gost.ResolverHandlerOption(resolver),
gost.HostsHandlerOption(hosts),
gost.RetryHandlerOption(node.GetInt("retry")),
gost.TimeoutHandlerOption(time.Duration(node.GetInt("timeout"))*time.Second),
gost.ProbeResistHandlerOption(node.Get("probe_resist")),
)
r.server = &gost.Server{Listener: ln}
go r.server.Serve(handler)
rt := router{
node: node,
server: &gost.Server{Listener: ln},
handler: handler,
chain: chain,
resolver: resolver,
hosts: hosts,
}
rts = append(rts, rt)
}
return nil
return rts, nil
}
func (r *route) Close() error {
type router struct {
node gost.Node
server *gost.Server
handler gost.Handler
chain *gost.Chain
resolver gost.Resolver
hosts *gost.Hosts
}
func (r *router) Serve() error {
log.Logf("[route] start %s on %s", r.node.String(), r.server.Addr())
return r.server.Serve(r.handler)
}
func (r *router) Close() error {
if r == nil || r.server == nil {
return nil
}

View File

@ -24,15 +24,17 @@ type Host struct {
// Fields of the entry are separated by any number of blanks and/or tab characters.
// Text from a "#" character until the end of the line is a comment, and is ignored.
type Hosts struct {
hosts []Host
period time.Duration
mux sync.RWMutex
hosts []Host
period time.Duration
stopped chan struct{}
mux sync.RWMutex
}
// NewHosts creates a Hosts with optional list of host
func NewHosts(hosts ...Host) *Hosts {
return &Hosts{
hosts: hosts,
hosts: hosts,
stopped: make(chan struct{}),
}
}
@ -76,6 +78,10 @@ func (h *Hosts) Reload(r io.Reader) error {
var period time.Duration
var hosts []Host
if h.Stopped() {
return nil
}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
@ -130,8 +136,31 @@ func (h *Hosts) Reload(r io.Reader) error {
// Period returns the reload period
func (h *Hosts) Period() time.Duration {
if h.Stopped() {
return -1
}
h.mux.RLock()
defer h.mux.RUnlock()
return h.period
}
// Stop stops reloading.
func (h *Hosts) Stop() {
select {
case <-h.stopped:
default:
close(h.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (h *Hosts) Stopped() bool {
select {
case <-h.stopped:
return true
default:
return false
}
}

View File

@ -263,7 +263,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
if err == nil {
return
}
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err)
// log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err)
continue
}

View File

@ -468,6 +468,7 @@ func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response)
type http2Listener struct {
server *http.Server
connChan chan *http2ServerConn
addr net.Addr
errChan chan error
}
@ -494,6 +495,8 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) {
if err != nil {
return nil, err
}
l.addr = ln.Addr()
go func() {
err := server.Serve(ln)
if err != nil {
@ -532,8 +535,7 @@ func (l *http2Listener) Accept() (conn net.Conn, err error) {
}
func (l *http2Listener) Addr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", l.server.Addr)
return addr
return l.addr
}
func (l *http2Listener) Close() (err error) {

17
node.go
View File

@ -2,7 +2,6 @@ package gost
import (
"errors"
"fmt"
"net/url"
"strconv"
"strings"
@ -22,6 +21,7 @@ type Node struct {
Protocol string
Transport string
Remote string // remote address, used by tcp/udp port forwarding
url string // raw url
User *url.Userinfo
Values url.Values
DialOptions []DialOption
@ -57,6 +57,9 @@ func ParseNode(s string) (node Node, err error) {
marker: &failMarker{},
}
u.RawQuery = ""
node.url = u.String()
schemes := strings.Split(u.Scheme, "+")
if len(schemes) == 1 {
node.Protocol = schemes[0]
@ -136,8 +139,7 @@ func (node *Node) GetInt(key string) int {
}
func (node Node) String() string {
return fmt.Sprintf("%d@%s+%s://%s",
node.ID, node.Protocol, node.Transport, node.Addr)
return node.url
}
// NodeGroup is a group of nodes.
@ -167,16 +169,19 @@ func (group *NodeGroup) AddNode(node ...Node) {
group.nodes = append(group.nodes, node...)
}
// SetNodes replaces the group nodes to the specified nodes.
func (group *NodeGroup) SetNodes(nodes ...Node) {
// SetNodes replaces the group nodes to the specified nodes,
// and returns the previous nodes.
func (group *NodeGroup) SetNodes(nodes ...Node) []Node {
if group == nil {
return
return nil
}
group.mux.Lock()
defer group.mux.Unlock()
old := group.nodes
group.nodes = nodes
return old
}
// SetSelector sets node selector with options for the group.

View File

@ -8,7 +8,7 @@ var nodeTests = []struct {
out Node
hasError bool
}{
{"", Node{}, false},
{"", Node{}, true},
{"://", Node{}, true},
{"localhost", Node{Addr: "localhost", Transport: "tcp"}, false},
{":", Node{Addr: ":", Transport: "tcp"}, false},

View File

@ -14,43 +14,71 @@ type Reloader interface {
Period() time.Duration
}
// PeriodReload reloads the config periodically according to the period of the reloader.
// Stoppable is the interface that indicates a Reloader can be stopped.
type Stoppable interface {
Stop()
}
//StopReloader is the interface that adds Stop method to the Reloader.
type StopReloader interface {
Reloader
Stoppable
}
type nopStoppable struct {
Reloader
}
func (nopStoppable) Stop() {
return
}
// NopStoppable returns a StopReloader with a no-op Stop method,
// wrapping the provided Reloader r.
func NopStoppable(r Reloader) StopReloader {
return nopStoppable{r}
}
// PeriodReload reloads the config configFile periodically according to the period of the Reloader r.
func PeriodReload(r Reloader, configFile string) error {
if configFile == "" {
if r == nil || configFile == "" {
return nil
}
var lastMod time.Time
for {
if r.Period() < 0 {
log.Log("[reload] stopped:", configFile)
return nil
}
f, err := os.Open(configFile)
if err != nil {
return err
}
finfo, err := f.Stat()
if err != nil {
f.Close()
return err
mt := lastMod
if finfo, err := f.Stat(); err == nil {
mt = finfo.ModTime()
}
mt := finfo.ModTime()
if !mt.Equal(lastMod) {
if !lastMod.IsZero() && !mt.Equal(lastMod) {
log.Log("[reload]", configFile)
if err := r.Reload(f); err != nil {
log.Logf("[reload] %s: %s", configFile, err)
}
lastMod = mt
}
f.Close()
lastMod = mt
period := r.Period()
if period <= 0 {
if period == 0 {
log.Log("[reload] disabled:", configFile)
return nil
}
if period < time.Second {
period = time.Second
}
<-time.After(period)
}
}

View File

@ -29,10 +29,11 @@ type Resolver interface {
Resolve(host string) ([]net.IP, error)
}
// ReloadResolver is resolover that support live reloading
// ReloadResolver is resolover that support live reloading.
type ReloadResolver interface {
Resolver
Reloader
Stoppable
}
// NameServer is a name server.
@ -68,6 +69,7 @@ type resolver struct {
TTL time.Duration
period time.Duration
domain string
stopped chan struct{}
mux sync.RWMutex
}
@ -78,6 +80,7 @@ func NewResolver(timeout, ttl time.Duration, servers ...NameServer) ReloadResolv
Timeout: timeout,
TTL: ttl,
mCache: &sync.Map{},
stopped: make(chan struct{}),
}
if r.Timeout <= 0 {
@ -110,6 +113,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
r.mux.RLock()
domain = r.domain
timeout = r.Timeout
ttl = r.TTL
servers = r.copyServers()
r.mux.RUnlock()
@ -219,6 +223,10 @@ func (r *resolver) Reload(rd io.Reader) error {
var domain string
var nss []NameServer
if r.Stopped() {
return nil
}
split := func(line string) []string {
if line == "" {
return nil
@ -305,12 +313,35 @@ func (r *resolver) Reload(rd io.Reader) error {
}
func (r *resolver) Period() time.Duration {
if r.Stopped() {
return -1
}
r.mux.RLock()
defer r.mux.RUnlock()
return r.period
}
// Stop stops reloading.
func (r *resolver) Stop() {
select {
case <-r.stopped:
default:
close(r.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (r *resolver) Stopped() bool {
select {
case <-r.stopped:
return true
default:
return false
}
}
func (r *resolver) String() string {
if r == nil {
return ""

View File

@ -68,6 +68,20 @@ type Strategy interface {
String() string
}
// NewStrategy creates a Strategy by the name s.
func NewStrategy(s string) Strategy {
switch s {
case "random":
return &RandomStrategy{}
case "fifo":
return &FIFOStrategy{}
case "round":
fallthrough
default:
return &RoundStrategy{}
}
}
// RoundStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm.
type RoundStrategy struct {

View File

@ -86,21 +86,11 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error {
// ServerOptions holds the options for Server.
type ServerOptions struct {
Bypass *Bypass
}
// ServerOption allows a common way to set server options.
type ServerOption func(opts *ServerOptions)
/*
// BypassServerOption sets the bypass option of ServerOptions.
func BypassServerOption(bypass *Bypass) ServerOption {
return func(opts *ServerOptions) {
opts.Bypass = bypass
}
}
*/
// Listener is a proxy server listener, just like a net.Listener.
type Listener interface {
net.Listener

5
ss.go
View File

@ -84,8 +84,9 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect
return nil, err
}
sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher)
if err != nil {
sc := ss.NewConn(conn, cipher)
// sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher)
if _, err := sc.Write(rawaddr); err != nil {
return nil, err
}
return &shadowConn{conn: sc}, nil

8
ws.go
View File

@ -384,7 +384,6 @@ func WSListener(addr string, options *WSOptions) (Listener, error) {
options = &WSOptions{}
}
l := &wsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize,
@ -403,6 +402,7 @@ func WSListener(addr string, options *WSOptions) (Listener, error) {
if err != nil {
return nil, err
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(tcpKeepAliveListener{ln})
@ -473,7 +473,6 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) {
options = &WSOptions{}
}
l := &mwsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize,
@ -492,6 +491,7 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) {
if err != nil {
return nil, err
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(tcpKeepAliveListener{ln})
@ -584,7 +584,6 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen
}
l := &wssListener{
wsListener: &wsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize,
@ -612,6 +611,7 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen
if err != nil {
return nil, err
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig))
@ -644,7 +644,6 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste
}
l := &mwssListener{
mwsListener: &mwsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize,
@ -672,6 +671,7 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste
if err != nil {
return nil, err
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig))