From daf7c32cb85a0eda98e779fcbb5385023e69c30c Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Wed, 21 Sep 2016 17:36:02 +0800 Subject: [PATCH] forward http request directly when end of chain is http proxy --- conn.go | 20 +++++++++++--------- http.go | 41 +++++++++++++++++++++++++++++++++++------ util.go | 26 ++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 15 deletions(-) diff --git a/conn.go b/conn.go index cec2338..e6906eb 100644 --- a/conn.go +++ b/conn.go @@ -77,10 +77,7 @@ func listenAndServe(arg Args) error { continue } - if tc, ok := conn.(*net.TCPConn); ok { - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(time.Second * 180) - } + setKeepAlive(conn, keepAliveTime) go handleConn(conn, arg) } @@ -104,6 +101,8 @@ func listenAndServeTcpForward(arg Args) error { glog.V(LWARNING).Infoln(err) continue } + setKeepAlive(conn, keepAliveTime) + go handleTcpForward(conn, raddr) } } @@ -361,15 +360,19 @@ func (r *reqReader) Read(p []byte) (n int, err error) { } func Connect(addr string) (conn net.Conn, err error) { + return connectWithChain(addr, forwardArgs...) +} + +func connectWithChain(addr string, chain ...Args) (conn net.Conn, err error) { if !strings.Contains(addr, ":") { addr += ":80" } - if len(forwardArgs) == 0 { + if len(chain) == 0 { return net.DialTimeout("tcp", addr, time.Second*90) } var end Args - conn, end, err = forwardChain(forwardArgs...) + conn, end, err = forwardChain(chain...) if err != nil { return nil, err } @@ -394,9 +397,7 @@ func forwardChain(chain ...Args) (conn net.Conn, end Args, err error) { return } - tc := conn.(*net.TCPConn) - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(time.Second * 180) // 3min + setKeepAlive(conn, keepAliveTime) c, err := forward(conn, end) if err != nil { @@ -453,6 +454,7 @@ func forward(conn net.Conn, arg Args) (net.Conn, error) { case "tls": // tls connection tlsUsed = true conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) + // conn = tls.Client(conn, &tls.Config{ServerName: "ice139.com"}) case "tcp": fallthrough default: diff --git a/http.go b/http.go index 0d45960..36f4449 100644 --- a/http.go +++ b/http.go @@ -42,7 +42,40 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { return } - c, err := Connect(req.Host) + var c net.Conn + var err error + + if len(forwardArgs) > 0 { + last := forwardArgs[len(forwardArgs)-1] + if last.Protocol == "http" || last.Protocol == "" { + c, _, err = forwardChain(forwardArgs...) + if err != nil { + glog.V(LWARNING).Infof("[http] %s -> %s : %s", conn.RemoteAddr(), last.Addr, err) + + b := []byte("HTTP/1.1 503 Service unavailable\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n") + glog.V(LDEBUG).Infof("[http] %s <- %s\n%s", conn.RemoteAddr(), last.Addr, string(b)) + conn.Write(b) + return + } + defer c.Close() + + if last.User != nil { + req.Header.Set("Proxy-Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(last.User.String()))) + } + + if err = req.Write(c); err != nil { + glog.V(LWARNING).Infof("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + return + } + glog.V(LINFO).Infof("[http] %s <-> %s", conn.RemoteAddr(), req.Host) + Transport(conn, c) + glog.V(LINFO).Infof("[http] %s >-< %s", conn.RemoteAddr(), req.Host) + return + } + } + c, err = Connect(req.Host) if err != nil { glog.V(LWARNING).Infof("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) @@ -58,11 +91,7 @@ func handleHttpRequest(req *http.Request, conn net.Conn, arg Args) { b := []byte("HTTP/1.1 200 Connection established\r\n" + "Proxy-Agent: gost/" + Version + "\r\n\r\n") glog.V(LDEBUG).Infof("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b)) - - if _, err := conn.Write(b); err != nil { - glog.V(LWARNING).Infof("[http] %s <- %s : %s", conn.RemoteAddr(), req.Host, err) - return - } + conn.Write(b) } else { req.Header.Del("Proxy-Connection") req.Header.Set("Connection", "Keep-Alive") diff --git a/util.go b/util.go index d348948..27f3d9f 100644 --- a/util.go +++ b/util.go @@ -2,12 +2,18 @@ package main import ( "crypto/tls" + "errors" "fmt" "github.com/golang/glog" "io" "net" "net/url" "strings" + "time" +) + +const ( + keepAliveTime = 180 * time.Second ) type strSlice []string @@ -69,12 +75,18 @@ func parseArgs(ss []string) (args []Args) { switch arg.Protocol { case "http", "socks", "socks5", "ss": + case "https": + arg.Protocol = "http" + arg.Transport = "tls" default: arg.Protocol = "" } switch arg.Transport { case "ws", "wss", "tls": + case "https": + arg.Protocol = "http" + arg.Transport = "tls" case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding arg.Remote = strings.Trim(u.EscapedPath(), "/") case "rtcp", "rudp": // started from v2.1, rtcp and rudp are for remote port forwarding @@ -147,3 +159,17 @@ func Transport(conn, conn2 net.Conn) (err error) { return } + +func setKeepAlive(conn net.Conn, d time.Duration) error { + c, ok := conn.(*net.TCPConn) + if !ok { + return errors.New("Not a TCP connection") + } + if err := c.SetKeepAlive(true); err != nil { + return err + } + if err := c.SetKeepAlivePeriod(d); err != nil { + return err + } + return nil +}