add test files

This commit is contained in:
rui.zheng 2017-08-13 09:30:18 +08:00
parent 4ce6f0e15f
commit 8cbd2722f6
8 changed files with 249 additions and 41 deletions

View File

@ -5,5 +5,5 @@ go:
install: true install: true
script: script:
- go test -v - go test -race -v
- cd cmd/gost && go build - cd cmd/gost && go build

View File

@ -56,6 +56,7 @@ func init() {
os.Exit(0) os.Exit(0)
} }
gost.SetLogger(&gost.LogLogger{})
gost.Debug = options.DebugMode gost.Debug = options.DebugMode
} }
@ -313,18 +314,11 @@ func serve(chain *gost.Chain) error {
if whitelist, err = gost.ParsePermissions(node.Values.Get("whitelist")); err != nil { if whitelist, err = gost.ParsePermissions(node.Values.Get("whitelist")); err != nil {
return err return err
} }
} else {
// By default allow for everything
whitelist, _ = gost.ParsePermissions("*:*:*")
} }
if node.Values.Get("blacklist") != "" { if node.Values.Get("blacklist") != "" {
if blacklist, err = gost.ParsePermissions(node.Values.Get("blacklist")); err != nil { if blacklist, err = gost.ParsePermissions(node.Values.Get("blacklist")); err != nil {
return err return err
} }
} else {
// By default block nothing
blacklist, _ = gost.ParsePermissions("")
} }
var handlerOptions []gost.HandlerOption var handlerOptions []gost.HandlerOption
@ -366,7 +360,8 @@ func serve(chain *gost.Chain) error {
default: default:
handler = gost.AutoHandler(handlerOptions...) handler = gost.AutoHandler(handlerOptions...)
} }
go new(gost.Server).Serve(ln, handler) srv := &gost.Server{Listener: ln}
go srv.Serve(handler)
} }
return nil return nil

View File

@ -64,7 +64,7 @@ func init() {
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
} }
log.DefaultLogger = &LogLogger{} // log.DefaultLogger = &LogLogger{}
} }
// SetLogger sets a new logger for internal log system // SetLogger sets a new logger for internal log system

127
http_test.go Normal file
View File

