diff --git a/obfs.go b/obfs.go index 60b89d5..7f1e1e0 100644 --- a/obfs.go +++ b/obfs.go @@ -5,6 +5,7 @@ package gost import ( "bufio" "bytes" + "errors" "fmt" "io" "net" @@ -70,6 +71,7 @@ type obfsHTTPConn struct { rbuf bytes.Buffer wbuf bytes.Buffer isServer bool + headerDrained bool handshaked bool handshakeMutex sync.Mutex } @@ -103,7 +105,7 @@ func (c *obfsHTTPConn) serverHandshake() (err error) { } if Debug { dump, _ := httputil.DumpRequest(r, false) - log.Logf("[ohttp] %s -> %s\n%s", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), string(dump)) + log.Logf("[ohttp] %s -> %s\n%s", c.RemoteAddr(), c.LocalAddr(), string(dump)) } if r.ContentLength > 0 { @@ -121,26 +123,31 @@ func (c *obfsHTTPConn) serverHandshake() (err error) { } b := bytes.Buffer{} - if r.Header.Get("Upgrade") == "websocket" { - b.WriteString("HTTP/1.1 101 Switching Protocols\r\n") - b.WriteString("Server: nginx/1.10.0\r\n") + + if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" { + b.WriteString("HTTP/1.1 503 Service Unavailable\r\n") + b.WriteString("Content-Length: 0\r\n") b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") - b.WriteString("Connection: Upgrade\r\n") - b.WriteString("Upgrade: websocket\r\n") - b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(r.Header.Get("Sec-WebSocket-Key")))) - b.WriteString("\r\n") - } else { - b.WriteString("HTTP/1.1 200 OK\r\n") - b.WriteString("Server: nginx/1.10.0\r\n") - b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") - b.WriteString("Content-Type: application/octet-stream\r\n") - b.WriteString("Connection: keep-alive\r\n") - b.WriteString("Cache-Control: private, no-cache, no-store, proxy-revalidate, no-transform\r\n") - b.WriteString("Pragma: no-cache\r\n") b.WriteString("\r\n") + + if Debug { + log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String()) + } + + b.WriteTo(c.Conn) + return errors.New("bad request") } + + b.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + b.WriteString("Server: nginx/1.10.0\r\n") + b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") + b.WriteString("Connection: Upgrade\r\n") + b.WriteString("Upgrade: websocket\r\n") + b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(r.Header.Get("Sec-WebSocket-Key")))) + b.WriteString("\r\n") + if Debug { - log.Logf("[ohttp] %s <- %s\n%s", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), b.String()) + log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String()) } if c.rbuf.Len() > 0 { @@ -160,29 +167,20 @@ func (c *obfsHTTPConn) clientHandshake() (err error) { URL: &url.URL{Scheme: "http", Host: c.host}, Header: make(http.Header), } - r.Header.Set("User-Agent", "curl/7.49.1") + r.Header.Set("User-Agent", DefaultUserAgent) r.Header.Set("Connection", "Upgrade") r.Header.Set("Upgrade", "websocket") key, _ := generateChallengeKey() r.Header.Set("Sec-WebSocket-Key", key) - if err = r.Write(c.Conn); err != nil { + // cache the request header + if err = r.Write(&c.wbuf); err != nil { return } - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Logf("[ohttp] %s -> %s\n%s", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), string(dump)) - } - var resp *http.Response - resp, err = http.ReadResponse(bufio.NewReader(c.Conn), r) - if err != nil { - return - } - defer resp.Body.Close() if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Logf("[ohttp] %s <- %s\n%s", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), string(dump)) + dump, _ := httputil.DumpRequest(r, false) + log.Logf("[ohttp] %s -> %s\n%s", c.LocalAddr(), c.RemoteAddr(), string(dump)) } return nil @@ -192,17 +190,56 @@ func (c *obfsHTTPConn) Read(b []byte) (n int, err error) { if err = c.Handshake(); err != nil { return } + + if !c.isServer { + if err = c.drainHeader(); err != nil { + return + } + } + if c.rbuf.Len() > 0 { return c.rbuf.Read(b) } return c.Conn.Read(b) } +func (c *obfsHTTPConn) drainHeader() (err error) { + if c.headerDrained { + return + } + c.headerDrained = true + + br := bufio.NewReader(c.Conn) + // drain and discard the response header + var line string + var buf bytes.Buffer + for { + line, err = br.ReadString('\n') + if err != nil { + return + } + buf.WriteString(line) + if line == "\r\n" { + break + } + } + + if Debug { + log.Logf("[ohttp] %s <- %s\n%s", c.LocalAddr(), c.RemoteAddr(), buf.String()) + } + // cache the extra data for next read. + var b []byte + b, err = br.Peek(br.Buffered()) + if len(b) > 0 { + _, err = c.rbuf.Write(b) + } + return +} + func (c *obfsHTTPConn) Write(b []byte) (n int, err error) { if err = c.Handshake(); err != nil { return } - if c.wbuf.Len() > 0 { c.wbuf.Write(b) // append the data to the cached header _, err = c.wbuf.WriteTo(c.Conn) diff --git a/ss.go b/ss.go index 753b202..44b6e9b 100644 --- a/ss.go +++ b/ss.go @@ -61,12 +61,12 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect return nil, err } - sc := ss.NewConn(conn, cipher) - // sc, err := ss.DialWithRawAddrConn(rawaddr, conn, cipher) - if _, err := sc.Write(rawaddr); err != nil { - return nil, err + sc := &shadowConn{ + Conn: ss.NewConn(conn, cipher), } - return &shadowConn{sc}, nil + sc.wbuf.Write(rawaddr) // cache the header + + return sc, nil } type shadowHandler struct { @@ -106,7 +106,7 @@ func (h *shadowHandler) Handle(conn net.Conn) { conn.RemoteAddr(), conn.LocalAddr(), err) return } - conn = &shadowConn{ss.NewConn(conn, cipher)} + conn = &shadowConn{Conn: ss.NewConn(conn, cipher)} conn.SetReadDeadline(time.Now().Add(ReadTimeout)) host, err := h.getRequest(conn) @@ -526,11 +526,19 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error { // Due to in/out byte length is inconsistent of the shadowsocks.Conn.Write, // we wrap around it to make io.Copy happy. type shadowConn struct { + wbuf bytes.Buffer net.Conn } func (c *shadowConn) Write(b []byte) (n int, err error) { n = len(b) // force byte length consistent + + if c.wbuf.Len() > 0 { + c.wbuf.Write(b) // append the data to the cached header + _, err = c.Conn.Write(c.wbuf.Bytes()) + c.wbuf.Reset() + return + } _, err = c.Conn.Write(b) return }