From 31f6d0af35b594510d6e8f49283906a930a3e1be Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Tue, 31 Oct 2017 17:30:51 +0800 Subject: [PATCH] add multiplex TLS tunnel --- .gitignore | 3 +- cmd/gost/main.go | 4 + kcp.go | 74 ++--------------- mux.go | 49 ++++++++++++ node.go | 2 +- sni.go | 3 + tls.go | 202 +++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 267 insertions(+), 70 deletions(-) create mode 100644 mux.go diff --git a/.gitignore b/.gitignore index 57726ad..2f239f7 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,5 @@ _testmain.go *.bak -cmd/gost \ No newline at end of file +cmd/gost +snap diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 532df03..e99a450 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -173,6 +173,8 @@ func initChain() (*gost.Chain, error) { tr = gost.Obfs4Transporter() case "ohttp": tr = gost.ObfsHTTPTransporter() + case "mtls": + tr = gost.MTLSTransporter() default: tr = gost.TCPTransporter() } @@ -317,6 +319,8 @@ func serve(chain *gost.Chain) error { ln, err = gost.Obfs4Listener(node.Addr) case "ohttp": ln, err = gost.ObfsHTTPListener(node.Addr) + case "mtls": + ln, err = gost.MTLSListener(node.Addr, tlsCfg) default: ln, err = gost.TCPListener(node.Addr) } diff --git a/kcp.go b/kcp.go index 857c697..b0ae961 100644 --- a/kcp.go +++ b/kcp.go @@ -90,70 +90,8 @@ var ( } ) -type kcpConn struct { - conn net.Conn - stream *smux.Stream -} - -func (c *kcpConn) Read(b []byte) (n int, err error) { - return c.stream.Read(b) -} - -func (c *kcpConn) Write(b []byte) (n int, err error) { - return c.stream.Write(b) -} - -func (c *kcpConn) Close() error { - return c.stream.Close() -} - -func (c *kcpConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *kcpConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *kcpConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *kcpConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *kcpConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -type kcpSession struct { - conn net.Conn - session *smux.Session -} - -func (session *kcpSession) GetConn() (*kcpConn, error) { - stream, err := session.session.OpenStream() - if err != nil { - return nil, err - } - return &kcpConn{conn: session.conn, stream: stream}, nil -} - -func (session *kcpSession) Close() error { - return session.session.Close() -} - -func (session *kcpSession) IsClosed() bool { - return session.session.IsClosed() -} - -func (session *kcpSession) NumStreams() int { - return session.session.NumStreams() -} - type kcpTransporter struct { - sessions map[string]*kcpSession + sessions map[string]*muxSession sessionMutex sync.Mutex config *KCPConfig } @@ -172,7 +110,7 @@ func KCPTransporter(config *KCPConfig) Transporter { return &kcpTransporter{ config: config, - sessions: make(map[string]*kcpSession), + sessions: make(map[string]*muxSession), } } @@ -191,7 +129,7 @@ func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Con if err != nil { return } - session = &kcpSession{conn: conn} + session = &muxSession{conn: conn} tr.sessions[addr] = session } return session.conn, nil @@ -234,7 +172,7 @@ func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( return cc, nil } -func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*kcpSession, error) { +func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*muxSession, error) { udpConn, ok := conn.(*net.UDPConn) if !ok { return nil, errors.New("kcp: wrong connection type") @@ -276,7 +214,7 @@ func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPCon if err != nil { return nil, err } - return &kcpSession{conn: conn, session: session}, nil + return &muxSession{conn: conn, session: session}, nil } func (tr *kcpTransporter) Multiplex() bool { @@ -374,7 +312,7 @@ func (l *kcpListener) mux(conn net.Conn) { return } - cc := &kcpConn{conn: conn, stream: stream} + cc := &muxStreamConn{Conn: conn, stream: stream} select { case l.connChan <- cc: default: diff --git a/mux.go b/mux.go new file mode 100644 index 0000000..c84463d --- /dev/null +++ b/mux.go @@ -0,0 +1,49 @@ +package gost + +import ( + "net" + + smux "gopkg.in/xtaci/smux.v1" +) + +type muxStreamConn struct { + net.Conn + stream *smux.Stream +} + +func (c *muxStreamConn) Read(b []byte) (n int, err error) { + return c.stream.Read(b) +} + +func (c *muxStreamConn) Write(b []byte) (n int, err error) { + return c.stream.Write(b) +} + +func (c *muxStreamConn) Close() error { + return c.stream.Close() +} + +type muxSession struct { + conn net.Conn + session *smux.Session +} + +func (session *muxSession) GetConn() (net.Conn, error) { + stream, err := session.session.OpenStream() + if err != nil { + return nil, err + } + return &muxStreamConn{Conn: session.conn, stream: stream}, nil +} + +func (session *muxSession) Close() error { + return session.session.Close() +} + +func (session *muxSession) IsClosed() bool { + return session.session.IsClosed() +} + +func (session *muxSession) NumStreams() int { + return session.session.NumStreams() +} diff --git a/node.go b/node.go index 324dd12..a1a8b89 100644 --- a/node.go +++ b/node.go @@ -52,7 +52,7 @@ func ParseNode(s string) (node Node, err error) { } switch node.Transport { - case "tls", "ws", "wss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "obfs4": + case "tls", "ws", "wss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "obfs4", "mtls": case "https": node.Protocol = "http" node.Transport = "tls" diff --git a/sni.go b/sni.go index 8e21722..d4ec608 100644 --- a/sni.go +++ b/sni.go @@ -213,6 +213,9 @@ func readClientHelloRecord(r io.Reader, host string, isClient bool) ([]byte, str for _, ext := range clientHello.Extensions { if ext.Type() == dissector.ExtServerName { snExtension := ext.(*dissector.ServerNameExtension) + if host == "" { + host = snExtension.Name + } if isClient { clientHello.Extensions = append(clientHello.Extensions, dissector.NewExtension(0xFFFE, []byte(encodeServerName(snExtension.Name)))) diff --git a/tls.go b/tls.go index 3c617be..6b8005c 100644 --- a/tls.go +++ b/tls.go @@ -3,8 +3,15 @@ package gost import ( "crypto/tls" "crypto/x509" + "errors" "net" + "sync" + "sync/atomic" "time" + + "github.com/go-log/log" + + smux "gopkg.in/xtaci/smux.v1" ) type tlsTransporter struct { @@ -27,6 +34,113 @@ func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( return wrapTLSClient(conn, opts.TLSConfig) } +type mtlsTransporter struct { + tcpTransporter + sessions map[string]*muxSession + sessionMutex sync.Mutex +} + +// MTLSTransporter creates a Transporter that is used by multiplex-TLS proxy client. +func MTLSTransporter() Transporter { + return &mtlsTransporter{ + sessions: make(map[string]*muxSession), + } +} + +func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + if len(opts.IPs) > 0 { + count := atomic.AddUint64(&tr.count, 1) + _, sport, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + n := uint64(len(opts.IPs)) + addr = opts.IPs[int(count%n)] + ":" + sport + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] // TODO: the addr may be changed. + if !ok { + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, opts.Timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &muxSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *mtlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("mtls: unrecognized connection") + } + if !ok || session.session == nil { + s, err := tr.initSession(opts.Addr, conn, opts) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (tr *mtlsTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) { + if opts == nil { + opts = &HandshakeOptions{} + } + if opts.TLSConfig == nil { + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + conn, err := wrapTLSClient(conn, opts.TLSConfig) + if err != nil { + return nil, err + } + + // stream multiplex + smuxConfig := smux.DefaultConfig() + session, err := smux.Client(conn, smuxConfig) + if err != nil { + return nil, err + } + return &muxSession{conn: conn, session: session}, nil +} + +func (tr *mtlsTransporter) Multiplex() bool { + return true +} + type tlsListener struct { net.Listener } @@ -43,6 +157,94 @@ func TLSListener(addr string, config *tls.Config) (Listener, error) { return &tlsListener{ln}, nil } +type mtlsListener struct { + ln net.Listener + connChan chan net.Conn + errChan chan error +} + +// MTLSListener creates a Listener for multiplex-TLS proxy server. +func MTLSListener(addr string, config *tls.Config) (Listener, error) { + if config == nil { + config = DefaultTLSConfig + } + ln, err := tls.Listen("tcp", addr, config) + if err != nil { + return nil, err + } + + l := &mtlsListener{ + ln: ln, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + go l.listenLoop() + + return l, nil +} + +func (l *mtlsListener) listenLoop() { + for { + conn, err := l.ln.Accept() + if err != nil { + log.Log("[mtls] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.mux(conn) + } +} + +func (l *mtlsListener) mux(conn net.Conn) { + log.Logf("[mtls] %s - %s", conn.RemoteAddr(), l.Addr()) + smuxConfig := smux.DefaultConfig() + mux, err := smux.Server(conn, smuxConfig) + if err != nil { + log.Logf("[mtls] %s - %s : %s", conn.RemoteAddr(), l.Addr(), err) + return + } + defer mux.Close() + + log.Logf("[mtls] %s <-> %s", conn.RemoteAddr(), l.Addr()) + defer log.Logf("[mtls] %s >-< %s", conn.RemoteAddr(), l.Addr()) + + for { + stream, err := mux.AcceptStream() + if err != nil { + log.Log("[mtls] accept stream:", err) + return + } + + cc := &muxStreamConn{Conn: conn, stream: stream} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[mtls] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } + } +} + +func (l *mtlsListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} +func (l *mtlsListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *mtlsListener) Close() error { + return l.ln.Close() +} + // Wrap a net.Conn into a client tls connection, performing any // additional verification as needed. //