diff --git a/http.go b/http.go index d307689..4a03ebf 100644 --- a/http.go +++ b/http.go @@ -376,7 +376,17 @@ func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Ch if route.IsEmpty() { return nil } - lastNode := route.LastNode() + + host := req.Host + var userpass string + + if user := route.LastNode().User; user != nil { + s := user.String() + if _, set := user.Password(); !set { + s += ":" + } + userpass = base64.StdEncoding.EncodeToString([]byte(s)) + } cc, err := route.Conn() if err != nil { @@ -384,28 +394,47 @@ func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Ch } defer cc.Close() - if lastNode.User != nil { - s := lastNode.User.String() - if _, set := lastNode.User.Password(); !set { - s += ":" + errc := make(chan error, 1) + go func() { + errc <- copyBuffer(conn, cc) + }() + + go func() { + for { + if userpass != "" { + req.Header.Set("Proxy-Authorization", "Basic "+userpass) + } + + cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if !req.URL.IsAbs() { + req.URL.Scheme = "http" // make sure that the URL is absolute + } + err := req.WriteProxy(cc) + if err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + errc <- err + return + } + cc.SetWriteDeadline(time.Time{}) + + req, err = http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + errc <- err + return + } + + if Debug { + dump, _ := httputil.DumpRequest(req, false) + log.Logf("[http] %s -> %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } } - req.Header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) - } + }() - cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if !req.URL.IsAbs() { - req.URL.Scheme = "http" // make sure that the URL is absolute - } - if err = req.WriteProxy(cc); err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return nil - } - cc.SetWriteDeadline(time.Time{}) + log.Logf("[http] %s <-> %s", conn.RemoteAddr(), host) + <-errc + log.Logf("[http] %s >-< %s", conn.RemoteAddr(), host) - log.Logf("[http] %s <-> %s", conn.RemoteAddr(), req.Host) - transport(conn, cc) - log.Logf("[http] %s >-< %s", conn.RemoteAddr(), req.Host) return nil } diff --git a/server.go b/server.go index 86edf3e..17f8cc1 100644 --- a/server.go +++ b/server.go @@ -144,19 +144,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { func transport(rw1, rw2 io.ReadWriter) error { errc := make(chan error, 1) go func() { - buf := lPool.Get().([]byte) - defer lPool.Put(buf) - - _, err := io.CopyBuffer(rw1, rw2, buf) - errc <- err + errc <- copyBuffer(rw1, rw2) }() go func() { - buf := lPool.Get().([]byte) - defer lPool.Put(buf) - - _, err := io.CopyBuffer(rw2, rw1, buf) - errc <- err + errc <- copyBuffer(rw2, rw1) }() err := <-errc @@ -165,3 +157,11 @@ func transport(rw1, rw2 io.ReadWriter) error { } return err } + +func copyBuffer(dst io.Writer, src io.Reader) error { + buf := lPool.Get().([]byte) + defer lPool.Put(buf) + + _, err := io.CopyBuffer(dst, src, buf) + return err +}