ws supports user defined url path

This commit is contained in:
ginuerzh 2019-01-12 14:25:59 +08:00
parent 88efafca06
commit 7c7a51ec02
2 changed files with 66 additions and 9 deletions

View File

@ -109,6 +109,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
wsOpts.ReadBufferSize = node.GetInt("rbuf") wsOpts.ReadBufferSize = node.GetInt("rbuf")
wsOpts.WriteBufferSize = node.GetInt("wbuf") wsOpts.WriteBufferSize = node.GetInt("wbuf")
wsOpts.UserAgent = node.Get("agent") wsOpts.UserAgent = node.Get("agent")
wsOpts.Path = node.Get("path")
var host string var host string
@ -276,6 +277,7 @@ func (r *route) GenRouters() ([]router, error) {
wsOpts.EnableCompression = node.GetBool("compression") wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize = node.GetInt("rbuf") wsOpts.ReadBufferSize = node.GetInt("rbuf")
wsOpts.WriteBufferSize = node.GetInt("wbuf") wsOpts.WriteBufferSize = node.GetInt("wbuf")
wsOpts.Path = node.Get("path")
var ln gost.Listener var ln gost.Listener
switch node.Transport { switch node.Transport {
@ -284,7 +286,6 @@ func (r *route) GenRouters() ([]router, error) {
case "mtls": case "mtls":
ln, err = gost.MTLSListener(node.Addr, tlsCfg) ln, err = gost.MTLSListener(node.Addr, tlsCfg)
case "ws": case "ws":
wsOpts.WriteBufferSize = node.GetInt("wbuf")
ln, err = gost.WSListener(node.Addr, wsOpts) ln, err = gost.WSListener(node.Addr, wsOpts)
case "mws": case "mws":
ln, err = gost.MWSListener(node.Addr, wsOpts) ln, err = gost.MWSListener(node.Addr, wsOpts)

72
ws.go
View File

@ -19,6 +19,10 @@ import (
smux "gopkg.in/xtaci/smux.v1" smux "gopkg.in/xtaci/smux.v1"
) )
const (
defaultWSPath = "/ws"
)
// WSOptions describes the options for websocket. // WSOptions describes the options for websocket.
type WSOptions struct { type WSOptions struct {
ReadBufferSize int ReadBufferSize int
@ -26,6 +30,7 @@ type WSOptions struct {
HandshakeTimeout time.Duration HandshakeTimeout time.Duration
EnableCompression bool EnableCompression bool
UserAgent string UserAgent string
Path string
} }
type wsTransporter struct { type wsTransporter struct {
@ -49,7 +54,15 @@ func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (n
if opts.WSOptions != nil { if opts.WSOptions != nil {
wsOptions = opts.WSOptions wsOptions = opts.WSOptions
} }
url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} if wsOptions == nil {
wsOptions = &WSOptions{}
}
path := wsOptions.Path
if path == "" {
path = defaultWSPath
}
url := url.URL{Scheme: "ws", Host: opts.Host, Path: path}
return websocketClientConn(url.String(), conn, nil, wsOptions) return websocketClientConn(url.String(), conn, nil, wsOptions)
} }
@ -148,7 +161,15 @@ func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *Handshak
if opts.WSOptions != nil { if opts.WSOptions != nil {
wsOptions = opts.WSOptions wsOptions = opts.WSOptions
} }
url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} if wsOptions == nil {
wsOptions = &WSOptions{}
}
path := wsOptions.Path
if path == "" {
path = defaultWSPath
}
url := url.URL{Scheme: "ws", Host: opts.Host, Path: path}
conn, err := websocketClientConn(url.String(), conn, nil, wsOptions) conn, err := websocketClientConn(url.String(), conn, nil, wsOptions)
if err != nil { if err != nil {
return nil, err return nil, err
@ -187,10 +208,18 @@ func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
if opts.WSOptions != nil { if opts.WSOptions != nil {
wsOptions = opts.WSOptions wsOptions = opts.WSOptions
} }
if wsOptions == nil {
wsOptions = &WSOptions{}
}
if opts.TLSConfig == nil { if opts.TLSConfig == nil {
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
} }
url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} path := wsOptions.Path
if path == "" {
path = defaultWSPath
}
url := url.URL{Scheme: "wss", Host: opts.Host, Path: path}
return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions)
} }
@ -288,11 +317,19 @@ func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *Handsha
if opts.WSOptions != nil { if opts.WSOptions != nil {
wsOptions = opts.WSOptions wsOptions = opts.WSOptions
} }
if wsOptions == nil {
wsOptions = &WSOptions{}
}
tlsConfig := opts.TLSConfig tlsConfig := opts.TLSConfig
if tlsConfig == nil { if tlsConfig == nil {
tlsConfig = &tls.Config{InsecureSkipVerify: true} tlsConfig = &tls.Config{InsecureSkipVerify: true}
} }
url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} path := wsOptions.Path
if path == "" {
path = defaultWSPath
}
url := url.URL{Scheme: "wss", Host: opts.Host, Path: path}
conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions) conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions)
if err != nil { if err != nil {
return nil, err return nil, err
@ -338,8 +375,12 @@ func WSListener(addr string, options *WSOptions) (Listener, error) {
errChan: make(chan error, 1), errChan: make(chan error, 1),
} }
path := options.Path
if path == "" {
path = defaultWSPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: addr, Addr: addr,
Handler: mux, Handler: mux,
@ -431,8 +472,13 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) {
errChan: make(chan error, 1), errChan: make(chan error, 1),
} }
path := options.Path
if path == "" {
path = defaultWSPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: addr, Addr: addr,
Handler: mux, Handler: mux,
@ -551,8 +597,13 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen
tlsConfig = DefaultTLSConfig tlsConfig = DefaultTLSConfig
} }
path := options.Path
if path == "" {
path = defaultWSPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: addr, Addr: addr,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
@ -612,8 +663,13 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste
tlsConfig = DefaultTLSConfig tlsConfig = DefaultTLSConfig
} }
path := options.Path
if path == "" {
path = defaultWSPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: addr, Addr: addr,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,