diff --git a/gost.go b/gost.go index bec6f23..6889460 100644 --- a/gost.go +++ b/gost.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "io" "math/big" "time" @@ -100,3 +101,16 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) { return } + +type readWriter struct { + r io.Reader + w io.Writer +} + +func (rw *readWriter) Read(p []byte) (n int, err error) { + return rw.r.Read(p) +} + +func (rw *readWriter) Write(p []byte) (n int, err error) { + return rw.w.Write(p) +} diff --git a/http.go b/http.go index 5f349ce..8aed753 100644 --- a/http.go +++ b/http.go @@ -175,10 +175,12 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { case "host": cc, err := net.Dial("tcp", ss[1]) if err == nil { + defer cc.Close() + req.Write(cc) - log.Logf("[http] %s <-> %s", conn.LocalAddr(), ss[1]) + log.Logf("[http] %s <-> %s : forward to %s", conn.LocalAddr(), req.Host, ss[1]) transport(conn, cc) - log.Logf("[http] %s >-< %s", conn.LocalAddr(), ss[1]) + log.Logf("[http] %s >-< %s : forward to %s", conn.LocalAddr(), req.Host, ss[1]) return } case "file": diff --git a/http2.go b/http2.go index d0e8878..8d82896 100644 --- a/http2.go +++ b/http2.go @@ -2,14 +2,18 @@ package gost import ( "bufio" + "bytes" "crypto/tls" "encoding/base64" "errors" "io" + "io/ioutil" "net" "net/http" "net/http/httputil" "net/url" + "os" + "strconv" "strings" "sync" "time" @@ -308,14 +312,79 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { return } + resp := &http.Response{ + ProtoMajor: 2, + ProtoMinor: 0, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) if Debug && (u != "" || p != "") { log.Logf("[http] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p) } if !authenticate(u, p, h.options.Users...) { - log.Logf("[http2] %s <- %s : proxy authentication required", r.RemoteAddr, target) - w.Header().Set("Proxy-Authenticate", "Basic realm=\"gost\"") - w.WriteHeader(http.StatusProxyAuthRequired) + // probing resistance is enabled + if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 { + switch ss[0] { + case "error": + resp.StatusCode, _ = strconv.Atoi(ss[1]) + case "web": + url := ss[1] + if !strings.HasPrefix(url, "http") { + url = "http://" + url + } + if r, err := http.Get(url); err == nil { + resp = r + } + case "host": + cc, err := net.Dial("tcp", ss[1]) + if err == nil { + defer cc.Close() + log.Logf("[http2] %s <-> %s : forward to %s", r.RemoteAddr, target, ss[1]) + if err := h.forwardRequest(w, r, cc); err != nil { + log.Logf("[http2] %s - %s : %s", r.RemoteAddr, target, err) + } + log.Logf("[http2] %s >-< %s : forward to %s", r.RemoteAddr, target, ss[1]) + return + } + case "file": + f, _ := os.Open(ss[1]) + if f != nil { + resp.StatusCode = http.StatusOK + if finfo, _ := f.Stat(); finfo != nil { + resp.ContentLength = finfo.Size() + } + resp.Body = f + } + } + } + + if resp.StatusCode == 0 { + log.Logf("[http2] %s <- %s : proxy authentication required", r.RemoteAddr, target) + resp.StatusCode = http.StatusProxyAuthRequired + resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"") + } else { + w.Header().Del("Proxy-Agent") + resp.Header = http.Header{} + resp.Header.Set("Server", "nginx/1.14.1") + resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) + if resp.ContentLength > 0 { + resp.Header.Set("Content-Type", "text/html") + } + if resp.StatusCode == http.StatusOK { + resp.Header.Set("Connection", "keep-alive") + } + } + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http2] %s <- %s\n%s", r.RemoteAddr, target, string(dump)) + } + + h.writeResponse(w, resp) + resp.Body.Close() + return } @@ -359,47 +428,41 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { } log.Logf("[http2] %s <-> %s", r.RemoteAddr, target) - errc := make(chan error, 2) - go func() { - _, err := io.Copy(cc, r.Body) - errc <- err - }() - go func() { - _, err := io.Copy(flushWriter{w}, cc) - errc <- err - }() - - select { - case <-errc: - // glog.V(LWARNING).Infoln("exit", err) - } + transport(&readWriter{r: r.Body, w: flushWriter{w}}, cc) log.Logf("[http2] %s >-< %s", r.RemoteAddr, target) return } log.Logf("[http2] %s <-> %s", r.RemoteAddr, target) - if err = r.Write(cc); err != nil { - log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, target, err) + if err := h.forwardRequest(w, r, cc); err != nil { + log.Logf("[http2] %s - %s : %s", r.RemoteAddr, target, err) + } + log.Logf("[http2] %s >-< %s", r.RemoteAddr, target) +} + +func (h *http2Handler) forwardRequest(w http.ResponseWriter, r *http.Request, rw io.ReadWriter) (err error) { + if err = r.Write(rw); err != nil { return } - resp, err := http.ReadResponse(bufio.NewReader(cc), r) + resp, err := http.ReadResponse(bufio.NewReader(rw), r) if err != nil { - log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, target, err) return } defer resp.Body.Close() + return h.writeResponse(w, resp) +} + +func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response) error { for k, v := range resp.Header { for _, vv := range v { w.Header().Add(k, vv) } } w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(flushWriter{w}, resp.Body); err != nil { - log.Logf("[http2] %s <- %s : %s", r.RemoteAddr, target, err) - } - log.Logf("[http2] %s >-< %s", r.RemoteAddr, target) + _, err := io.Copy(flushWriter{w}, resp.Body) + return err } type http2Listener struct {