diff --git a/gost/cmd/gost/main.go b/gost/cmd/gost/main.go index 87c3e40..8d8f192 100644 --- a/gost/cmd/gost/main.go +++ b/gost/cmd/gost/main.go @@ -46,6 +46,8 @@ func init() { fmt.Fprintf(os.Stderr, "gost %s (%s)\n", gost.Version, runtime.Version()) os.Exit(0) } + + gost.Debug = options.debugMode } func main() { @@ -54,8 +56,8 @@ func main() { func buildChain() (*gost.Chain, error) { chain := gost.NewChain() - for _, cn := range options.chainNodes { - node, err := parseNode(cn) + for _, ns := range options.chainNodes { + node, err := gost.ParseNode(ns) if err != nil { return nil, err } @@ -66,9 +68,60 @@ func buildChain() (*gost.Chain, error) { tr = gost.TLSTransporter() case "ws": tr = gost.WSTransporter(nil) + case "wss": + tr = gost.WSSTransporter(nil) + case "kcp": + if !chain.IsEmpty() { + log.Log("KCP must be the first node in the proxy chain") + return nil, err + } + tr = gost.KCPTransporter(nil) + case "ssh": + if node.Protocol == "direct" || node.Protocol == "remote" { + tr = gost.SSHForwardTransporter() + } else { + tr = gost.SSHTunnelTransporter() + } + case "quic": + if !chain.IsEmpty() { + log.Log("QUIC must be the first node in the proxy chain") + return nil, err + } + tr = gost.QUICTransporter(nil) + case "http2": + tr = gost.HTTP2Transporter(nil) + case "h2": + tr = gost.H2Transporter(nil) + case "h2c": + tr = gost.H2CTransporter() + default: + tr = gost.TCPTransporter() } var connector gost.Connector + switch node.Protocol { + case "http2": + connector = gost.HTTP2Connector(nil) + case "socks", "socks5": + connector = gost.SOCKS5Connector(nil) + case "socks4": + connector = gost.SOCKS4Connector() + case "socks4a": + connector = gost.SOCKS4AConnector() + case "ss": + connector = gost.ShadowConnector(nil) + case "http": + fallthrough + default: + node.Protocol = "http" // default protocol is HTTP + connector = gost.HTTPConnector(nil) + } + + node.Client = &gost.Client{ + Connector: connector, + Transporter: tr, + } + chain.AddNode(node) } return chain, nil diff --git a/gost/cmd/gost/parser.go b/gost/cmd/gost/parser.go deleted file mode 100644 index 25f53fb..0000000 --- a/gost/cmd/gost/parser.go +++ /dev/null @@ -1,139 +0,0 @@ -package main - -import ( - "bufio" - "net" - "net/url" - "os" - "strings" - - "github.com/ginuerzh/gost/gost" - "github.com/go-log/log" -) - -type node struct { - Addr string - Protocol string // protocol: http/socks5/ss - Transport string // transport: ws/wss/tls/http2/tcp/udp/rtcp/rudp - Remote string // remote address, used by tcp/udp port forwarding - Users []*url.Userinfo // authentication for proxy - Whitelist *gost.Permissions - Blacklist *gost.Permissions - values url.Values - serverName string -} - -func parseNode(s string) (n node, err error) { - if !strings.Contains(s, "://") { - s = "gost://" + s - } - u, err := url.Parse(s) - if err != nil { - return - } - - query := u.Query() - - n = node{ - Addr: u.Host, - } - - if query.Get("whitelist") != "" { - if n.Whitelist, err = gost.ParsePermissions(query.Get("whitelist")); err != nil { - return - } - } else { - // By default allow for everyting - n.Whitelist, _ = gost.ParsePermissions("*:*:*") - } - - if query.Get("blacklist") != "" { - if n.Blacklist, err = gost.ParsePermissions(query.Get("blacklist")); err != nil { - return - } - } else { - // By default block nothing - n.Blacklist, _ = gost.ParsePermissions("") - } - - if u.User != nil { - n.Users = append(n.Users, u.User) - } - - users, er := parseUsers(n.values.Get("secrets")) - if users != nil { - n.Users = append(n.Users, users...) - } - if er != nil { - log.Log("load secrets:", er) - } - - if strings.Contains(u.Host, ":") { - n.serverName, _, _ = net.SplitHostPort(u.Host) - if n.serverName == "" { - n.serverName = "localhost" // default server name - } - } - - schemes := strings.Split(u.Scheme, "+") - if len(schemes) == 1 { - n.Protocol = schemes[0] - n.Transport = schemes[0] - } - if len(schemes) == 2 { - n.Protocol = schemes[0] - n.Transport = schemes[1] - } - - switch n.Transport { - case "ws", "wss", "tls", "h2", "h2c", "quic", "kcp", "redirect", "ssu", "ssh": - case "https": - n.Protocol = "http" - n.Transport = "tls" - case "http2": // http2 -> http2+tls, h2c mode is http2+tcp - n.Protocol = "http2" - n.Transport = "tls" - case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding - n.Remote = strings.Trim(u.EscapedPath(), "/") - case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding - n.Remote = strings.Trim(u.EscapedPath(), "/") - default: - n.Transport = "" - } - - switch n.Protocol { - case "http", "http2", "socks", "socks4", "socks4a", "socks5", "ss": - default: - n.Protocol = "" - } - - return -} - -func parseUsers(s string) (users []*url.Userinfo, err error) { - if s == "" { - return - } - - f, err := os.Open(s) - if err != nil { - return - } - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - s := strings.SplitN(line, " ", 2) - if len(s) == 1 { - users = append(users, url.User(strings.TrimSpace(s[0]))) - } else if len(s) == 2 { - users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) - } - } - - err = scanner.Err() - return -} diff --git a/gost/examples/bench/cli.go b/gost/examples/bench/cli.go index f942c4c..bcd9644 100644 --- a/gost/examples/bench/cli.go +++ b/gost/examples/bench/cli.go @@ -188,14 +188,14 @@ func request(chain *gost.Chain, start <-chan struct{}) { swg.Done() <-start - conn, err := chain.Dial("localhost:10000") + conn, err := chain.Dial("localhost:18888") if err != nil { log.Println(err) return } defer conn.Close() //conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) - req, err := http.NewRequest(http.MethodGet, "http://localhost:10000/pkg", nil) + req, err := http.NewRequest(http.MethodGet, "http://localhost:18888", nil) if err != nil { log.Println(err) return diff --git a/gost/examples/bench/srv.go b/gost/examples/bench/srv.go index d8f9205..86379b7 100644 --- a/gost/examples/bench/srv.go +++ b/gost/examples/bench/srv.go @@ -3,7 +3,9 @@ package main import ( "crypto/tls" "flag" + "fmt" "log" + "net/http" "net/url" "time" @@ -45,6 +47,7 @@ func main() { go http2TunnelServer() go quicServer() go shadowUDPServer() + go testServer() select {} } @@ -344,3 +347,13 @@ func tlsConfig() *tls.Config { PreferServerCipherSuites: true, } } + +func testServer() { + s := &http.Server{ + Addr: ":18888", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "abcdefghijklmnopqrstuvwxyz") + }), + } + log.Fatal(s.ListenAndServe()) +} diff --git a/gost/gost.go b/gost/gost.go index c9c16dc..e479024 100644 --- a/gost/gost.go +++ b/gost/gost.go @@ -1,7 +1,12 @@ package gost import ( - "errors" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" "time" "github.com/go-log/log" @@ -38,13 +43,66 @@ var ( ) var ( - ErrSessionDead = errors.New("session is dead") + defaultRawCert []byte + defaultRawKey []byte ) func init() { + rawCert, rawKey, err := generateKeyPair() + if err != nil { + panic(err) + } + defaultRawCert, defaultRawKey = rawCert, rawKey + log.DefaultLogger = &LogLogger{} } func SetLogger(logger log.Logger) { log.DefaultLogger = logger } + +func generateKeyPair() (rawCert, rawKey []byte, err error) { + if defaultRawCert != nil && defaultRawKey != nil { + return defaultRawCert, defaultRawKey, nil + } + + // Create private key and self-signed certificate + // Adapted from https://golang.org/src/crypto/tls/generate_cert.go + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + validFor := time.Hour * 24 * 365 * 10 // ten years + notBefore := time.Now() + notAfter := notBefore.Add(validFor) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"gost"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return + } + + rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return +} + +// SetDefaultCertificate replaces the default certificate by your own +func SetDefaultCertificate(rawCert, rawKey []byte) { + defaultRawCert = rawCert + defaultRawKey = rawKey +} diff --git a/gost/handler.go b/gost/handler.go index 1188787..6a6c7de 100644 --- a/gost/handler.go +++ b/gost/handler.go @@ -17,6 +17,8 @@ type HandlerOptions struct { Chain *Chain Users []*url.Userinfo TLSConfig *tls.Config + Whitelist *Permissions + Blacklist *Permissions } // HandlerOption allows a common way to set handler options. @@ -49,3 +51,17 @@ func TLSConfigHandlerOption(config *tls.Config) HandlerOption { opts.TLSConfig = config } } + +// WhitelistHandlerOption sets the Whitelist option of HandlerOptions. +func WhitelistHandlerOption(whitelist *Permissions) HandlerOption { + return func(opts *HandlerOptions) { + opts.Whitelist = whitelist + } +} + +// BlacklistHandlerOption sets the Blacklist option of HandlerOptions. +func BlacklistHandlerOption(blacklist *Permissions) HandlerOption { + return func(opts *HandlerOptions) { + opts.Blacklist = blacklist + } +} diff --git a/gost/http.go b/gost/http.go index 24eb067..dfc5dc0 100644 --- a/gost/http.go +++ b/gost/http.go @@ -121,6 +121,17 @@ func (h *httpHandler) Handle(conn net.Conn) { req.Header.Del("Proxy-Authorization") req.Header.Del("Proxy-Connection") + if !Can("tcp", req.Host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[http] Unauthorized to tcp connect to %s", req.Host) + b := []byte("HTTP/1.1 403 Forbidden\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n") + conn.Write(b) + if Debug { + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b)) + } + return + } + // forward http request lastNode := h.options.Chain.LastNode() if req.Method != http.MethodConnect && lastNode.Protocol == "http" { @@ -128,11 +139,6 @@ func (h *httpHandler) Handle(conn net.Conn) { return } - // if !s.Base.Node.Can("tcp", req.Host) { - // glog.Errorf("Unauthorized to tcp connect to %s", req.Host) - // return - // } - host := req.Host if !strings.Contains(req.Host, ":") { host += ":80" diff --git a/gost/http2.go b/gost/http2.go index 066271a..853ae95 100644 --- a/gost/http2.go +++ b/gost/http2.go @@ -287,10 +287,11 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { w.Header().Set("Proxy-Agent", "gost/"+Version) - //! if !s.Base.Node.Can("tcp", target) { - //! glog.Errorf("Unauthorized to tcp connect to %s", target) - //! return - //! } + if !Can("tcp", target, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[http2] Unauthorized to tcp connect to %s", target) + w.WriteHeader(http.StatusForbidden) + return + } u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) if !authenticate(u, p, h.options.Users...) { diff --git a/gost/log.go b/gost/log.go index 4d5ed73..0c22259 100644 --- a/gost/log.go +++ b/gost/log.go @@ -21,7 +21,7 @@ func (l *LogLogger) Logf(format string, v ...interface{}) { log.Output(3, fmt.Sprintf(format, v...)) } -// NopLogger is a null logger that discards the log outputs +// NopLogger is a dummy logger that discards the log outputs type NopLogger struct { } diff --git a/gost/node.go b/gost/node.go index ce72707..31b3b2f 100644 --- a/gost/node.go +++ b/gost/node.go @@ -1,15 +1,165 @@ package gost import ( + "bufio" + "net" "net/url" + "os" + "strconv" + "strings" + + "github.com/go-log/log" ) // Node is a proxy node, mainly used to construct a proxy chain. type Node struct { - Addr string - Protocol string - Transport string - User *url.Userinfo - Client *Client - Server *Server + Addr string + Protocol string + Transport string + Remote string // remote address, used by tcp/udp port forwarding + User *url.Userinfo + users []*url.Userinfo // authentication or cipher for proxy + Whitelist *Permissions + Blacklist *Permissions + values url.Values + serverName string + Client *Client + Server *Server +} + +func ParseNode(s string) (node Node, err error) { + if !strings.Contains(s, "://") { + s = "auto://" + s + } + u, err := url.Parse(s) + if err != nil { + return + } + + query := u.Query() + node = Node{ + Addr: u.Host, + values: query, + serverName: u.Host, + } + + if query.Get("whitelist") != "" { + if node.Whitelist, err = ParsePermissions(query.Get("whitelist")); err != nil { + return + } + } else { + // By default allow for everyting + node.Whitelist, _ = ParsePermissions("*:*:*") + } + + if query.Get("blacklist") != "" { + if node.Blacklist, err = ParsePermissions(query.Get("blacklist")); err != nil { + return + } + } else { + // By default block nothing + node.Blacklist, _ = ParsePermissions("") + } + + if u.User != nil { + node.User = u.User + node.users = append(node.users, u.User) + } + + users, er := parseUsers(node.values.Get("secrets")) + if users != nil { + node.users = append(node.users, users...) + } + if er != nil { + log.Log("load secrets:", er) + } + + if strings.Contains(u.Host, ":") { + node.serverName, _, _ = net.SplitHostPort(u.Host) + if node.serverName == "" { + node.serverName = "localhost" // default server name + } + } + + schemes := strings.Split(u.Scheme, "+") + if len(schemes) == 1 { + node.Protocol = schemes[0] + node.Transport = schemes[0] + } + if len(schemes) == 2 { + node.Protocol = schemes[0] + node.Transport = schemes[1] + } + + switch node.Transport { + case "tls", "ws", "wss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "redirect": + case "https": + node.Protocol = "http" + node.Transport = "tls" + case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding + node.Remote = strings.Trim(u.EscapedPath(), "/") + case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding + node.Remote = strings.Trim(u.EscapedPath(), "/") + default: + node.Transport = "" + } + + switch node.Protocol { + case "http", "http2", "socks4", "socks4a", "socks", "socks5", "ss": + case "tcp", "udp", "rtcp", "rudp": // port forwarding + case "direct", "remote": // SSH port forwarding + default: + node.Protocol = "" + } + + return +} + +func parseUsers(authFile string) (users []*url.Userinfo, err error) { + if authFile == "" { + return + } + + file, err := os.Open(authFile) + if err != nil { + return + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + s := strings.SplitN(line, " ", 2) + if len(s) == 1 { + users = append(users, url.User(strings.TrimSpace(s[0]))) + } else if len(s) == 2 { + users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) + } + } + + err = scanner.Err() + return +} + +func Can(action string, addr string, whitelist, blacklist *Permissions) bool { + if !strings.Contains(addr, ":") { + addr = addr + ":80" + } + host, strport, err := net.SplitHostPort(addr) + + if err != nil { + return false + } + + port, err := strconv.Atoi(strport) + + if err != nil { + return false + } + + log.Logf("Can action: %s, host: %s, port %d", action, host, port) + + return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port) } diff --git a/gost/permissions.go b/gost/permissions.go index 3e079eb..8566c80 100644 --- a/gost/permissions.go +++ b/gost/permissions.go @@ -9,14 +9,6 @@ import ( glob "github.com/ryanuber/go-glob" ) -type PortRange struct { - Min, Max int -} - -type PortSet []PortRange - -type StringSet []string - type Permission struct { Actions StringSet Hosts StringSet @@ -39,6 +31,10 @@ func maxint(x, y int) int { return y } +type PortRange struct { + Min, Max int +} + func (ir *PortRange) Contains(value int) bool { return value >= ir.Min && value <= ir.Max } @@ -88,6 +84,8 @@ func (ps *PortSet) Contains(value int) bool { return false } +type PortSet []PortRange + func ParsePortSet(s string) (*PortSet, error) { ps := &PortSet{} @@ -120,6 +118,8 @@ func (ss *StringSet) Contains(subj string) bool { return false } +type StringSet []string + func ParseStringSet(s string) (*StringSet, error) { ss := &StringSet{} if s == "" { diff --git a/gost/socks.go b/gost/socks.go index 9815c27..54ef37c 100644 --- a/gost/socks.go +++ b/gost/socks.go @@ -395,13 +395,15 @@ func (h *socks5Handler) Handle(conn net.Conn) { func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { addr := req.Addr.String() - - //! if !s.Base.Node.Can("tcp", addr) { - //! glog.Errorf("Unauthorized to tcp connect to %s", addr) - //! rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - //! rep.Write(s.conn) - //! return - //! } + if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5-connect] Unauthorized to tcp connect to %s", addr) + rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + return + } cc, err := h.options.Chain.Dial(addr) if err != nil { @@ -430,13 +432,12 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { func (h *socks5Handler) handleBind(conn net.Conn, req *gosocks5.Request) { if h.options.Chain.IsEmpty() { - - //! if !s.Base.Node.Can("rtcp", addr) { - //! glog.Errorf("Unauthorized to tcp bind to %s", addr) - //! return - //! } - - h.bindOn(conn, req.Addr.String()) + addr := req.Addr.String() + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("Unauthorized to tcp bind to %s", addr) + return + } + h.bindOn(conn, addr) return } @@ -554,14 +555,16 @@ func (h *socks5Handler) bindOn(conn net.Conn, addr string) { } func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) { - //! addr := req.Addr.String() - //! - //! if !s.Base.Node.Can("udp", addr) { - //! glog.Errorf("Unauthorized to udp connect to %s", addr) - //! rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - //! rep.Write(s.conn) - //! return - //! } + addr := req.Addr.String() + if !Can("udp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5-udp] Unauthorized to udp connect to %s", addr) + rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + return + } relay, err := net.ListenUDP("udp", nil) if err != nil { @@ -817,10 +820,10 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { if h.options.Chain.IsEmpty() { addr := req.Addr.String() - //! if !s.Base.Node.Can("rudp", addr) { - //! glog.Errorf("Unauthorized to udp bind to %s", addr) - //! return - //! } + if !Can("rudp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5-udp] Unauthorized to udp bind to %s", addr) + return + } bindAddr, _ := net.ResolveUDPAddr("udp", addr) uc, err := net.ListenUDP("udp", bindAddr) @@ -992,12 +995,15 @@ func (h *socks4Handler) Handle(conn net.Conn) { func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { addr := req.Addr.String() - //! if !s.Base.Node.Can("tcp", addr) { - //! glog.Errorf("Unauthorized to tcp connect to %s", addr) - //! rep := gosocks5.NewReply(gosocks4.Rejected, nil) - //! rep.Write(s.conn) - //! return - //! } + if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks4-connect] Unauthorized to tcp connect to %s", addr) + rep := gosocks5.NewReply(gosocks4.Rejected, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks4-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + return + } cc, err := h.options.Chain.Dial(addr) if err != nil { diff --git a/gost/ss.go b/gost/ss.go index 0a73f4c..079015d 100644 --- a/gost/ss.go +++ b/gost/ss.go @@ -110,7 +110,6 @@ func (h *shadowHandler) Handle(conn net.Conn) { defer conn.Close() var method, password string - users := h.options.Users if len(users) > 0 { method = users[0].Username() @@ -132,6 +131,11 @@ func (h *shadowHandler) Handle(conn net.Conn) { } log.Logf("[ss] %s -> %s", conn.RemoteAddr(), addr) + if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ss] Unauthorized to tcp connect to %s", addr) + return + } + cc, err := h.options.Chain.Dial(addr) if err != nil { log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err) diff --git a/gost/ssh.go b/gost/ssh.go index d570006..8c69603 100644 --- a/gost/ssh.go +++ b/gost/ssh.go @@ -27,6 +27,10 @@ const ( GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel ) +var ( + errSessionDead = errors.New("session is dead") +) + type sshDirectForwardConnector struct { } @@ -188,7 +192,7 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp } if session.Closed() { delete(tr.sessions, opts.Addr) - return nil, ErrSessionDead + return nil, errSessionDead } return &sshNopConn{session: session}, nil @@ -288,7 +292,7 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt if session.Closed() { delete(tr.sessions, opts.Addr) - return nil, ErrSessionDead + return nil, errSessionDead } channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil) @@ -485,10 +489,10 @@ func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr log.Logf("[ssh-tcp] %s - %s", h.options.Addr, raddr) - //! if !s.Base.Node.Can("tcp", raddr) { - //! glog.Errorf("Unauthorized to tcp connect to %s", raddr) - //! return - //! } + if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr) + return + } conn, err := h.options.Chain.Dial(raddr) if err != nil { @@ -514,11 +518,11 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque addr := fmt.Sprintf("%s:%d", t.Host, t.Port) - //! if !s.Base.Node.Can("rtcp", addr) { - //! glog.Errorf("Unauthorized to tcp bind to %s", addr) - //! req.Reply(false, nil) - //! return - //! } + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr) + req.Reply(false, nil) + return + } log.Log("[ssh-rtcp] listening on tcp", addr) ln, err := net.Listen("tcp", addr) //tie to the client connection diff --git a/gost/tls.go b/gost/tls.go index 613e8a4..c65f6d0 100644 --- a/gost/tls.go +++ b/gost/tls.go @@ -25,7 +25,7 @@ func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( if opts.TLSConfig == nil { opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} } - return tls.Client(conn, opts.TLSConfig), nil + return wrapTLSClient(conn, opts.TLSConfig) } type tlsListener struct {