diff --git a/client.go b/client.go index 8730916..60fc327 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,8 @@ import ( "net" "net/url" "time" + + "github.com/ginuerzh/gosocks5" ) // Client is a proxy client. @@ -236,8 +238,10 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption { // ConnectOptions describes the options for Connector.Connect. type ConnectOptions struct { - Addr string - Timeout time.Duration + Addr string + Timeout time.Duration + User *url.Userinfo + Selector gosocks5.Selector } // ConnectOption allows a common way to set ConnectOptions. @@ -256,3 +260,17 @@ func TimeoutConnectOption(timeout time.Duration) ConnectOption { opts.Timeout = timeout } } + +// UserConnectOption specifies the user info for authentication. +func UserConnectOption(user *url.Userinfo) ConnectOption { + return func(opts *ConnectOptions) { + opts.User = user + } +} + +// SelectorConnectOption specifies the SOCKS5 client selector. +func SelectorConnectOption(s gosocks5.Selector) ConnectOption { + return func(opts *ConnectOptions) { + opts.Selector = s + } +} diff --git a/common_test.go b/common_test.go index 0e8c348..a68d4a3 100644 --- a/common_test.go +++ b/common_test.go @@ -13,6 +13,8 @@ import ( "net/url" "sync" "time" + + "github.com/go-log/log" ) func init() { @@ -95,7 +97,7 @@ func httpRoundtrip(conn net.Conn, targetURL string, data []byte) (err error) { return } -func udpRoundtrip(client *Client, server *Server, host string, data []byte) (err error) { +func udpRoundtrip(logger log.Logger, client *Client, server *Server, host string, data []byte) (err error) { conn, err := proxyConn(client, server) if err != nil { return @@ -107,15 +109,17 @@ func udpRoundtrip(client *Client, server *Server, host string, data []byte) (err return } - conn.SetDeadline(time.Now().Add(3 * time.Second)) + conn.SetDeadline(time.Now().Add(1 * time.Second)) defer conn.SetDeadline(time.Time{}) if _, err = conn.Write(data); err != nil { + logger.Logf("write to %s via %s: %s", host, server.Addr(), err) return } recv := make([]byte, len(data)) if _, err = conn.Read(recv); err != nil { + logger.Logf("read from %s via %s: %s", host, server.Addr(), err) return } @@ -143,7 +147,7 @@ func proxyRoundtrip(client *Client, server *Server, targetURL string, data []byt return } - conn.SetDeadline(time.Now().Add(500 * time.Millisecond)) + conn.SetDeadline(time.Now().Add(1000 * time.Millisecond)) defer conn.SetDeadline(time.Time{}) return httpRoundtrip(conn, targetURL, data) @@ -167,12 +171,13 @@ type udpHandlerFunc func(w io.Writer, r *udpRequest) // udpTestServer is a UDP server for test. type udpTestServer struct { - ln net.PacketConn - handler udpHandlerFunc - wg sync.WaitGroup - mu sync.Mutex // guards closed and conns - closed bool - exitChan chan struct{} + ln net.PacketConn + handler udpHandlerFunc + wg sync.WaitGroup + mu sync.Mutex // guards closed and conns + closed bool + startChan chan struct{} + exitChan chan struct{} } func newUDPTestServer(handler udpHandlerFunc) *udpTestServer { @@ -181,23 +186,30 @@ func newUDPTestServer(handler udpHandlerFunc) *udpTestServer { if err != nil { panic(fmt.Sprintf("udptest: failed to listen on a port: %v", err)) } - ln.SetReadBuffer(1024 * 1024) - ln.SetWriteBuffer(1024 * 1024) return &udpTestServer{ - ln: ln, - handler: handler, - exitChan: make(chan struct{}), + ln: ln, + handler: handler, + startChan: make(chan struct{}), + exitChan: make(chan struct{}), } } func (s *udpTestServer) Start() { go s.serve() + <-s.startChan } func (s *udpTestServer) serve() { + select { + case <-s.startChan: + return + default: + close(s.startChan) + } + for { - data := make([]byte, 1024) + data := make([]byte, 32*1024) n, raddr, err := s.ln.ReadFrom(data) if err != nil { break diff --git a/forward.go b/forward.go index 29b3e7b..1034dbb 100644 --- a/forward.go +++ b/forward.go @@ -787,7 +787,7 @@ func (l *tcpRemoteForwardListener) getSession() (s *muxSession, err error) { conn.SetDeadline(time.Now().Add(HandshakeTimeout)) defer conn.SetDeadline(time.Time{}) - conn, err = socks5Handshake(conn, l.chain.LastNode().User) + conn, err = socks5Handshake(conn, nil, l.chain.LastNode().User) if err != nil { return nil, err } @@ -822,7 +822,7 @@ func (l *tcpRemoteForwardListener) getSession() (s *muxSession, err error) { } func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) { - conn, err := socks5Handshake(conn, l.chain.LastNode().User) + conn, err := socks5Handshake(conn, nil, l.chain.LastNode().User) if err != nil { return nil, err } diff --git a/forward_test.go b/forward_test.go index 6838f61..fc58a3d 100644 --- a/forward_test.go +++ b/forward_test.go @@ -119,7 +119,7 @@ func BenchmarkTCPDirectForwardParallel(b *testing.B) { }) } -func udpDirectForwardRoundtrip(host string, data []byte) error { +func udpDirectForwardRoundtrip(t *testing.T, host string, data []byte) error { ln, err := UDPDirectForwardListener("localhost:0", 0) if err != nil { return err @@ -138,7 +138,7 @@ func udpDirectForwardRoundtrip(host string, data []byte) error { go server.Run() defer server.Close() - return udpRoundtrip(client, server, host, data) + return udpRoundtrip(t, client, server, host, data) } func TestUDPDirectForward(t *testing.T) { @@ -148,7 +148,7 @@ func TestUDPDirectForward(t *testing.T) { sendData := make([]byte, 128) rand.Read(sendData) - err := udpDirectForwardRoundtrip(udpSrv.Addr(), sendData) + err := udpDirectForwardRoundtrip(t, udpSrv.Addr(), sendData) if err != nil { t.Error(err) } @@ -181,7 +181,7 @@ func BenchmarkUDPDirectForward(b *testing.B) { defer server.Close() for i := 0; i < b.N; i++ { - if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil { + if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil { b.Error(err) } } @@ -215,7 +215,7 @@ func BenchmarkUDPDirectForwardParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil { + if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil { b.Error(err) } } @@ -281,7 +281,7 @@ func udpRemoteForwardRoundtrip(t *testing.T, host string, data []byte) error { go server.Run() defer server.Close() - return udpRoundtrip(client, server, host, data) + return udpRoundtrip(t, client, server, host, data) } func TestUDPRemoteForward(t *testing.T) { diff --git a/http.go b/http.go index 43d0f69..d815ebe 100644 --- a/http.go +++ b/http.go @@ -52,9 +52,14 @@ func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOp req.Header.Set("User-Agent", DefaultUserAgent) req.Header.Set("Proxy-Connection", "keep-alive") - if c.User != nil { - u := c.User.Username() - p, _ := c.User.Password() + user := opts.User + if user == nil { + user = c.User + } + + if user != nil { + u := user.Username() + p, _ := user.Password() req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) } diff --git a/http2.go b/http2.go index 4073a51..6902a62 100644 --- a/http2.go +++ b/http2.go @@ -58,9 +58,15 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO } // TODO: use the standard CONNECT method. req.Header.Set("Gost-Target", addr) - if c.User != nil { - u := c.User.Username() - p, _ := c.User.Password() + + user := opts.User + if user == nil { + user = c.User + } + + if user != nil { + u := user.Username() + p, _ := user.Password() req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) } diff --git a/http2_test.go b/http2_test.go index cdba55d..dcd6230 100644 --- a/http2_test.go +++ b/http2_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -36,7 +37,7 @@ func http2ProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo return proxyRoundtrip(client, server, targetURL, data) } -func TestHTTP2Proxy(t *testing.T) { +func TestHTTP2ProxyAuth(t *testing.T) { httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() @@ -1108,3 +1109,42 @@ func TestHTTP2ProxyWithFileProbeResist(t *testing.T) { t.Error("data not equal") } } + +func TestHTTP2ProxyWithBypass(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + u, err := url.Parse(httpSrv.URL) + if err != nil { + t.Error(err) + } + ln, err := HTTP2Listener("", nil) + if err != nil { + t.Error(err) + } + + client := &Client{ + Connector: HTTP2Connector(nil), + Transporter: HTTP2Transporter(nil), + } + + host := u.Host + if h, _, _ := net.SplitHostPort(u.Host); h != "" { + host = h + } + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + BypassHandlerOption(NewBypassPatterns(false, host)), + ), + } + go server.Run() + defer server.Close() + + if err = proxyRoundtrip(client, server, httpSrv.URL, sendData); err == nil { + t.Error("should failed") + } +} diff --git a/http_test.go b/http_test.go index 0f99d7b..901c44e 100644 --- a/http_test.go +++ b/http_test.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -55,7 +56,7 @@ func httpProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, return proxyRoundtrip(client, server, targetURL, data) } -func TestHTTPProxy(t *testing.T) { +func TestHTTPProxyAuth(t *testing.T) { httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() @@ -82,6 +83,40 @@ func TestHTTPProxy(t *testing.T) { } } +func TestHTTPProxyWithInvalidRequest(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + t.Error(err) + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler(), + } + go server.Run() + defer server.Close() + + r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), bytes.NewReader(sendData)) + if err != nil { + t.Error(err) + } + resp, err := http.DefaultClient.Do(r) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Error("got status:", resp.Status) + } +} + func BenchmarkHTTPProxy(b *testing.B) { httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() @@ -302,3 +337,42 @@ func TestHTTPProxyWithFileProbeResist(t *testing.T) { t.Error("data not equal, got:", string(recv)) } } + +func TestHTTPProxyWithBypass(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + u, err := url.Parse(httpSrv.URL) + if err != nil { + t.Error(err) + } + ln, err := TCPListener("") + if err != nil { + t.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(nil), + Transporter: TCPTransporter(), + } + + host := u.Host + if h, _, _ := net.SplitHostPort(u.Host); h != "" { + host = h + } + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + BypassHandlerOption(NewBypassPatterns(false, host)), + ), + } + go server.Run() + defer server.Close() + + if err = proxyRoundtrip(client, server, httpSrv.URL, sendData); err == nil { + t.Error("should failed") + } +} diff --git a/obfs_test.go b/obfs_test.go index 8999cda..cd33702 100644 --- a/obfs_test.go +++ b/obfs_test.go @@ -369,3 +369,56 @@ func TestSNIOverObfsHTTP(t *testing.T) { }) } } + +func httpOverObfs4Roundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := Obfs4Listener("") + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: Obfs4Transporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func _TestHTTPOverObfs4(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := httpOverObfs4Roundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + }) + } +} diff --git a/resolver.go b/resolver.go index 624928a..b0f8659 100644 --- a/resolver.go +++ b/resolver.go @@ -238,7 +238,7 @@ type resolverCacheItem struct { } func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP { - if ttl < 0 { + if name == "" || ttl < 0 { return nil } @@ -248,7 +248,8 @@ func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP { ttl = item.ttl } - if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl { + if time.Since(time.Unix(item.ts, 0)) > ttl { + r.mCache.Delete(name) return nil } return item.IPs @@ -258,7 +259,7 @@ func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP { } func (r *resolver) storeCache(name string, ips []net.IP, ttl time.Duration) { - if name == "" || len(ips) == 0 { + if name == "" || len(ips) == 0 || ttl < 0 { return } r.mCache.Store(name, &resolverCacheItem{ diff --git a/resolver_test.go b/resolver_test.go index 3847697..c7289bf 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "net" "testing" "time" ) @@ -13,6 +14,8 @@ var dnsTests = []struct { host string pass bool }{ + {NameServer{Addr: "1.1.1.1"}, "192.168.1.1", true}, + {NameServer{Addr: "1.1.1.1"}, "github", true}, {NameServer{Addr: "1.1.1.1"}, "github.com", true}, {NameServer{Addr: "1.1.1.1:53"}, "github.com", true}, {NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true}, @@ -47,6 +50,8 @@ func TestDNSResolver(t *testing.T) { } t.Log(ns) r := NewResolver(0, ns) + resolv := r.(*resolver) + resolv.domain = "com" err := dnsResolverRoundtrip(t, r, tc.host) if err != nil { if tc.pass { @@ -61,6 +66,56 @@ func TestDNSResolver(t *testing.T) { } } +var resolverCacheTests = []struct { + name string + ips []net.IP + ttl time.Duration + result []net.IP +}{ + {"", nil, 0, nil}, + {"", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil}, + {"", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second, nil}, + {"example.com", nil, 10 * time.Second, nil}, + {"example.com", []net.IP{}, 10 * time.Second, nil}, + {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil}, + {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, -1, nil}, + {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second, + []net.IP{net.IPv4(192, 168, 1, 1)}}, + {"example.com", []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}, 10 * time.Second, + []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}}, +} + +func TestResolverCache(t *testing.T) { + isEqual := func(a, b []net.IP) bool { + if a == nil && b == nil { + return true + } + + if a == nil || b == nil || len(a) != len(b) { + return false + } + + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + return true + } + for i, tc := range resolverCacheTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + r := newResolver(tc.ttl) + r.storeCache(tc.name, tc.ips, tc.ttl) + ips := r.loadCache(tc.name, tc.ttl) + + if !isEqual(tc.result, ips) { + t.Error("unexpected cache value:", tc.name, ips, tc.ttl) + } + }) + } +} + var resolverReloadTests = []struct { r io.Reader diff --git a/selector_test.go b/selector_test.go new file mode 100644 index 0000000..41fe95f --- /dev/null +++ b/selector_test.go @@ -0,0 +1,150 @@ +package gost + +import ( + "testing" + "time" +) + +func TestRoundStrategy(t *testing.T) { + nodes := []Node{ + Node{ID: 1}, + Node{ID: 2}, + Node{ID: 3}, + } + s := NewStrategy("round") + t.Log(s.String()) + + if node := s.Apply(nil); node.ID > 0 { + t.Error("unexpected node", node.String()) + } + for i := 0; i <= len(nodes); i++ { + node := s.Apply(nodes) + if node.ID != nodes[i%len(nodes)].ID { + t.Error("unexpected node", node.String()) + } + } +} + +func TestRandomStrategy(t *testing.T) { + nodes := []Node{ + Node{ID: 1}, + Node{ID: 2}, + Node{ID: 3}, + } + s := NewStrategy("random") + t.Log(s.String()) + + if node := s.Apply(nil); node.ID > 0 { + t.Error("unexpected node", node.String()) + } + for i := 0; i <= len(nodes); i++ { + node := s.Apply(nodes) + if node.ID == 0 { + t.Error("unexpected node", node.String()) + } + } +} + +func TestFIFOStrategy(t *testing.T) { + nodes := []Node{ + Node{ID: 1}, + Node{ID: 2}, + Node{ID: 3}, + } + s := NewStrategy("fifo") + t.Log(s.String()) + + if node := s.Apply(nil); node.ID > 0 { + t.Error("unexpected node", node.String()) + } + for i := 0; i <= len(nodes); i++ { + node := s.Apply(nodes) + if node.ID != 1 { + t.Error("unexpected node", node.String()) + } + } +} + +func TestFailFilter(t *testing.T) { + nodes := []Node{ + Node{ID: 1, marker: &failMarker{}}, + Node{ID: 2, marker: &failMarker{}}, + Node{ID: 3, marker: &failMarker{}}, + } + filter := &FailFilter{} + t.Log(filter.String()) + + isEqual := func(a, b []Node) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil || len(a) != len(b) { + return false + } + + for i := range a { + if a[i].ID != b[i].ID { + return false + } + } + return true + } + if v := filter.Filter(nil); v != nil { + t.Error("unexpected node", v) + } + + if v := filter.Filter(nodes); !isEqual(v, nodes) { + t.Error("unexpected node", v) + } + + filter.MaxFails = 1 + if v := filter.Filter(nodes); !isEqual(v, nodes) { + t.Error("unexpected node", v) + } + + nodes[0].MarkDead() + if v := filter.Filter(nodes); !isEqual(v, nodes) { + t.Error("unexpected node", v) + } + + filter.FailTimeout = 5 * time.Second + if v := filter.Filter(nodes); isEqual(v, nodes) { + t.Error("unexpected node", v) + } + + nodes[1].MarkDead() + nodes[2].MarkDead() + if v := filter.Filter(nodes); len(v) > 0 { + t.Error("unexpected node", v) + } + + for i := range nodes { + nodes[i].ResetDead() + } + if v := filter.Filter(nodes); !isEqual(v, nodes) { + t.Error("unexpected node", v) + } +} + +func TestSelector(t *testing.T) { + nodes := []Node{ + Node{ID: 1, marker: &failMarker{}}, + Node{ID: 2, marker: &failMarker{}}, + Node{ID: 3, marker: &failMarker{}}, + } + selector := &defaultSelector{} + if _, err := selector.Select(nil); err != ErrNoneAvailable { + t.Error("got unexpected error:", err) + } + + if node, _ := selector.Select(nodes); node.ID != 1 { + t.Error("unexpected node:", node) + } + + if node, _ := selector.Select(nodes, + WithStrategy(NewStrategy("")), + WithFilter(&FailFilter{MaxFails: 1, FailTimeout: 3 * time.Second}), + ); node.ID != 1 { + t.Error("unexpected node:", node) + } +} diff --git a/socks.go b/socks.go index 66daaa1..c2cee6d 100644 --- a/socks.go +++ b/socks.go @@ -218,18 +218,12 @@ func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...Connect conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) - selector := &clientSelector{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - User: c.User, + user := opts.User + if user == nil { + user = c.User } - selector.AddMethod( - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - ) - - cc := gosocks5.ClientConn(conn, selector) - if err := cc.Handleshake(); err != nil { + cc, err := socks5Handshake(conn, opts.Selector, user) + if err != nil { return nil, err } conn = cc @@ -292,7 +286,11 @@ func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...Con conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) - cc, err := socks5Handshake(conn, c.User) + user := opts.User + if user == nil { + user = c.User + } + cc, err := socks5Handshake(conn, opts.Selector, user) if err != nil { return nil, err } @@ -442,7 +440,7 @@ func (tr *socks5MuxBindTransporter) initSession(conn net.Conn, addr string, opts opts = &HandshakeOptions{} } - cc, err := socks5Handshake(conn, opts.User) + cc, err := socks5Handshake(conn, nil, opts.User) if err != nil { return nil, err } @@ -522,7 +520,11 @@ func (c *socks5UDPConnector) Connect(conn net.Conn, addr string, options ...Conn conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) - cc, err := socks5Handshake(conn, c.User) + user := opts.User + if user == nil { + user = c.User + } + cc, err := socks5Handshake(conn, opts.Selector, user) if err != nil { return nil, err } @@ -597,7 +599,11 @@ func (c *socks5UDPTunConnector) Connect(conn net.Conn, addr string, options ...C conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) - cc, err := socks5Handshake(conn, c.User) + user := opts.User + if user == nil { + user = c.User + } + cc, err := socks5Handshake(conn, opts.Selector, user) if err != nil { return nil, err } @@ -642,68 +648,6 @@ func (c *socks5UDPTunConnector) Connect(conn net.Conn, addr string, options ...C return &udpTunnelConn{Conn: conn, raddr: taddr.String()}, nil } -func (c *socks5UDPTunConnector) tunnelClientUDP(pc net.PacketConn, cc net.Conn) (err error) { - errc := make(chan error, 2) - - go func() { - b := mPool.Get().([]byte) - defer mPool.Put(b) - - for { - n, addr, err := pc.ReadFrom(b) - if err != nil { - log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) - errc <- err - return - } - - // pipe from peer to tunnel - dgram := gosocks5.NewUDPDatagram( - gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n]) - if err := dgram.Write(cc); err != nil { - log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err) - errc <- err - return - } - if Debug { - log.Logf("[udp-tun] %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data)) - } - } - }() - - go func() { - for { - dgram, err := gosocks5.ReadUDPDatagram(cc) - if err != nil { - log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) - errc <- err - return - } - - // pipe from tunnel to peer - addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) - if err != nil { - continue // drop silently - } - - if _, err := pc.WriteTo(dgram.Data, addr); err != nil { - log.Logf("[udp-tun] %s -> %s : %s", cc.RemoteAddr(), addr, err) - errc <- err - return - } - if Debug { - log.Logf("[udp-tun] %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data)) - } - } - }() - - select { - case err = <-errc: - } - - return -} - type socks4Connector struct{} // SOCKS4Connector creates a Connector for SOCKS4 proxy client. @@ -1186,7 +1130,7 @@ func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) { } defer cc.Close() - cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User) + cc, err = socks5Handshake(cc, nil, h.options.Chain.LastNode().User) if err != nil { log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err) return @@ -1450,7 +1394,7 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { } defer cc.Close() - cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User) + cc, err = socks5Handshake(cc, nil, h.options.Chain.LastNode().User) if err != nil { log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) return @@ -1844,7 +1788,7 @@ func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { conn.SetDeadline(time.Now().Add(HandshakeTimeout)) defer conn.SetDeadline(time.Time{}) - cc, err := socks5Handshake(conn, chain.LastNode().User) + cc, err := socks5Handshake(conn, nil, chain.LastNode().User) if err != nil { conn.Close() return nil, err @@ -1877,16 +1821,20 @@ func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { return conn, nil } -func socks5Handshake(conn net.Conn, user *url.Userinfo) (net.Conn, error) { - selector := &clientSelector{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - User: user, +func socks5Handshake(conn net.Conn, selector gosocks5.Selector, user *url.Userinfo) (net.Conn, error) { + if selector == nil { + cs := &clientSelector{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + User: user, + } + cs.AddMethod( + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + MethodTLS, + ) + selector = cs } - selector.AddMethod( - gosocks5.MethodNoAuth, - gosocks5.MethodUserPass, - MethodTLS, - ) + cc := gosocks5.ClientConn(conn, selector) if err := cc.Handleshake(); err != nil { return nil, err @@ -1996,24 +1944,20 @@ type socks5UDPConn struct { taddr net.Addr } -func (c *socks5UDPConn) Read(b []byte) (int, error) { - data := mPool.Get().([]byte) - defer mPool.Put(data) - - n, err := c.UDPConn.Read(data) - if err != nil { - return 0, err - } - dg, err := gosocks5.ReadUDPDatagram(bytes.NewReader(data[:n])) - if err != nil { - return 0, err - } - - return copy(b, dg.Data), nil +func (c *socks5UDPConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return } func (c *socks5UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - dg, err := gosocks5.ReadUDPDatagram(c.UDPConn) + data := mPool.Get().([]byte) + defer mPool.Put(data) + + n, err = c.UDPConn.Read(data) + if err != nil { + return + } + dg, err := gosocks5.ReadUDPDatagram(bytes.NewReader(data[:n])) if err != nil { return } @@ -2025,16 +1969,7 @@ func (c *socks5UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { } func (c *socks5UDPConn) Write(b []byte) (int, error) { - addr, err := gosocks5.NewAddr(c.taddr.String()) - if err != nil { - return 0, err - } - h := gosocks5.NewUDPHeader(0, 0, addr) - dg := gosocks5.NewUDPDatagram(h, b) - if err = dg.Write(c.UDPConn); err != nil { - return 0, err - } - return len(b), nil + return c.WriteTo(b, c.taddr) } func (c *socks5UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { diff --git a/socks_test.go b/socks_test.go index 1d424d1..b2a94fa 100644 --- a/socks_test.go +++ b/socks_test.go @@ -550,7 +550,7 @@ func socks5UDPRoundtrip(t *testing.T, host string, data []byte) (err error) { go server.Run() defer server.Close() - return udpRoundtrip(client, server, host, data) + return udpRoundtrip(t, client, server, host, data) } func TestSOCKS5UDP(t *testing.T) { @@ -593,7 +593,7 @@ func BenchmarkSOCKS5UDP(b *testing.B) { defer server.Close() for i := 0; i < b.N; i++ { - if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil { + if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil { b.Error(err) } } @@ -679,7 +679,7 @@ func socks5UDPTunRoundtrip(t *testing.T, host string, data []byte) (err error) { go server.Run() defer server.Close() - return udpRoundtrip(client, server, host, data) + return udpRoundtrip(t, client, server, host, data) } func TestSOCKS5UDPTun(t *testing.T) { @@ -721,7 +721,7 @@ func BenchmarkSOCKS5UDPTun(b *testing.B) { defer server.Close() for i := 0; i < b.N; i++ { - if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil { + if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil { b.Error(err) } } diff --git a/ss.go b/ss.go index a5eb563..753b202 100644 --- a/ss.go +++ b/ss.go @@ -47,9 +47,13 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect } var method, password string - if c.Cipher != nil { - method = c.Cipher.Username() - password, _ = c.Cipher.Password() + cp := opts.User + if cp == nil { + cp = c.Cipher + } + if cp != nil { + method = cp.Username() + password, _ = cp.Password() } cipher, err := ss.NewCipher(method, password) @@ -446,6 +450,10 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error { b := mPool.Get().([]byte) defer mPool.Put(b) + b[0] = 0 + b[1] = 0 + b[2] = 0 + n, err := sc.Read(b[3:]) // add rsv and frag fields to make it the standard SOCKS5 UDP datagram if err != nil { // log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err) @@ -535,10 +543,14 @@ type shadowUDPConn struct { func (c *shadowUDPConn) Write(b []byte) (n int, err error) { n = len(b) // force byte length consistent - if len(c.header) > 0 { - b = append(c.header, b...) + buf := bytes.Buffer{} + if _, err = buf.Write(c.header); err != nil { + return } - _, err = c.PacketConn.WriteTo(b, c.raddr) + if _, err = buf.Write(b); err != nil { + return + } + _, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr) return } @@ -546,6 +558,10 @@ func (c *shadowUDPConn) Read(b []byte) (n int, err error) { buf := mPool.Get().([]byte) defer mPool.Put(buf) + buf[0] = 0 + buf[1] = 0 + buf[2] = 0 + n, _, err = c.PacketConn.ReadFrom(buf[3:]) if err != nil { return diff --git a/ss_test.go b/ss_test.go index 24f14ec..8390ca4 100644 --- a/ss_test.go +++ b/ss_test.go @@ -1,6 +1,7 @@ package gost import ( + "bytes" "crypto/rand" "fmt" "net/http/httptest" @@ -299,14 +300,15 @@ func BenchmarkSSProxyParallel(b *testing.B) { }) } -func shadowUDPRoundtrip(t *testing.T, host string, data []byte) error { - ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), 0) +func shadowUDPRoundtrip(t *testing.T, host string, data []byte, + clientInfo *url.Userinfo, serverInfo *url.Userinfo) error { + ln, err := ShadowUDPListener("localhost:0", serverInfo, 0) if err != nil { return err } client := &Client{ - Connector: ShadowUDPConnector(url.UserPassword("chacha20-ietf", "123456")), + Connector: ShadowUDPConnector(clientInfo), Transporter: UDPTransporter(), } @@ -318,19 +320,35 @@ func shadowUDPRoundtrip(t *testing.T, host string, data []byte) error { go server.Run() defer server.Close() - return udpRoundtrip(client, server, host, data) + return udpRoundtrip(t, client, server, host, data) } -func _TestShadowUDP(t *testing.T) { - udpSrv := newUDPTestServer(udpTestHandler) - udpSrv.Start() - defer udpSrv.Close() - +func TestShadowUDP(t *testing.T) { sendData := make([]byte, 128) rand.Read(sendData) - err := shadowUDPRoundtrip(t, udpSrv.Addr(), sendData) - if err != nil { - t.Error(err) + + for i, tc := range ssTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + err := shadowUDPRoundtrip(t, udpSrv.Addr(), sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) } } @@ -343,7 +361,7 @@ func BenchmarkShadowUDP(b *testing.B) { sendData := make([]byte, 128) rand.Read(sendData) - ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), 1000*time.Millisecond) + ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), 0) if err != nil { b.Error(err) } @@ -361,9 +379,32 @@ func BenchmarkShadowUDP(b *testing.B) { go server.Run() defer server.Close() + conn, err := proxyConn(client, server) + if err != nil { + b.Error(err) + } + defer conn.Close() + + conn, err = client.Connect(conn, udpSrv.Addr()) + if err != nil { + return + } + for i := 0; i < b.N; i++ { - if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil { + conn.SetDeadline(time.Now().Add(1 * time.Second)) + defer conn.SetDeadline(time.Time{}) + + if _, err = conn.Write(sendData); err != nil { b.Error(err) } + + recv := make([]byte, len(sendData)) + if _, err = conn.Read(recv); err != nil { + b.Error(err) + } + + if !bytes.Equal(sendData, recv) { + b.Error("data not equal") + } } }