From 9ccddf3cf8ee48682b923f11f9bd5353cbeac13f Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Mon, 23 Mar 2015 18:17:04 +0800 Subject: [PATCH] add shadowsocks compatible --- gost.go | 178 +++++++++++++++++++++++++++++++++++++++++++++++--------- log.go | 44 +++++++++++--- main.go | 2 + 3 files changed, 188 insertions(+), 36 deletions(-) diff --git a/gost.go b/gost.go index 328955d..6831afa 100644 --- a/gost.go +++ b/gost.go @@ -13,7 +13,10 @@ import ( "strconv" "strings" //"sync/atomic" - //"fmt" + "encoding/binary" + "fmt" + "github.com/shadowsocks/shadowsocks-go/shadowsocks" + "net/url" "time" ) @@ -24,6 +27,7 @@ const ( type Gost struct { Laddr, Saddr, Proxy string + Shadows bool // shadowsocks compatible } func (g *Gost) Run() error { @@ -43,6 +47,7 @@ func (g *Gost) Run() error { log.Println("accept:", err) continue } + //log.Println("accept", conn.RemoteAddr().String()) go g.handle(conn) } @@ -62,7 +67,13 @@ func (g *Gost) handle(conn net.Conn) { } func (g *Gost) cli(conn net.Conn) { - lg := NewLog() + lg := NewLog(true) + defer func() { + lg.Logln() + lg.Flush() + }() + + lg.Logln("accept", conn.(*net.TCPConn).RemoteAddr().String()) sconn, err := g.connect(g.Saddr) if err != nil { @@ -80,6 +91,48 @@ func (g *Gost) cli(conn net.Conn) { } lg.Logln(">>>|", []byte{5, 1, 0}) + if g.Shadows { + lg.Logln("shadowsocks, aes-256-cfb") + cipher, _ := shadowsocks.NewCipher("aes-256-cfb", "123456") + conn = shadowsocks.NewConn(conn, cipher) + addr, port, extra, err := getRequest(conn) + if err != nil { + lg.Logln(err) + return + } + lg.Logln(addr, port) + + cmd := NewCmd(CmdConnect, AddrDomain, addr, port) + if err = cmd.Write(sconn); err != nil { + lg.Logln(err) + return + } + lg.Logln(">>>|", cmd) + + if cmd, err = ReadCmd(sconn); err != nil { + lg.Logln(err) + return + } + lg.Logln("<<<|", cmd) + + if cmd.Cmd != Succeeded { + conn.Write([]byte("HTTP/1.1 503 Service unavailable\r\n" + + "Proxy-Agent: gost/1.0\r\n\r\n")) + return + } + + if extra != nil { + if _, err := sconn.Write(extra); err != nil { + log.Println(err) + return + } + } + + g.transport(conn, sconn) + + return + } + b := make([]byte, 8192) n, err := io.ReadFull(sconn, b[:2]) @@ -94,7 +147,7 @@ func (g *Gost) cli(conn net.Conn) { lg.Logln(err) return } - + //log.Println(b[:n]) if b[0] == 5 { // socks5,NO AUTHENTICATION lg.Logln("|>>>", b[:n]) @@ -130,9 +183,6 @@ func (g *Gost) cli(conn net.Conn) { } lg.Logln("|<<<", cmd) - lg.Logln() - lg.Flush() - g.transport(conn, sconn) return } @@ -192,14 +242,17 @@ func (g *Gost) cli(conn net.Conn) { } } - lg.Logln() - lg.Flush() g.transport(conn, sconn) } func (g *Gost) srv(conn net.Conn) { b := make([]byte, 8192) - lg := NewLog() + lg := NewLog(true) + defer func() { + lg.Logln() + lg.Flush() + }() + lg.Logln("accept", conn.(*net.TCPConn).RemoteAddr().String()) n, err := conn.Read(b) if err != nil { @@ -306,36 +359,34 @@ func (g *Gost) connect(addr string) (net.Conn, error) { } pconn, err := net.DialTCP("tcp", nil, paddr) if err != nil { + log.Println(err) return nil, err } - b := make([]byte, 1500) - buffer := bytes.NewBuffer(b) - buffer.WriteString("CONNECT " + addr + " HTTP/1.1\r\n") - buffer.WriteString("Host: " + addr + "\r\n") - buffer.WriteString("Proxy-Connection: keep-alive\r\n\r\n") - if _, err = pconn.Write(buffer.Bytes()); err != nil { + header := http.Header{} + header.Set("Proxy-Connection", "keep-alive") + req := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Host: addr}, + Host: addr, + Header: header, + } + if err := req.Write(pconn); err != nil { + log.Println(err) pconn.Close() return nil, err } - r := "" - for !strings.HasSuffix(r, "\r\n\r\n") { - n := 0 - if n, err = pconn.Read(b); err != nil { - pconn.Close() - return nil, err - } - r += string(b[:n]) - } - - log.Println(r) - if !strings.Contains(r, "200") { - log.Println("connection failed:\n", r) - err = errors.New(r) + resp, err := http.ReadResponse(bufio.NewReader(pconn), req) + if err != nil { + log.Println(err) pconn.Close() return nil, err } + if resp.StatusCode != http.StatusOK { + pconn.Close() + return nil, errors.New(resp.Status) + } return pconn, nil } @@ -361,3 +412,72 @@ func (g *Gost) transport(conn, conn2 net.Conn) (err error) { return } + +func getRequest(conn net.Conn) (host string, port uint16, extra []byte, err error) { + const ( + idType = 0 // address type index + idIP0 = 1 // ip addres start index + idDmLen = 1 // domain address length index + idDm0 = 2 // domain address start index + + typeIPv4 = 1 // type is ipv4 address + typeDm = 3 // type is domain address + typeIPv6 = 4 // type is ipv6 address + + lenIPv4 = 1 + net.IPv4len + 2 // 1addrType + ipv4 + 2port + lenIPv6 = 1 + net.IPv6len + 2 // 1addrType + ipv6 + 2port + lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen + ) + + // buf size should at least have the same size with the largest possible + // request size (when addrType is 3, domain name has at most 256 bytes) + // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + buf := make([]byte, 260) + var n int + // read till we get possible domain length field + //ss.SetReadTimeout(conn) + if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil { + log.Println(err) + return + } + log.Println(buf[:n]) + + reqLen := -1 + switch buf[idType] { + case typeIPv4: + reqLen = lenIPv4 + case typeIPv6: + reqLen = lenIPv6 + case typeDm: + reqLen = int(buf[idDmLen]) + lenDmBase + default: + err = fmt.Errorf("addr type %d not supported", buf[idType]) + return + } + + if n < reqLen { // rare case + //ss.SetReadTimeout(conn) + if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil { + log.Println(err) + return + } + } else if n > reqLen { + // it's possible to read more than just the request head + extra = buf[reqLen:n] + } + + // Return string for typeIP is not most efficient, but browsers (Chrome, + // Safari, Firefox) all seems using typeDm exclusively. So this is not a + // big problem. + switch buf[idType] { + case typeIPv4: + host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String() + case typeIPv6: + host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String() + case typeDm: + host = string(buf[idDm0 : idDm0+buf[idDmLen]]) + } + // parse port + port = binary.BigEndian.Uint16(buf[reqLen-2 : reqLen]) + return +} diff --git a/log.go b/log.go index ec727b5..a16944b 100644 --- a/log.go +++ b/log.go @@ -7,16 +7,24 @@ import ( "os" ) +var ( + Debug bool +) + type BufferedLog struct { buffer *bytes.Buffer w io.WriteCloser } -func NewLog() *BufferedLog { - return &BufferedLog{ - buffer: &bytes.Buffer{}, - w: os.Stdout, +func NewLog(buffered bool) *BufferedLog { + log := &BufferedLog{ + w: os.Stdout, } + if buffered { + log.buffer = &bytes.Buffer{} + } + + return log } func NewFileLog(file *os.File) *BufferedLog { @@ -27,15 +35,33 @@ func NewFileLog(file *os.File) *BufferedLog { } func (log *BufferedLog) Log(a ...interface{}) (int, error) { - return fmt.Fprint(log.buffer, a...) + if !Debug { + return 0, nil + } + if log.buffer != nil { + return fmt.Fprint(log.buffer, a...) + } + return fmt.Fprint(log.w, a...) } func (log *BufferedLog) Logln(a ...interface{}) (int, error) { - return fmt.Fprintln(log.buffer, a...) + if !Debug { + return 0, nil + } + if log.buffer != nil { + return fmt.Fprintln(log.buffer, a...) + } + return fmt.Fprintln(log.w, a...) } func (log *BufferedLog) Logf(format string, a ...interface{}) (int, error) { - return fmt.Fprintf(log.buffer, format, a...) + if !Debug { + return 0, nil + } + if log.buffer != nil { + return fmt.Fprintf(log.buffer, format, a...) + } + return fmt.Fprintf(log.w, format, a...) } func (log *BufferedLog) Flush() error { @@ -45,6 +71,10 @@ func (log *BufferedLog) Flush() error { } }() + if !Debug || log.buffer == nil { + return nil + } + _, err := log.buffer.WriteTo(log.w) return err } diff --git a/main.go b/main.go index a68e6c5..79be70a 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,8 @@ func init() { flag.StringVar(&gost.Proxy, "P", "", "proxy for forward") flag.StringVar(&gost.Saddr, "S", "", "the server that connecting to") flag.StringVar(&gost.Laddr, "L", ":8080", "listen address") + flag.BoolVar(&gost.Shadows, "ss", false, "shadowsocks compatible") + flag.BoolVar(&Debug, "d", false, "debug option") flag.Parse() log.SetFlags(log.LstdFlags | log.Lshortfile)