From c7b3111e8970fd0964d309f360375d46fa6a5a8f Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 15 Dec 2018 20:30:51 +0800 Subject: [PATCH] add tests for reload --- bypass.go | 16 ++-- bypass_test.go | 180 +++++++++++++++++++++++++++++++++++--- handler_test.go | 223 ++++++++++++++++++++++++++++++++++++++++++++++++ hosts.go | 15 +++- hosts_test.go | 121 ++++++++++++++++++++++++++ http_test.go | 2 +- 6 files changed, 530 insertions(+), 27 deletions(-) create mode 100644 handler_test.go create mode 100644 hosts_test.go diff --git a/bypass.go b/bypass.go index 067dff6..85bf12b 100644 --- a/bypass.go +++ b/bypass.go @@ -95,7 +95,7 @@ func DomainMatcher(pattern string) Matcher { p := pattern if strings.HasPrefix(pattern, ".") { p = pattern[1:] // trim the prefix '.' - pattern = "*" + pattern + pattern = "*" + p } return &domainMatcher{ pattern: p, @@ -143,8 +143,8 @@ func NewBypass(reversed bool, matchers ...Matcher) *Bypass { func NewBypassPatterns(reversed bool, patterns ...string) *Bypass { var matchers []Matcher for _, pattern := range patterns { - if pattern != "" { - matchers = append(matchers, NewMatcher(pattern)) + if m := NewMatcher(pattern); m != nil { + matchers = append(matchers, m) } } return NewBypass(reversed, matchers...) @@ -152,15 +152,9 @@ func NewBypassPatterns(reversed bool, patterns ...string) *Bypass { // Contains reports whether the bypass includes addr. func (bp *Bypass) Contains(addr string) bool { - if bp == nil || addr == "" { + if bp == nil || len(bp.matchers) == 0 || addr == "" { return false } - // try to strip the port - if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { - if p, _ := strconv.Atoi(port); p > 0 { // port is valid - addr = host - } - } bp.mux.RLock() defer bp.mux.RUnlock() @@ -209,7 +203,7 @@ func (bp *Bypass) Reload(r io.Reader) error { var period time.Duration var reversed bool - if bp.Stopped() { + if r == nil || bp.Stopped() { return nil } diff --git a/bypass_test.go b/bypass_test.go index ae124c5..5cc0e81 100644 --- a/bypass_test.go +++ b/bypass_test.go @@ -1,34 +1,94 @@ package gost -import "testing" +import ( + "bytes" + "io" + "testing" + "time" +) -var bypassTests = []struct { +var bypassContainTests = []struct { patterns []string reversed bool addr string bypassed bool }{ + // empty pattern + {[]string{""}, false, "", false}, + {[]string{""}, false, "192.168.1.1", false}, + {[]string{""}, true, "", false}, + {[]string{""}, true, "192.168.1.1", false}, + // IP address {[]string{"192.168.1.1"}, false, "192.168.1.1", true}, + {[]string{"192.168.1.1"}, true, "192.168.1.1", false}, {[]string{"192.168.1.1"}, false, "192.168.1.2", false}, + {[]string{"192.168.1.1"}, true, "192.168.1.2", true}, {[]string{"0.0.0.0"}, false, "0.0.0.0", true}, + {[]string{"0.0.0.0"}, true, "0.0.0.0", false}, // CIDR address {[]string{"192.168.1.0/0"}, false, "1.2.3.4", true}, + {[]string{"192.168.1.0/0"}, true, "1.2.3.4", false}, {[]string{"192.168.1.0/8"}, false, "192.1.0.255", true}, + {[]string{"192.168.1.0/8"}, true, "192.1.0.255", false}, + {[]string{"192.168.1.0/8"}, false, "191.1.0.255", false}, + {[]string{"192.168.1.0/8"}, true, "191.1.0.255", true}, {[]string{"192.168.1.0/16"}, false, "192.168.0.255", true}, + {[]string{"192.168.1.0/16"}, true, "192.168.0.255", false}, + {[]string{"192.168.1.0/16"}, false, "192.0.1.255", false}, + {[]string{"192.168.1.0/16"}, true, "192.0.0.255", true}, {[]string{"192.168.1.0/24"}, false, "192.168.1.255", true}, + {[]string{"192.168.1.0/24"}, true, "192.168.1.255", false}, + {[]string{"192.168.1.0/24"}, false, "192.168.0.255", false}, + {[]string{"192.168.1.0/24"}, true, "192.168.0.255", true}, {[]string{"192.168.1.1/32"}, false, "192.168.1.1", true}, + {[]string{"192.168.1.1/32"}, true, "192.168.1.1", false}, {[]string{"192.168.1.1/32"}, false, "192.168.1.2", false}, + {[]string{"192.168.1.1/32"}, true, "192.168.1.2", true}, // plain domain {[]string{"www.example.com"}, false, "www.example.com", true}, + {[]string{"www.example.com"}, true, "www.example.com", false}, {[]string{"http://www.example.com"}, false, "http://www.example.com", true}, + {[]string{"http://www.example.com"}, true, "http://www.example.com", false}, {[]string{"http://www.example.com"}, false, "http://example.com", false}, + {[]string{"http://www.example.com"}, true, "http://example.com", true}, {[]string{"www.example.com"}, false, "example.com", false}, + {[]string{"www.example.com"}, true, "example.com", true}, + + // host:port + {[]string{"192.168.1.1"}, false, "192.168.1.1:80", false}, + {[]string{"192.168.1.1"}, true, "192.168.1.1:80", true}, + {[]string{"192.168.1.1:80"}, false, "192.168.1.1", false}, + {[]string{"192.168.1.1:80"}, true, "192.168.1.1", true}, + {[]string{"192.168.1.1:80"}, false, "192.168.1.1:80", true}, + {[]string{"192.168.1.1:80"}, true, "192.168.1.1:80", false}, + {[]string{"192.168.1.1:80"}, false, "192.168.1.1:8080", false}, + {[]string{"192.168.1.1:80"}, true, "192.168.1.1:8080", true}, + + {[]string{"example.com"}, false, "example.com:80", false}, + {[]string{"example.com"}, true, "example.com:80", true}, + {[]string{"example.com:80"}, false, "example.com", false}, + {[]string{"example.com:80"}, true, "example.com", true}, + {[]string{"example.com:80"}, false, "example.com:80", true}, + {[]string{"example.com:80"}, true, "example.com:80", false}, + {[]string{"example.com:80"}, false, "example.com:8080", false}, + {[]string{"example.com:80"}, true, "example.com:8080", true}, // domain wildcard + {[]string{"*"}, false, "", false}, + {[]string{"*"}, false, "192.168.1.1", true}, + {[]string{"*"}, false, "192.168.0.0/16", true}, + {[]string{"*"}, false, "http://example.com", true}, + {[]string{"*"}, false, "example.com:80", true}, + {[]string{"*"}, true, "", false}, + {[]string{"*"}, true, "192.168.1.1", false}, + {[]string{"*"}, true, "192.168.0.0/16", false}, + {[]string{"*"}, true, "http://example.com", false}, + {[]string{"*"}, true, "example.com:80", false}, + // sub-domain {[]string{"*.example.com"}, false, "example.com", false}, {[]string{"*.example.com"}, false, "http://example.com", false}, @@ -78,18 +138,114 @@ var bypassTests = []struct { {[]string{".example.com"}, false, "www.example.com", true}, {[]string{".example.com"}, false, "example.com", true}, {[]string{".example.com"}, false, "www.example.com.cn", false}, + + {[]string{"example.com:*"}, false, "example.com", false}, + {[]string{"example.com:*"}, false, "example.com:80", true}, + {[]string{"example.com:*"}, false, "example.com:8080", true}, + {[]string{"example.com:*"}, false, "example.com:http", true}, + {[]string{"example.com:*"}, false, "http://example.com:80", false}, + + {[]string{"*example.com:*"}, false, "example.com:80", true}, + + {[]string{".example.com:*"}, false, "www.example.com", false}, + {[]string{".example.com:*"}, false, "http://www.example.com", false}, + {[]string{".example.com:*"}, false, "example.com:80", true}, + {[]string{".example.com:*"}, false, "www.example.com:8080", true}, + {[]string{".example.com:*"}, false, "http://www.example.com:80", true}, } -func TestBypass(t *testing.T) { - for _, test := range bypassTests { - bp := NewBypassPatterns(test.reversed, test.patterns...) - if bp.Contains(test.addr) != test.bypassed { - t.Errorf("test failed: %v, %s", test.patterns, test.addr) - } - - rbp := NewBypassPatterns(!test.reversed, test.patterns...) - if rbp.Contains(test.addr) == test.bypassed { - t.Errorf("reverse test failed: %v, %s", test.patterns, test.addr) +func TestBypassContains(t *testing.T) { + for i, tc := range bypassContainTests { + bp := NewBypassPatterns(tc.reversed, tc.patterns...) + if bp.Contains(tc.addr) != tc.bypassed { + t.Errorf("#%d test failed: %v, %s", i, tc.patterns, tc.addr) + } + } +} + +var bypassReloadTests = []struct { + r io.Reader + + reversed bool + period time.Duration + + addr string + bypassed bool + stopped bool +}{ + { + r: nil, + reversed: false, + period: 0, + addr: "192.168.1.1", + bypassed: false, + stopped: false, + }, + { + r: bytes.NewBufferString(""), + reversed: false, + period: 0, + addr: "192.168.1.1", + bypassed: false, + stopped: false, + }, + { + r: bytes.NewBufferString("reverse true\nreload 10s"), + reversed: true, + period: 10 * time.Second, + addr: "192.168.1.1", + bypassed: false, + stopped: false, + }, + { + r: bytes.NewBufferString("reverse false\nreload 10s\n192.168.1.1"), + reversed: false, + period: 10 * time.Second, + addr: "192.168.1.1", + bypassed: true, + stopped: false, + }, + { + r: bytes.NewBufferString("reverse false\nreload 10s\n192.168.1.1\n#example.com"), + reversed: false, + period: 10 * time.Second, + addr: "example.com", + bypassed: false, + stopped: false, + }, + { + r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.1\n#example.com"), + reversed: false, + period: 0, + addr: "192.168.1.1", + bypassed: true, + stopped: true, + }, +} + +func TestByapssReload(t *testing.T) { + for i, tc := range bypassReloadTests { + bp := NewBypass(false) + if err := bp.Reload(tc.r); err != nil { + t.Error(err) + } + if bp.Reversed() != tc.reversed { + t.Errorf("#%d test failed: reversed value should be %v, got %v", + i, tc.reversed, bp.reversed) + } + if bp.Period() != tc.period { + t.Errorf("#%d test failed: period value should be %v, got %v", + i, tc.period, bp.Period()) + } + if bp.Contains(tc.addr) != tc.bypassed { + t.Errorf("#%d test failed: %v, %s", i, bp.reversed, tc.addr) + } + if tc.stopped { + bp.Stop() + } + if bp.Stopped() != tc.stopped { + t.Errorf("#%d test failed: stopped value should be %v, got %v", + i, tc.stopped, bp.Stopped()) } } } diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..e69bdb3 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,223 @@ +package gost + +import ( + "crypto/rand" + "crypto/tls" + "net/http/httptest" + "net/url" + "testing" +) + +func autoHTTPProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: AutoHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoHTTPProxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := autoHTTPProxyRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func autoSocks5ProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: AutoHandler(UsersHandlerOption(serverInfo...)), + Listener: ln, + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoSOCKS5Proxy(t *testing.T) { + cert, err := GenCertificate() + if err != nil { + panic(err) + } + DefaultTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := autoSocks5ProxyRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func autoSOCKS4ProxyRoundtrip(targetURL string, data []byte) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: AutoHandler(), + } + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoSOCKS4Proxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := autoSOCKS4ProxyRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func autoSocks4aProxyRoundtrip(targetURL string, data []byte) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: AutoHandler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoSOCKS4AProxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := autoSocks4aProxyRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func autoSSProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo *url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: AutoHandler(UsersHandlerOption(serverInfo)), + Listener: ln, + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoSSProxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssTests { + err := autoSSProxyRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + t.Errorf("#%d should failed", i) + } + } +} diff --git a/hosts.go b/hosts.go index 4c492fb..aa0ef99 100644 --- a/hosts.go +++ b/hosts.go @@ -18,6 +18,15 @@ type Host struct { Aliases []string } +// NewHost creates a Host. +func NewHost(ip net.IP, hostname string, aliases ...string) Host { + return Host{ + IP: ip, + Hostname: hostname, + Aliases: aliases, + } +} + // Hosts is a static table lookup for hostnames. // For each host a single line should be present with the following information: // IP_address canonical_hostname [aliases...] @@ -30,7 +39,7 @@ type Hosts struct { mux sync.RWMutex } -// NewHosts creates a Hosts with optional list of host +// NewHosts creates a Hosts with optional list of hosts. func NewHosts(hosts ...Host) *Hosts { return &Hosts{ hosts: hosts, @@ -48,7 +57,7 @@ func (h *Hosts) AddHost(host ...Host) { // Lookup searches the IP address corresponds to the given host from the host table. func (h *Hosts) Lookup(host string) (ip net.IP) { - if h == nil { + if h == nil || host == "" { return } @@ -78,7 +87,7 @@ func (h *Hosts) Reload(r io.Reader) error { var period time.Duration var hosts []Host - if h.Stopped() { + if r == nil || h.Stopped() { return nil } diff --git a/hosts_test.go b/hosts_test.go new file mode 100644 index 0000000..7b775a8 --- /dev/null +++ b/hosts_test.go @@ -0,0 +1,121 @@ +package gost + +import ( + "bytes" + "io" + "net" + "testing" + "time" +) + +var hostsLookupTests = []struct { + hosts []Host + host string + ip net.IP +}{ + {nil, "", nil}, + {nil, "example.com", nil}, + {[]Host{}, "", nil}, + {[]Host{}, "example.com", nil}, + {[]Host{NewHost(nil, "")}, "", nil}, + {[]Host{NewHost(nil, "example.com")}, "example.com", nil}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "")}, "", nil}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com")}, "example.com", net.IPv4(192, 168, 1, 1)}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com")}, "example", nil}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com", "example", "examples")}, "example", net.IPv4(192, 168, 1, 1)}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com", "example", "examples")}, "examples", net.IPv4(192, 168, 1, 1)}, +} + +func TestHostsLookup(t *testing.T) { + for i, tc := range hostsLookupTests { + hosts := NewHosts(tc.hosts...) + ip := hosts.Lookup(tc.host) + if !ip.Equal(tc.ip) { + t.Errorf("#%d test failed: lookup should be %s, got %s", i, tc.ip, ip) + } + } +} + +var HostsReloadTests = []struct { + r io.Reader + period time.Duration + host string + ip net.IP + stopped bool +}{ + { + r: nil, + period: 0, + host: "", + ip: nil, + }, + { + r: bytes.NewBufferString(""), + period: 0, + host: "example.com", + ip: nil, + }, + { + r: bytes.NewBufferString("reload 10s"), + period: 10 * time.Second, + host: "example.com", + ip: nil, + }, + { + r: bytes.NewBufferString("reload 10s\n192.168.1.1"), + period: 10 * time.Second, + host: "", + ip: nil, + }, + { + r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com"), + period: 0, + host: "example.com", + ip: net.IPv4(192, 168, 1, 1), + }, + { + r: bytes.NewBufferString("#reload 10s\n#192.168.1.1 example.com"), + period: 0, + host: "example.com", + ip: nil, + stopped: true, + }, + { + r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com example examples"), + period: 0, + host: "example", + ip: net.IPv4(192, 168, 1, 1), + stopped: true, + }, + { + r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com example examples"), + period: 0, + host: "examples", + ip: net.IPv4(192, 168, 1, 1), + stopped: true, + }, +} + +func TestHostsReload(t *testing.T) { + for i, tc := range HostsReloadTests { + hosts := NewHosts() + if err := hosts.Reload(tc.r); err != nil { + t.Error(err) + } + if hosts.Period() != tc.period { + t.Errorf("#%d test failed: period value should be %v, got %v", + i, tc.period, hosts.Period()) + } + ip := hosts.Lookup(tc.host) + if !ip.Equal(tc.ip) { + t.Errorf("#%d test failed: lookup should be %s, got %s", i, tc.ip, ip) + } + if tc.stopped { + hosts.Stop() + } + if hosts.Stopped() != tc.stopped { + t.Errorf("#%d test failed: stopped value should be %v, got %v", + i, tc.stopped, hosts.Stopped()) + } + } +} diff --git a/http_test.go b/http_test.go index e8ac188..b322672 100644 --- a/http_test.go +++ b/http_test.go @@ -77,7 +77,7 @@ func proxyRoundtrip(client *Client, server *Server, targetURL string, data []byt return } - conn.SetDeadline(time.Now().Add(1 * time.Second)) + conn.SetDeadline(time.Now().Add(500 * time.Millisecond)) defer conn.SetDeadline(time.Time{}) conn, err = client.Connect(conn, u.Host)