From 89d2b7b46614e4d55109b65921fe4351614a4db1 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Tue, 25 Jul 2017 23:12:14 +0800 Subject: [PATCH] change Transporter interface --- gost/cli/cli.go | 8 +-- gost/client.go | 67 +++++++++++++---- gost/http2.go | 186 ++++++++++++++++++++++++++++++++---------------- gost/kcp.go | 10 ++- gost/srv/srv.go | 4 +- gost/tls.go | 18 +++-- gost/ws.go | 48 ++++++++----- 7 files changed, 232 insertions(+), 109 deletions(-) diff --git a/gost/cli/cli.go b/gost/cli/cli.go index 2100974..57ec9ed 100644 --- a/gost/cli/cli.go +++ b/gost/cli/cli.go @@ -89,15 +89,15 @@ func main() { }, */ - // http2 + // http2+tls, http2+tcp gost.Node{ Addr: "127.0.0.1:1443", Client: gost.NewClient( gost.HTTP2Connector(url.UserPassword("admin", "123456")), gost.HTTP2Transporter( nil, - &tls.Config{InsecureSkipVerify: true}, - time.Second*60, + &tls.Config{InsecureSkipVerify: true}, // or nil, will use h2c mode (http2+tcp). + time.Second*1, ), ), }, @@ -138,6 +138,6 @@ func main() { rb, _ = httputil.DumpResponse(resp, true) log.Println(string(rb)) - time.Sleep(100 * time.Millisecond) + time.Sleep(1000 * time.Millisecond) } } diff --git a/gost/client.go b/gost/client.go index 9f458ba..1de131a 100644 --- a/gost/client.go +++ b/gost/client.go @@ -1,6 +1,7 @@ package gost import ( + "crypto/tls" "net" ) @@ -22,13 +23,13 @@ func NewClient(c Connector, tr Transporter) *Client { } // Dial connects to the target address. -func (c *Client) Dial(addr string) (net.Conn, error) { - return c.Transporter.Dial(addr) +func (c *Client) Dial(addr string, options ...DialOption) (net.Conn, error) { + return c.Transporter.Dial(addr, options...) } // Handshake performs a handshake with the proxy over connection conn. -func (c *Client) Handshake(conn net.Conn) (net.Conn, error) { - return c.Transporter.Handshake(conn) +func (c *Client) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return c.Transporter.Handshake(conn, options...) } // Connect connects to the address addr via the proxy over connection conn. @@ -40,13 +41,13 @@ func (c *Client) Connect(conn net.Conn, addr string) (net.Conn, error) { var DefaultClient = NewClient(HTTPConnector(nil), TCPTransporter()) // Dial connects to the address addr via the DefaultClient. -func Dial(addr string) (net.Conn, error) { - return DefaultClient.Dial(addr) +func Dial(addr string, options ...DialOption) (net.Conn, error) { + return DefaultClient.Dial(addr, options...) } // Handshake performs a handshake via the DefaultClient. -func Handshake(conn net.Conn) (net.Conn, error) { - return DefaultClient.Handshake(conn) +func Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return DefaultClient.Handshake(conn, options...) } // Connect connects to the address addr via the DefaultClient. @@ -61,8 +62,8 @@ type Connector interface { // Transporter is responsible for handshaking with the proxy server. type Transporter interface { - Dial(addr string) (net.Conn, error) - Handshake(conn net.Conn) (net.Conn, error) + Dial(addr string, options ...DialOption) (net.Conn, error) + Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) // Indicate that the Transporter supports multiplex Multiplex() bool } @@ -75,14 +76,56 @@ func TCPTransporter() Transporter { return &tcpTransporter{} } -func (tr *tcpTransporter) Dial(addr string) (net.Conn, error) { +func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { return net.Dial("tcp", addr) } -func (tr *tcpTransporter) Handshake(conn net.Conn) (net.Conn, error) { +func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { return conn, nil } func (tr *tcpTransporter) Multiplex() bool { return false } + +// DialOptions describes the options for dialing. +type DialOptions struct { +} + +// DialOption allows a common way to set dial options. +type DialOption func(opts *DialOptions) + +// HandshakeOptions describes the options for handshake. +type HandshakeOptions struct { + Addr string + TLSConfig *tls.Config + WSOptions *WSOptions + KCPConfig *KCPConfig +} + +// HandshakeOption allows a common way to set handshake options. +type HandshakeOption func(opts *HandshakeOptions) + +func AddrHandshakeOption(addr string) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Addr = addr + } +} + +func TLSConfigHandshakeOption(config *tls.Config) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.TLSConfig = config + } +} + +func WSOptionsHandshakeOption(options *WSOptions) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.WSOptions = options + } +} + +func KCPConfigHandshakeOption(config *KCPConfig) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.KCPConfig = config + } +} diff --git a/gost/http2.go b/gost/http2.go index ad09d6e..365ef53 100644 --- a/gost/http2.go +++ b/gost/http2.go @@ -79,8 +79,8 @@ type http2Transporter struct { tlsConfig *tls.Config tr *http2.Transport chain *Chain - conns map[string]*http2.ClientConn - connMutex sync.Mutex + sessions map[string]*http2Session + sessionMutex sync.Mutex pingInterval time.Duration } @@ -97,93 +97,49 @@ func HTTP2Transporter(chain *Chain, config *tls.Config, ping time.Duration) Tran tr: new(http2.Transport), chain: chain, pingInterval: ping, - conns: make(map[string]*http2.ClientConn), + sessions: make(map[string]*http2Session), } } -func (tr *http2Transporter) Dial(addr string) (net.Conn, error) { - tr.connMutex.Lock() - conn, ok := tr.conns[addr] +func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + session, ok := tr.sessions[addr] if !ok { - cc, err := tr.chain.Dial(addr) + conn, err := tr.chain.Dial(addr) if err != nil { - tr.connMutex.Unlock() return nil, err } if tr.tlsConfig != nil { - tc := tls.Client(cc, tr.tlsConfig) + tc := tls.Client(conn, tr.tlsConfig) if err := tc.Handshake(); err != nil { - tr.connMutex.Unlock() return nil, err } - cc = tc + conn = tc } - conn, err = tr.tr.NewClientConn(cc) + cc, err := tr.tr.NewClientConn(conn) if err != nil { - tr.connMutex.Unlock() return nil, err } - tr.conns[addr] = conn - go tr.ping(tr.pingInterval, addr, conn) + session = newHTTP2Session(conn, cc, tr.pingInterval) + tr.sessions[addr] = session } - tr.connMutex.Unlock() - if !conn.CanTakeNewRequest() { - tr.connMutex.Lock() - delete(tr.conns, addr) // TODO: we could re-connect to the addr automatically. - tr.connMutex.Unlock() + if !session.Healthy() { + session.Close() + delete(tr.sessions, addr) // TODO: we could re-connect to the addr automatically. return nil, errors.New("connection is dead") } return &http2DummyConn{ raddr: addr, - conn: conn, + conn: session.clientConn, }, nil } -func (tr *http2Transporter) ping(interval time.Duration, addr string, conn *http2.ClientConn) { - if interval <= 0 { - return - } - log.Log("[http2] ping is enabled, interval:", interval) - - baseCtx := context.Background() - t := time.NewTicker(interval) - retries := PingRetries - for { - select { - case <-t.C: - if !conn.CanTakeNewRequest() { - return - } - ctx, cancel := context.WithTimeout(baseCtx, PingTimeout) - if err := conn.Ping(ctx); err != nil { - log.Logf("[http2] ping: %s", err) - if retries > 0 { - retries-- - log.Log("[http2] retry ping") - cancel() - continue - } - - // connection is dead, remove it. - tr.connMutex.Lock() - delete(tr.conns, addr) - tr.connMutex.Unlock() - - cancel() - return - } - - cancel() - retries = PingRetries - } - } -} - -func (tr *http2Transporter) Handshake(conn net.Conn) (net.Conn, error) { +func (tr *http2Transporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { return conn, nil } @@ -333,6 +289,112 @@ func (h *http2Handler) handleFunc(w http.ResponseWriter, r *http.Request) { log.Logf("[http2] %s >-< %s", r.RemoteAddr, target) } +type http2Listener struct { + ln net.Listener +} + +// HTTP2Listener creates a Listener for server using HTTP2 as transport. +func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { + var ln net.Listener + var err error + + if config != nil { + ln, err = tls.Listen("tcp", addr, config) + } else { + ln, err = net.Listen("tcp", addr) + } + if err != nil { + return nil, err + } + return ln, err + //return &http2Listener{ln: ln}, nil +} + +type http2Session struct { + conn net.Conn + clientConn *http2.ClientConn + closeChan chan struct{} + pingChan chan struct{} +} + +func newHTTP2Session(conn net.Conn, clientConn *http2.ClientConn, interval time.Duration) *http2Session { + session := &http2Session{ + conn: conn, + clientConn: clientConn, + closeChan: make(chan struct{}), + } + if interval > 0 { + session.pingChan = make(chan struct{}) + go session.Ping(interval) + } + return session +} + +func (s *http2Session) Ping(interval time.Duration) { + if interval <= 0 { + return + } + + defer close(s.pingChan) + log.Log("[http2] ping is enabled, interval:", interval) + + baseCtx := context.Background() + t := time.NewTicker(interval) + retries := PingRetries + for { + select { + case <-t.C: + if Debug { + log.Log("[http2] sending ping") + } + if !s.clientConn.CanTakeNewRequest() { + log.Logf("[http2] connection is dead") + return + } + ctx, cancel := context.WithTimeout(baseCtx, PingTimeout) + if err := s.clientConn.Ping(ctx); err != nil { + log.Logf("[http2] ping: %s", err) + if retries > 0 { + retries-- + log.Log("[http2] retry ping") + cancel() + continue + } + + cancel() + return + } + + if Debug { + log.Log("[http2] ping OK") + } + cancel() + retries = PingRetries + + case <-s.closeChan: + return + } + } +} + +func (s *http2Session) Healthy() bool { + select { + case <-s.pingChan: + return false + default: + } + return s.clientConn.CanTakeNewRequest() +} + +func (s *http2Session) Close() error { + select { + case <-s.closeChan: + default: + close(s.closeChan) + } + return nil +} + // HTTP2 connection, wrapped up just like a net.Conn type http2Conn struct { r io.Reader diff --git a/gost/kcp.go b/gost/kcp.go index 95d3ad4..694a97f 100644 --- a/gost/kcp.go +++ b/gost/kcp.go @@ -179,25 +179,23 @@ func KCPTransporter(config *KCPConfig) Transporter { } } -func (tr *kcpTransporter) Dial(addr string) (conn net.Conn, err error) { +func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + session, ok := tr.sessions[addr] if !ok { session, err = tr.dial(addr, tr.config) if err != nil { - tr.sessionMutex.Unlock() return } tr.sessions[addr] = session } - tr.sessionMutex.Unlock() conn, err = session.GetConn() if err != nil { - tr.sessionMutex.Lock() session.Close() delete(tr.sessions, addr) // TODO: we could obtain a new session automatically. - tr.sessionMutex.Unlock() } return } @@ -241,7 +239,7 @@ func (tr *kcpTransporter) dial(addr string, config *KCPConfig) (*kcpSession, err return &kcpSession{conn: conn, session: session}, nil } -func (tr *kcpTransporter) Handshake(conn net.Conn) (net.Conn, error) { +func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { return conn, nil } diff --git a/gost/srv/srv.go b/gost/srv/srv.go index 77185a1..78e8d64 100644 --- a/gost/srv/srv.go +++ b/gost/srv/srv.go @@ -189,8 +189,8 @@ func http2Server() { s.Handle(gost.HTTP2Handler( gost.UsersHandlerOption(url.UserPassword("admin", "123456")), )) - ln, err := gost.TLSListener(":1443", tlsConfig()) - // ln, err := gost.TCPListener(":1443") + ln, err := gost.TLSListener(":1443", tlsConfig()) // HTTP2 h2 mode + // ln, err := gost.TCPListener(":1443") // HTTP2 h2c mode if err != nil { log.Fatal(err) } diff --git a/gost/tls.go b/gost/tls.go index 9a9bc54..877f7ef 100644 --- a/gost/tls.go +++ b/gost/tls.go @@ -6,21 +6,27 @@ import ( ) type tlsTransporter struct { - tlsConfig *tls.Config } // TLSTransporter creates a Transporter that is used by TLS proxy client. // It accepts a TLS config for TLS handshake. -func TLSTransporter(cfg *tls.Config) Transporter { - return &tlsTransporter{tlsConfig: cfg} +func TLSTransporter() Transporter { + return &tlsTransporter{} } -func (tr *tlsTransporter) Dial(addr string) (net.Conn, error) { +func (tr *tlsTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { return net.Dial("tcp", addr) } -func (tr *tlsTransporter) Handshake(conn net.Conn) (net.Conn, error) { - return tls.Client(conn, tr.tlsConfig), nil +func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + if opts.TLSConfig == nil { + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + return tls.Client(conn, opts.TLSConfig), nil } func (tr *tlsTransporter) Multiplex() bool { diff --git a/gost/ws.go b/gost/ws.go index 9f9d615..778bc80 100644 --- a/gost/ws.go +++ b/gost/ws.go @@ -27,14 +27,14 @@ type websocketConn struct { rb []byte } -func websocketClientConn(url string, conn net.Conn, options *WSOptions) (net.Conn, error) { +func websocketClientConn(url string, conn net.Conn, tlsConfig *tls.Config, options *WSOptions) (net.Conn, error) { if options == nil { options = &WSOptions{} } dialer := websocket.Dialer{ ReadBufferSize: options.ReadBufferSize, WriteBufferSize: options.WriteBufferSize, - TLSClientConfig: options.TLSConfig, + TLSClientConfig: tlsConfig, HandshakeTimeout: options.HandshakeTimeout, EnableCompression: options.EnableCompression, NetDial: func(net, addr string) (net.Conn, error) { @@ -98,26 +98,31 @@ func (c *websocketConn) SetWriteDeadline(t time.Time) error { } type wsTransporter struct { - addr string options *WSOptions } // WSTransporter creates a Transporter that is used by websocket proxy client. -func WSTransporter(addr string, opts *WSOptions) Transporter { +func WSTransporter(opts *WSOptions) Transporter { return &wsTransporter{ - addr: addr, options: opts, } } -func (tr *wsTransporter) Dial(addr string) (net.Conn, error) { - tr.addr = addr // NOTE: the addr must match the initial tr.addr +func (tr *wsTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { return net.Dial("tcp", addr) } -func (tr *wsTransporter) Handshake(conn net.Conn) (net.Conn, error) { - url := url.URL{Scheme: "ws", Host: tr.addr, Path: "/ws"} - return websocketClientConn(url.String(), conn, tr.options) +func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + wsOptions := tr.options + if opts.WSOptions != nil { + wsOptions = opts.WSOptions + } + url := url.URL{Scheme: "ws", Host: opts.Addr, Path: "/ws"} + return websocketClientConn(url.String(), conn, nil, wsOptions) } func (tr *wsTransporter) Multiplex() bool { @@ -125,25 +130,34 @@ func (tr *wsTransporter) Multiplex() bool { } type wssTransporter struct { - addr string options *WSOptions } // WSSTransporter creates a Transporter that is used by websocket secure proxy client. -func WSSTransporter(addr string, opts *WSOptions) Transporter { +func WSSTransporter(opts *WSOptions) Transporter { return &wssTransporter{ - addr: addr, options: opts, } } -func (tr *wssTransporter) Dial(addr string) (net.Conn, error) { +func (tr *wssTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { return net.Dial("tcp", addr) } -func (tr *wssTransporter) Handshake(conn net.Conn) (net.Conn, error) { - url := url.URL{Scheme: "wss", Host: tr.addr, Path: "/ws"} - return websocketClientConn(url.String(), conn, tr.options) +func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + wsOptions := tr.options + if opts.WSOptions != nil { + wsOptions = opts.WSOptions + } + if opts.TLSConfig == nil { + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + url := url.URL{Scheme: "wss", Host: opts.Addr, Path: "/ws"} + return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) } func (tr *wssTransporter) Multiplex() bool {