From 8cbd2722f6146c6968e5a0e110e5805edfbbf2ec Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Sun, 13 Aug 2017 09:30:18 +0800 Subject: [PATCH] add test files --- .travis.yml | 2 +- cmd/gost/main.go | 11 ++-- gost.go | 2 +- http_test.go | 127 +++++++++++++++++++++++++++++++++++++++++++++++ node.go | 28 ++--------- node_test.go | 68 +++++++++++++++++++++++++ permissions.go | 22 ++++++++ server.go | 30 ++++++++--- 8 files changed, 249 insertions(+), 41 deletions(-) create mode 100644 http_test.go create mode 100644 node_test.go diff --git a/.travis.yml b/.travis.yml index 608d171..7aea3ca 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,5 +5,5 @@ go: install: true script: - - go test -v + - go test -race -v - cd cmd/gost && go build \ No newline at end of file diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 0b193ee..a39248f 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -56,6 +56,7 @@ func init() { os.Exit(0) } + gost.SetLogger(&gost.LogLogger{}) gost.Debug = options.DebugMode } @@ -313,18 +314,11 @@ func serve(chain *gost.Chain) error { if whitelist, err = gost.ParsePermissions(node.Values.Get("whitelist")); err != nil { return err } - } else { - // By default allow for everything - whitelist, _ = gost.ParsePermissions("*:*:*") } - if node.Values.Get("blacklist") != "" { if blacklist, err = gost.ParsePermissions(node.Values.Get("blacklist")); err != nil { return err } - } else { - // By default block nothing - blacklist, _ = gost.ParsePermissions("") } var handlerOptions []gost.HandlerOption @@ -366,7 +360,8 @@ func serve(chain *gost.Chain) error { default: handler = gost.AutoHandler(handlerOptions...) } - go new(gost.Server).Serve(ln, handler) + srv := &gost.Server{Listener: ln} + go srv.Serve(handler) } return nil diff --git a/gost.go b/gost.go index 0fa0b54..3be76b6 100644 --- a/gost.go +++ b/gost.go @@ -64,7 +64,7 @@ func init() { Certificates: []tls.Certificate{cert}, } - log.DefaultLogger = &LogLogger{} + // log.DefaultLogger = &LogLogger{} } // SetLogger sets a new logger for internal log system diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..0292ab9 --- /dev/null +++ b/http_test.go @@ -0,0 +1,127 @@ +package gost + +import ( + "bufio" + "bytes" + "crypto/rand" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +var httpTestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, r.Body) +}) + +func httpProxyRoundtrip(urlStr string, cliUser *url.Userinfo, srvUsers []*url.Userinfo, body io.Reader) (statusCode int, recv []byte, err error) { + ln, err := TCPListener("") + if err != nil { + return + } + h := HTTPHandler(UsersHandlerOption(srvUsers...)) + server := &Server{Listener: ln} + go server.Serve(h) + + exitChan := make(chan struct{}) + defer close(exitChan) + go func() { + defer server.Close() + <-exitChan + }() + + client := &Client{ + Connector: HTTPConnector(cliUser), + Transporter: TCPTransporter(), + } + conn, err := client.Dial(ln.Addr().String()) + if err != nil { + return + } + defer conn.Close() + conn, err = client.Handshake(conn) + if err != nil { + return + } + url, err := url.Parse(urlStr) + if err != nil { + return + } + conn, err = client.Connect(conn, url.Host) + if err != nil { + return + } + req, err := http.NewRequest(http.MethodGet, urlStr, body) + if err != nil { + return + } + if err = req.Write(conn); err != nil { + return + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return + } + defer resp.Body.Close() + statusCode = resp.StatusCode + recv, err = ioutil.ReadAll(resp.Body) + return +} + +var httpProxyTests = []struct { + url string + cliUser *url.Userinfo + srvUsers []*url.Userinfo + errStr string +}{ + {"", nil, nil, ""}, + {"", nil, []*url.Userinfo{url.User("admin")}, "407 Proxy Authentication Required"}, + {"", nil, []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"}, + {"", url.User("admin"), []*url.Userinfo{url.User("test")}, "407 Proxy Authentication Required"}, + {"", url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, "407 Proxy Authentication Required"}, + {"", url.User("admin"), []*url.Userinfo{url.User("admin")}, ""}, + {"", url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""}, + {"", url.UserPassword("admin", "123456"), nil, ""}, + {"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""}, + {"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""}, + {"", url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""}, + {"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""}, + {"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""}, + {"http://:0", nil, nil, "503 Service unavailable"}, +} + +func TestHTTPProxy(t *testing.T) { + Debug = true + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + for _, test := range httpProxyTests { + send := make([]byte, 16) + rand.Read(send) + urlStr := test.url + if urlStr == "" { + urlStr = httpSrv.URL + } + _, recv, err := httpProxyRoundtrip(urlStr, test.cliUser, test.srvUsers, bytes.NewReader(send)) + if err == nil { + if test.errStr != "" { + t.Errorf("HTTP proxy response should failed with error %s", test.errStr) + continue + } + } else { + if test.errStr == "" { + t.Errorf("HTTP proxy got error %v", err) + } + if err.Error() != test.errStr { + t.Errorf("HTTP proxy got error %v, want %v", err, test.errStr) + } + continue + } + if !bytes.Equal(send, recv) { + t.Errorf("got %v, want %v", recv, send) + continue + } + } +} diff --git a/node.go b/node.go index 32211b8..e9da9ec 100644 --- a/node.go +++ b/node.go @@ -1,9 +1,7 @@ package gost import ( - "net" "net/url" - "strconv" "strings" ) @@ -24,6 +22,10 @@ type Node struct { // The proxy node string pattern is [scheme://][user:pass@host]:port. // Scheme can be divided into two parts by character '+', such as: http+tls. func ParseNode(s string) (node Node, err error) { + if s == "" { + return Node{}, nil + } + if !strings.Contains(s, "://") { s = "auto://" + s } @@ -58,7 +60,7 @@ func ParseNode(s string) (node Node, err error) { case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding node.Remote = strings.Trim(u.EscapedPath(), "/") default: - node.Transport = "" + node.Transport = "tcp" } switch node.Protocol { @@ -74,23 +76,3 @@ func ParseNode(s string) (node Node, err error) { return } - -// Can tests whether the given action and address is allowed by the whitelist and blacklist. -func Can(action string, addr string, whitelist, blacklist *Permissions) bool { - if !strings.Contains(addr, ":") { - addr = addr + ":80" - } - host, strport, err := net.SplitHostPort(addr) - - if err != nil { - return false - } - - port, err := strconv.Atoi(strport) - - if err != nil { - return false - } - - return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port) -} diff --git a/node_test.go b/node_test.go new file mode 100644 index 0000000..fa83b98 --- /dev/null +++ b/node_test.go @@ -0,0 +1,68 @@ +package gost + +import "testing" +import "net/url" + +var nodeTests = []struct { + in string + out Node + hasError bool +}{ + {"", Node{}, false}, + {"://", Node{}, true}, + {"localhost", Node{Addr: "localhost", Transport: "tcp"}, false}, + {":", Node{Addr: ":", Transport: "tcp"}, false}, + {":8080", Node{Addr: ":8080", Transport: "tcp"}, false}, + {"http://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp"}, false}, + {"http://localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp"}, false}, + {"http://admin:123456@:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("admin", "123456")}, false}, + {"http://admin@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("admin")}, false}, + {"http://:123456@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "123456")}, false}, + {"http://@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("")}, false}, + {"http://:@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "")}, false}, + {"https://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tls"}, false}, + {"socks+tls://:8080", Node{Addr: ":8080", Protocol: "socks5", Transport: "tls"}, false}, + {"tls://:8080", Node{Addr: ":8080", Transport: "tls"}, false}, + {"tcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "tcp", Transport: "tcp"}, false}, + {"udp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "udp", Transport: "udp"}, false}, + {"rtcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rtcp", Transport: "rtcp"}, false}, + {"rudp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rudp", Transport: "rudp"}, false}, + {"redirect://:8080", Node{Addr: ":8080", Protocol: "redirect", Transport: "tcp"}, false}, +} + +func TestParseNode(t *testing.T) { + for _, test := range nodeTests { + actual, err := ParseNode(test.in) + if err != nil { + if test.hasError { + t.Logf("ParseNode(%q) got expected error: %v", test.in, err) + continue + } + t.Errorf("ParseNode(%q) got error: %v", test.in, err) + } else { + if test.hasError { + t.Errorf("ParseNode(%q) got %v, but should return error", test.in, actual) + continue + } + if actual.Addr != test.out.Addr || actual.Protocol != test.out.Protocol || + actual.Transport != test.out.Transport || actual.Remote != test.out.Remote { + t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out) + } + if actual.User == nil { + if test.out.User != nil { + t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out) + } + continue + } + if actual.User != nil { + if test.out.User == nil { + t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out) + continue + } + if *actual.User != *test.out.User { + t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out) + } + } + } + } +} diff --git a/permissions.go b/permissions.go index c76f9f7..a2943fa 100644 --- a/permissions.go +++ b/permissions.go @@ -3,6 +3,7 @@ package gost import ( "errors" "fmt" + "net" "strconv" "strings" @@ -199,3 +200,24 @@ func maxint(x, y int) int { } return y } + +// Can tests whether the given action and address is allowed by the whitelist and blacklist. +func Can(action string, addr string, whitelist, blacklist *Permissions) bool { + if !strings.Contains(addr, ":") { + addr = addr + ":80" + } + host, strport, err := net.SplitHostPort(addr) + + if err != nil { + return false + } + + port, err := strconv.Atoi(strport) + + if err != nil { + return false + } + + return (whitelist == nil || whitelist.Can(action, host, port)) && + (blacklist == nil || !blacklist.Can(action, host, port)) +} diff --git a/server.go b/server.go index fb34958..174fe60 100644 --- a/server.go +++ b/server.go @@ -10,23 +10,33 @@ import ( // Server is a proxy server. type Server struct { + Listener Listener +} + +// Addr returns the address of the server +func (s *Server) Addr() net.Addr { + return s.Listener.Addr() +} + +// Close closes the server +func (s *Server) Close() error { + return s.Listener.Close() } // Serve serves as a proxy server. -func (s *Server) Serve(l net.Listener, h Handler) error { - defer l.Close() - - if l == nil { - ln, err := TCPListener(":8080") +func (s *Server) Serve(h Handler) error { + if s.Listener == nil { + ln, err := TCPListener("") if err != nil { return err } - l = ln + s.Listener = ln } if h == nil { h = HTTPHandler() } + l := s.Listener var tempDelay time.Duration for { conn, e := l.Accept() @@ -63,11 +73,15 @@ type tcpListener struct { // TCPListener creates a Listener for TCP proxy server. func TCPListener(addr string) (Listener, error) { - ln, err := net.Listen("tcp", addr) + laddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return nil, err } - return &tcpListener{Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}}, nil + ln, err := net.ListenTCP("tcp", laddr) + if err != nil { + return nil, err + } + return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil } type tcpKeepAliveListener struct {