diff --git a/obfs.go b/obfs.go index 6b789ed..6bdac5d 100644 --- a/obfs.go +++ b/obfs.go @@ -5,8 +5,8 @@ package gost import ( "bufio" "bytes" - "errors" "fmt" + "io/ioutil" "net" "net/http" "net/http/httputil" @@ -66,6 +66,8 @@ func (l *obfsHTTPListener) Accept() (net.Conn, error) { type obfsHTTPConn struct { net.Conn r *http.Request + rbuf []byte + wbuf []byte isServer bool handshaked bool handshakeMutex sync.Mutex @@ -80,7 +82,8 @@ func (c *obfsHTTPConn) Handshake() (err error) { } if c.isServer { - c.r, err = http.ReadRequest(bufio.NewReader(c.Conn)) + br := bufio.NewReader(c.Conn) + c.r, err = http.ReadRequest(br) if err != nil { return } @@ -88,22 +91,58 @@ func (c *obfsHTTPConn) Handshake() (err error) { dump, _ := httputil.DumpRequest(c.r, false) log.Logf("[ohttp] %s -> %s\n%s", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), string(dump)) } - b := bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\n\r\n") + + if br.Buffered() > 0 { + c.rbuf, err = br.Peek(br.Buffered()) + } else { + c.rbuf, err = ioutil.ReadAll(c.r.Body) + } + + if err != nil { + log.Logf("[ohttp] %s -> %s : %v", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), err) + return + } + + b := bytes.Buffer{} + if c.r.Header.Get("Connection") == "Upgrade" && + c.r.Header.Get("Upgrade") == "websocket" { + b.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + b.WriteString("Server: nginx/1.10.0\r\n") + b.WriteString("Connection: Upgrade\r\n") + b.WriteString("Upgrade: websocket\r\n") + b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(c.r.Header.Get("Sec-WebSocket-Key")))) + b.WriteString("\r\n") + } else { + b.WriteString("HTTP/1.1 200 OK\r\n") + b.WriteString("Server: nginx/1.10.0\r\n") + b.WriteString("Content-Type: application/octet-stream\r\n") + b.WriteString("Connection: keep-alive\r\n") + b.WriteString("Cache-Control: private, no-cache, no-store, proxy-revalidate, no-transform\r\n") + b.WriteString("Pragma: no-cache\r\n") + b.WriteString("\r\n") + } if Debug { log.Logf("[ohttp] %s <- %s\n%s", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), b.String()) } if _, err = b.WriteTo(c.Conn); err != nil { return } - } else { r := c.r if r == nil { - r, err = http.NewRequest(http.MethodPost, "http://www.baidu.com/", nil) - if err != nil { - return + r = &http.Request{ + Method: http.MethodPost, + ProtoMajor: 1, + ProtoMinor: 1, + URL: &url.URL{Scheme: "http", Host: "www.baidu.com"}, + Header: make(http.Header), } + r.Header.Set("Connection", "keep-alive") r.Header.Set("User-Agent", DefaultUserAgent) + if len(c.wbuf) > 0 { + r.Body = ioutil.NopCloser(bytes.NewReader(c.wbuf)) + r.ContentLength = int64(len(c.wbuf)) + } } if err = r.Write(c.Conn); err != nil { return @@ -117,13 +156,12 @@ func (c *obfsHTTPConn) Handshake() (err error) { if err != nil { return } + defer resp.Body.Close() + if Debug { dump, _ := httputil.DumpResponse(resp, false) log.Logf("[ohttp] %s <- %s\n%s", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), string(dump)) } - if resp.StatusCode != http.StatusOK { - return errors.New(resp.Status) - } } c.handshaked = true return nil @@ -133,13 +171,24 @@ func (c *obfsHTTPConn) Read(b []byte) (n int, err error) { if err = c.Handshake(); err != nil { return } + if len(c.rbuf) > 0 { + n = copy(b, c.rbuf) + c.rbuf = c.rbuf[n:] + return + } return c.Conn.Read(b) } func (c *obfsHTTPConn) Write(b []byte) (n int, err error) { + handshaked := c.handshaked + c.wbuf = b if err = c.Handshake(); err != nil { return } + if !handshaked { + n = len(c.wbuf) + return + } return c.Conn.Write(b) } diff --git a/vendor/github.com/ginuerzh/gosocks5/socks5.go b/vendor/github.com/ginuerzh/gosocks5/socks5.go index 64709b2..2ee533a 100644 --- a/vendor/github.com/ginuerzh/gosocks5/socks5.go +++ b/vendor/github.com/ginuerzh/gosocks5/socks5.go @@ -516,8 +516,9 @@ func (r *Reply) Write(w io.Writer) (err error) { b[1] = r.Rep b[2] = 0 //rsv b[3] = AddrIPv4 // default - length := 10 + b[4], b[5], b[6], b[7], b[8], b[9] = 0, 0, 0, 0, 0, 0 // reset address field + if r.Addr != nil { n, _ := r.Addr.Encode(b[3:]) length = 3 + n diff --git a/vendor/vendor.json b/vendor/vendor.json index 3691d4f..df0fa4f 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -87,10 +87,10 @@ "revisionTime": "2017-02-09T14:09:51Z" }, { - "checksumSHA1": "4JEexBJToQeQm7fAo2PHVdCU3zM=", + "checksumSHA1": "Onmjh8hT6pjAixkuGJN4KKAaTT4=", "path": "github.com/ginuerzh/gosocks5", - "revision": "cb895c2f7a2cdceaf74ac6497f709b71a999168a", - "revisionTime": "2017-08-01T04:47:37Z" + "revision": "9e981f6c6b480d7e35e856a334c652a1ebe17fd1", + "revisionTime": "2017-09-11T08:28:29Z" }, { "checksumSHA1": "9e9tjPDTESeCEdUMElph250lurs=", diff --git a/ws.go b/ws.go index 69be7dc..69d41ae 100644 --- a/ws.go +++ b/ws.go @@ -1,7 +1,11 @@ package gost import ( + "crypto/rand" + "crypto/sha1" "crypto/tls" + "encoding/base64" + "io" "net" "net/http" "net/http/httputil" @@ -303,3 +307,20 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen return l, nil } + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +}