add path support for h2/h2c

This commit is contained in:
ginuerzh 2020-01-09 11:11:51 +08:00
parent 085b0b30e7
commit 9a24c06f96
2 changed files with 32 additions and 13 deletions

View File

@ -163,10 +163,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
case "http2": case "http2":
tr = gost.HTTP2Transporter(tlsCfg) tr = gost.HTTP2Transporter(tlsCfg)
case "h2": case "h2":
tr = gost.H2Transporter(tlsCfg) tr = gost.H2Transporter(tlsCfg, node.Get("path"))
case "h2c": case "h2c":
tr = gost.H2CTransporter() tr = gost.H2CTransporter(node.Get("path"))
case "obfs4": case "obfs4":
tr = gost.Obfs4Transporter() tr = gost.Obfs4Transporter()
case "ohttp": case "ohttp":
@ -348,9 +347,9 @@ func (r *route) GenRouters() ([]router, error) {
case "http2": case "http2":
ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) ln, err = gost.HTTP2Listener(node.Addr, tlsCfg)
case "h2": case "h2":
ln, err = gost.H2Listener(node.Addr, tlsCfg) ln, err = gost.H2Listener(node.Addr, tlsCfg, node.Get("path"))
case "h2c": case "h2c":
ln, err = gost.H2CListener(node.Addr) ln, err = gost.H2CListener(node.Addr, node.Get("path"))
case "tcp": case "tcp":
// Directly use SSH port forwarding if the last chain node is forward+ssh // Directly use SSH port forwarding if the last chain node is forward+ssh
if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" {

View File

@ -180,27 +180,31 @@ func (tr *http2Transporter) Multiplex() bool {
return true return true
} }
// TODO: clean closed clients
type h2Transporter struct { type h2Transporter struct {
clients map[string]*http.Client clients map[string]*http.Client
clientMutex sync.Mutex clientMutex sync.Mutex
tlsConfig *tls.Config tlsConfig *tls.Config
path string
} }
// H2Transporter creates a Transporter that is used by HTTP2 h2 tunnel client. // H2Transporter creates a Transporter that is used by HTTP2 h2 tunnel client.
func H2Transporter(config *tls.Config) Transporter { func H2Transporter(config *tls.Config, path string) Transporter {
if config == nil { if config == nil {
config = &tls.Config{InsecureSkipVerify: true} config = &tls.Config{InsecureSkipVerify: true}
} }
return &h2Transporter{ return &h2Transporter{
clients: make(map[string]*http.Client), clients: make(map[string]*http.Client),
tlsConfig: config, tlsConfig: config,
path: path,
} }
} }
// H2CTransporter creates a Transporter that is used by HTTP2 h2c tunnel client. // H2CTransporter creates a Transporter that is used by HTTP2 h2c tunnel client.
func H2CTransporter() Transporter { func H2CTransporter(path string) Transporter {
return &h2Transporter{ return &h2Transporter{
clients: make(map[string]*http.Client), clients: make(map[string]*http.Client),
path: path,
} }
} }
@ -251,6 +255,11 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err
Host: addr, Host: addr,
ContentLength: -1, ContentLength: -1,
} }
if tr.path != "" {
req.Method = http.MethodGet
req.URL.Path = tr.path
}
if Debug { if Debug {
dump, _ := httputil.DumpRequest(req, false) dump, _ := httputil.DumpRequest(req, false)
log.Log("[http2]", string(dump)) log.Log("[http2]", string(dump))
@ -650,12 +659,13 @@ type h2Listener struct {
net.Listener net.Listener
server *http2.Server server *http2.Server
tlsConfig *tls.Config tlsConfig *tls.Config
path string
connChan chan net.Conn connChan chan net.Conn
errChan chan error errChan chan error
} }
// H2Listener creates a Listener for HTTP2 h2 tunnel server. // H2Listener creates a Listener for HTTP2 h2 tunnel server.
func H2Listener(addr string, config *tls.Config) (Listener, error) { func H2Listener(addr string, config *tls.Config, path string) (Listener, error) {
ln, err := net.Listen("tcp", addr) ln, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -672,6 +682,7 @@ func H2Listener(addr string, config *tls.Config) (Listener, error) {
IdleTimeout: 5 * time.Minute, IdleTimeout: 5 * time.Minute,
}, },
tlsConfig: config, tlsConfig: config,
path: path,
connChan: make(chan net.Conn, 1024), connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1), errChan: make(chan error, 1),
} }
@ -681,7 +692,7 @@ func H2Listener(addr string, config *tls.Config) (Listener, error) {
} }
// H2CListener creates a Listener for HTTP2 h2c tunnel server. // H2CListener creates a Listener for HTTP2 h2c tunnel server.
func H2CListener(addr string) (Listener, error) { func H2CListener(addr string, path string) (Listener, error) {
ln, err := net.Listen("tcp", addr) ln, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -691,6 +702,7 @@ func H2CListener(addr string) (Listener, error) {
server: &http2.Server{ server: &http2.Server{
// MaxConcurrentStreams: 1000, // MaxConcurrentStreams: 1000,
}, },
path: path,
connChan: make(chan net.Conn, 1024), connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1), errChan: make(chan error, 1),
} }
@ -733,7 +745,8 @@ func (l *h2Listener) handleLoop(conn net.Conn) {
} }
func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
log.Logf("[http2] %s %s - %s %s", r.Method, r.RemoteAddr, r.Host, r.Proto) log.Logf("[http2] %s -> %s %s %s %s",
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto)
if Debug { if Debug {
dump, _ := httputil.DumpRequest(r, false) dump, _ := httputil.DumpRequest(r, false)
log.Log("[http2]", string(dump)) log.Log("[http2]", string(dump))
@ -741,7 +754,8 @@ func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Proxy-Agent", "gost/"+Version) w.Header().Set("Proxy-Agent", "gost/"+Version)
conn, err := l.upgrade(w, r) conn, err := l.upgrade(w, r)
if err != nil { if err != nil {
log.Logf("[http2] %s %s - %s %s", r.Method, r.RemoteAddr, r.Host, r.Proto) log.Logf("[http2] %s - %s %s %s %s: %s",
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err)
return return
} }
select { select {
@ -755,10 +769,16 @@ func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
} }
func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*http2Conn, error) { func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*http2Conn, error) {
if r.Method != http.MethodConnect { if l.path == "" && r.Method != http.MethodConnect {
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
return nil, errors.New("Method not allowed") return nil, errors.New("method not allowed")
} }
if l.path != "" && r.RequestURI != l.path {
w.WriteHeader(http.StatusBadRequest)
return nil, errors.New("bad request")
}
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok { if fw, ok := w.(http.Flusher); ok {
fw.Flush() // write header to client fw.Flush() // write header to client