add more test cases

This commit is contained in:
ginuerzh 2019-01-05 13:21:33 +08:00
parent 0f824edc1b
commit 47220e0687
16 changed files with 580 additions and 174 deletions

View File

@ -5,6 +5,8 @@ import (
"net"
"net/url"
"time"
"github.com/ginuerzh/gosocks5"
)
// Client is a proxy client.
@ -238,6 +240,8 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption {
type ConnectOptions struct {
Addr string
Timeout time.Duration
User *url.Userinfo
Selector gosocks5.Selector
}
// ConnectOption allows a common way to set ConnectOptions.
@ -256,3 +260,17 @@ func TimeoutConnectOption(timeout time.Duration) ConnectOption {
opts.Timeout = timeout
}
}
// UserConnectOption specifies the user info for authentication.
func UserConnectOption(user *url.Userinfo) ConnectOption {
return func(opts *ConnectOptions) {
opts.User = user
}
}
// SelectorConnectOption specifies the SOCKS5 client selector.
func SelectorConnectOption(s gosocks5.Selector) ConnectOption {
return func(opts *ConnectOptions) {
opts.Selector = s
}
}

View File

@ -13,6 +13,8 @@ import (
"net/url"
"sync"
"time"
"github.com/go-log/log"
)
func init() {
@ -95,7 +97,7 @@ func httpRoundtrip(conn net.Conn, targetURL string, data []byte) (err error) {
return
}
func udpRoundtrip(client *Client, server *Server, host string, data []byte) (err error) {
func udpRoundtrip(logger log.Logger, client *Client, server *Server, host string, data []byte) (err error) {
conn, err := proxyConn(client, server)
if err != nil {
return
@ -107,15 +109,17 @@ func udpRoundtrip(client *Client, server *Server, host string, data []byte) (err
return
}
conn.SetDeadline(time.Now().Add(3 * time.Second))
conn.SetDeadline(time.Now().Add(1 * time.Second))
defer conn.SetDeadline(time.Time{})
if _, err = conn.Write(data); err != nil {
logger.Logf("write to %s via %s: %s", host, server.Addr(), err)
return
}
recv := make([]byte, len(data))
if _, err = conn.Read(recv); err != nil {
logger.Logf("read from %s via %s: %s", host, server.Addr(), err)
return
}
@ -143,7 +147,7 @@ func proxyRoundtrip(client *Client, server *Server, targetURL string, data []byt
return
}
conn.SetDeadline(time.Now().Add(500 * time.Millisecond))
conn.SetDeadline(time.Now().Add(1000 * time.Millisecond))
defer conn.SetDeadline(time.Time{})
return httpRoundtrip(conn, targetURL, data)
@ -172,6 +176,7 @@ type udpTestServer struct {
wg sync.WaitGroup
mu sync.Mutex // guards closed and conns
closed bool
startChan chan struct{}
exitChan chan struct{}
}
@ -181,23 +186,30 @@ func newUDPTestServer(handler udpHandlerFunc) *udpTestServer {
if err != nil {
panic(fmt.Sprintf("udptest: failed to listen on a port: %v", err))
}
ln.SetReadBuffer(1024 * 1024)
ln.SetWriteBuffer(1024 * 1024)
return &udpTestServer{
ln: ln,
handler: handler,
startChan: make(chan struct{}),
exitChan: make(chan struct{}),
}
}
func (s *udpTestServer) Start() {
go s.serve()
<-s.startChan
}
func (s *udpTestServer) serve() {
select {
case <-s.startChan:
return
default:
close(s.startChan)
}
for {
data := make([]byte, 1024)
data := make([]byte, 32*1024)
n, raddr, err := s.ln.ReadFrom(data)
if err != nil {
break

View File

@ -787,7 +787,7 @@ func (l *tcpRemoteForwardListener) getSession() (s *muxSession, err error) {
conn.SetDeadline(time.Now().Add(HandshakeTimeout))
defer conn.SetDeadline(time.Time{})
conn, err = socks5Handshake(conn, l.chain.LastNode().User)
conn, err = socks5Handshake(conn, nil, l.chain.LastNode().User)
if err != nil {
return nil, err
}
@ -822,7 +822,7 @@ func (l *tcpRemoteForwardListener) getSession() (s *muxSession, err error) {
}
func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) {
conn, err := socks5Handshake(conn, l.chain.LastNode().User)
conn, err := socks5Handshake(conn, nil, l.chain.LastNode().User)
if err != nil {
return nil, err
}

View File

@ -119,7 +119,7 @@ func BenchmarkTCPDirectForwardParallel(b *testing.B) {
})
}
func udpDirectForwardRoundtrip(host string, data []byte) error {
func udpDirectForwardRoundtrip(t *testing.T, host string, data []byte) error {
ln, err := UDPDirectForwardListener("localhost:0", 0)
if err != nil {
return err
@ -138,7 +138,7 @@ func udpDirectForwardRoundtrip(host string, data []byte) error {
go server.Run()
defer server.Close()
return udpRoundtrip(client, server, host, data)
return udpRoundtrip(t, client, server, host, data)
}
func TestUDPDirectForward(t *testing.T) {
@ -148,7 +148,7 @@ func TestUDPDirectForward(t *testing.T) {
sendData := make([]byte, 128)
rand.Read(sendData)
err := udpDirectForwardRoundtrip(udpSrv.Addr(), sendData)
err := udpDirectForwardRoundtrip(t, udpSrv.Addr(), sendData)
if err != nil {
t.Error(err)
}
@ -181,7 +181,7 @@ func BenchmarkUDPDirectForward(b *testing.B) {
defer server.Close()
for i := 0; i < b.N; i++ {
if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil {
if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil {
b.Error(err)
}
}
@ -215,7 +215,7 @@ func BenchmarkUDPDirectForwardParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil {
if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil {
b.Error(err)
}
}
@ -281,7 +281,7 @@ func udpRemoteForwardRoundtrip(t *testing.T, host string, data []byte) error {
go server.Run()
defer server.Close()
return udpRoundtrip(client, server, host, data)
return udpRoundtrip(t, client, server, host, data)
}
func TestUDPRemoteForward(t *testing.T) {

11
http.go
View File

@ -52,9 +52,14 @@ func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOp
req.Header.Set("User-Agent", DefaultUserAgent)
req.Header.Set("Proxy-Connection", "keep-alive")
if c.User != nil {
u := c.User.Username()
p, _ := c.User.Password()
user := opts.User
if user == nil {
user = c.User
}
if user != nil {
u := user.Username()
p, _ := user.Password()
req.Header.Set("Proxy-Authorization",
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
}

View File

@ -58,9 +58,15 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO
}
// TODO: use the standard CONNECT method.
req.Header.Set("Gost-Target", addr)
if c.User != nil {
u := c.User.Username()
p, _ := c.User.Password()
user := opts.User
if user == nil {
user = c.User
}
if user != nil {
u := user.Username()
p, _ := user.Password()
req.Header.Set("Proxy-Authorization",
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
}

View File

@ -6,6 +6,7 @@ import (
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
@ -36,7 +37,7 @@ func http2ProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo
return proxyRoundtrip(client, server, targetURL, data)
}
func TestHTTP2Proxy(t *testing.T) {
func TestHTTP2ProxyAuth(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
@ -1108,3 +1109,42 @@ func TestHTTP2ProxyWithFileProbeResist(t *testing.T) {
t.Error("data not equal")
}
}
func TestHTTP2ProxyWithBypass(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
u, err := url.Parse(httpSrv.URL)
if err != nil {
t.Error(err)
}
ln, err := HTTP2Listener("", nil)
if err != nil {
t.Error(err)
}
client := &Client{
Connector: HTTP2Connector(nil),
Transporter: HTTP2Transporter(nil),
}
host := u.Host
if h, _, _ := net.SplitHostPort(u.Host); h != "" {
host = h
}
server := &Server{
Listener: ln,
Handler: HTTP2Handler(
BypassHandlerOption(NewBypassPatterns(false, host)),
),
}
go server.Run()
defer server.Close()
if err = proxyRoundtrip(client, server, httpSrv.URL, sendData); err == nil {
t.Error("should failed")
}
}

View File

@ -5,6 +5,7 @@ import (
"crypto/rand"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
@ -55,7 +56,7 @@ func httpProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo,
return proxyRoundtrip(client, server, targetURL, data)
}
func TestHTTPProxy(t *testing.T) {
func TestHTTPProxyAuth(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
@ -82,6 +83,40 @@ func TestHTTPProxy(t *testing.T) {
}
}
func TestHTTPProxyWithInvalidRequest(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
ln, err := TCPListener("")
if err != nil {
t.Error(err)
}
server := &Server{
Listener: ln,
Handler: HTTPHandler(),
}
go server.Run()
defer server.Close()
r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), bytes.NewReader(sendData))
if err != nil {
t.Error(err)
}
resp, err := http.DefaultClient.Do(r)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Error("got status:", resp.Status)
}
}
func BenchmarkHTTPProxy(b *testing.B) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
@ -302,3 +337,42 @@ func TestHTTPProxyWithFileProbeResist(t *testing.T) {
t.Error("data not equal, got:", string(recv))
}
}
func TestHTTPProxyWithBypass(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
u, err := url.Parse(httpSrv.URL)
if err != nil {
t.Error(err)
}
ln, err := TCPListener("")
if err != nil {
t.Error(err)
}
client := &Client{
Connector: HTTPConnector(nil),
Transporter: TCPTransporter(),
}
host := u.Host
if h, _, _ := net.SplitHostPort(u.Host); h != "" {
host = h
}
server := &Server{
Listener: ln,
Handler: HTTPHandler(
BypassHandlerOption(NewBypassPatterns(false, host)),
),
}
go server.Run()
defer server.Close()
if err = proxyRoundtrip(client, server, httpSrv.URL, sendData); err == nil {
t.Error("should failed")
}
}

View File

@ -369,3 +369,56 @@ func TestSNIOverObfsHTTP(t *testing.T) {
})
}
}
func httpOverObfs4Roundtrip(targetURL string, data []byte,
clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error {
ln, err := Obfs4Listener("")
if err != nil {
return err
}
client := &Client{
Connector: HTTPConnector(clientInfo),
Transporter: Obfs4Transporter(),
}
server := &Server{
Listener: ln,
Handler: HTTPHandler(
UsersHandlerOption(serverInfo...),
),
}
go server.Run()
defer server.Close()
return proxyRoundtrip(client, server, targetURL, data)
}
func _TestHTTPOverObfs4(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
for i, tc := range httpProxyTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
err := httpOverObfs4Roundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers)
if err == nil {
if tc.errStr != "" {
t.Errorf("#%d should failed with error %s", i, tc.errStr)
}
} else {
if tc.errStr == "" {
t.Errorf("#%d got error %v", i, err)
}
if err.Error() != tc.errStr {
t.Errorf("#%d got error %v, want %v", i, err, tc.errStr)
}
}
})
}
}

View File

@ -238,7 +238,7 @@ type resolverCacheItem struct {
}
func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
if ttl < 0 {
if name == "" || ttl < 0 {
return nil
}
@ -248,7 +248,8 @@ func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
ttl = item.ttl
}
if item == nil || time.Since(time.Unix(item.ts, 0)) > ttl {
if time.Since(time.Unix(item.ts, 0)) > ttl {
r.mCache.Delete(name)
return nil
}
return item.IPs
@ -258,7 +259,7 @@ func (r *resolver) loadCache(name string, ttl time.Duration) []net.IP {
}
func (r *resolver) storeCache(name string, ips []net.IP, ttl time.Duration) {
if name == "" || len(ips) == 0 {
if name == "" || len(ips) == 0 || ttl < 0 {
return
}
r.mCache.Store(name, &resolverCacheItem{

View File

@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"io"
"net"
"testing"
"time"
)
@ -13,6 +14,8 @@ var dnsTests = []struct {
host string
pass bool
}{
{NameServer{Addr: "1.1.1.1"}, "192.168.1.1", true},
{NameServer{Addr: "1.1.1.1"}, "github", true},
{NameServer{Addr: "1.1.1.1"}, "github.com", true},
{NameServer{Addr: "1.1.1.1:53"}, "github.com", true},
{NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true},
@ -47,6 +50,8 @@ func TestDNSResolver(t *testing.T) {
}
t.Log(ns)
r := NewResolver(0, ns)
resolv := r.(*resolver)
resolv.domain = "com"
err := dnsResolverRoundtrip(t, r, tc.host)
if err != nil {
if tc.pass {
@ -61,6 +66,56 @@ func TestDNSResolver(t *testing.T) {
}
}
var resolverCacheTests = []struct {
name string
ips []net.IP
ttl time.Duration
result []net.IP
}{
{"", nil, 0, nil},
{"", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil},
{"", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second, nil},
{"example.com", nil, 10 * time.Second, nil},
{"example.com", []net.IP{}, 10 * time.Second, nil},
{"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil},
{"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, -1, nil},
{"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second,
[]net.IP{net.IPv4(192, 168, 1, 1)}},
{"example.com", []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}, 10 * time.Second,
[]net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}},
}
func TestResolverCache(t *testing.T) {
isEqual := func(a, b []net.IP) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil || len(a) != len(b) {
return false
}
for i := range a {
if !a[i].Equal(b[i]) {
return false
}
}
return true
}
for i, tc := range resolverCacheTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
r := newResolver(tc.ttl)
r.storeCache(tc.name, tc.ips, tc.ttl)
ips := r.loadCache(tc.name, tc.ttl)
if !isEqual(tc.result, ips) {
t.Error("unexpected cache value:", tc.name, ips, tc.ttl)
}
})
}
}
var resolverReloadTests = []struct {
r io.Reader

150
selector_test.go Normal file
View File

@ -0,0 +1,150 @@
package gost
import (
"testing"
"time"
)
func TestRoundStrategy(t *testing.T) {
nodes := []Node{
Node{ID: 1},
Node{ID: 2},
Node{ID: 3},
}
s := NewStrategy("round")
t.Log(s.String())
if node := s.Apply(nil); node.ID > 0 {
t.Error("unexpected node", node.String())
}
for i := 0; i <= len(nodes); i++ {
node := s.Apply(nodes)
if node.ID != nodes[i%len(nodes)].ID {
t.Error("unexpected node", node.String())
}
}
}
func TestRandomStrategy(t *testing.T) {
nodes := []Node{
Node{ID: 1},
Node{ID: 2},
Node{ID: 3},
}
s := NewStrategy("random")
t.Log(s.String())
if node := s.Apply(nil); node.ID > 0 {
t.Error("unexpected node", node.String())
}
for i := 0; i <= len(nodes); i++ {
node := s.Apply(nodes)
if node.ID == 0 {
t.Error("unexpected node", node.String())
}
}
}
func TestFIFOStrategy(t *testing.T) {
nodes := []Node{
Node{ID: 1},
Node{ID: 2},
Node{ID: 3},
}
s := NewStrategy("fifo")
t.Log(s.String())
if node := s.Apply(nil); node.ID > 0 {
t.Error("unexpected node", node.String())
}
for i := 0; i <= len(nodes); i++ {
node := s.Apply(nodes)
if node.ID != 1 {
t.Error("unexpected node", node.String())
}
}
}
func TestFailFilter(t *testing.T) {
nodes := []Node{
Node{ID: 1, marker: &failMarker{}},
Node{ID: 2, marker: &failMarker{}},
Node{ID: 3, marker: &failMarker{}},
}
filter := &FailFilter{}
t.Log(filter.String())
isEqual := func(a, b []Node) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil || len(a) != len(b) {
return false
}
for i := range a {
if a[i].ID != b[i].ID {
return false
}
}
return true
}
if v := filter.Filter(nil); v != nil {
t.Error("unexpected node", v)
}
if v := filter.Filter(nodes); !isEqual(v, nodes) {
t.Error("unexpected node", v)
}
filter.MaxFails = 1
if v := filter.Filter(nodes); !isEqual(v, nodes) {
t.Error("unexpected node", v)
}
nodes[0].MarkDead()
if v := filter.Filter(nodes); !isEqual(v, nodes) {
t.Error("unexpected node", v)
}
filter.FailTimeout = 5 * time.Second
if v := filter.Filter(nodes); isEqual(v, nodes) {
t.Error("unexpected node", v)
}
nodes[1].MarkDead()
nodes[2].MarkDead()
if v := filter.Filter(nodes); len(v) > 0 {
t.Error("unexpected node", v)
}
for i := range nodes {
nodes[i].ResetDead()
}
if v := filter.Filter(nodes); !isEqual(v, nodes) {
t.Error("unexpected node", v)
}
}
func TestSelector(t *testing.T) {
nodes := []Node{
Node{ID: 1, marker: &failMarker{}},
Node{ID: 2, marker: &failMarker{}},
Node{ID: 3, marker: &failMarker{}},
}
selector := &defaultSelector{}
if _, err := selector.Select(nil); err != ErrNoneAvailable {
t.Error("got unexpected error:", err)
}
if node, _ := selector.Select(nodes); node.ID != 1 {
t.Error("unexpected node:", node)
}
if node, _ := selector.Select(nodes,
WithStrategy(NewStrategy("")),
WithFilter(&FailFilter{MaxFails: 1, FailTimeout: 3 * time.Second}),
); node.ID != 1 {
t.Error("unexpected node:", node)
}
}

151
socks.go
View File

@ -218,18 +218,12 @@ func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...Connect
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
selector := &clientSelector{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
User: c.User,
user := opts.User
if user == nil {
user = c.User
}
selector.AddMethod(
gosocks5.MethodNoAuth,
gosocks5.MethodUserPass,
MethodTLS,
)
cc := gosocks5.ClientConn(conn, selector)
if err := cc.Handleshake(); err != nil {
cc, err := socks5Handshake(conn, opts.Selector, user)
if err != nil {
return nil, err
}
conn = cc
@ -292,7 +286,11 @@ func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...Con
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
cc, err := socks5Handshake(conn, c.User)
user := opts.User
if user == nil {
user = c.User
}
cc, err := socks5Handshake(conn, opts.Selector, user)
if err != nil {
return nil, err
}
@ -442,7 +440,7 @@ func (tr *socks5MuxBindTransporter) initSession(conn net.Conn, addr string, opts
opts = &HandshakeOptions{}
}
cc, err := socks5Handshake(conn, opts.User)
cc, err := socks5Handshake(conn, nil, opts.User)
if err != nil {
return nil, err
}
@ -522,7 +520,11 @@ func (c *socks5UDPConnector) Connect(conn net.Conn, addr string, options ...Conn
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
cc, err := socks5Handshake(conn, c.User)
user := opts.User
if user == nil {
user = c.User
}
cc, err := socks5Handshake(conn, opts.Selector, user)
if err != nil {
return nil, err
}
@ -597,7 +599,11 @@ func (c *socks5UDPTunConnector) Connect(conn net.Conn, addr string, options ...C
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
cc, err := socks5Handshake(conn, c.User)
user := opts.User
if user == nil {
user = c.User
}
cc, err := socks5Handshake(conn, opts.Selector, user)
if err != nil {
return nil, err
}
@ -642,68 +648,6 @@ func (c *socks5UDPTunConnector) Connect(conn net.Conn, addr string, options ...C
return &udpTunnelConn{Conn: conn, raddr: taddr.String()}, nil
}
func (c *socks5UDPTunConnector) tunnelClientUDP(pc net.PacketConn, cc net.Conn) (err error) {
errc := make(chan error, 2)
go func() {
b := mPool.Get().([]byte)
defer mPool.Put(b)
for {
n, addr, err := pc.ReadFrom(b)
if err != nil {
log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err)
errc <- err
return
}
// pipe from peer to tunnel
dgram := gosocks5.NewUDPDatagram(
gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n])
if err := dgram.Write(cc); err != nil {
log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err)
errc <- err
return
}
if Debug {
log.Logf("[udp-tun] %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data))
}
}
}()
go func() {
for {
dgram, err := gosocks5.ReadUDPDatagram(cc)
if err != nil {
log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err)
errc <- err
return
}
// pipe from tunnel to peer
addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String())
if err != nil {
continue // drop silently
}
if _, err := pc.WriteTo(dgram.Data, addr); err != nil {
log.Logf("[udp-tun] %s -> %s : %s", cc.RemoteAddr(), addr, err)
errc <- err
return
}
if Debug {
log.Logf("[udp-tun] %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data))
}
}
}()
select {
case err = <-errc:
}
return
}
type socks4Connector struct{}
// SOCKS4Connector creates a Connector for SOCKS4 proxy client.
@ -1186,7 +1130,7 @@ func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) {
}
defer cc.Close()
cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User)
cc, err = socks5Handshake(cc, nil, h.options.Chain.LastNode().User)
if err != nil {
log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err)
return
@ -1450,7 +1394,7 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) {
}
defer cc.Close()
cc, err = socks5Handshake(cc, h.options.Chain.LastNode().User)
cc, err = socks5Handshake(cc, nil, h.options.Chain.LastNode().User)
if err != nil {
log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err)
return
@ -1844,7 +1788,7 @@ func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) {
conn.SetDeadline(time.Now().Add(HandshakeTimeout))
defer conn.SetDeadline(time.Time{})
cc, err := socks5Handshake(conn, chain.LastNode().User)
cc, err := socks5Handshake(conn, nil, chain.LastNode().User)
if err != nil {
conn.Close()
return nil, err
@ -1877,16 +1821,20 @@ func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) {
return conn, nil
}
func socks5Handshake(conn net.Conn, user *url.Userinfo) (net.Conn, error) {
selector := &clientSelector{
func socks5Handshake(conn net.Conn, selector gosocks5.Selector, user *url.Userinfo) (net.Conn, error) {
if selector == nil {
cs := &clientSelector{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
User: user,
}
selector.AddMethod(
cs.AddMethod(
gosocks5.MethodNoAuth,
gosocks5.MethodUserPass,
MethodTLS,
)
selector = cs
}
cc := gosocks5.ClientConn(conn, selector)
if err := cc.Handleshake(); err != nil {
return nil, err
@ -1996,24 +1944,20 @@ type socks5UDPConn struct {
taddr net.Addr
}
func (c *socks5UDPConn) Read(b []byte) (int, error) {
data := mPool.Get().([]byte)
defer mPool.Put(data)
n, err := c.UDPConn.Read(data)
if err != nil {
return 0, err
}
dg, err := gosocks5.ReadUDPDatagram(bytes.NewReader(data[:n]))
if err != nil {
return 0, err
}
return copy(b, dg.Data), nil
func (c *socks5UDPConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *socks5UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
dg, err := gosocks5.ReadUDPDatagram(c.UDPConn)
data := mPool.Get().([]byte)
defer mPool.Put(data)
n, err = c.UDPConn.Read(data)
if err != nil {
return
}
dg, err := gosocks5.ReadUDPDatagram(bytes.NewReader(data[:n]))
if err != nil {
return
}
@ -2025,16 +1969,7 @@ func (c *socks5UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
}
func (c *socks5UDPConn) Write(b []byte) (int, error) {
addr, err := gosocks5.NewAddr(c.taddr.String())
if err != nil {
return 0, err
}
h := gosocks5.NewUDPHeader(0, 0, addr)
dg := gosocks5.NewUDPDatagram(h, b)
if err = dg.Write(c.UDPConn); err != nil {
return 0, err
}
return len(b), nil
return c.WriteTo(b, c.taddr)
}
func (c *socks5UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {

View File

@ -550,7 +550,7 @@ func socks5UDPRoundtrip(t *testing.T, host string, data []byte) (err error) {
go server.Run()
defer server.Close()
return udpRoundtrip(client, server, host, data)
return udpRoundtrip(t, client, server, host, data)
}
func TestSOCKS5UDP(t *testing.T) {
@ -593,7 +593,7 @@ func BenchmarkSOCKS5UDP(b *testing.B) {
defer server.Close()
for i := 0; i < b.N; i++ {
if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil {
if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil {
b.Error(err)
}
}
@ -679,7 +679,7 @@ func socks5UDPTunRoundtrip(t *testing.T, host string, data []byte) (err error) {
go server.Run()
defer server.Close()
return udpRoundtrip(client, server, host, data)
return udpRoundtrip(t, client, server, host, data)
}
func TestSOCKS5UDPTun(t *testing.T) {
@ -721,7 +721,7 @@ func BenchmarkSOCKS5UDPTun(b *testing.B) {
defer server.Close()
for i := 0; i < b.N; i++ {
if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil {
if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil {
b.Error(err)
}
}

28
ss.go
View File

@ -47,9 +47,13 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect
}
var method, password string
if c.Cipher != nil {
method = c.Cipher.Username()
password, _ = c.Cipher.Password()
cp := opts.User
if cp == nil {
cp = c.Cipher
}
if cp != nil {
method = cp.Username()
password, _ = cp.Password()
}
cipher, err := ss.NewCipher(method, password)
@ -446,6 +450,10 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
b := mPool.Get().([]byte)
defer mPool.Put(b)
b[0] = 0
b[1] = 0
b[2] = 0
n, err := sc.Read(b[3:]) // add rsv and frag fields to make it the standard SOCKS5 UDP datagram
if err != nil {
// log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err)
@ -535,10 +543,14 @@ type shadowUDPConn struct {
func (c *shadowUDPConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
if len(c.header) > 0 {
b = append(c.header, b...)
buf := bytes.Buffer{}
if _, err = buf.Write(c.header); err != nil {
return
}
_, err = c.PacketConn.WriteTo(b, c.raddr)
if _, err = buf.Write(b); err != nil {
return
}
_, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr)
return
}
@ -546,6 +558,10 @@ func (c *shadowUDPConn) Read(b []byte) (n int, err error) {
buf := mPool.Get().([]byte)
defer mPool.Put(buf)
buf[0] = 0
buf[1] = 0
buf[2] = 0
n, _, err = c.PacketConn.ReadFrom(buf[3:])
if err != nil {
return

View File

@ -1,6 +1,7 @@
package gost
import (
"bytes"
"crypto/rand"
"fmt"
"net/http/httptest"
@ -299,14 +300,15 @@ func BenchmarkSSProxyParallel(b *testing.B) {
})
}
func shadowUDPRoundtrip(t *testing.T, host string, data []byte) error {
ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), 0)
func shadowUDPRoundtrip(t *testing.T, host string, data []byte,
clientInfo *url.Userinfo, serverInfo *url.Userinfo) error {
ln, err := ShadowUDPListener("localhost:0", serverInfo, 0)
if err != nil {
return err
}
client := &Client{
Connector: ShadowUDPConnector(url.UserPassword("chacha20-ietf", "123456")),
Connector: ShadowUDPConnector(clientInfo),
Transporter: UDPTransporter(),
}
@ -318,19 +320,35 @@ func shadowUDPRoundtrip(t *testing.T, host string, data []byte) error {
go server.Run()
defer server.Close()
return udpRoundtrip(client, server, host, data)
return udpRoundtrip(t, client, server, host, data)
}
func _TestShadowUDP(t *testing.T) {
func TestShadowUDP(t *testing.T) {
sendData := make([]byte, 128)
rand.Read(sendData)
for i, tc := range ssTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
udpSrv := newUDPTestServer(udpTestHandler)
udpSrv.Start()
defer udpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
err := shadowUDPRoundtrip(t, udpSrv.Addr(), sendData)
if err != nil {
t.Error(err)
err := shadowUDPRoundtrip(t, udpSrv.Addr(), sendData,
tc.clientCipher,
tc.serverCipher,
)
if err == nil {
if !tc.pass {
t.Errorf("#%d should failed", i)
}
} else {
// t.Logf("#%d %v", i, err)
if tc.pass {
t.Errorf("#%d got error: %v", i, err)
}
}
})
}
}
@ -343,7 +361,7 @@ func BenchmarkShadowUDP(b *testing.B) {
sendData := make([]byte, 128)
rand.Read(sendData)
ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), 1000*time.Millisecond)
ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), 0)
if err != nil {
b.Error(err)
}
@ -361,9 +379,32 @@ func BenchmarkShadowUDP(b *testing.B) {
go server.Run()
defer server.Close()
for i := 0; i < b.N; i++ {
if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil {
conn, err := proxyConn(client, server)
if err != nil {
b.Error(err)
}
defer conn.Close()
conn, err = client.Connect(conn, udpSrv.Addr())
if err != nil {
return
}
for i := 0; i < b.N; i++ {
conn.SetDeadline(time.Now().Add(1 * time.Second))
defer conn.SetDeadline(time.Time{})
if _, err = conn.Write(sendData); err != nil {
b.Error(err)
}
recv := make([]byte, len(sendData))
if _, err = conn.Read(recv); err != nil {
b.Error(err)
}
if !bytes.Equal(sendData, recv) {
b.Error("data not equal")
}
}
}