From 057a17be8757e9acfc7789726cdd4b4b509a53b9 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Wed, 28 Sep 2016 17:56:12 +0800 Subject: [PATCH] support http2 proxy --- conn.go | 32 ++++++++ http.go | 235 ++++++++++++++++++++++++++++++++++++++++++++++++------- http2.go | 116 --------------------------- main.go | 13 ++- util.go | 43 +++++++--- 5 files changed, 281 insertions(+), 158 deletions(-) delete mode 100644 http2.go diff --git a/conn.go b/conn.go index 7454af6..a9e9be0 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "github.com/golang/glog" "github.com/shadowsocks/shadowsocks-go/shadowsocks" "io" + "io/ioutil" "net" "net/http" "net/http/httputil" @@ -378,9 +379,40 @@ func (r *reqReader) Read(p []byte) (n int, err error) { } func Connect(addr string) (conn net.Conn, err error) { + if len(forwardArgs) > 0 { + last := forwardArgs[len(forwardArgs)-1] + if http2Client != nil && last.Protocol == "http2" { + return connectHttp2(http2Client, addr) + } + } return connectWithChain(addr, forwardArgs...) } +func connectHttp2(client *http.Client, host string) (net.Conn, error) { + pr, pw := io.Pipe() + u := url.URL{Scheme: "https", Host: host} + req, err := http.NewRequest(http.MethodConnect, u.String(), ioutil.NopCloser(pr)) + if err != nil { + return nil, err + } + req.ContentLength = -1 + if glog.V(LDEBUG) { + dump, _ := httputil.DumpRequest(req, false) + glog.Infoln(string(dump)) + } + resp, err := http2Client.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, errors.New(resp.Status) + } + conn := &Http2ClientConn{r: resp.Body, w: pw} + conn.remoteAddr, _ = net.ResolveTCPAddr("tcp", host) + return conn, nil +} + func connectWithChain(addr string, chain ...Args) (conn net.Conn, err error) { if !strings.Contains(addr, ":") { addr += ":80" diff --git a/http.go b/http.go index 1f5151e..14fa328 100644 --- a/http.go +++ b/http.go @@ -1,12 +1,21 @@ package main import ( + "bufio" + "crypto/tls" "encoding/base64" "github.com/golang/glog" + "golang.org/x/net/http2" + "io" "net" "net/http" "net/http/httputil" "strings" + "time" +) + +var ( + http2Client *http.Client ) func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { @@ -39,40 +48,14 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { return } - var c net.Conn - var err error - if len(forwardArgs) > 0 { last := forwardArgs[len(forwardArgs)-1] if last.Protocol == "http" || last.Protocol == "" { - c, _, err = forwardChain(forwardArgs...) - if err != nil { - glog.V(LWARNING).Infof("[http] %s -> %s : %s", conn.RemoteAddr(), last.Addr, err) - - b := []byte("HTTP/1.1 503 Service unavailable\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n") - glog.V(LDEBUG).Infof("[http] %s <- %s\n%s", conn.RemoteAddr(), last.Addr, string(b)) - conn.Write(b) - return - } - defer c.Close() - - if last.User != nil { - req.Header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(last.User.String()))) - } - - if err = req.Write(c); err != nil { - glog.V(LWARNING).Infof("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) - return - } - glog.V(LINFO).Infof("[http] %s <-> %s", conn.RemoteAddr(), req.Host) - Transport(conn, c) - glog.V(LINFO).Infof("[http] %s >-< %s", conn.RemoteAddr(), req.Host) + forwardHttpRequest(req, conn, arg) return } } - c, err = Connect(req.Host) + c, err := Connect(req.Host) if err != nil { glog.V(LWARNING).Infof("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) @@ -104,6 +87,186 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { glog.V(LINFO).Infof("[http] %s >-< %s", conn.RemoteAddr(), req.Host) } +func forwardHttpRequest(req *http.Request, conn net.Conn, arg Args) { + last := forwardArgs[len(forwardArgs)-1] + c, _, err := forwardChain(forwardArgs...) + if err != nil { + glog.V(LWARNING).Infof("[http] %s -> %s : %s", conn.RemoteAddr(), last.Addr, err) + + b := []byte("HTTP/1.1 503 Service unavailable\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n") + glog.V(LDEBUG).Infof("[http] %s <- %s\n%s", conn.RemoteAddr(), last.Addr, string(b)) + conn.Write(b) + return + } + defer c.Close() + + if last.User != nil { + req.Header.Set("Proxy-Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(last.User.String()))) + } + + if err = req.Write(c); err != nil { + glog.V(LWARNING).Infof("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + return + } + glog.V(LINFO).Infof("[http] %s <-> %s", conn.RemoteAddr(), req.Host) + Transport(conn, c) + glog.V(LINFO).Infof("[http] %s >-< %s", conn.RemoteAddr(), req.Host) + return +} + +type Http2ClientConn struct { + r io.Reader + w io.Writer + localAddr net.Addr + remoteAddr net.Addr +} + +func (c *Http2ClientConn) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} + +func (c *Http2ClientConn) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *Http2ClientConn) Close() error { + if rc, ok := c.r.(io.ReadCloser); ok { + return rc.Close() + } + return nil +} + +func (c *Http2ClientConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *Http2ClientConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *Http2ClientConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *Http2ClientConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *Http2ClientConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// init http2 client with target http2 proxy server addr, and forward chain chain +func initHttp2Client(host string, chain ...Args) { + glog.V(LINFO).Infoln("init http2 client") + tr := http2.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + // replace the default dialer with our forward chain. + conn, err := connectWithChain(host, chain...) + if err != nil { + return conn, err + } + return tls.Client(conn, cfg), nil + }, + } + http2Client = &http.Client{Transport: &tr} +} + +func handlerHttp2Request(w http.ResponseWriter, req *http.Request) { + glog.V(LINFO).Infof("[http2] %s - %s", req.RemoteAddr, req.Host) + if glog.V(LDEBUG) { + dump, _ := httputil.DumpRequest(req, false) + glog.Infoln(string(dump)) + } + + c, err := Connect(req.Host) + if err != nil { + glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, req.Host, err) + w.Header().Set("Proxy-Agent", "gost/"+Version) + w.WriteHeader(http.StatusServiceUnavailable) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + return + } + defer c.Close() + + glog.V(LINFO).Infof("[http2] %s <-> %s", req.RemoteAddr, req.Host) + errc := make(chan error, 2) + + if req.Method == http.MethodConnect { + w.Header().Set("Proxy-Agent", "gost/"+Version) + w.WriteHeader(http.StatusOK) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + + // compatible with HTTP 1.x + if hj, ok := w.(http.Hijacker); ok && req.ProtoMajor == 1 { + // we take over the underly connection + conn, _, err := hj.Hijack() + if err != nil { + glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, req.Host, err) + return + } + defer conn.Close() + + go Pipe(conn, c, errc) + go Pipe(c, conn, errc) + } else { + go Pipe(req.Body, c, errc) + go Pipe(c, flushWriter{w}, errc) + } + + select { + case <-errc: + // glog.V(LWARNING).Infoln("exit", err) + } + } else { + req.Header.Set("Connection", "Keep-Alive") + if err = req.Write(c); err != nil { + glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, req.Host, err) + return + } + + resp, err := http.ReadResponse(bufio.NewReader(c), req) + if err != nil { + glog.V(LWARNING).Infoln(err) + return + } + defer resp.Body.Close() + + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + w.WriteHeader(resp.StatusCode) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + + if _, err := io.Copy(flushWriter{w}, resp.Body); err != nil { + glog.V(LWARNING).Infof("[http2] %s <- %s : %s", req.RemoteAddr, req.Host, err) + } + } + + glog.V(LINFO).Infof("[http2] %s >-< %s", req.RemoteAddr, req.Host) +} + +func handleHttp2Transport(w http.ResponseWriter, req *http.Request) { + glog.V(LINFO).Infof("[http2] %s - %s", req.RemoteAddr, req.Host) + if glog.V(LDEBUG) { + dump, _ := httputil.DumpRequest(req, false) + glog.Infoln(string(dump)) + } +} + func basicAuth(authInfo string) (username, password string, ok bool) { if authInfo == "" { return @@ -124,3 +287,19 @@ func basicAuth(authInfo string) (username, password string, ok bool) { return cs[:s], cs[s+1:], true } + +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(p []byte) (n int, err error) { + n, err = fw.w.Write(p) + if err != nil { + glog.V(LWARNING).Infoln("flush writer:", err) + return + } + if f, ok := fw.w.(http.Flusher); ok { + f.Flush() + } + return +} diff --git a/http2.go b/http2.go deleted file mode 100644 index 142dd2c..0000000 --- a/http2.go +++ /dev/null @@ -1,116 +0,0 @@ -package main - -import ( - "bufio" - "github.com/golang/glog" - "golang.org/x/net/http2" - "io" - //"net" - "net/http" - "net/http/httputil" -) - -func init() { - if glog.V(LDEBUG) { - http2.VerboseLogs = true - } -} - -func handlerHttp2Request(w http.ResponseWriter, req *http.Request) { - glog.V(LINFO).Infof("[http2] %s - %s", req.RemoteAddr, req.Host) - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(req, false) - glog.Infoln(string(dump)) - } - - c, err := Connect(req.Host) - if err != nil { - glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, req.Host, err) - w.Header().Set("Proxy-Agent", "gost/"+Version) - w.WriteHeader(http.StatusServiceUnavailable) - if fw, ok := w.(http.Flusher); ok { - fw.Flush() - } - return - } - defer c.Close() - - glog.V(LINFO).Infof("[http2] %s <-> %s", req.RemoteAddr, req.Host) - errc := make(chan error, 2) - - if req.Method == http.MethodConnect { - w.Header().Set("Proxy-Agent", "gost/"+Version) - w.WriteHeader(http.StatusOK) - if fw, ok := w.(http.Flusher); ok { - fw.Flush() - } - - // compatible with HTTP 1.x - if hj, ok := w.(http.Hijacker); ok && req.ProtoMajor == 1 { - // we take over the underly connection - conn, _, err := hj.Hijack() - if err != nil { - glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, req.Host, err) - return - } - defer conn.Close() - - go Pipe(conn, c, errc) - go Pipe(c, conn, errc) - } else { - go Pipe(req.Body, c, errc) - go Pipe(c, flushWriter{w}, errc) - } - - select { - case <-errc: - // glog.V(LWARNING).Infoln("exit", err) - } - } else { - req.Header.Set("Connection", "Keep-Alive") - if err = req.Write(c); err != nil { - glog.V(LWARNING).Infof("[http2] %s -> %s : %s", req.RemoteAddr, req.Host, err) - return - } - - resp, err := http.ReadResponse(bufio.NewReader(c), req) - if err != nil { - glog.V(LWARNING).Infoln(err) - return - } - defer resp.Body.Close() - - for k, v := range resp.Header { - for _, vv := range v { - w.Header().Add(k, vv) - } - } - w.WriteHeader(resp.StatusCode) - - if _, err := io.Copy(flushWriter{w}, resp.Body); err != nil { - glog.V(LWARNING).Infof("[http2] %s <- %s : %s", req.RemoteAddr, req.Host, err) - } - } - - glog.V(LINFO).Infof("[http2] %s >-< %s", req.RemoteAddr, req.Host) -} - -func handleHttp2Transport(w http.ResponseWriter, req *http.Request) { - glog.V(LINFO).Infof("[http2] %s - %s", req.RemoteAddr, req.Host) - if glog.V(LDEBUG) { - dump, _ := httputil.DumpRequest(req, false) - glog.Infoln(string(dump)) - } -} - -type flushWriter struct { - w io.Writer -} - -func (fw flushWriter) Write(p []byte) (n int, err error) { - n, err = fw.w.Write(p) - if f, ok := fw.w.(http.Flusher); ok { - f.Flush() - } - return -} diff --git a/main.go b/main.go index 70253b3..2b768c6 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "github.com/golang/glog" + "golang.org/x/net/http2" "os" "runtime" "sync" @@ -35,6 +36,10 @@ func init() { flag.Var(&forwardAddr, "F", "forward address, can make a forward chain") flag.BoolVar(&pv, "V", false, "print version") flag.Parse() + + if glog.V(LDEBUG) { + http2.VerboseLogs = true + } } func main() { @@ -50,12 +55,14 @@ func main() { } listenArgs = parseArgs(listenAddr) - forwardArgs = parseArgs(forwardAddr) - if len(listenArgs) == 0 { - glog.Exitln("no listen addr") + fmt.Fprintln(os.Stderr, "no listen address, please specify at least one -L parameter") + return } + forwardArgs = parseArgs(forwardAddr) + processForwardChain(forwardArgs...) + var wg sync.WaitGroup for _, args := range listenArgs { wg.Add(1) diff --git a/util.go b/util.go index c9448d1..343f768 100644 --- a/util.go +++ b/util.go @@ -29,8 +29,8 @@ func (ss *strSlice) Set(value string) error { // admin:123456@localhost:8080 type Args struct { Addr string // host:port - Protocol string // protocol: http/socks(5)/ss - Transport string // transport: ws(s)/tls/tcp/udp/rtcp/rudp + Protocol string // protocol: http/http2/socks5/ss + Transport string // transport: ws/wss/tls/tcp/udp/rtcp/rudp Remote string // remote address, used by tcp/udp port forwarding User *url.Userinfo // authentication for proxy Cert tls.Certificate // tls certificate @@ -73,15 +73,6 @@ func parseArgs(ss []string) (args []Args) { arg.Transport = schemes[1] } - switch arg.Protocol { - case "http", "http2", "socks", "socks5", "ss": - case "https": - arg.Protocol = "http" - arg.Transport = "tls" - default: - arg.Protocol = "" - } - switch arg.Transport { case "ws", "wss", "tls": case "https": @@ -97,12 +88,42 @@ func parseArgs(ss []string) (args []Args) { arg.Transport = "" } + switch arg.Protocol { + case "http", "socks", "socks5", "ss": + case "http2": + arg.Transport = "tls" // standard http2 proxy, only support http2 over tls + default: + arg.Protocol = "" + } + args = append(args, arg) } return } +func processForwardChain(chain ...Args) { + glog.V(LINFO).Infoln(chain) + if len(chain) == 0 { + return + } + length := len(chain) + c, last := chain[:length-1], chain[length-1] + + // http2 restrict: only last proxy can enable http2 + for i, _ := range c { + if c[i].Protocol == "http2" { + c[i].Protocol = "http" + } + if c[i].Transport == "http2" { + c[i].Transport = "" + } + } + if last.Protocol == "http2" || last.Transport == "http2" { + initHttp2Client(last.Addr, c...) + } +} + // Based on io.Copy, but the io.ErrShortWrite is ignored (mainly for websocket) func Copy(dst io.Writer, src io.Reader) (written int64, err error) { // b := make([]byte, 32*1024)