add multiplex support for ws/wss

This commit is contained in:
rui.zheng 2017-11-01 15:33:44 +08:00
parent 31f6d0af35
commit cd842de218
4 changed files with 429 additions and 22 deletions

View File

@ -26,6 +26,7 @@ gost - GO Simple Tunnel
* SSH通道 (2.4+)
* QUIC通道 (2.4+)
* obfs4通道 (2.4+)
* SNI (2.5+)
二进制文件下载https://github.com/ginuerzh/gost/releases
@ -53,7 +54,7 @@ $ sudo snap install gost
```
scheme分为两部分: protocol+transport
protocol: 代理协议类型(http, socks4(a), socks5, ss), transport: 数据传输方式(ws, wss, tls, quic, kcp, ssh, h2, h2c, obfs4), 二者可以任意组合,或单独使用:
protocol: 代理协议类型(http, socks4(a), socks5, ss, sni), transport: 数据传输方式(ws, wss, tls, mtls, quic, kcp, ssh, h2, h2c, obfs4), 二者可以任意组合,或单独使用:
> http - 标准HTTP代理: http://:8080
@ -73,6 +74,8 @@ protocol: 代理协议类型(http, socks4(a), socks5, ss), transport: 数据传
> tls - HTTP/SOCKS5代理使用TLS传输数据: tls://:443
> mtls - HTTP/SOCKS5代理使用TLS以多路复用方式传输数据: mtls://:443
> ss - Shadowsocks代理: ss://chacha20:123456@:8338
> ssu - Shadowsocks UDP relay: ssu://chacha20:123456@:8338
@ -87,6 +90,7 @@ protocol: 代理协议类型(http, socks4(a), socks5, ss), transport: 数据传
> obfs4 - obfs4通道: obfs4://:8080
> sni - SNI代理: sni://:443
#### 端口转发

View File

@ -102,23 +102,26 @@ func initChain() (*gost.Chain, error) {
InsecureSkipVerify: !toBool(node.Values.Get("secure")),
RootCAs: rootCAs,
}
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
wsOpts.UserAgent = node.Values.Get("agent")
var tr gost.Transporter
switch node.Transport {
case "tls":
tr = gost.TLSTransporter()
case "mtls":
tr = gost.MTLSTransporter()
case "ws":
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
wsOpts.UserAgent = node.Values.Get("agent")
tr = gost.WSTransporter(wsOpts)
case "mws":
tr = gost.MWSTransporter(wsOpts)
case "wss":
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
tr = gost.WSSTransporter(wsOpts)
case "mwss":
tr = gost.MWSSTransporter(wsOpts)
case "kcp":
if !chain.IsEmpty() {
return nil, errors.New("KCP must be the first node in the proxy chain")
@ -173,8 +176,6 @@ func initChain() (*gost.Chain, error) {
tr = gost.Obfs4Transporter()
case "ohttp":
tr = gost.ObfsHTTPTransporter()
case "mtls":
tr = gost.MTLSTransporter()
default:
tr = gost.TCPTransporter()
}
@ -248,22 +249,26 @@ func serve(chain *gost.Chain) error {
return err
}
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
var ln gost.Listener
switch node.Transport {
case "tls":
ln, err = gost.TLSListener(node.Addr, tlsCfg)
case "mtls":
ln, err = gost.MTLSListener(node.Addr, tlsCfg)
case "ws":
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
ln, err = gost.WSListener(node.Addr, wsOpts)
case "mws":
ln, err = gost.MWSListener(node.Addr, wsOpts)
case "wss":
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
ln, err = gost.WSSListener(node.Addr, tlsCfg, wsOpts)
case "mwss":
ln, err = gost.MWSSListener(node.Addr, tlsCfg, wsOpts)
case "kcp":
config, er := parseKCPConfig(node.Values.Get("c"))
if er != nil {
@ -319,8 +324,6 @@ func serve(chain *gost.Chain) error {
ln, err = gost.Obfs4Listener(node.Addr)
case "ohttp":
ln, err = gost.ObfsHTTPListener(node.Addr)
case "mtls":
ln, err = gost.MTLSListener(node.Addr, tlsCfg)
default:
ln, err = gost.TCPListener(node.Addr)
}

View File

@ -52,7 +52,7 @@ func ParseNode(s string) (node Node, err error) {
}
switch node.Transport {
case "tls", "ws", "wss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "obfs4", "mtls":
case "tls", "mtls", "ws", "mws", "wss", "mwss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "obfs4":
case "https":
node.Protocol = "http"
node.Transport = "tls"

400
ws.go
View File

@ -5,16 +5,20 @@ import (
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"net"
"net/http"
"net/http/httputil"
"sync"
"sync/atomic"
"time"
"net/url"
"github.com/go-log/log"
"gopkg.in/gorilla/websocket.v1"
smux "gopkg.in/xtaci/smux.v1"
)
// WSOptions describes the options for websocket.
@ -131,6 +135,115 @@ func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (n
return websocketClientConn(url.String(), conn, nil, wsOptions)
}
type mwsTransporter struct {
tcpTransporter
options *WSOptions
sessions map[string]*muxSession
sessionMutex sync.Mutex
}
// MWSTransporter creates a Transporter that is used by multiplex-websocket proxy client.
func MWSTransporter(opts *WSOptions) Transporter {
return &mwsTransporter{
options: opts,
sessions: make(map[string]*muxSession),
}
}
func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) {
opts := &DialOptions{}
for _, option := range options {
option(opts)
}
if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
} else {
conn, err = opts.Chain.Dial(addr)
}
if err != nil {
return
}
session = &muxSession{conn: conn}
tr.sessions[addr] = session
}
return session.conn, nil
}
func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
opts := &HandshakeOptions{}
for _, option := range options {
option(opts)
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[opts.Addr]
if session != nil && session.conn != conn {
conn.Close()
return nil, errors.New("mws: unrecognized connection")
}
if !ok || session.session == nil {
s, err := tr.initSession(opts.Addr, conn, opts)
if err != nil {
conn.Close()
delete(tr.sessions, opts.Addr)
return nil, err
}
session = s
tr.sessions[opts.Addr] = session
}
cc, err := session.GetConn()
if err != nil {
session.Close()
delete(tr.sessions, opts.Addr)
return nil, err
}
return cc, nil
}
func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) {
if opts == nil {
opts = &HandshakeOptions{}
}
wsOptions := tr.options
if opts.WSOptions != nil {
wsOptions = opts.WSOptions
}
url := url.URL{Scheme: "ws", Host: opts.Addr, Path: "/ws"}
conn, err := websocketClientConn(url.String(), conn, nil, wsOptions)
if err != nil {
return nil, err
}
// stream multiplex
smuxConfig := smux.DefaultConfig()
session, err := smux.Client(conn, smuxConfig)
if err != nil {
return nil, err
}
return &muxSession{conn: conn, session: session}, nil
}
func (tr *mwsTransporter) Multiplex() bool {
return true
}
type wssTransporter struct {
tcpTransporter
options *WSOptions
@ -159,6 +272,119 @@ func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions)
}
type mwssTransporter struct {
tcpTransporter
options *WSOptions
sessions map[string]*muxSession
sessionMutex sync.Mutex
}
// MWSSTransporter creates a Transporter that is used by multiplex-websocket secure proxy client.
func MWSSTransporter(opts *WSOptions) Transporter {
return &mwssTransporter{
options: opts,
sessions: make(map[string]*muxSession),
}
}
func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) {
opts := &DialOptions{}
for _, option := range options {
option(opts)
}
if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
} else {
conn, err = opts.Chain.Dial(addr)
}
if err != nil {
return
}
session = &muxSession{conn: conn}
tr.sessions[addr] = session
}
return session.conn, nil
}
func (tr *mwssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
opts := &HandshakeOptions{}
for _, option := range options {
option(opts)
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[opts.Addr]
if session != nil && session.conn != conn {
conn.Close()
return nil, errors.New("mws: unrecognized connection")
}
if !ok || session.session == nil {
s, err := tr.initSession(opts.Addr, conn, opts)
if err != nil {
conn.Close()
delete(tr.sessions, opts.Addr)
return nil, err
}
session = s
tr.sessions[opts.Addr] = session
}
cc, err := session.GetConn()
if err != nil {
session.Close()
delete(tr.sessions, opts.Addr)
return nil, err
}
return cc, nil
}
func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) {
if opts == nil {
opts = &HandshakeOptions{}
}
wsOptions := tr.options
if opts.WSOptions != nil {
wsOptions = opts.WSOptions
}
tlsConfig := opts.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{InsecureSkipVerify: true}
}
url := url.URL{Scheme: "wss", Host: opts.Addr, Path: "/ws"}
conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions)
if err != nil {
return nil, err
}
// stream multiplex
smuxConfig := smux.DefaultConfig()
session, err := smux.Client(conn, smuxConfig)
if err != nil {
return nil, err
}
return &muxSession{conn: conn, session: session}, nil
}
func (tr *mwssTransporter) Multiplex() bool {
return true
}
type wsListener struct {
addr net.Addr
upgrader *websocket.Upgrader
@ -248,6 +474,120 @@ func (l *wsListener) Addr() net.Addr {
return l.addr
}
type mwsListener struct {
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
}
// MWSListener creates a Listener for multiplex-websocket proxy server.
func MWSListener(addr string, options *WSOptions) (Listener, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
if options == nil {
options = &WSOptions{}
}
l := &mwsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: options.EnableCompression,
},
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1),
}
mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade))
l.srv = &http.Server{Addr: addr, Handler: mux}
ln, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return nil, err
}
go func() {
err := l.srv.Serve(tcpKeepAliveListener{ln})
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
select {
case err := <-l.errChan:
return nil, err
default:
}
return l, nil
}
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
log.Logf("[mws] %s -> %s", r.RemoteAddr, l.addr)
if Debug {
dump, _ := httputil.DumpRequest(r, false)
log.Log(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Logf("[mws] %s - %s : %s", r.RemoteAddr, l.addr, err)
return
}
l.mux(websocketServerConn(conn))
}
func (l *mwsListener) mux(conn net.Conn) {
smuxConfig := smux.DefaultConfig()
mux, err := smux.Server(conn, smuxConfig)
if err != nil {
log.Logf("[mws] %s - %s : %s", conn.RemoteAddr(), l.Addr(), err)
return
}
defer mux.Close()
log.Logf("[mws] %s <-> %s", conn.RemoteAddr(), l.Addr())
defer log.Logf("[mws] %s >-< %s", conn.RemoteAddr(), l.Addr())
for {
stream, err := mux.AcceptStream()
if err != nil {
log.Log("[mws] accept stream:", err)
return
}
cc := &muxStreamConn{Conn: conn, stream: stream}
select {
case l.connChan <- cc:
default:
cc.Close()
log.Logf("[mws] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
}
}
func (l *mwsListener) Accept() (conn net.Conn, err error) {
select {
case conn = <-l.connChan:
case err = <-l.errChan:
}
return
}
func (l *mwsListener) Close() error {
return l.srv.Close()
}
func (l *mwsListener) Addr() net.Addr {
return l.addr
}
type wssListener struct {
*wsListener
}
@ -308,6 +648,66 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen
return l, nil
}
type mwssListener struct {
*mwsListener
}
// MWSSListener creates a Listener for multiplex-websocket secure proxy server.
func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listener, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
if options == nil {
options = &WSOptions{}
}
l := &mwssListener{
mwsListener: &mwsListener{
addr: tcpAddr,
upgrader: &websocket.Upgrader{
ReadBufferSize: options.ReadBufferSize,
WriteBufferSize: options.WriteBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: options.EnableCompression,
},
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1),
},
}
if tlsConfig == nil {
tlsConfig = DefaultTLSConfig
}
mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: addr,
TLSConfig: tlsConfig,
Handler: mux,
}
ln, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return nil, err
}
go func() {
err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig))
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
select {
case err := <-l.errChan:
return nil, err
default:
}
return l, nil
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {