add test files

This commit is contained in:
rui.zheng 2017-08-13 09:30:18 +08:00
parent 4ce6f0e15f
commit 8cbd2722f6
8 changed files with 249 additions and 41 deletions

View File

@ -5,5 +5,5 @@ go:
install: true
script:
- go test -v
- go test -race -v
- cd cmd/gost && go build

View File

@ -56,6 +56,7 @@ func init() {
os.Exit(0)
}
gost.SetLogger(&gost.LogLogger{})
gost.Debug = options.DebugMode
}
@ -313,18 +314,11 @@ func serve(chain *gost.Chain) error {
if whitelist, err = gost.ParsePermissions(node.Values.Get("whitelist")); err != nil {
return err
}
} else {
// By default allow for everything
whitelist, _ = gost.ParsePermissions("*:*:*")
}
if node.Values.Get("blacklist") != "" {
if blacklist, err = gost.ParsePermissions(node.Values.Get("blacklist")); err != nil {
return err
}
} else {
// By default block nothing
blacklist, _ = gost.ParsePermissions("")
}
var handlerOptions []gost.HandlerOption
@ -366,7 +360,8 @@ func serve(chain *gost.Chain) error {
default:
handler = gost.AutoHandler(handlerOptions...)
}
go new(gost.Server).Serve(ln, handler)
srv := &gost.Server{Listener: ln}
go srv.Serve(handler)
}
return nil

View File

@ -64,7 +64,7 @@ func init() {
Certificates: []tls.Certificate{cert},
}
log.DefaultLogger = &LogLogger{}
// log.DefaultLogger = &LogLogger{}
}
// SetLogger sets a new logger for internal log system

127
http_test.go Normal file
View File

