From 45022d241512806deccc0bb7aae2407d87140fee Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 22 Dec 2018 15:29:38 +0800 Subject: [PATCH] add test for SOCKS5 bind & UDP relay --- Dockerfile | 6 +- bypass.go | 13 ++- bypass_test.go | 28 +++--- gost.go | 25 ++++- server.go | 17 +--- socks.go | 250 ++++++++++++++++++++++++++++++++++++++++++++++++- socks_test.go | 125 ++++++++++++++++++++----- 7 files changed, 403 insertions(+), 61 deletions(-) diff --git a/Dockerfile b/Dockerfile index 42ef103..22b8ae8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,12 +4,14 @@ ADD . /data WORKDIR /data -RUN cd cmd/gost && go install +ENV GO111MODULE=on + +RUN cd cmd/gost && go build FROM alpine:latest WORKDIR /bin/ -COPY --from=builder /go/bin/gost . +COPY --from=builder /data/cmd/gost/gost . ENTRYPOINT ["/bin/gost"] \ No newline at end of file diff --git a/bypass.go b/bypass.go index 85bf12b..5422b5e 100644 --- a/bypass.go +++ b/bypass.go @@ -152,13 +152,24 @@ func NewBypassPatterns(reversed bool, patterns ...string) *Bypass { // Contains reports whether the bypass includes addr. func (bp *Bypass) Contains(addr string) bool { - if bp == nil || len(bp.matchers) == 0 || addr == "" { + if bp == nil || addr == "" { return false } + // try to strip the port + if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { + if p, _ := strconv.Atoi(port); p > 0 { // port is valid + addr = host + } + } + bp.mux.RLock() defer bp.mux.RUnlock() + if len(bp.matchers) == 0 { + return false + } + var matched bool for _, matcher := range bp.matchers { if matcher == nil { diff --git a/bypass_test.go b/bypass_test.go index 5cc0e81..5e295b0 100644 --- a/bypass_test.go +++ b/bypass_test.go @@ -58,21 +58,21 @@ var bypassContainTests = []struct { {[]string{"www.example.com"}, true, "example.com", true}, // host:port - {[]string{"192.168.1.1"}, false, "192.168.1.1:80", false}, - {[]string{"192.168.1.1"}, true, "192.168.1.1:80", true}, + {[]string{"192.168.1.1"}, false, "192.168.1.1:80", true}, + {[]string{"192.168.1.1"}, true, "192.168.1.1:80", false}, {[]string{"192.168.1.1:80"}, false, "192.168.1.1", false}, {[]string{"192.168.1.1:80"}, true, "192.168.1.1", true}, - {[]string{"192.168.1.1:80"}, false, "192.168.1.1:80", true}, - {[]string{"192.168.1.1:80"}, true, "192.168.1.1:80", false}, + {[]string{"192.168.1.1:80"}, false, "192.168.1.1:80", false}, + {[]string{"192.168.1.1:80"}, true, "192.168.1.1:80", true}, {[]string{"192.168.1.1:80"}, false, "192.168.1.1:8080", false}, {[]string{"192.168.1.1:80"}, true, "192.168.1.1:8080", true}, - {[]string{"example.com"}, false, "example.com:80", false}, - {[]string{"example.com"}, true, "example.com:80", true}, + {[]string{"example.com"}, false, "example.com:80", true}, + {[]string{"example.com"}, true, "example.com:80", false}, {[]string{"example.com:80"}, false, "example.com", false}, {[]string{"example.com:80"}, true, "example.com", true}, - {[]string{"example.com:80"}, false, "example.com:80", true}, - {[]string{"example.com:80"}, true, "example.com:80", false}, + {[]string{"example.com:80"}, false, "example.com:80", false}, + {[]string{"example.com:80"}, true, "example.com:80", true}, {[]string{"example.com:80"}, false, "example.com:8080", false}, {[]string{"example.com:80"}, true, "example.com:8080", true}, @@ -139,18 +139,20 @@ var bypassContainTests = []struct { {[]string{".example.com"}, false, "example.com", true}, {[]string{".example.com"}, false, "www.example.com.cn", false}, + {[]string{"example.com*"}, false, "example.com", true}, {[]string{"example.com:*"}, false, "example.com", false}, - {[]string{"example.com:*"}, false, "example.com:80", true}, - {[]string{"example.com:*"}, false, "example.com:8080", true}, + {[]string{"example.com:*"}, false, "example.com:80", false}, + {[]string{"example.com:*"}, false, "example.com:8080", false}, {[]string{"example.com:*"}, false, "example.com:http", true}, {[]string{"example.com:*"}, false, "http://example.com:80", false}, - {[]string{"*example.com:*"}, false, "example.com:80", true}, + {[]string{"*example.com*"}, false, "example.com:80", true}, + {[]string{"*example.com:*"}, false, "example.com:80", false}, {[]string{".example.com:*"}, false, "www.example.com", false}, {[]string{".example.com:*"}, false, "http://www.example.com", false}, - {[]string{".example.com:*"}, false, "example.com:80", true}, - {[]string{".example.com:*"}, false, "www.example.com:8080", true}, + {[]string{".example.com:*"}, false, "example.com:80", false}, + {[]string{".example.com:*"}, false, "www.example.com:8080", false}, {[]string{".example.com:*"}, false, "http://www.example.com:80", true}, } diff --git a/gost.go b/gost.go index 6889460..f3e9e8b 100644 --- a/gost.go +++ b/gost.go @@ -9,6 +9,7 @@ import ( "encoding/pem" "io" "math/big" + "sync" "time" "github.com/go-log/log" @@ -27,15 +28,33 @@ var ( largeBufferSize = 32 * 1024 // 32KB large buffer ) +var ( + sPool = sync.Pool{ + New: func() interface{} { + return make([]byte, smallBufferSize) + }, + } + mPool = sync.Pool{ + New: func() interface{} { + return make([]byte, mediumBufferSize) + }, + } + lPool = sync.Pool{ + New: func() interface{} { + return make([]byte, largeBufferSize) + }, + } +) + var ( // KeepAliveTime is the keep alive time period for TCP connection. KeepAliveTime = 180 * time.Second // DialTimeout is the timeout of dial. - DialTimeout = 30 * time.Second + DialTimeout = 5 * time.Second // ReadTimeout is the timeout for reading. - ReadTimeout = 30 * time.Second + ReadTimeout = 5 * time.Second // WriteTimeout is the timeout for writing. - WriteTimeout = 60 * time.Second + WriteTimeout = 5 * time.Second // PingTimeout is the timeout for pinging. PingTimeout = 30 * time.Second // PingRetries is the reties of ping. diff --git a/server.go b/server.go index 9a07032..5af2e08 100644 --- a/server.go +++ b/server.go @@ -3,7 +3,6 @@ package gost import ( "io" "net" - "sync" "time" "github.com/go-log/log" @@ -137,27 +136,19 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { return tc, nil } -var ( - trPool = sync.Pool{ - New: func() interface{} { - return make([]byte, 32*1024) - }, - } -) - func transport(rw1, rw2 io.ReadWriter) error { errc := make(chan error, 1) go func() { - buf := trPool.Get().([]byte) - defer trPool.Put(buf) + buf := lPool.Get().([]byte) + defer lPool.Put(buf) _, err := io.CopyBuffer(rw1, rw2, buf) errc <- err }() go func() { - buf := trPool.Get().([]byte) - defer trPool.Put(buf) + buf := lPool.Get().([]byte) + defer lPool.Put(buf) _, err := io.CopyBuffer(rw2, rw1, buf) errc <- err diff --git a/socks.go b/socks.go index 480623f..16ce22f 100644 --- a/socks.go +++ b/socks.go @@ -9,6 +9,7 @@ import ( "net" "net/url" "strconv" + "sync" "time" "github.com/ginuerzh/gosocks4" @@ -254,6 +255,131 @@ func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...Connect return conn, nil } +type socks5BindConnector struct { + User *url.Userinfo +} + +// SOCKS5BindConnector creates a connector for SOCKS5 bind. +// It accepts an optional auth info for SOCKS5 Username/Password Authentication. +func SOCKS5BindConnector(user *url.Userinfo) Connector { + return &socks5BindConnector{User: user} +} + +func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { + cc, err := socks5Handshake(conn, c.User) + if err != nil { + return nil, err + } + conn = cc + + laddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + log.Log(err) + return nil, err + } + + req := gosocks5.NewRequest(gosocks5.CmdBind, &gosocks5.Addr{ + Type: gosocks5.AddrIPv4, + Host: laddr.IP.String(), + Port: uint16(laddr.Port), + }) + + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] bind\n", req) + } + + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + conn.SetReadDeadline(time.Time{}) + + if Debug { + log.Log("[socks5] bind\n", reply) + } + + if reply.Rep != gosocks5.Succeeded { + log.Logf("[socks5] bind on %s failure", addr) + return nil, fmt.Errorf("SOCKS5 bind on %s failure", addr) + } + baddr, err := net.ResolveTCPAddr("tcp", reply.Addr.String()) + if err != nil { + return nil, err + } + log.Logf("[socks5] bind on %s OK", baddr) + + return &socks5BindConn{Conn: conn, laddr: baddr}, nil +} + +type socks5UDPConnector struct { + User *url.Userinfo +} + +// SOCKS5UDPConnector creates a connector for SOCKS5 UDP relay. +// It accepts an optional auth info for SOCKS5 Username/Password Authentication. +func SOCKS5UDPConnector(user *url.Userinfo) Connector { + return &socks5UDPConnector{User: user} +} + +func (c *socks5UDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { + cc, err := socks5Handshake(conn, c.User) + if err != nil { + return nil, err + } + conn = cc + + taddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + req := gosocks5.NewRequest(gosocks5.CmdUdp, &gosocks5.Addr{ + Type: gosocks5.AddrIPv4, + }) + + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] udp\n", req) + } + + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + conn.SetReadDeadline(time.Time{}) + + if Debug { + log.Log("[socks5] udp\n", reply) + } + + if reply.Rep != gosocks5.Succeeded { + log.Logf("[socks5] udp relay failure") + return nil, fmt.Errorf("SOCKS5 udp relay failure") + } + baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) + if err != nil { + return nil, err + } + log.Logf("[socks5] udp associate on %s OK", baddr) + + uc, err := net.DialUDP("udp", nil, baddr) + if err != nil { + return nil, err + } + log.Logf("udp laddr:%s, raddr:%s", uc.LocalAddr(), uc.RemoteAddr()) + + return &socks5UDPConn{UDPConn: uc, taddr: taddr}, nil +} + type socks4Connector struct{} // SOCKS4Connector creates a Connector for SOCKS4 proxy client. @@ -773,7 +899,8 @@ func (h *socks5Handler) transportUDP(relay, peer *net.UDPConn) (err error) { var clientAddr *net.UDPAddr go func() { - b := make([]byte, largeBufferSize) + b := mPool.Get().([]byte) + defer mPool.Put(b) for { n, laddr, err := relay.ReadFromUDP(b) @@ -809,7 +936,8 @@ func (h *socks5Handler) transportUDP(relay, peer *net.UDPConn) (err error) { }() go func() { - b := make([]byte, largeBufferSize) + b := mPool.Get().([]byte) + defer mPool.Put(b) for { n, raddr, err := peer.ReadFromUDP(b) @@ -851,7 +979,8 @@ func (h *socks5Handler) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error var clientAddr *net.UDPAddr go func() { - b := make([]byte, mediumBufferSize) + b := mPool.Get().([]byte) + defer mPool.Put(b) for { n, addr, err := uc.ReadFromUDP(b) @@ -990,7 +1119,8 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error errc := make(chan error, 2) go func() { - b := make([]byte, mediumBufferSize) + b := mPool.Get().([]byte) + defer mPool.Put(b) for { n, addr, err := uc.ReadFromUDP(b) @@ -1453,3 +1583,115 @@ func (c *udpTunnelConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { } return len(b), nil } + +// socks5BindConn is a connection for SOCKS5 bind request. +type socks5BindConn struct { + raddr net.Addr + laddr net.Addr + net.Conn + handshaked bool + handshakeMux sync.Mutex +} + +// Handshake waits for a peer to connect to the bind port. +func (c *socks5BindConn) Handshake() (err error) { + c.handshakeMux.Lock() + defer c.handshakeMux.Unlock() + + if c.handshaked { + return nil + } + + c.handshaked = true + + rep, err := gosocks5.ReadReply(c.Conn) + if err != nil { + return fmt.Errorf("bind: read reply %v", err) + } + if rep.Rep != gosocks5.Succeeded { + return fmt.Errorf("bind: peer connect failure") + } + c.raddr, err = net.ResolveTCPAddr("tcp", rep.Addr.String()) + return +} + +func (c *socks5BindConn) Read(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + return c.Conn.Read(b) +} + +func (c *socks5BindConn) Write(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + return c.Conn.Write(b) +} + +func (c *socks5BindConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *socks5BindConn) RemoteAddr() net.Addr { + return c.raddr +} + +type socks5UDPConn struct { + *net.UDPConn + taddr net.Addr +} + +func (c *socks5UDPConn) Read(b []byte) (int, error) { + data := mPool.Get().([]byte) + defer mPool.Put(data) + + n, err := c.UDPConn.Read(data) + if err != nil { + return 0, err + } + dg, err := gosocks5.ReadUDPDatagram(bytes.NewReader(data[:n])) + if err != nil { + return 0, err + } + + return copy(b, dg.Data), nil +} + +func (c *socks5UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + dg, err := gosocks5.ReadUDPDatagram(c.UDPConn) + if err != nil { + return + } + + n = copy(b, dg.Data) + addr, err = net.ResolveUDPAddr("udp", dg.Header.Addr.String()) + + return +} + +func (c *socks5UDPConn) Write(b []byte) (int, error) { + addr, err := gosocks5.NewAddr(c.taddr.String()) + if err != nil { + return 0, err + } + h := gosocks5.NewUDPHeader(0, 0, addr) + dg := gosocks5.NewUDPDatagram(h, b) + if err = dg.Write(c.UDPConn); err != nil { + return 0, err + } + return len(b), nil +} + +func (c *socks5UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + adr, err := gosocks5.NewAddr(addr.String()) + if err != nil { + return 0, err + } + h := gosocks5.NewUDPHeader(0, 0, adr) + dg := gosocks5.NewUDPDatagram(h, b) + if err = dg.Write(c.UDPConn); err != nil { + return 0, err + } + return len(b), nil +} diff --git a/socks_test.go b/socks_test.go index 21be11c..b4137ec 100644 --- a/socks_test.go +++ b/socks_test.go @@ -2,7 +2,7 @@ package gost import ( "crypto/rand" - "crypto/tls" + "net" "net/http/httptest" "net/url" "testing" @@ -51,14 +51,6 @@ func socks5ProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinf } func TestSOCKS5Proxy(t *testing.T) { - cert, err := GenCertificate() - if err != nil { - panic(err) - } - DefaultTLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() @@ -84,14 +76,6 @@ func TestSOCKS5Proxy(t *testing.T) { } func BenchmarkSOCKS5Proxy(b *testing.B) { - cert, err := GenCertificate() - if err != nil { - panic(err) - } - DefaultTLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() @@ -124,14 +108,6 @@ func BenchmarkSOCKS5Proxy(b *testing.B) { } func BenchmarkSOCKS5ProxyParallel(b *testing.B) { - cert, err := GenCertificate() - if err != nil { - panic(err) - } - DefaultTLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() @@ -365,3 +341,102 @@ func BenchmarkSOCKS4AProxyParallel(b *testing.B) { } }) } + +func socks5BindRoundtrip(t *testing.T, targetURL string, data []byte) (err error) { + ln, err := TCPListener("") + if err != nil { + return + } + + client := &Client{ + Connector: SOCKS5BindConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + return + } + defer conn.Close() + + conn, err = client.Connect(conn, "") + if err != nil { + return + } + + cc, err := net.Dial("tcp", conn.LocalAddr().String()) + if err != nil { + return + } + defer cc.Close() + + if err = conn.(*socks5BindConn).Handshake(); err != nil { + return + } + + u, err := url.Parse(targetURL) + if err != nil { + return + } + hc, err := net.Dial("tcp", u.Host) + if err != nil { + return + } + go transport(hc, conn) + + return httpRoundtrip(cc, targetURL, data) +} + +func TestSOCKS5Bind(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + if err := socks5BindRoundtrip(t, httpSrv.URL, sendData); err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks5UDPRoundtrip(t *testing.T, host string, data []byte) (err error) { + ln, err := TCPListener("") + if err != nil { + return + } + + client := &Client{ + Connector: SOCKS5UDPConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + go server.Run() + defer server.Close() + + return udpRoundtrip(client, server, host, data) +} + +func TestSOCKS5UDP(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + if err := socks5UDPRoundtrip(t, udpSrv.Addr(), sendData); err != nil { + t.Errorf("got error: %v", err) + } +}