@ -0,0 +1,127 @@
package gost
import (
"bufio"
"bytes"
"crypto/rand"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
var httpTestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(w, r.Body)
})
func httpProxyRoundtrip(urlStr string, cliUser *url.Userinfo, srvUsers []*url.Userinfo, body io.Reader) (statusCode int, recv []byte, err error) {
ln, err := TCPListener("")
if err != nil {
return
}
h := HTTPHandler(UsersHandlerOption(srvUsers...))
server := &Server{Listener: ln}
go server.Serve(h)
exitChan := make(chan struct{})
defer close(exitChan)
go func() {
defer server.Close()
<-exitChan
}()
client := &Client{
Connector: HTTPConnector(cliUser),
Transporter: TCPTransporter(),
}
conn, err := client.Dial(ln.Addr().String())
if err != nil {
return
}
defer conn.Close()
conn, err = client.Handshake(conn)
if err != nil {
return
}
url, err := url.Parse(urlStr)
if err != nil {
return
}
conn, err = client.Connect(conn, url.Host)
if err != nil {
return
}
req, err := http.NewRequest(http.MethodGet, urlStr, body)
if err != nil {
return
}
if err = req.Write(conn); err != nil {
return
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return
}
defer resp.Body.Close()
statusCode = resp.StatusCode
recv, err = ioutil.ReadAll(resp.Body)
return
}
var httpProxyTests = []struct {
url string
cliUser *url.Userinfo
srvUsers []*url.Userinfo
errStr string
}{
{"", nil, nil, ""},
{"", nil, []*url.Userinfo{url.User("admin")}, "407 Proxy Authentication Required"},
{"", nil, []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"},
{"", url.User("admin"), []*url.Userinfo{url.User("test")}, "407 Proxy Authentication Required"},
{"", url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, "407 Proxy Authentication Required"},
{"", url.User("admin"), []*url.Userinfo{url.User("admin")}, ""},
{"", url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""},
{"", url.UserPassword("admin", "123456"), nil, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
{"", url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""},
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""},
{"http://:0", nil, nil, "503 Service unavailable"},
}
func TestHTTPProxy(t *testing.T) {
Debug = true
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
for _, test := range httpProxyTests {
send := make([]byte, 16)
rand.Read(send)
urlStr := test.url
if urlStr == "" {
urlStr = httpSrv.URL
}
_, recv, err := httpProxyRoundtrip(urlStr, test.cliUser, test.srvUsers, bytes.NewReader(send))
if err == nil {
if test.errStr != "" {
t.Errorf("HTTP proxy response should failed with error %s", test.errStr)
continue
}
} else {
if test.errStr == "" {
t.Errorf("HTTP proxy got error %v", err)
}
if err.Error() != test.errStr {
t.Errorf("HTTP proxy got error %v, want %v", err, test.errStr)
}
continue
}
if !bytes.Equal(send, recv) {
t.Errorf("got %v, want %v", recv, send)
continue
}
}
}

28
node.go
View File

@ -1,9 +1,7 @@
package gost package gost
import ( import (
"net"
"net/url" "net/url"
"strconv"
"strings" "strings"
) )
@ -24,6 +22,10 @@ type Node struct {
// The proxy node string pattern is [scheme://][user:pass@host]:port. // The proxy node string pattern is [scheme://][user:pass@host]:port.
// Scheme can be divided into two parts by character '+', such as: http+tls. // Scheme can be divided into two parts by character '+', such as: http+tls.
func ParseNode(s string) (node Node, err error) { func ParseNode(s string) (node Node, err error) {
if s == "" {
return Node{}, nil
}
if !strings.Contains(s, "://") { if !strings.Contains(s, "://") {
s = "auto://" + s s = "auto://" + s
} }
@ -58,7 +60,7 @@ func ParseNode(s string) (node Node, err error) {
case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding
node.Remote = strings.Trim(u.EscapedPath(), "/") node.Remote = strings.Trim(u.EscapedPath(), "/")
default: default:
node.Transport = "" node.Transport = "tcp"
} }
switch node.Protocol { switch node.Protocol {
@ -74,23 +76,3 @@ func ParseNode(s string) (node Node, err error) {
return return
} }
// Can tests whether the given action and address is allowed by the whitelist and blacklist.
func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
if !strings.Contains(addr, ":") {
addr = addr + ":80"
}
host, strport, err := net.SplitHostPort(addr)
if err != nil {
return false
}
port, err := strconv.Atoi(strport)
if err != nil {
return false
}
return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port)
}

68
node_test.go Normal file
View File

@ -0,0 +1,68 @@
package gost
import "testing"
import "net/url"
var nodeTests = []struct {
in string
out Node
hasError bool
}{
{"", Node{}, false},
{"://", Node{}, true},
{"localhost", Node{Addr: "localhost", Transport: "tcp"}, false},
{":", Node{Addr: ":", Transport: "tcp"}, false},
{":8080", Node{Addr: ":8080", Transport: "tcp"}, false},
{"http://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp"}, false},
{"http://localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp"}, false},
{"http://admin:123456@:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("admin", "123456")}, false},
{"http://admin@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("admin")}, false},
{"http://:123456@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "123456")}, false},
{"http://@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("")}, false},
{"http://:@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "")}, false},
{"https://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tls"}, false},
{"socks+tls://:8080", Node{Addr: ":8080", Protocol: "socks5", Transport: "tls"}, false},
{"tls://:8080", Node{Addr: ":8080", Transport: "tls"}, false},
{"tcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "tcp", Transport: "tcp"}, false},
{"udp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "udp", Transport: "udp"}, false},
{"rtcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rtcp", Transport: "rtcp"}, false},
{"rudp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rudp", Transport: "rudp"}, false},
{"redirect://:8080", Node{Addr: ":8080", Protocol: "redirect", Transport: "tcp"}, false},
}
func TestParseNode(t *testing.T) {
for _, test := range nodeTests {
actual, err := ParseNode(test.in)
if err != nil {
if test.hasError {
t.Logf("ParseNode(%q) got expected error: %v", test.in, err)
continue
}
t.Errorf("ParseNode(%q) got error: %v", test.in, err)
} else {
if test.hasError {
t.Errorf("ParseNode(%q) got %v, but should return error", test.in, actual)
continue
}
if actual.Addr != test.out.Addr || actual.Protocol != test.out.Protocol ||
actual.Transport != test.out.Transport || actual.Remote != test.out.Remote {
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
}
if actual.User == nil {
if test.out.User != nil {
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
}
continue
}
if actual.User != nil {
if test.out.User == nil {
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
continue
}
if *actual.User != *test.out.User {
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
}
}
}
}
}

View File

@ -3,6 +3,7 @@ package gost
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"strconv" "strconv"
"strings" "strings"
@ -199,3 +200,24 @@ func maxint(x, y int) int {
} }
return y return y
} }
// Can tests whether the given action and address is allowed by the whitelist and blacklist.
func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
if !strings.Contains(addr, ":") {
addr = addr + ":80"
}
host, strport, err := net.SplitHostPort(addr)
if err != nil {
return false
}
port, err := strconv.Atoi(strport)
if err != nil {
return false
}
return (whitelist == nil || whitelist.Can(action, host, port)) &&
(blacklist == nil || !blacklist.Can(action, host, port))
}

View File

@ -10,23 +10,33 @@ import (
// Server is a proxy server. // Server is a proxy server.
type Server struct { type Server struct {
Listener Listener
}
// Addr returns the address of the server
func (s *Server) Addr() net.Addr {
return s.Listener.Addr()
}
// Close closes the server
func (s *Server) Close() error {
return s.Listener.Close()
} }
// Serve serves as a proxy server. // Serve serves as a proxy server.
func (s *Server) Serve(l net.Listener, h Handler) error { func (s *Server) Serve(h Handler) error {
defer l.Close() if s.Listener == nil {
ln, err := TCPListener("")
if l == nil {
ln, err := TCPListener(":8080")
if err != nil { if err != nil {
return err return err
} }
l = ln s.Listener = ln
} }
if h == nil { if h == nil {
h = HTTPHandler() h = HTTPHandler()
} }
l := s.Listener
var tempDelay time.Duration var tempDelay time.Duration
for { for {
conn, e := l.Accept() conn, e := l.Accept()
@ -63,11 +73,15 @@ type tcpListener struct {
// TCPListener creates a Listener for TCP proxy server. // TCPListener creates a Listener for TCP proxy server.
func TCPListener(addr string) (Listener, error) { func TCPListener(addr string) (Listener, error) {
ln, err := net.Listen("tcp", addr) laddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tcpListener{Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}}, nil ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return nil, err
}
return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil
} }
type tcpKeepAliveListener struct { type tcpKeepAliveListener struct {