diff --git a/gost.go b/gost.go index 90c47f3..328955d 100644 --- a/gost.go +++ b/gost.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" //"sync/atomic" + //"fmt" "time" ) @@ -56,56 +57,193 @@ func (g *Gost) handle(conn net.Conn) { g.cli(conn) return } - // as server g.srv(conn) } func (g *Gost) cli(conn net.Conn) { + lg := NewLog() + sconn, err := g.connect(g.Saddr) if err != nil { + lg.Logln(err) return } defer sconn.Close() + laddr := sconn.(*net.TCPConn).LocalAddr().String() + lg.Logln(laddr) + + if _, err := sconn.Write([]byte{5, 1, 0}); err != nil { + lg.Logln(err) + return + } + lg.Logln(">>>|", []byte{5, 1, 0}) + + b := make([]byte, 8192) + + n, err := io.ReadFull(sconn, b[:2]) + if err != nil { + lg.Logln(err) + return + } + lg.Logln("<<<|", b[:n]) + + n, err = conn.Read(b) + if err != nil { + lg.Logln(err) + return + } + + if b[0] == 5 { // socks5,NO AUTHENTICATION + lg.Logln("|>>>", b[:n]) + + if _, err := conn.Write([]byte{5, 0}); err != nil { + lg.Logln(err) + return + } + lg.Logln("|<<<", []byte{5, 0}) + + cmd, err := ReadCmd(conn) + if err != nil { + lg.Logln(err) + return + } + lg.Logln("|>>>", cmd) + + if err = cmd.Write(sconn); err != nil { + lg.Logln(err) + return + } + lg.Logln(">>>|", cmd) + + cmd, err = ReadCmd(sconn) + if err != nil { + lg.Logln(err) + return + } + lg.Logln("<<<|", cmd) + + if err = cmd.Write(conn); err != nil { + lg.Logln(err) + return + } + lg.Logln("|<<<", cmd) + + lg.Logln() + lg.Flush() + + g.transport(conn, sconn) + return + } + + //log.Println(string(b[:n])) + req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b[:n]))) + if err != nil { + lg.Logln(err) + return + } + lg.Logln(req.Method, req.RequestURI) + + var addr string + var port uint16 + + host := strings.Split(req.Host, ":") + if len(host) == 1 { + addr = host[0] + port = 80 + } + if len(host) == 2 { + addr = host[0] + n, _ := strconv.ParseUint(host[1], 10, 16) + port = uint16(n) + } + + cmd := NewCmd(CmdConnect, AddrDomain, addr, port) + if err = cmd.Write(sconn); err != nil { + lg.Logln(err) + return + } + lg.Logln(">>>|", cmd) + + if cmd, err = ReadCmd(sconn); err != nil { + lg.Logln(err) + return + } + lg.Logln("<<<|", cmd) + + if cmd.Cmd != Succeeded { + conn.Write([]byte("HTTP/1.1 503 Service unavailable\r\n" + + "Proxy-Agent: gost/1.0\r\n\r\n")) + return + } + + if req.Method == "CONNECT" { + if _, err = conn.Write( + []byte("HTTP/1.1 200 Connection established\r\n" + + "Proxy-Agent: gost/2.0\r\n\r\n")); err != nil { + lg.Logln(err) + return + } + } else { + if err = req.Write(sconn); err != nil { + lg.Logln(err) + return + } + } + + lg.Logln() + lg.Flush() g.transport(conn, sconn) - return } func (g *Gost) srv(conn net.Conn) { b := make([]byte, 8192) + lg := NewLog() n, err := conn.Read(b) if err != nil { - log.Println(err) + lg.Logln(err) return } - if bytes.Equal(b[:n], []byte{5, 1, 0}) { // socks5,NO AUTHENTICATION - log.Println("read cmd:", b[:n]) + if b[0] == 5 { // socks5,NO AUTHENTICATION + lg.Logln("|>>>", b[:n]) if _, err := conn.Write([]byte{5, 0}); err != nil { - log.Println(err) + lg.Logln(err) return } + lg.Logln("|<<<", []byte{5, 0}) cmd, err := ReadCmd(conn) if err != nil { + lg.Logln(err) return } + lg.Logln("|>>>", cmd) + host := cmd.Addr + ":" + strconv.Itoa(int(cmd.Port)) - log.Println("connect", host) + lg.Logln("connect", host) + tconn, err := g.connect(host) if err != nil { - log.Println(err) - NewCmd(ConnRefused, 0, "", 0).Write(conn) + lg.Logln(err) + cmd = NewCmd(ConnRefused, 0, "", 0) + cmd.Write(conn) + lg.Logln("|<<<", cmd) return } defer tconn.Close() - if err = NewCmd(Succeeded, AddrIPv4, "0.0.0.0", 0).Write(conn); err != nil { - log.Println(err) + cmd = NewCmd(Succeeded, AddrIPv4, "0.0.0.0", 0) + if err = cmd.Write(conn); err != nil { + lg.Logln(err) return } + lg.Logln("|<<<", cmd) + + lg.Logln() + lg.Flush() g.transport(conn, tconn) return @@ -114,18 +252,18 @@ func (g *Gost) srv(conn net.Conn) { //log.Println(string(b[:n])) req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b[:n]))) if err != nil { - log.Println(err) + lg.Logln(err) return } - log.Println(req.Method, req.RequestURI) + lg.Logln(req.Method, req.RequestURI) host := req.Host if !strings.Contains(host, ":") { host = host + ":80" } tconn, err := g.connect(host) if err != nil { - log.Println(err) + lg.Logln(err) conn.Write([]byte("HTTP/1.1 503 Service unavailable\r\n" + "Proxy-Agent: gost/1.0\r\n\r\n")) return @@ -136,14 +274,19 @@ func (g *Gost) srv(conn net.Conn) { if _, err = conn.Write( []byte("HTTP/1.1 200 Connection established\r\n" + "Proxy-Agent: gost/1.0\r\n\r\n")); err != nil { + lg.Logln(err) return } } else { if err := req.Write(tconn); err != nil { + lg.Logln(err) return } } + lg.Logln() + lg.Flush() + g.transport(conn, tconn) } @@ -151,6 +294,7 @@ func (g *Gost) connect(addr string) (net.Conn, error) { if len(g.Proxy) == 0 { taddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { + log.Println(err) return nil, err } return net.DialTCP("tcp", nil, taddr) diff --git a/log.go b/log.go new file mode 100644 index 0000000..ec727b5 --- /dev/null +++ b/log.go @@ -0,0 +1,50 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "os" +) + +type BufferedLog struct { + buffer *bytes.Buffer + w io.WriteCloser +} + +func NewLog() *BufferedLog { + return &BufferedLog{ + buffer: &bytes.Buffer{}, + w: os.Stdout, + } +} + +func NewFileLog(file *os.File) *BufferedLog { + return &BufferedLog{ + buffer: &bytes.Buffer{}, + w: file, + } +} + +func (log *BufferedLog) Log(a ...interface{}) (int, error) { + return fmt.Fprint(log.buffer, a...) +} + +func (log *BufferedLog) Logln(a ...interface{}) (int, error) { + return fmt.Fprintln(log.buffer, a...) +} + +func (log *BufferedLog) Logf(format string, a ...interface{}) (int, error) { + return fmt.Fprintf(log.buffer, format, a...) +} + +func (log *BufferedLog) Flush() error { + defer func() { + if log.w != os.Stdout { + log.w.Close() + } + }() + + _, err := log.buffer.WriteTo(log.w) + return err +} diff --git a/socks5.go b/socks5.go index c8795b3..892eb5b 100644 --- a/socks5.go +++ b/socks5.go @@ -3,8 +3,9 @@ package main import ( "encoding/binary" "errors" + "fmt" "io" - "log" + //"log" "net" ) @@ -92,7 +93,6 @@ func ReadCmd(r io.Reader) (*Cmd, error) { if err != nil { return nil, err } - log.Println("read cmd:", b[:n]) if n < 10 { return nil, ErrBadFormat } @@ -159,12 +159,12 @@ func (cmd *Cmd) Write(w io.Writer) (err error) { binary.BigEndian.PutUint16(b[pos:], cmd.Port) pos += 2 - log.Println("write cmd:", b[:pos]) _, err = w.Write(b[:pos]) return } -func (cmd *Cmd) GetError() error { - return cmdErrMap[cmd.Cmd] +func (cmd *Cmd) String() string { + return fmt.Sprintf("5 %d 0 %d %s %d", + cmd.Cmd, cmd.AddrType, cmd.Addr, cmd.Port) }