diff --git a/gost/examples/bench/cli.go b/gost/examples/bench/cli.go index 47948f1..911d226 100644 --- a/gost/examples/bench/cli.go +++ b/gost/examples/bench/cli.go @@ -152,8 +152,10 @@ func main() { gost.Node{ Addr: "localhost:8443", Client: &gost.Client{ - Connector: gost.HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: gost.H2Transporter(), + // Connector: gost.HTTPConnector(url.UserPassword("admin", "123456")), + Connector: gost.SOCKS5Connector(url.UserPassword("admin", "123456")), + // Transporter: gost.H2CTransporter(), // HTTP2 h2c mode + Transporter: gost.H2Transporter(), // HTTP2 h2 }, }, ) diff --git a/gost/examples/bench/srv.go b/gost/examples/bench/srv.go index 0306f77..2f0fe86 100644 --- a/gost/examples/bench/srv.go +++ b/gost/examples/bench/srv.go @@ -236,12 +236,16 @@ func http2Server() { func http2TunnelServer() { s := &gost.Server{} ln, err := gost.H2Listener(":8443", tlsConfig()) // HTTP2 h2 mode - // ln, err := gost.H2Listener(":8443", nil) // HTTP2 h2c mode + // ln, err := gost.H2CListener(":8443") // HTTP2 h2c mode if err != nil { log.Fatal(err) } - h := gost.HTTPHandler( + // h := gost.HTTPHandler( + // gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + // ) + h := gost.SOCKS5Handler( gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + gost.TLSConfigHandlerOption(tlsConfig()), ) log.Fatal(s.Serve(ln, h)) } diff --git a/gost/examples/http2/http2.go b/gost/examples/http2/http2.go index 951bf77..5e185e6 100644 --- a/gost/examples/http2/http2.go +++ b/gost/examples/http2/http2.go @@ -16,14 +16,12 @@ var ( keyFile, certFile string laddr string user, passwd string - tlsEnabled bool ) func init() { log.SetFlags(log.LstdFlags | log.Lshortfile) flag.StringVar(&laddr, "L", ":1443", "listen address") - flag.BoolVar(&tlsEnabled, "tls", true, "enable TLS (h2)") flag.StringVar(&user, "u", "", "username") flag.StringVar(&passwd, "p", "", "password") flag.BoolVar(&quiet, "q", false, "quiet mode") @@ -43,9 +41,17 @@ func main() { } func http2Server() { + cert, er := tls.LoadX509KeyPair(certFile, keyFile) + if er != nil { + log.Println(er) + cert, er = tls.X509KeyPair(rawCert, rawKey) + if er != nil { + panic(er) + } + } s := &gost.Server{} - ln, err := gost.TCPListener(laddr) + ln, err := gost.HTTP2Listener(laddr, &tls.Config{Certificates: []tls.Certificate{cert}}) if err != nil { log.Fatal(err) } @@ -55,22 +61,9 @@ func http2Server() { users = append(users, url.UserPassword(user, passwd)) } - var tlsConfig *tls.Config - if tlsEnabled { - cert, er := tls.LoadX509KeyPair(certFile, keyFile) - if er != nil { - log.Println(er) - cert, er = tls.X509KeyPair(rawCert, rawKey) - if er != nil { - panic(er) - } - } - tlsConfig = &tls.Config{Certificates: []tls.Certificate{cert}} - } h := gost.HTTP2Handler( gost.UsersHandlerOption(users...), gost.AddrHandlerOption(laddr), - gost.TLSConfigHandlerOption(tlsConfig), ) log.Fatal(s.Serve(ln, h)) } diff --git a/gost/http2.go b/gost/http2.go index 8cd6f02..f1b3f64 100644 --- a/gost/http2.go +++ b/gost/http2.go @@ -40,7 +40,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { u := &url.URL{ Host: addr, } - req, err := http.NewRequest("CONNECT", u.String(), ioutil.NopCloser(pr)) + req, err := http.NewRequest(http.MethodConnect, u.String(), ioutil.NopCloser(pr)) if err != nil { log.Logf("[http2] %s - %s : %s", cc.raddr, addr, err) return nil, err @@ -161,6 +161,12 @@ func H2Transporter() Transporter { } } +func H2CTransporter() Transporter { + return &h2Transporter{ + clients: make(map[string]*http.Client), + } +} + func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, error) { opts := &DialOptions{} for _, option := range options { @@ -177,13 +183,16 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err if err != nil { return nil, err } - if cfg == nil { + if tr.tlsConfig == nil { return conn, nil } - return wrapTLSClient(conn, cfg) + return wrapTLSClient(conn, tr.tlsConfig) }, } - client = &http.Client{Transport: &transport} + client = &http.Client{ + Transport: &transport, + Timeout: opts.Timeout, + } tr.clients[addr] = client } tr.clientMutex.Unlock() @@ -240,8 +249,6 @@ func (tr *h2Transporter) Multiplex() bool { } type http2Handler struct { - base *http.Server - server *http2.Server options *HandlerOptions } @@ -254,47 +261,27 @@ func HTTP2Handler(opts ...HandlerOption) Handler { opt(h.options) } - h.base = &http.Server{ - Addr: h.options.Addr, - TLSConfig: h.options.TLSConfig, - Handler: http.HandlerFunc(h.handleFunc), - } - h.server = new(http2.Server) - if err := http2.ConfigureServer(h.base, h.server); err != nil { - log.Log("[http2]", err) - } return h } func (h *http2Handler) Handle(conn net.Conn) { defer conn.Close() - if h.options.TLSConfig != nil { - conn = tls.Server(conn, h.options.TLSConfig) + h2c, ok := conn.(*http2ServerConn) + if !ok { + log.Log("[http2] wrong connection type") + return } - if tc, ok := conn.(*tls.Conn); ok { - // NOTE: HTTP2 server will check the TLS version, - // so we must ensure that the TLS connection is handshake completed. - if err := tc.Handshake(); err != nil { - log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - } - - opt := http2.ServeConnOpts{ - BaseConfig: h.base, - Handler: http.HandlerFunc(h.handleFunc), - } - h.server.ServeConn(conn, &opt) + h.roundTrip(h2c.w, h2c.r) } -func (h *http2Handler) handleFunc(w http.ResponseWriter, r *http.Request) { - // target := r.Header.Get("Gost-Target") // compitable with old version - // if target == "" { - // target = r.Host - // } - target := r.Host +func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { + target := r.Header.Get("Gost-Target") // compitable with old version + if target == "" { + target = r.Host + } + // target := r.Host if !strings.Contains(target, ":") { target += ":80" } @@ -398,6 +385,74 @@ func (h *http2Handler) handleFunc(w http.ResponseWriter, r *http.Request) { log.Logf("[http2] %s >-< %s", r.RemoteAddr, target) } +type http2Listener struct { + server *http.Server + connChan chan *http2ServerConn + errChan chan error +} + +func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { + l := &http2Listener{ + connChan: make(chan *http2ServerConn, 1024), + errChan: make(chan error, 1), + } + server := &http.Server{ + Addr: addr, + Handler: http.HandlerFunc(l.handleFunc), + TLSConfig: config, + } + if err := http2.ConfigureServer(server, nil); err != nil { + return nil, err + } + l.server = server + go server.ListenAndServeTLS("", "") + + return l, nil +} + +func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { + conn := &http2ServerConn{ + r: r, + w: w, + closed: make(chan struct{}), + } + select { + case l.connChan <- conn: + default: + log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr) + return + } + + <-conn.closed +} + +func (l *http2Listener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errChan: + if err == nil { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *http2Listener) Addr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", l.server.Addr) + return addr +} + +func (l *http2Listener) Close() (err error) { + select { + case <-l.errChan: + default: + err = l.server.Close() + l.errChan <- err + close(l.errChan) + } + return nil +} + type h2Listener struct { net.Listener server *http2.Server @@ -406,6 +461,7 @@ type h2Listener struct { errChan chan error } +// H2Listener creates a Listener for HTTP2 h2 tunnel server. func H2Listener(addr string, config *tls.Config) (Listener, error) { ln, err := net.Listen("tcp", addr) if err != nil { @@ -414,8 +470,9 @@ func H2Listener(addr string, config *tls.Config) (Listener, error) { l := &h2Listener{ Listener: ln, server: &http2.Server{ - MaxConcurrentStreams: 1000, + // MaxConcurrentStreams: 1000, PermitProhibitedCipherSuites: true, + IdleTimeout: 5 * time.Minute, }, tlsConfig: config, connChan: make(chan net.Conn, 1024), @@ -426,8 +483,23 @@ func H2Listener(addr string, config *tls.Config) (Listener, error) { return l, nil } +// H2CListener creates a Listener for HTTP2 h2c tunnel server. func H2CListener(addr string) (Listener, error) { - return H2Listener(addr, nil) + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + l := &h2Listener{ + Listener: ln, + server: &http2.Server{ + // MaxConcurrentStreams: 1000, + }, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + go l.listenLoop() + + return l, nil } func (l *h2Listener) listenLoop() { @@ -482,7 +554,7 @@ func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) } - <-conn.closed // wait for streaming + <-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed } func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*http2Conn, error) { @@ -662,6 +734,52 @@ func (c *http2Conn) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } +// a dummy HTTP2 server conn used by HTTP2 handler +type http2ServerConn struct { + r *http.Request + w http.ResponseWriter + closed chan struct{} +} + +func (c *http2ServerConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *http2ServerConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *http2ServerConn) Close() error { + select { + case <-c.closed: + default: + close(c.closed) + } + return nil +} + +func (c *http2ServerConn) LocalAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.r.Host) + return addr +} + +func (c *http2ServerConn) RemoteAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr) + return addr +} + +func (c *http2ServerConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2ServerConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2ServerConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + // Dummy HTTP2 connection. type http2DummyConn struct { raddr string