@ -0,0 +1,127 @@
package gost
import (
"bufio"
"bytes"
"crypto/rand"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
var httpTestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(w, r.Body)
})
func httpProxyRoundtrip(urlStr string, cliUser *url.Userinfo, srvUsers []*url.Userinfo, body io.Reader) (statusCode int, recv []byte, err error) {
ln, err := TCPListener("")
if err != nil {
return
}
h := HTTPHandler(UsersHandlerOption(srvUsers...))
server := &Server{Listener: ln}
go server.Serve(h)
exitChan := make(chan struct{})
defer close(exitChan)
go func() {
defer server.Close()
<-exitChan
}()
client := &Client{
Connector: HTTPConnector(cliUser),
Transporter: TCPTransporter(),
}
conn, err := client.Dial(ln.Addr().String())
if err != nil {
return
}
defer conn.Close()
conn, err = client.Handshake(conn)
if err != nil {
return
}
url, err := url.Parse(urlStr)
if err != nil {
return
}
conn, err = client.Connect(conn, url.Host)
if err != nil {
return
}
req, err := http.NewRequest(http.MethodGet, urlStr, body)
if err != nil {
return
}
if err = req.Write(conn); err != nil {
return
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return
}
defer resp.Body.Close()
statusCode = resp.StatusCode
recv, err = ioutil.ReadAll(resp.Body)
return
}
var httpProxyTests = []struct {
url string
cliUser *url.Userinfo
srvUsers []*url.Userinfo
errStr string
}{
{"", nil, nil, ""},
{"", nil, []*url.Userinfo{url.User("admin")}, "407 Proxy Authentication Required"},
{"", nil, []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"},
{"", url.User("admin"), []*url.Userinfo{url.User("test")}, "407 Proxy Authentication Required"},
{"", url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, "407 Proxy Authentication Required"},
{"", url.User("admin"), []*url.Userinfo{url.User("admin")}, ""},
{"", url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""},
{"", url.UserPassword("admin", "123456"), nil, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
{"", url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""},
{"http://:0", nil, nil, "503 Service unavailable"},
}
func TestHTTPProxy(t *testing.T) {
Debug = true
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
for _, test := range httpProxyTests {
send := make([]byte, 16)
rand.Read(send)
urlStr := test.url
if urlStr == "" {
urlStr = httpSrv.URL
}
_, recv, err := httpProxyRoundtrip(urlStr, test.cliUser, test.srvUsers, bytes.NewReader(send))
if err == nil {
if test.errStr != "" {
t.Errorf("HTTP proxy response should failed with error %s", test.errStr)
continue
}
} else {
if test.errStr == "" {
t.Errorf("HTTP proxy got error %v", err)
}
if err.Error() != test.errStr {
t.Errorf("HTTP proxy got error %v, want %v", err, test.errStr)
}
continue
}
if !bytes.Equal(send, recv) {
t.Errorf("got %v, want %v", recv, send)
continue
}
}
}

28
node.go
View File

@ -1,9 +1,7 @@
package gost
import (
"net"
"net/url"
"strconv"
"strings"
)
@ -24,6 +22,10 @@ type Node struct {
// The proxy node string pattern is [scheme://][user:pass@host]:port.
// Scheme can be divided into two parts by character '+', such as: http+tls.
func ParseNode(s string) (node Node, err error) {
if s == "" {
return Node{}, nil
}
if !strings.Contains(s, "://") {
s = "auto://" + s
}
@ -58,7 +60,7 @@ func ParseNode(s string) (node Node, err error) {
case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding
node.Remote = strings.Trim(u.EscapedPath(), "/")
default:
node.Transport = ""
node.Transport = "tcp"
}
switch node.Protocol {
@ -74,23 +76,3 @@ func ParseNode(s string) (node Node, err error) {
return
}
// Can tests whether the given action and address is allowed by the whitelist and blacklist.
func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
if !strings.Contains(addr, ":") {
addr = addr + ":80"
}
host, strport, err := net.SplitHostPort(addr)
if err != nil {
return false
}
port, err := strconv.Atoi(strport)
if err != nil {
return false
}
return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port)
}

68
node_test.go Normal file
View File

@ -0,0 +1,68 @@
package gost
import "testing"
import "net/url"
var nodeTests = []struct {
in string
out Node
hasError bool
}{
{"", Node{}, false},
{"://", Node{}, true},
{"localhost", Node{Addr: "localhost", Transport: "tcp"}, false},
{":", Node{Addr: ":", Transport: "tcp"}, false},
{":8080", Node{Addr: ":8080", Transport: "tcp"}, false},
{"http://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp"}, false},
{"http://localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp"}, false},
{"http://admin:123456@:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("admin", "123456")}, false},
{"http://admin@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("admin")}, false},
{"http://:123456@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "123456")}, false},
{"http://@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("")}, false},
{"http://:@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "")}, false},
{"https://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tls"}, false},
{"socks+tls://:8080", Node{Addr: ":8080", Protocol: "socks5", Transport: "tls"}, false},
{"tls://:8080", Node{Addr: ":8080", Transport: "tls"}, false},
{"tcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "tcp", Transport: "tcp"}, false},
{"udp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "udp", Transport: "udp"}, false},
{"rtcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rtcp", Transport: "rtcp"}, false},
{"rudp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rudp", Transport: "rudp"}, false},
{"redirect://:8080", Node{Addr: ":8080", Protocol: "redirect", Transport: "tcp"}, false},
}
func TestParseNode(t *testing.T) {
for _, test := range nodeTests {
actual, err := ParseNode(test.in)
if err != nil {
if test.hasError {
t.Logf("ParseNode(%q) got expected error: %v", test.in, err)
continue
}
t.Errorf("ParseNode(%q) got error: %v", test.in, err)
} else {
if test.hasError {
t.Errorf("ParseNode(%q) got %v, but should return error", test.in, actual)
continue
}
if actual.Addr != test.out.Addr || actual.Protocol != test.out.Protocol ||
actual.Transport != test.out.Transport || actual.Remote != test.out.Remote {
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
}
if actual.User == nil {
if test.out.User != nil {
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
}
continue
}
if actual.User != nil {
if test.out.User == nil {
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
continue
}
if *actual.User != *test.out.User {
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
}
}
}
}
}

View File

@ -3,6 +3,7 @@ package gost
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
@ -199,3 +200,24 @@ func maxint(x, y int) int {
}
return y
}
// Can tests whether the given action and address is allowed by the whitelist and blacklist.
func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
if !strings.Contains(addr, ":") {
addr = addr + ":80"
}
host, strport, err := net.SplitHostPort(addr)
if err != nil {
return false
}
port, err := strconv.Atoi(strport)
if err != nil {
return false
}
return (whitelist == nil || whitelist.Can(action, host, port)) &&
(blacklist == nil || !blacklist.Can(action, host, port))
}

View File

@ -10,23 +10,33 @@ import (
// Server is a proxy server.
type Server struct {
Listener Listener
}
// Addr returns the address of the server
func (s *Server) Addr() net.Addr {
return s.Listener.Addr()
}
// Close closes the server
func (s *Server) Close() error {
return s.Listener.Close()
}
// Serve serves as a proxy server.
func (s *Server) Serve(l net.Listener, h Handler) error {
defer l.Close()
if l == nil {
ln, err := TCPListener(":8080")
func (s *Server) Serve(h Handler) error {
if s.Listener == nil {
ln, err := TCPListener("")
if err != nil {
return err
}
l = ln
s.Listener = ln
}
if h == nil {
h = HTTPHandler()
}
l := s.Listener
var tempDelay time.Duration
for {
conn, e := l.Accept()
@ -63,11 +73,15 @@ type tcpListener struct {
// TCPListener creates a Listener for TCP proxy server.
func TCPListener(addr string) (Listener, error) {
ln, err := net.Listen("tcp", addr)
laddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
return &tcpListener{Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}}, nil
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return nil, err
}
return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil
}
type tcpKeepAliveListener struct {