add test files
This commit is contained in:
parent
4ce6f0e15f
commit
8cbd2722f6
@ -5,5 +5,5 @@ go:
|
||||
|
||||
install: true
|
||||
script:
|
||||
- go test -v
|
||||
- go test -race -v
|
||||
- cd cmd/gost && go build
|
@ -56,6 +56,7 @@ func init() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
gost.SetLogger(&gost.LogLogger{})
|
||||
gost.Debug = options.DebugMode
|
||||
}
|
||||
|
||||
@ -313,18 +314,11 @@ func serve(chain *gost.Chain) error {
|
||||
if whitelist, err = gost.ParsePermissions(node.Values.Get("whitelist")); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// By default allow for everything
|
||||
whitelist, _ = gost.ParsePermissions("*:*:*")
|
||||
}
|
||||
|
||||
if node.Values.Get("blacklist") != "" {
|
||||
if blacklist, err = gost.ParsePermissions(node.Values.Get("blacklist")); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// By default block nothing
|
||||
blacklist, _ = gost.ParsePermissions("")
|
||||
}
|
||||
|
||||
var handlerOptions []gost.HandlerOption
|
||||
@ -366,7 +360,8 @@ func serve(chain *gost.Chain) error {
|
||||
default:
|
||||
handler = gost.AutoHandler(handlerOptions...)
|
||||
}
|
||||
go new(gost.Server).Serve(ln, handler)
|
||||
srv := &gost.Server{Listener: ln}
|
||||
go srv.Serve(handler)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
2
gost.go
2
gost.go
@ -64,7 +64,7 @@ func init() {
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
log.DefaultLogger = &LogLogger{}
|
||||
// log.DefaultLogger = &LogLogger{}
|
||||
}
|
||||
|
||||
// SetLogger sets a new logger for internal log system
|
||||
|
127
http_test.go
Normal file
127
http_test.go
Normal file
@ -0,0 +1,127 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var httpTestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
io.Copy(w, r.Body)
|
||||
})
|
||||
|
||||
func httpProxyRoundtrip(urlStr string, cliUser *url.Userinfo, srvUsers []*url.Userinfo, body io.Reader) (statusCode int, recv []byte, err error) {
|
||||
ln, err := TCPListener("")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h := HTTPHandler(UsersHandlerOption(srvUsers...))
|
||||
server := &Server{Listener: ln}
|
||||
go server.Serve(h)
|
||||
|
||||
exitChan := make(chan struct{})
|
||||
defer close(exitChan)
|
||||
go func() {
|
||||
defer server.Close()
|
||||
<-exitChan
|
||||
}()
|
||||
|
||||
client := &Client{
|
||||
Connector: HTTPConnector(cliUser),
|
||||
Transporter: TCPTransporter(),
|
||||
}
|
||||
conn, err := client.Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
conn, err = client.Handshake(conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
url, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn, err = client.Connect(conn, url.Host)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, urlStr, body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err = req.Write(conn); err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
statusCode = resp.StatusCode
|
||||
recv, err = ioutil.ReadAll(resp.Body)
|
||||
return
|
||||
}
|
||||
|
||||
var httpProxyTests = []struct {
|
||||
url string
|
||||
cliUser *url.Userinfo
|
||||
srvUsers []*url.Userinfo
|
||||
errStr string
|
||||
}{
|
||||
{"", nil, nil, ""},
|
||||
{"", nil, []*url.Userinfo{url.User("admin")}, "407 Proxy Authentication Required"},
|
||||
{"", nil, []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"},
|
||||
{"", url.User("admin"), []*url.Userinfo{url.User("test")}, "407 Proxy Authentication Required"},
|
||||
{"", url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, "407 Proxy Authentication Required"},
|
||||
{"", url.User("admin"), []*url.Userinfo{url.User("admin")}, ""},
|
||||
{"", url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""},
|
||||
{"", url.UserPassword("admin", "123456"), nil, ""},
|
||||
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""},
|
||||
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
|
||||
{"", url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""},
|
||||
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""},
|
||||
{"", url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""},
|
||||
{"http://:0", nil, nil, "503 Service unavailable"},
|
||||
}
|
||||
|
||||
func TestHTTPProxy(t *testing.T) {
|
||||
Debug = true
|
||||
httpSrv := httptest.NewServer(httpTestHandler)
|
||||
defer httpSrv.Close()
|
||||
|
||||
for _, test := range httpProxyTests {
|
||||
send := make([]byte, 16)
|
||||
rand.Read(send)
|
||||
urlStr := test.url
|
||||
if urlStr == "" {
|
||||
urlStr = httpSrv.URL
|
||||
}
|
||||
_, recv, err := httpProxyRoundtrip(urlStr, test.cliUser, test.srvUsers, bytes.NewReader(send))
|
||||
if err == nil {
|
||||
if test.errStr != "" {
|
||||
t.Errorf("HTTP proxy response should failed with error %s", test.errStr)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if test.errStr == "" {
|
||||
t.Errorf("HTTP proxy got error %v", err)
|
||||
}
|
||||
if err.Error() != test.errStr {
|
||||
t.Errorf("HTTP proxy got error %v, want %v", err, test.errStr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !bytes.Equal(send, recv) {
|
||||
t.Errorf("got %v, want %v", recv, send)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
28
node.go
28
node.go
@ -1,9 +1,7 @@
|
||||
package gost
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -24,6 +22,10 @@ type Node struct {
|
||||
// The proxy node string pattern is [scheme://][user:pass@host]:port.
|
||||
// Scheme can be divided into two parts by character '+', such as: http+tls.
|
||||
func ParseNode(s string) (node Node, err error) {
|
||||
if s == "" {
|
||||
return Node{}, nil
|
||||
}
|
||||
|
||||
if !strings.Contains(s, "://") {
|
||||
s = "auto://" + s
|
||||
}
|
||||
@ -58,7 +60,7 @@ func ParseNode(s string) (node Node, err error) {
|
||||
case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding
|
||||
node.Remote = strings.Trim(u.EscapedPath(), "/")
|
||||
default:
|
||||
node.Transport = ""
|
||||
node.Transport = "tcp"
|
||||
}
|
||||
|
||||
switch node.Protocol {
|
||||
@ -74,23 +76,3 @@ func ParseNode(s string) (node Node, err error) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Can tests whether the given action and address is allowed by the whitelist and blacklist.
|
||||
func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr = addr + ":80"
|
||||
}
|
||||
host, strport, err := net.SplitHostPort(addr)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(strport)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port)
|
||||
}
|
||||
|
68
node_test.go
Normal file
68
node_test.go
Normal file
@ -0,0 +1,68 @@
|
||||
package gost
|
||||
|
||||
import "testing"
|
||||
import "net/url"
|
||||
|
||||
var nodeTests = []struct {
|
||||
in string
|
||||
out Node
|
||||
hasError bool
|
||||
}{
|
||||
{"", Node{}, false},
|
||||
{"://", Node{}, true},
|
||||
{"localhost", Node{Addr: "localhost", Transport: "tcp"}, false},
|
||||
{":", Node{Addr: ":", Transport: "tcp"}, false},
|
||||
{":8080", Node{Addr: ":8080", Transport: "tcp"}, false},
|
||||
{"http://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp"}, false},
|
||||
{"http://localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp"}, false},
|
||||
{"http://admin:123456@:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("admin", "123456")}, false},
|
||||
{"http://admin@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("admin")}, false},
|
||||
{"http://:123456@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "123456")}, false},
|
||||
{"http://@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("")}, false},
|
||||
{"http://:@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "")}, false},
|
||||
{"https://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tls"}, false},
|
||||
{"socks+tls://:8080", Node{Addr: ":8080", Protocol: "socks5", Transport: "tls"}, false},
|
||||
{"tls://:8080", Node{Addr: ":8080", Transport: "tls"}, false},
|
||||
{"tcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "tcp", Transport: "tcp"}, false},
|
||||
{"udp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "udp", Transport: "udp"}, false},
|
||||
{"rtcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rtcp", Transport: "rtcp"}, false},
|
||||
{"rudp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rudp", Transport: "rudp"}, false},
|
||||
{"redirect://:8080", Node{Addr: ":8080", Protocol: "redirect", Transport: "tcp"}, false},
|
||||
}
|
||||
|
||||
func TestParseNode(t *testing.T) {
|
||||
for _, test := range nodeTests {
|
||||
actual, err := ParseNode(test.in)
|
||||
if err != nil {
|
||||
if test.hasError {
|
||||
t.Logf("ParseNode(%q) got expected error: %v", test.in, err)
|
||||
continue
|
||||
}
|
||||
t.Errorf("ParseNode(%q) got error: %v", test.in, err)
|
||||
} else {
|
||||
if test.hasError {
|
||||
t.Errorf("ParseNode(%q) got %v, but should return error", test.in, actual)
|
||||
continue
|
||||
}
|
||||
if actual.Addr != test.out.Addr || actual.Protocol != test.out.Protocol ||
|
||||
actual.Transport != test.out.Transport || actual.Remote != test.out.Remote {
|
||||
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
|
||||
}
|
||||
if actual.User == nil {
|
||||
if test.out.User != nil {
|
||||
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if actual.User != nil {
|
||||
if test.out.User == nil {
|
||||
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
|
||||
continue
|
||||
}
|
||||
if *actual.User != *test.out.User {
|
||||
t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ package gost
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -199,3 +200,24 @@ func maxint(x, y int) int {
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
||||
// Can tests whether the given action and address is allowed by the whitelist and blacklist.
|
||||
func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr = addr + ":80"
|
||||
}
|
||||
host, strport, err := net.SplitHostPort(addr)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(strport)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return (whitelist == nil || whitelist.Can(action, host, port)) &&
|
||||
(blacklist == nil || !blacklist.Can(action, host, port))
|
||||
}
|
||||
|
30
server.go
30
server.go
@ -10,23 +10,33 @@ import (
|
||||
|
||||
// Server is a proxy server.
|
||||
type Server struct {
|
||||
Listener Listener
|
||||
}
|
||||
|
||||
// Addr returns the address of the server
|
||||
func (s *Server) Addr() net.Addr {
|
||||
return s.Listener.Addr()
|
||||
}
|
||||
|
||||
// Close closes the server
|
||||
func (s *Server) Close() error {
|
||||
return s.Listener.Close()
|
||||
}
|
||||
|
||||
// Serve serves as a proxy server.
|
||||
func (s *Server) Serve(l net.Listener, h Handler) error {
|
||||
defer l.Close()
|
||||
|
||||
if l == nil {
|
||||
ln, err := TCPListener(":8080")
|
||||
func (s *Server) Serve(h Handler) error {
|
||||
if s.Listener == nil {
|
||||
ln, err := TCPListener("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l = ln
|
||||
s.Listener = ln
|
||||
}
|
||||
if h == nil {
|
||||
h = HTTPHandler()
|
||||
}
|
||||
|
||||
l := s.Listener
|
||||
var tempDelay time.Duration
|
||||
for {
|
||||
conn, e := l.Accept()
|
||||
@ -63,11 +73,15 @@ type tcpListener struct {
|
||||
|
||||
// TCPListener creates a Listener for TCP proxy server.
|
||||
func TCPListener(addr string) (Listener, error) {
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
laddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tcpListener{Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}}, nil
|
||||
ln, err := net.ListenTCP("tcp", laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil
|
||||
}
|
||||
|
||||
type tcpKeepAliveListener struct {
|
||||
|
Loading…
Reference in New Issue
Block a user