From dc0b4f54234c7e78b7fed088b96c9abb398be75d Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Wed, 26 Jul 2017 18:34:29 +0800 Subject: [PATCH] update examples --- gost/chain.go | 6 +- gost/cli/cli.go | 141 ++++++++++++++++++----------- gost/client.go | 31 ++++++- gost/cmd/gost/main.go | 99 +++++++++++++++++++++ gost/cmd/gost/parser.go | 139 +++++++++++++++++++++++++++++ gost/gost.go | 6 +- gost/http.go | 4 +- gost/kcp.go | 76 +++++++++++----- gost/log.go | 19 +++- gost/permissions.go | 185 +++++++++++++++++++++++++++++++++++++++ gost/permissions_test.go | 152 ++++++++++++++++++++++++++++++++ gost/srv/srv.go | 45 ++++++---- gost/tls.go | 5 +- gost/ws.go | 14 +-- 14 files changed, 803 insertions(+), 119 deletions(-) create mode 100644 gost/cmd/gost/main.go create mode 100644 gost/cmd/gost/parser.go create mode 100644 gost/permissions.go create mode 100644 gost/permissions_test.go diff --git a/gost/chain.go b/gost/chain.go index 8d6b6e2..f976cf8 100644 --- a/gost/chain.go +++ b/gost/chain.go @@ -78,12 +78,12 @@ func (c *Chain) Conn() (net.Conn, error) { } nodes := c.nodes - conn, err := nodes[0].Client.Dial(nodes[0].Addr) + conn, err := nodes[0].Client.Dial(nodes[0].Addr, TimeoutDialOption(DialTimeout)) if err != nil { return nil, err } - conn, err = nodes[0].Client.Handshake(conn) + conn, err = nodes[0].Client.Handshake(conn, AddrHandshakeOption(nodes[0].Addr)) if err != nil { return nil, err } @@ -99,7 +99,7 @@ func (c *Chain) Conn() (net.Conn, error) { conn.Close() return nil, err } - cc, err = next.Client.Handshake(cc) + cc, err = next.Client.Handshake(cc, AddrHandshakeOption(next.Addr)) if err != nil { conn.Close() return nil, err diff --git a/gost/cli/cli.go b/gost/cli/cli.go index 57ec9ed..0ceda4a 100644 --- a/gost/cli/cli.go +++ b/gost/cli/cli.go @@ -2,28 +2,43 @@ package main import ( "bufio" - "crypto/tls" + "flag" "log" "net/http" "net/http/httputil" - "net/url" - + "sync" "time" "github.com/ginuerzh/gost/gost" ) +var ( + requests, concurrency int + quiet bool + swg, ewg sync.WaitGroup +) + func init() { log.SetFlags(log.LstdFlags | log.Lshortfile) - gost.Debug = true + + flag.IntVar(&requests, "n", 1, "Number of requests to perform") + flag.IntVar(&concurrency, "c", 1, "Number of multiple requests to make at a time") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } } func main() { chain := gost.NewChain( + /* // http+tcp gost.Node{ - Addr: "127.0.0.1:8080", + Addr: "127.0.0.1:18080", Client: gost.NewClient( gost.HTTPConnector(url.UserPassword("admin", "123456")), gost.TCPTransporter(), @@ -34,7 +49,7 @@ func main() { /* // socks5+tcp gost.Node{ - Addr: "127.0.0.1:1080", + Addr: "127.0.0.1:11080", Client: gost.NewClient( gost.SOCKS5Connector(url.UserPassword("admin", "123456")), gost.TCPTransporter(), @@ -45,7 +60,7 @@ func main() { /* // ss+tcp gost.Node{ - Addr: "127.0.0.1:8338", + Addr: "127.0.0.1:18338", Client: gost.NewClient( gost.ShadowConnector(url.UserPassword("chacha20", "123456")), gost.TCPTransporter(), @@ -56,10 +71,10 @@ func main() { /* // http+ws gost.Node{ - Addr: "127.0.0.1:8000", + Addr: "127.0.0.1:18000", Client: gost.NewClient( gost.HTTPConnector(url.UserPassword("admin", "123456")), - gost.WSTransporter("127.0.0.1:8000", nil), + gost.WSTransporter(nil), ), }, */ @@ -67,13 +82,10 @@ func main() { /* // http+wss gost.Node{ - Addr: "127.0.0.1:8443", + Addr: "127.0.0.1:18443", Client: gost.NewClient( gost.HTTPConnector(url.UserPassword("admin", "123456")), - gost.WSSTransporter( - "127.0.0.1:8443", - &gost.WSOptions{TLSConfig: &tls.Config{InsecureSkipVerify: true}}, - ), + gost.WSSTransporter(nil), ), }, */ @@ -81,63 +93,88 @@ func main() { /* // http+tls gost.Node{ - Addr: "127.0.0.1:1443", + Addr: "127.0.0.1:11443", Client: gost.NewClient( gost.HTTPConnector(url.UserPassword("admin", "123456")), - gost.TLSTransporter(&tls.Config{InsecureSkipVerify: true}), + gost.TLSTransporter(), ), }, */ - // http2+tls, http2+tcp - gost.Node{ - Addr: "127.0.0.1:1443", - Client: gost.NewClient( - gost.HTTP2Connector(url.UserPassword("admin", "123456")), - gost.HTTP2Transporter( - nil, - &tls.Config{InsecureSkipVerify: true}, // or nil, will use h2c mode (http2+tcp). - time.Second*1, - ), - ), - }, - /* - // http+kcp + // http2+tls, http2+tcp gost.Node{ - Addr: "127.0.0.1:8388", + Addr: "127.0.0.1:1443", Client: gost.NewClient( - gost.HTTPConnector(nil), - gost.KCPTransporter(nil), + gost.HTTP2Connector(url.UserPassword("admin", "123456")), + gost.HTTP2Transporter( + nil, + &tls.Config{InsecureSkipVerify: true}, // or nil, will use h2c mode (http2+tcp). + time.Second*1, + ), ), }, */ + + // http+kcp + gost.Node{ + Addr: "127.0.0.1:18388", + Client: gost.NewClient( + gost.HTTPConnector(nil), + gost.KCPTransporter(nil), + ), + }, ) - for i := 0; i < 10; i++ { - conn, err := chain.Dial("localhost:10000") - if err != nil { - log.Fatal(err) + total := 0 + for total < requests { + startChan := make(chan struct{}) + for i := 0; i < concurrency; i++ { + swg.Add(1) + ewg.Add(1) + go request(chain, startChan) } - //conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) - req, err := http.NewRequest(http.MethodGet, "http://localhost:10000/pkg", nil) - if err != nil { - log.Fatal(err) - } - if err := req.Write(conn); err != nil { - log.Fatal(err) - } - resp, err := http.ReadResponse(bufio.NewReader(conn), req) - if err != nil { - log.Fatal(err) - } - defer resp.Body.Close() + start := time.Now() + swg.Wait() // wait for workers ready + close(startChan) // start signal + ewg.Wait() // wait for workers done + + duration := time.Since(start) + total += concurrency + log.Printf("%d/%d/%d requests done (%v/%v)", total, requests, concurrency, duration, duration/time.Duration(concurrency)) + } +} + +func request(chain *gost.Chain, start <-chan struct{}) { + defer ewg.Done() + + swg.Done() + <-start + + conn, err := chain.Dial("localhost:10000") + if err != nil { + log.Fatal(err) + } + defer conn.Close() + //conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) + req, err := http.NewRequest(http.MethodGet, "http://localhost:10000/pkg", nil) + if err != nil { + log.Fatal(err) + } + if err := req.Write(conn); err != nil { + log.Fatal(err) + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + if !quiet { rb, _ := httputil.DumpRequest(req, true) log.Println(string(rb)) rb, _ = httputil.DumpResponse(resp, true) log.Println(string(rb)) - - time.Sleep(1000 * time.Millisecond) } } diff --git a/gost/client.go b/gost/client.go index 1de131a..9553488 100644 --- a/gost/client.go +++ b/gost/client.go @@ -3,6 +3,7 @@ package gost import ( "crypto/tls" "net" + "time" ) // Client is a proxy client. @@ -77,7 +78,14 @@ func TCPTransporter() Transporter { } func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - return net.Dial("tcp", addr) + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + if opts.Chain == nil { + return net.DialTimeout("tcp", addr, opts.Timeout) + } + return opts.Chain.Dial(addr) } func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { @@ -90,14 +98,29 @@ func (tr *tcpTransporter) Multiplex() bool { // DialOptions describes the options for dialing. type DialOptions struct { + Timeout time.Duration + Chain *Chain } // DialOption allows a common way to set dial options. type DialOption func(opts *DialOptions) +func TimeoutDialOption(timeout time.Duration) DialOption { + return func(opts *DialOptions) { + opts.Timeout = timeout + } +} + +func ChainDialOption(chain *Chain) DialOption { + return func(opts *DialOptions) { + opts.Chain = chain + } +} + // HandshakeOptions describes the options for handshake. type HandshakeOptions struct { Addr string + Timeout time.Duration TLSConfig *tls.Config WSOptions *WSOptions KCPConfig *KCPConfig @@ -112,6 +135,12 @@ func AddrHandshakeOption(addr string) HandshakeOption { } } +func TimeoutHandshakeOption(timeout time.Duration) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Timeout = timeout + } +} + func TLSConfigHandshakeOption(config *tls.Config) HandshakeOption { return func(opts *HandshakeOptions) { opts.TLSConfig = config diff --git a/gost/cmd/gost/main.go b/gost/cmd/gost/main.go new file mode 100644 index 0000000..87c3e40 --- /dev/null +++ b/gost/cmd/gost/main.go @@ -0,0 +1,99 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "os" + "runtime" + + "github.com/ginuerzh/gost/gost" + "github.com/go-log/log" +) + +var ( + options struct { + chainNodes, serveNodes stringList + debugMode bool + } +) + +func init() { + var ( + configureFile string + printVersion bool + ) + + flag.Var(&options.chainNodes, "F", "forward address, can make a forward chain") + flag.Var(&options.serveNodes, "L", "listen address, can listen on multiple ports") + flag.StringVar(&configureFile, "C", "", "configure file") + flag.BoolVar(&options.debugMode, "D", false, "enable debug log") + flag.BoolVar(&printVersion, "V", false, "print version") + flag.Parse() + + if err := loadConfigureFile(configureFile); err != nil { + log.Log(err) + os.Exit(1) + } + + if flag.NFlag() == 0 { + flag.PrintDefaults() + os.Exit(0) + } + + if printVersion { + fmt.Fprintf(os.Stderr, "gost %s (%s)\n", gost.Version, runtime.Version()) + os.Exit(0) + } +} + +func main() { + +} + +func buildChain() (*gost.Chain, error) { + chain := gost.NewChain() + for _, cn := range options.chainNodes { + node, err := parseNode(cn) + if err != nil { + return nil, err + } + + var tr gost.Transporter + switch node.Transport { + case "tls": + tr = gost.TLSTransporter() + case "ws": + tr = gost.WSTransporter(nil) + } + + var connector gost.Connector + } + + return chain, nil +} + +func loadConfigureFile(configureFile string) error { + if configureFile == "" { + return nil + } + content, err := ioutil.ReadFile(configureFile) + if err != nil { + return err + } + if err := json.Unmarshal(content, &options); err != nil { + return err + } + return nil +} + +type stringList []string + +func (l *stringList) String() string { + return fmt.Sprintf("%s", *l) +} +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} diff --git a/gost/cmd/gost/parser.go b/gost/cmd/gost/parser.go new file mode 100644 index 0000000..25f53fb --- /dev/null +++ b/gost/cmd/gost/parser.go @@ -0,0 +1,139 @@ +package main + +import ( + "bufio" + "net" + "net/url" + "os" + "strings" + + "github.com/ginuerzh/gost/gost" + "github.com/go-log/log" +) + +type node struct { + Addr string + Protocol string // protocol: http/socks5/ss + Transport string // transport: ws/wss/tls/http2/tcp/udp/rtcp/rudp + Remote string // remote address, used by tcp/udp port forwarding + Users []*url.Userinfo // authentication for proxy + Whitelist *gost.Permissions + Blacklist *gost.Permissions + values url.Values + serverName string +} + +func parseNode(s string) (n node, err error) { + if !strings.Contains(s, "://") { + s = "gost://" + s + } + u, err := url.Parse(s) + if err != nil { + return + } + + query := u.Query() + + n = node{ + Addr: u.Host, + } + + if query.Get("whitelist") != "" { + if n.Whitelist, err = gost.ParsePermissions(query.Get("whitelist")); err != nil { + return + } + } else { + // By default allow for everyting + n.Whitelist, _ = gost.ParsePermissions("*:*:*") + } + + if query.Get("blacklist") != "" { + if n.Blacklist, err = gost.ParsePermissions(query.Get("blacklist")); err != nil { + return + } + } else { + // By default block nothing + n.Blacklist, _ = gost.ParsePermissions("") + } + + if u.User != nil { + n.Users = append(n.Users, u.User) + } + + users, er := parseUsers(n.values.Get("secrets")) + if users != nil { + n.Users = append(n.Users, users...) + } + if er != nil { + log.Log("load secrets:", er) + } + + if strings.Contains(u.Host, ":") { + n.serverName, _, _ = net.SplitHostPort(u.Host) + if n.serverName == "" { + n.serverName = "localhost" // default server name + } + } + + schemes := strings.Split(u.Scheme, "+") + if len(schemes) == 1 { + n.Protocol = schemes[0] + n.Transport = schemes[0] + } + if len(schemes) == 2 { + n.Protocol = schemes[0] + n.Transport = schemes[1] + } + + switch n.Transport { + case "ws", "wss", "tls", "h2", "h2c", "quic", "kcp", "redirect", "ssu", "ssh": + case "https": + n.Protocol = "http" + n.Transport = "tls" + case "http2": // http2 -> http2+tls, h2c mode is http2+tcp + n.Protocol = "http2" + n.Transport = "tls" + case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding + n.Remote = strings.Trim(u.EscapedPath(), "/") + case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding + n.Remote = strings.Trim(u.EscapedPath(), "/") + default: + n.Transport = "" + } + + switch n.Protocol { + case "http", "http2", "socks", "socks4", "socks4a", "socks5", "ss": + default: + n.Protocol = "" + } + + return +} + +func parseUsers(s string) (users []*url.Userinfo, err error) { + if s == "" { + return + } + + f, err := os.Open(s) + if err != nil { + return + } + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + s := strings.SplitN(line, " ", 2) + if len(s) == 1 { + users = append(users, url.User(strings.TrimSpace(s[0]))) + } else if len(s) == 2 { + users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) + } + } + + err = scanner.Err() + return +} diff --git a/gost/gost.go b/gost/gost.go index 1a9b7da..751989e 100644 --- a/gost/gost.go +++ b/gost/gost.go @@ -37,5 +37,9 @@ var ( ) func init() { - log.DefaultLogger = &logger{} + log.DefaultLogger = &LogLogger{} +} + +func SetLogger(logger log.Logger) { + log.DefaultLogger = logger } diff --git a/gost/http.go b/gost/http.go index 959a302..a8e7034 100644 --- a/gost/http.go +++ b/gost/http.go @@ -76,9 +76,7 @@ type httpHandler struct { // HTTPHandler creates a server Handler for HTTP proxy server. func HTTPHandler(opts ...HandlerOption) Handler { h := &httpHandler{ - options: &HandlerOptions{ - Chain: new(Chain), - }, + options: &HandlerOptions{}, } for _, opt := range opts { opt(h.options) diff --git a/gost/kcp.go b/gost/kcp.go index 694a97f..a5ccfa2 100644 --- a/gost/kcp.go +++ b/gost/kcp.go @@ -3,6 +3,7 @@ package gost import ( "crypto/sha1" "encoding/csv" + "errors" "fmt" "net" "os" @@ -180,29 +181,61 @@ func KCPTransporter(config *KCPConfig) Transporter { } func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + uaddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return + } + tr.sessionMutex.Lock() defer tr.sessionMutex.Unlock() session, ok := tr.sessions[addr] if !ok { - session, err = tr.dial(addr, tr.config) - if err != nil { - return - } - tr.sessions[addr] = session + return net.DialUDP("udp", nil, uaddr) } - - conn, err = session.GetConn() - if err != nil { - session.Close() - delete(tr.sessions, addr) // TODO: we could obtain a new session automatically. - } - return + return session.conn, nil } -func (tr *kcpTransporter) dial(addr string, config *KCPConfig) (*kcpSession, error) { - kcpconn, err := kcp.DialWithOptions(addr, - blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard) +func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + config := tr.config + if opts.KCPConfig != nil { + config = opts.KCPConfig + } + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[opts.Addr] + if !ok { + s, err := tr.initSession(opts.Addr, conn, config) + if err != nil { + conn.Close() + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*kcpSession, error) { + pc, ok := conn.(net.PacketConn) + if !ok { + return nil, errors.New("wrong connection type") + } + + kcpconn, err := kcp.NewConn(addr, + blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard, pc) if err != nil { return nil, err } @@ -227,22 +260,17 @@ func (tr *kcpTransporter) dial(addr string, config *KCPConfig) (*kcpSession, err // stream multiplex smuxConfig := smux.DefaultConfig() smuxConfig.MaxReceiveBuffer = config.SockBuf - var conn net.Conn = kcpconn + var cc net.Conn = kcpconn if !config.NoComp { - conn = newCompStreamConn(kcpconn) + cc = newCompStreamConn(kcpconn) } - session, err := smux.Client(conn, smuxConfig) + session, err := smux.Client(cc, smuxConfig) if err != nil { - conn.Close() return nil, err } return &kcpSession{conn: conn, session: session}, nil } -func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - return conn, nil -} - func (tr *kcpTransporter) Multiplex() bool { return true } @@ -284,7 +312,7 @@ func KCPListener(addr string, config *KCPConfig) (Listener, error) { l := &kcpListener{ config: config, ln: ln, - connChan: make(chan net.Conn, 128), + connChan: make(chan net.Conn, 1024), errChan: make(chan error), } go l.acceptLoop() diff --git a/gost/log.go b/gost/log.go index b55303e..ba2dc69 100644 --- a/gost/log.go +++ b/gost/log.go @@ -1,18 +1,29 @@ package gost -import "log" +import ( + "log" +) func init() { log.SetFlags(log.LstdFlags | log.Lshortfile) } -type logger struct { +type LogLogger struct { } -func (l *logger) Log(v ...interface{}) { +func (l *LogLogger) Log(v ...interface{}) { log.Println(v...) } -func (l *logger) Logf(format string, v ...interface{}) { +func (l *LogLogger) Logf(format string, v ...interface{}) { log.Printf(format, v...) } + +type NopLogger struct { +} + +func (l *NopLogger) Log(v ...interface{}) { +} + +func (l *NopLogger) Logf(format string, v ...interface{}) { +} diff --git a/gost/permissions.go b/gost/permissions.go new file mode 100644 index 0000000..3e079eb --- /dev/null +++ b/gost/permissions.go @@ -0,0 +1,185 @@ +package gost + +import ( + "errors" + "fmt" + "strconv" + "strings" + + glob "github.com/ryanuber/go-glob" +) + +type PortRange struct { + Min, Max int +} + +type PortSet []PortRange + +type StringSet []string + +type Permission struct { + Actions StringSet + Hosts StringSet + Ports PortSet +} + +type Permissions []Permission + +func minint(x, y int) int { + if x < y { + return x + } + return y +} + +func maxint(x, y int) int { + if x > y { + return x + } + return y +} + +func (ir *PortRange) Contains(value int) bool { + return value >= ir.Min && value <= ir.Max +} + +func ParsePortRange(s string) (*PortRange, error) { + if s == "*" { + return &PortRange{Min: 0, Max: 65535}, nil + } + + minmax := strings.Split(s, "-") + switch len(minmax) { + case 1: + port, err := strconv.Atoi(s) + if err != nil { + return nil, err + } + if port < 0 || port > 65535 { + return nil, fmt.Errorf("invalid port: %s", s) + } + return &PortRange{Min: port, Max: port}, nil + case 2: + min, err := strconv.Atoi(minmax[0]) + if err != nil { + return nil, err + } + max, err := strconv.Atoi(minmax[1]) + if err != nil { + return nil, err + } + + realmin := maxint(0, minint(min, max)) + realmax := minint(65535, maxint(min, max)) + + return &PortRange{Min: realmin, Max: realmax}, nil + default: + return nil, fmt.Errorf("invalid range: %s", s) + } +} + +func (ps *PortSet) Contains(value int) bool { + for _, portRange := range *ps { + if portRange.Contains(value) { + return true + } + } + + return false +} + +func ParsePortSet(s string) (*PortSet, error) { + ps := &PortSet{} + + if s == "" { + return nil, errors.New("must specify at least one port") + } + + ranges := strings.Split(s, ",") + + for _, r := range ranges { + portRange, err := ParsePortRange(r) + + if err != nil { + return nil, err + } + + *ps = append(*ps, *portRange) + } + + return ps, nil +} + +func (ss *StringSet) Contains(subj string) bool { + for _, s := range *ss { + if glob.Glob(s, subj) { + return true + } + } + + return false +} + +func ParseStringSet(s string) (*StringSet, error) { + ss := &StringSet{} + if s == "" { + return nil, errors.New("cannot be empty") + } + + *ss = strings.Split(s, ",") + + return ss, nil +} + +func (ps *Permissions) Can(action string, host string, port int) bool { + for _, p := range *ps { + if p.Actions.Contains(action) && p.Hosts.Contains(host) && p.Ports.Contains(port) { + return true + } + } + + return false +} + +func ParsePermissions(s string) (*Permissions, error) { + ps := &Permissions{} + + if s == "" { + return &Permissions{}, nil + } + + perms := strings.Split(s, " ") + + for _, perm := range perms { + parts := strings.Split(perm, ":") + + switch len(parts) { + case 3: + actions, err := ParseStringSet(parts[0]) + + if err != nil { + return nil, fmt.Errorf("action list must look like connect,bind given: %s", parts[0]) + } + + hosts, err := ParseStringSet(parts[1]) + + if err != nil { + return nil, fmt.Errorf("hosts list must look like google.pl,*.google.com given: %s", parts[1]) + } + + ports, err := ParsePortSet(parts[2]) + + if err != nil { + return nil, fmt.Errorf("ports list must look like 80,8000-9000, given: %s", parts[2]) + } + + permission := Permission{Actions: *actions, Hosts: *hosts, Ports: *ports} + + *ps = append(*ps, permission) + default: + return nil, fmt.Errorf("permission must have format [actions]:[hosts]:[ports] given: %s", perm) + } + } + + return ps, nil +} diff --git a/gost/permissions_test.go b/gost/permissions_test.go new file mode 100644 index 0000000..bc99824 --- /dev/null +++ b/gost/permissions_test.go @@ -0,0 +1,152 @@ +package gost + +import ( + "fmt" + "testing" +) + +var portRangeTests = []struct { + in string + out *PortRange +}{ + {"1", &PortRange{Min: 1, Max: 1}}, + {"1-3", &PortRange{Min: 1, Max: 3}}, + {"3-1", &PortRange{Min: 1, Max: 3}}, + {"0-100000", &PortRange{Min: 0, Max: 65535}}, + {"*", &PortRange{Min: 0, Max: 65535}}, +} + +var stringSetTests = []struct { + in string + out *StringSet +}{ + {"*", &StringSet{"*"}}, + {"google.pl,google.com", &StringSet{"google.pl", "google.com"}}, +} + +var portSetTests = []struct { + in string + out *PortSet +}{ + {"1,3", &PortSet{PortRange{Min: 1, Max: 1}, PortRange{Min: 3, Max: 3}}}, + {"1-3,7-5", &PortSet{PortRange{Min: 1, Max: 3}, PortRange{Min: 5, Max: 7}}}, + {"0-100000", &PortSet{PortRange{Min: 0, Max: 65535}}}, + {"*", &PortSet{PortRange{Min: 0, Max: 65535}}}, +} + +var permissionsTests = []struct { + in string + out *Permissions +}{ + {"", &Permissions{}}, + {"*:*:*", &Permissions{ + Permission{ + Actions: StringSet{"*"}, + Hosts: StringSet{"*"}, + Ports: PortSet{PortRange{Min: 0, Max: 65535}}, + }, + }}, + {"bind:127.0.0.1,localhost:80,443,8000-8100 connect:*.google.pl:80,443", &Permissions{ + Permission{ + Actions: StringSet{"bind"}, + Hosts: StringSet{"127.0.0.1", "localhost"}, + Ports: PortSet{ + PortRange{Min: 80, Max: 80}, + PortRange{Min: 443, Max: 443}, + PortRange{Min: 8000, Max: 8100}, + }, + }, + Permission{ + Actions: StringSet{"connect"}, + Hosts: StringSet{"*.google.pl"}, + Ports: PortSet{ + PortRange{Min: 80, Max: 80}, + PortRange{Min: 443, Max: 443}, + }, + }, + }}, +} + +func TestPortRangeParse(t *testing.T) { + for _, test := range portRangeTests { + actual, err := ParsePortRange(test.in) + if err != nil { + t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) + } else if *actual != *test.out { + t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) + } + } +} + +func TestPortRangeContains(t *testing.T) { + actual, _ := ParsePortRange("5-10") + + if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { + t.Errorf("5-10 should contain 5, 7 and 10") + } + + if actual.Contains(4) || actual.Contains(11) { + t.Errorf("5-10 should not contain 4, 11") + } +} + +func TestStringSetParse(t *testing.T) { + for _, test := range stringSetTests { + actual, err := ParseStringSet(test.in) + if err != nil { + t.Errorf("ParseStringSet(%q) returned error: %v", test.in, err) + } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { + t.Errorf("ParseStringSet(%q): got %v, want %v", test.in, actual, test.out) + } + } +} + +func TestStringSetContains(t *testing.T) { + ss, _ := ParseStringSet("google.pl,*.google.com") + + if !ss.Contains("google.pl") || !ss.Contains("www.google.com") { + t.Errorf("google.pl,*.google.com should contain google.pl and www.google.com") + } + + if ss.Contains("www.google.pl") || ss.Contains("foobar.com") { + t.Errorf("google.pl,*.google.com shound not contain www.google.pl and foobar.com") + } +} + +func TestPortSetParse(t *testing.T) { + for _, test := range portSetTests { + actual, err := ParsePortSet(test.in) + if err != nil { + t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) + } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { + t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) + } + } +} + +func TestPortSetContains(t *testing.T) { + actual, _ := ParsePortSet("5-10,20-30") + + if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { + t.Errorf("5-10,20-30 should contain 5, 7 and 10") + } + + if !actual.Contains(20) || !actual.Contains(27) || !actual.Contains(30) { + t.Errorf("5-10,20-30 should contain 20, 27 and 30") + } + + if actual.Contains(4) || actual.Contains(11) || actual.Contains(31) { + t.Errorf("5-10,20-30 should not contain 4, 11, 31") + } +} + +func TestPermissionsParse(t *testing.T) { + for _, test := range permissionsTests { + actual, err := ParsePermissions(test.in) + if err != nil { + t.Errorf("ParsePermissions(%q) returned error: %v", test.in, err) + } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { + t.Errorf("ParsePermissions(%q): got %v, want %v", test.in, actual, test.out) + } + } +} diff --git a/gost/srv/srv.go b/gost/srv/srv.go index 78e8d64..0aac7a7 100644 --- a/gost/srv/srv.go +++ b/gost/srv/srv.go @@ -2,31 +2,42 @@ package main import ( "crypto/tls" + "flag" "log" - "net/url" "github.com/ginuerzh/gost/gost" ) +var ( + quiet bool +) + func init() { log.SetFlags(log.LstdFlags | log.Lshortfile) - gost.Debug = true + + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } } func main() { - // go httpServer() - // go socks5Server() - // go tlsServer() - // go shadowServer() - // go wsServer() - // go wssServer() - // go kcpServer() + go httpServer() + go socks5Server() + go tlsServer() + go shadowServer() + go wsServer() + go wssServer() + go kcpServer() // go tcpForwardServer() // go rtcpForwardServer() // go rudpForwardServer() // go tcpRedirectServer() - go http2Server() + // go http2Server() select {} } @@ -36,7 +47,7 @@ func httpServer() { s.Handle(gost.HTTPHandler( gost.UsersHandlerOption(url.UserPassword("admin", "123456")), )) - ln, err := gost.TCPListener(":8080") + ln, err := gost.TCPListener(":18080") if err != nil { log.Fatal(err) } @@ -49,7 +60,7 @@ func socks5Server() { gost.UsersHandlerOption(url.UserPassword("admin", "123456")), gost.TLSConfigHandlerOption(tlsConfig()), )) - ln, err := gost.TCPListener(":1080") + ln, err := gost.TCPListener(":11080") if err != nil { log.Fatal(err) } @@ -61,7 +72,7 @@ func shadowServer() { s.Handle(gost.ShadowHandler( gost.UsersHandlerOption(url.UserPassword("chacha20", "123456")), )) - ln, err := gost.TCPListener(":8338") + ln, err := gost.TCPListener(":18338") if err != nil { log.Fatal(err) } @@ -73,7 +84,7 @@ func tlsServer() { s.Handle(gost.HTTPHandler( gost.UsersHandlerOption(url.UserPassword("admin", "123456")), )) - ln, err := gost.TLSListener(":1443", tlsConfig()) + ln, err := gost.TLSListener(":11443", tlsConfig()) if err != nil { log.Fatal(err) } @@ -85,7 +96,7 @@ func wsServer() { s.Handle(gost.HTTPHandler( gost.UsersHandlerOption(url.UserPassword("admin", "123456")), )) - ln, err := gost.WSListener(":8000", nil) + ln, err := gost.WSListener(":18000", nil) if err != nil { log.Fatal(err) } @@ -97,7 +108,7 @@ func wssServer() { s.Handle(gost.HTTPHandler( gost.UsersHandlerOption(url.UserPassword("admin", "123456")), )) - ln, err := gost.WSSListener(":8443", &gost.WSOptions{TLSConfig: tlsConfig()}) + ln, err := gost.WSSListener(":18443", &gost.WSOptions{TLSConfig: tlsConfig()}) if err != nil { log.Fatal(err) } @@ -107,7 +118,7 @@ func wssServer() { func kcpServer() { s := &gost.Server{} s.Handle(gost.HTTPHandler()) - ln, err := gost.KCPListener(":8388", nil) + ln, err := gost.KCPListener(":18388", nil) if err != nil { log.Fatal(err) } diff --git a/gost/tls.go b/gost/tls.go index 877f7ef..3b5952f 100644 --- a/gost/tls.go +++ b/gost/tls.go @@ -6,6 +6,7 @@ import ( ) type tlsTransporter struct { + *tcpTransporter } // TLSTransporter creates a Transporter that is used by TLS proxy client. @@ -14,10 +15,6 @@ func TLSTransporter() Transporter { return &tlsTransporter{} } -func (tr *tlsTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - return net.Dial("tcp", addr) -} - func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { opts := &HandshakeOptions{} for _, option := range options { diff --git a/gost/ws.go b/gost/ws.go index 778bc80..191478c 100644 --- a/gost/ws.go +++ b/gost/ws.go @@ -98,6 +98,7 @@ func (c *websocketConn) SetWriteDeadline(t time.Time) error { } type wsTransporter struct { + *tcpTransporter options *WSOptions } @@ -108,10 +109,6 @@ func WSTransporter(opts *WSOptions) Transporter { } } -func (tr *wsTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - return net.Dial("tcp", addr) -} - func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { opts := &HandshakeOptions{} for _, option := range options { @@ -130,6 +127,7 @@ func (tr *wsTransporter) Multiplex() bool { } type wssTransporter struct { + *tcpTransporter options *WSOptions } @@ -140,10 +138,6 @@ func WSSTransporter(opts *WSOptions) Transporter { } } -func (tr *wssTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - return net.Dial("tcp", addr) -} - func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { opts := &HandshakeOptions{} for _, option := range options { @@ -189,7 +183,7 @@ func WSListener(addr string, options *WSOptions) (Listener, error) { CheckOrigin: func(r *http.Request) bool { return true }, EnableCompression: options.EnableCompression, }, - connChan: make(chan net.Conn, 128), + connChan: make(chan net.Conn, 1024), errChan: make(chan error, 1), } @@ -274,7 +268,7 @@ func WSSListener(addr string, options *WSOptions) (Listener, error) { CheckOrigin: func(r *http.Request) bool { return true }, EnableCompression: options.EnableCompression, }, - connChan: make(chan net.Conn, 128), + connChan: make(chan net.Conn, 1024), errChan: make(chan error, 1), }, }