diff --git a/cmd/gost/route.go b/cmd/gost/route.go index cb16e7a..2f0690a 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -109,6 +109,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { wsOpts.ReadBufferSize = node.GetInt("rbuf") wsOpts.WriteBufferSize = node.GetInt("wbuf") wsOpts.UserAgent = node.Get("agent") + wsOpts.Path = node.Get("path") var host string @@ -276,6 +277,7 @@ func (r *route) GenRouters() ([]router, error) { wsOpts.EnableCompression = node.GetBool("compression") wsOpts.ReadBufferSize = node.GetInt("rbuf") wsOpts.WriteBufferSize = node.GetInt("wbuf") + wsOpts.Path = node.Get("path") var ln gost.Listener switch node.Transport { @@ -284,7 +286,6 @@ func (r *route) GenRouters() ([]router, error) { case "mtls": ln, err = gost.MTLSListener(node.Addr, tlsCfg) case "ws": - wsOpts.WriteBufferSize = node.GetInt("wbuf") ln, err = gost.WSListener(node.Addr, wsOpts) case "mws": ln, err = gost.MWSListener(node.Addr, wsOpts) diff --git a/ws.go b/ws.go index a6aad99..477eaaa 100644 --- a/ws.go +++ b/ws.go @@ -19,6 +19,10 @@ import ( smux "gopkg.in/xtaci/smux.v1" ) +const ( + defaultWSPath = "/ws" +) + // WSOptions describes the options for websocket. type WSOptions struct { ReadBufferSize int @@ -26,6 +30,7 @@ type WSOptions struct { HandshakeTimeout time.Duration EnableCompression bool UserAgent string + Path string } type wsTransporter struct { @@ -49,7 +54,15 @@ func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (n if opts.WSOptions != nil { wsOptions = opts.WSOptions } - url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} + if wsOptions == nil { + wsOptions = &WSOptions{} + } + + path := wsOptions.Path + if path == "" { + path = defaultWSPath + } + url := url.URL{Scheme: "ws", Host: opts.Host, Path: path} return websocketClientConn(url.String(), conn, nil, wsOptions) } @@ -148,7 +161,15 @@ func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *Handshak if opts.WSOptions != nil { wsOptions = opts.WSOptions } - url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} + if wsOptions == nil { + wsOptions = &WSOptions{} + } + + path := wsOptions.Path + if path == "" { + path = defaultWSPath + } + url := url.URL{Scheme: "ws", Host: opts.Host, Path: path} conn, err := websocketClientConn(url.String(), conn, nil, wsOptions) if err != nil { return nil, err @@ -187,10 +208,18 @@ func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( if opts.WSOptions != nil { wsOptions = opts.WSOptions } + if wsOptions == nil { + wsOptions = &WSOptions{} + } + if opts.TLSConfig == nil { opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} } - url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} + path := wsOptions.Path + if path == "" { + path = defaultWSPath + } + url := url.URL{Scheme: "wss", Host: opts.Host, Path: path} return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) } @@ -288,11 +317,19 @@ func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *Handsha if opts.WSOptions != nil { wsOptions = opts.WSOptions } + if wsOptions == nil { + wsOptions = &WSOptions{} + } + tlsConfig := opts.TLSConfig if tlsConfig == nil { tlsConfig = &tls.Config{InsecureSkipVerify: true} } - url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} + path := wsOptions.Path + if path == "" { + path = defaultWSPath + } + url := url.URL{Scheme: "wss", Host: opts.Host, Path: path} conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions) if err != nil { return nil, err @@ -338,8 +375,12 @@ func WSListener(addr string, options *WSOptions) (Listener, error) { errChan: make(chan error, 1), } + path := options.Path + if path == "" { + path = defaultWSPath + } mux := http.NewServeMux() - mux.Handle("/ws", http.HandlerFunc(l.upgrade)) + mux.Handle(path, http.HandlerFunc(l.upgrade)) l.srv = &http.Server{ Addr: addr, Handler: mux, @@ -431,8 +472,13 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) { errChan: make(chan error, 1), } + path := options.Path + if path == "" { + path = defaultWSPath + } + mux := http.NewServeMux() - mux.Handle("/ws", http.HandlerFunc(l.upgrade)) + mux.Handle(path, http.HandlerFunc(l.upgrade)) l.srv = &http.Server{ Addr: addr, Handler: mux, @@ -551,8 +597,13 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen tlsConfig = DefaultTLSConfig } + path := options.Path + if path == "" { + path = defaultWSPath + } + mux := http.NewServeMux() - mux.Handle("/ws", http.HandlerFunc(l.upgrade)) + mux.Handle(path, http.HandlerFunc(l.upgrade)) l.srv = &http.Server{ Addr: addr, TLSConfig: tlsConfig, @@ -612,8 +663,13 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste tlsConfig = DefaultTLSConfig } + path := options.Path + if path == "" { + path = defaultWSPath + } + mux := http.NewServeMux() - mux.Handle("/ws", http.HandlerFunc(l.upgrade)) + mux.Handle(path, http.HandlerFunc(l.upgrade)) l.srv = &http.Server{ Addr: addr, TLSConfig: tlsConfig,