add whitelist/blacklist support

This commit is contained in:
rui.zheng 2017-08-03 18:08:55 +08:00
parent c02bc32cb6
commit 97d2de15a3
15 changed files with 386 additions and 214 deletions

View File

@ -46,6 +46,8 @@ func init() {
fmt.Fprintf(os.Stderr, "gost %s (%s)\n", gost.Version, runtime.Version()) fmt.Fprintf(os.Stderr, "gost %s (%s)\n", gost.Version, runtime.Version())
os.Exit(0) os.Exit(0)
} }
gost.Debug = options.debugMode
} }
func main() { func main() {
@ -54,8 +56,8 @@ func main() {
func buildChain() (*gost.Chain, error) { func buildChain() (*gost.Chain, error) {
chain := gost.NewChain() chain := gost.NewChain()
for _, cn := range options.chainNodes { for _, ns := range options.chainNodes {
node, err := parseNode(cn) node, err := gost.ParseNode(ns)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -66,9 +68,60 @@ func buildChain() (*gost.Chain, error) {
tr = gost.TLSTransporter() tr = gost.TLSTransporter()
case "ws": case "ws":
tr = gost.WSTransporter(nil) tr = gost.WSTransporter(nil)
case "wss":
tr = gost.WSSTransporter(nil)
case "kcp":
if !chain.IsEmpty() {
log.Log("KCP must be the first node in the proxy chain")
return nil, err
}
tr = gost.KCPTransporter(nil)
case "ssh":
if node.Protocol == "direct" || node.Protocol == "remote" {
tr = gost.SSHForwardTransporter()
} else {
tr = gost.SSHTunnelTransporter()
}
case "quic":
if !chain.IsEmpty() {
log.Log("QUIC must be the first node in the proxy chain")
return nil, err
}
tr = gost.QUICTransporter(nil)
case "http2":
tr = gost.HTTP2Transporter(nil)
case "h2":
tr = gost.H2Transporter(nil)
case "h2c":
tr = gost.H2CTransporter()
default:
tr = gost.TCPTransporter()
} }
var connector gost.Connector var connector gost.Connector
switch node.Protocol {
case "http2":
connector = gost.HTTP2Connector(nil)
case "socks", "socks5":
connector = gost.SOCKS5Connector(nil)
case "socks4":
connector = gost.SOCKS4Connector()
case "socks4a":
connector = gost.SOCKS4AConnector()
case "ss":
connector = gost.ShadowConnector(nil)
case "http":
fallthrough
default:
node.Protocol = "http" // default protocol is HTTP
connector = gost.HTTPConnector(nil)
}
node.Client = &gost.Client{
Connector: connector,
Transporter: tr,
}
chain.AddNode(node)
} }
return chain, nil return chain, nil

View File

@ -1,139 +0,0 @@
package main
import (
"bufio"
"net"
"net/url"
"os"
"strings"
"github.com/ginuerzh/gost/gost"
"github.com/go-log/log"
)
type node struct {
Addr string
Protocol string // protocol: http/socks5/ss
Transport string // transport: ws/wss/tls/http2/tcp/udp/rtcp/rudp
Remote string // remote address, used by tcp/udp port forwarding
Users []*url.Userinfo // authentication for proxy
Whitelist *gost.Permissions
Blacklist *gost.Permissions
values url.Values
serverName string
}
func parseNode(s string) (n node, err error) {
if !strings.Contains(s, "://") {
s = "gost://" + s
}
u, err := url.Parse(s)
if err != nil {
return
}
query := u.Query()
n = node{
Addr: u.Host,
}
if query.Get("whitelist") != "" {
if n.Whitelist, err = gost.ParsePermissions(query.Get("whitelist")); err != nil {
return
}
} else {
// By default allow for everyting
n.Whitelist, _ = gost.ParsePermissions("*:*:*")
}
if query.Get("blacklist") != "" {
if n.Blacklist, err = gost.ParsePermissions(query.Get("blacklist")); err != nil {
return
}
} else {
// By default block nothing
n.Blacklist, _ = gost.ParsePermissions("")
}
if u.User != nil {
n.Users = append(n.Users, u.User)
}
users, er := parseUsers(n.values.Get("secrets"))
if users != nil {
n.Users = append(n.Users, users...)
}
if er != nil {
log.Log("load secrets:", er)
}
if strings.Contains(u.Host, ":") {
n.serverName, _, _ = net.SplitHostPort(u.Host)
if n.serverName == "" {
n.serverName = "localhost" // default server name
}
}
schemes := strings.Split(u.Scheme, "+")
if len(schemes) == 1 {
n.Protocol = schemes[0]
n.Transport = schemes[0]
}
if len(schemes) == 2 {
n.Protocol = schemes[0]
n.Transport = schemes[1]
}
switch n.Transport {
case "ws", "wss", "tls", "h2", "h2c", "quic", "kcp", "redirect", "ssu", "ssh":
case "https":
n.Protocol = "http"
n.Transport = "tls"
case "http2": // http2 -> http2+tls, h2c mode is http2+tcp
n.Protocol = "http2"
n.Transport = "tls"
case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding
n.Remote = strings.Trim(u.EscapedPath(), "/")
case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding
n.Remote = strings.Trim(u.EscapedPath(), "/")
default:
n.Transport = ""
}
switch n.Protocol {
case "http", "http2", "socks", "socks4", "socks4a", "socks5", "ss":
default:
n.Protocol = ""
}
return
}
func parseUsers(s string) (users []*url.Userinfo, err error) {
if s == "" {
return
}
f, err := os.Open(s)
if err != nil {
return
}
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
s := strings.SplitN(line, " ", 2)
if len(s) == 1 {
users = append(users, url.User(strings.TrimSpace(s[0])))
} else if len(s) == 2 {
users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1])))
}
}
err = scanner.Err()
return
}

View File

@ -188,14 +188,14 @@ func request(chain *gost.Chain, start <-chan struct{}) {
swg.Done() swg.Done()
<-start <-start
conn, err := chain.Dial("localhost:10000") conn, err := chain.Dial("localhost:18888")
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
} }
defer conn.Close() defer conn.Close()
//conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) //conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
req, err := http.NewRequest(http.MethodGet, "http://localhost:10000/pkg", nil) req, err := http.NewRequest(http.MethodGet, "http://localhost:18888", nil)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return

View File

@ -3,7 +3,9 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"flag" "flag"
"fmt"
"log" "log"
"net/http"
"net/url" "net/url"
"time" "time"
@ -45,6 +47,7 @@ func main() {
go http2TunnelServer() go http2TunnelServer()
go quicServer() go quicServer()
go shadowUDPServer() go shadowUDPServer()
go testServer()
select {} select {}
} }
@ -344,3 +347,13 @@ func tlsConfig() *tls.Config {
PreferServerCipherSuites: true, PreferServerCipherSuites: true,
} }
} }
func testServer() {
s := &http.Server{
Addr: ":18888",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "abcdefghijklmnopqrstuvwxyz")
}),
}
log.Fatal(s.ListenAndServe())
}

View File

@ -1,7 +1,12 @@
package gost package gost
import ( import (
"errors" "crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"time" "time"
"github.com/go-log/log" "github.com/go-log/log"
@ -38,13 +43,66 @@ var (
) )
var ( var (
ErrSessionDead = errors.New("session is dead") defaultRawCert []byte
defaultRawKey []byte
) )
func init() { func init() {
rawCert, rawKey, err := generateKeyPair()
if err != nil {
panic(err)
}
defaultRawCert, defaultRawKey = rawCert, rawKey
log.DefaultLogger = &LogLogger{} log.DefaultLogger = &LogLogger{}
} }
func SetLogger(logger log.Logger) { func SetLogger(logger log.Logger) {
log.DefaultLogger = logger log.DefaultLogger = logger
} }
func generateKeyPair() (rawCert, rawKey []byte, err error) {
if defaultRawCert != nil && defaultRawKey != nil {
return defaultRawCert, defaultRawKey, nil
}
// Create private key and self-signed certificate
// Adapted from https://golang.org/src/crypto/tls/generate_cert.go
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return
}
validFor := time.Hour * 24 * 365 * 10 // ten years
notBefore := time.Now()
notAfter := notBefore.Add(validFor)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"gost"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return
}
rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
return
}
// SetDefaultCertificate replaces the default certificate by your own
func SetDefaultCertificate(rawCert, rawKey []byte) {
defaultRawCert = rawCert
defaultRawKey = rawKey
}

View File

@ -17,6 +17,8 @@ type HandlerOptions struct {
Chain *Chain Chain *Chain
Users []*url.Userinfo Users []*url.Userinfo
TLSConfig *tls.Config TLSConfig *tls.Config
Whitelist *Permissions
Blacklist *Permissions
} }
// HandlerOption allows a common way to set handler options. // HandlerOption allows a common way to set handler options.
@ -49,3 +51,17 @@ func TLSConfigHandlerOption(config *tls.Config) HandlerOption {
opts.TLSConfig = config opts.TLSConfig = config
} }
} }
// WhitelistHandlerOption sets the Whitelist option of HandlerOptions.
func WhitelistHandlerOption(whitelist *Permissions) HandlerOption {
return func(opts *HandlerOptions) {
opts.Whitelist = whitelist
}
}
// BlacklistHandlerOption sets the Blacklist option of HandlerOptions.
func BlacklistHandlerOption(blacklist *Permissions) HandlerOption {
return func(opts *HandlerOptions) {
opts.Blacklist = blacklist
}
}

View File

@ -121,6 +121,17 @@ func (h *httpHandler) Handle(conn net.Conn) {
req.Header.Del("Proxy-Authorization") req.Header.Del("Proxy-Authorization")
req.Header.Del("Proxy-Connection") req.Header.Del("Proxy-Connection")
if !Can("tcp", req.Host, h.options.Whitelist, h.options.Blacklist) {
log.Logf("[http] Unauthorized to tcp connect to %s", req.Host)
b := []byte("HTTP/1.1 403 Forbidden\r\n" +
"Proxy-Agent: gost/" + Version + "\r\n\r\n")
conn.Write(b)
if Debug {
log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b))
}
return
}
// forward http request // forward http request
lastNode := h.options.Chain.LastNode() lastNode := h.options.Chain.LastNode()
if req.Method != http.MethodConnect && lastNode.Protocol == "http" { if req.Method != http.MethodConnect && lastNode.Protocol == "http" {
@ -128,11 +139,6 @@ func (h *httpHandler) Handle(conn net.Conn) {
return return
} }
// if !s.Base.Node.Can("tcp", req.Host) {
// glog.Errorf("Unauthorized to tcp connect to %s", req.Host)
// return
// }
host := req.Host host := req.Host
if !strings.Contains(req.Host, ":") { if !strings.Contains(req.Host, ":") {
host += ":80" host += ":80"

View File

@ -287,10 +287,11 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Proxy-Agent", "gost/"+Version) w.Header().Set("Proxy-Agent", "gost/"+Version)
//! if !s.Base.Node.Can("tcp", target) { if !Can("tcp", target, h.options.Whitelist, h.options.Blacklist) {
//! glog.Errorf("Unauthorized to tcp connect to %s", target) log.Logf("[http2] Unauthorized to tcp connect to %s", target)
//! return w.WriteHeader(http.StatusForbidden)
//! } return
}
u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
if !authenticate(u, p, h.options.Users...) { if !authenticate(u, p, h.options.Users...) {

View File

@ -21,7 +21,7 @@ func (l *LogLogger) Logf(format string, v ...interface{}) {
log.Output(3, fmt.Sprintf(format, v...)) log.Output(3, fmt.Sprintf(format, v...))
} }
// NopLogger is a null logger that discards the log outputs // NopLogger is a dummy logger that discards the log outputs
type NopLogger struct { type NopLogger struct {
} }

View File

@ -1,7 +1,14 @@
package gost package gost
import ( import (
"bufio"
"net"
"net/url" "net/url"
"os"
"strconv"
"strings"
"github.com/go-log/log"
) )
// Node is a proxy node, mainly used to construct a proxy chain. // Node is a proxy node, mainly used to construct a proxy chain.
@ -9,7 +16,150 @@ type Node struct {
Addr string Addr string
Protocol string Protocol string
Transport string Transport string
Remote string // remote address, used by tcp/udp port forwarding
User *url.Userinfo User *url.Userinfo
users []*url.Userinfo // authentication or cipher for proxy
Whitelist *Permissions
Blacklist *Permissions
values url.Values
serverName string
Client *Client Client *Client
Server *Server Server *Server
} }
func ParseNode(s string) (node Node, err error) {
if !strings.Contains(s, "://") {
s = "auto://" + s
}
u, err := url.Parse(s)
if err != nil {
return
}
query := u.Query()
node = Node{
Addr: u.Host,
values: query,
serverName: u.Host,
}
if query.Get("whitelist") != "" {
if node.Whitelist, err = ParsePermissions(query.Get("whitelist")); err != nil {
return
}
} else {
// By default allow for everyting
node.Whitelist, _ = ParsePermissions("*:*:*")
}
if query.Get("blacklist") != "" {
if node.Blacklist, err = ParsePermissions(query.Get("blacklist")); err != nil {
return
}
} else {
// By default block nothing
node.Blacklist, _ = ParsePermissions("")
}
if u.User != nil {
node.User = u.User
node.users = append(node.users, u.User)
}
users, er := parseUsers(node.values.Get("secrets"))
if users != nil {
node.users = append(node.users, users...)
}
if er != nil {
log.Log("load secrets:", er)
}
if strings.Contains(u.Host, ":") {
node.serverName, _, _ = net.SplitHostPort(u.Host)
if node.serverName == "" {
node.serverName = "localhost" // default server name
}
}
schemes := strings.Split(u.Scheme, "+")
if len(schemes) == 1 {
node.Protocol = schemes[0]
node.Transport = schemes[0]
}
if len(schemes) == 2 {
node.Protocol = schemes[0]
node.Transport = schemes[1]
}
switch node.Transport {
case "tls", "ws", "wss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "redirect":
case "https":
node.Protocol = "http"
node.Transport = "tls"
case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding
node.Remote = strings.Trim(u.EscapedPath(), "/")
case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding
node.Remote = strings.Trim(u.EscapedPath(), "/")
default:
node.Transport = ""
}
switch node.Protocol {
case "http", "http2", "socks4", "socks4a", "socks", "socks5", "ss":
case "tcp", "udp", "rtcp", "rudp": // port forwarding
case "direct", "remote": // SSH port forwarding
default:
node.Protocol = ""
}
return
}
func parseUsers(authFile string) (users []*url.Userinfo, err error) {
if authFile == "" {
return
}
file, err := os.Open(authFile)
if err != nil {
return
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
s := strings.SplitN(line, " ", 2)
if len(s) == 1 {
users = append(users, url.User(strings.TrimSpace(s[0])))
} else if len(s) == 2 {
users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1])))
}
}
err = scanner.Err()
return
}
func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
if !strings.Contains(addr, ":") {
addr = addr + ":80"
}
host, strport, err := net.SplitHostPort(addr)
if err != nil {
return false
}
port, err := strconv.Atoi(strport)
if err != nil {
return false
}
log.Logf("Can action: %s, host: %s, port %d", action, host, port)
return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port)
}

View File

@ -9,14 +9,6 @@ import (
glob "github.com/ryanuber/go-glob" glob "github.com/ryanuber/go-glob"
) )
type PortRange struct {
Min, Max int
}
type PortSet []PortRange
type StringSet []string
type Permission struct { type Permission struct {
Actions StringSet Actions StringSet
Hosts StringSet Hosts StringSet
@ -39,6 +31,10 @@ func maxint(x, y int) int {
return y return y
} }
type PortRange struct {
Min, Max int
}
func (ir *PortRange) Contains(value int) bool { func (ir *PortRange) Contains(value int) bool {
return value >= ir.Min && value <= ir.Max return value >= ir.Min && value <= ir.Max
} }
@ -88,6 +84,8 @@ func (ps *PortSet) Contains(value int) bool {
return false return false
} }
type PortSet []PortRange
func ParsePortSet(s string) (*PortSet, error) { func ParsePortSet(s string) (*PortSet, error) {
ps := &PortSet{} ps := &PortSet{}
@ -120,6 +118,8 @@ func (ss *StringSet) Contains(subj string) bool {
return false return false
} }
type StringSet []string
func ParseStringSet(s string) (*StringSet, error) { func ParseStringSet(s string) (*StringSet, error) {
ss := &StringSet{} ss := &StringSet{}
if s == "" { if s == "" {

View File

@ -395,13 +395,15 @@ func (h *socks5Handler) Handle(conn net.Conn) {
func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) {
addr := req.Addr.String() addr := req.Addr.String()
if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) {
//! if !s.Base.Node.Can("tcp", addr) { log.Logf("[socks5-connect] Unauthorized to tcp connect to %s", addr)
//! glog.Errorf("Unauthorized to tcp connect to %s", addr) rep := gosocks5.NewReply(gosocks5.NotAllowed, nil)
//! rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) rep.Write(conn)
//! rep.Write(s.conn) if Debug {
//! return log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep)
//! } }
return
}
cc, err := h.options.Chain.Dial(addr) cc, err := h.options.Chain.Dial(addr)
if err != nil { if err != nil {
@ -430,13 +432,12 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) {
func (h *socks5Handler) handleBind(conn net.Conn, req *gosocks5.Request) { func (h *socks5Handler) handleBind(conn net.Conn, req *gosocks5.Request) {
if h.options.Chain.IsEmpty() { if h.options.Chain.IsEmpty() {
addr := req.Addr.String()
//! if !s.Base.Node.Can("rtcp", addr) { if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) {
//! glog.Errorf("Unauthorized to tcp bind to %s", addr) log.Logf("Unauthorized to tcp bind to %s", addr)
//! return return
//! } }
h.bindOn(conn, addr)
h.bindOn(conn, req.Addr.String())
return return
} }
@ -554,14 +555,16 @@ func (h *socks5Handler) bindOn(conn net.Conn, addr string) {
} }
func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) { func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) {
//! addr := req.Addr.String() addr := req.Addr.String()
//! if !Can("udp", addr, h.options.Whitelist, h.options.Blacklist) {
//! if !s.Base.Node.Can("udp", addr) { log.Logf("[socks5-udp] Unauthorized to udp connect to %s", addr)
//! glog.Errorf("Unauthorized to udp connect to %s", addr) rep := gosocks5.NewReply(gosocks5.NotAllowed, nil)
//! rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) rep.Write(conn)
//! rep.Write(s.conn) if Debug {
//! return log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep)
//! } }
return
}
relay, err := net.ListenUDP("udp", nil) relay, err := net.ListenUDP("udp", nil)
if err != nil { if err != nil {
@ -817,10 +820,10 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) {
if h.options.Chain.IsEmpty() { if h.options.Chain.IsEmpty() {
addr := req.Addr.String() addr := req.Addr.String()
//! if !s.Base.Node.Can("rudp", addr) { if !Can("rudp", addr, h.options.Whitelist, h.options.Blacklist) {
//! glog.Errorf("Unauthorized to udp bind to %s", addr) log.Logf("[socks5-udp] Unauthorized to udp bind to %s", addr)
//! return return
//! } }
bindAddr, _ := net.ResolveUDPAddr("udp", addr) bindAddr, _ := net.ResolveUDPAddr("udp", addr)
uc, err := net.ListenUDP("udp", bindAddr) uc, err := net.ListenUDP("udp", bindAddr)
@ -992,12 +995,15 @@ func (h *socks4Handler) Handle(conn net.Conn) {
func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) {
addr := req.Addr.String() addr := req.Addr.String()
//! if !s.Base.Node.Can("tcp", addr) { if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) {
//! glog.Errorf("Unauthorized to tcp connect to %s", addr) log.Logf("[socks4-connect] Unauthorized to tcp connect to %s", addr)
//! rep := gosocks5.NewReply(gosocks4.Rejected, nil) rep := gosocks5.NewReply(gosocks4.Rejected, nil)
//! rep.Write(s.conn) rep.Write(conn)
//! return if Debug {
//! } log.Logf("[socks4-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep)
}
return
}
cc, err := h.options.Chain.Dial(addr) cc, err := h.options.Chain.Dial(addr)
if err != nil { if err != nil {

View File

@ -110,7 +110,6 @@ func (h *shadowHandler) Handle(conn net.Conn) {
defer conn.Close() defer conn.Close()
var method, password string var method, password string
users := h.options.Users users := h.options.Users
if len(users) > 0 { if len(users) > 0 {
method = users[0].Username() method = users[0].Username()
@ -132,6 +131,11 @@ func (h *shadowHandler) Handle(conn net.Conn) {
} }
log.Logf("[ss] %s -> %s", conn.RemoteAddr(), addr) log.Logf("[ss] %s -> %s", conn.RemoteAddr(), addr)
if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) {
log.Logf("[ss] Unauthorized to tcp connect to %s", addr)
return
}
cc, err := h.options.Chain.Dial(addr) cc, err := h.options.Chain.Dial(addr)
if err != nil { if err != nil {
log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err) log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err)

View File

@ -27,6 +27,10 @@ const (
GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel
) )
var (
errSessionDead = errors.New("session is dead")
)
type sshDirectForwardConnector struct { type sshDirectForwardConnector struct {
} }
@ -188,7 +192,7 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
} }
if session.Closed() { if session.Closed() {
delete(tr.sessions, opts.Addr) delete(tr.sessions, opts.Addr)
return nil, ErrSessionDead return nil, errSessionDead
} }
return &sshNopConn{session: session}, nil return &sshNopConn{session: session}, nil
@ -288,7 +292,7 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt
if session.Closed() { if session.Closed() {
delete(tr.sessions, opts.Addr) delete(tr.sessions, opts.Addr)
return nil, ErrSessionDead return nil, errSessionDead
} }
channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil) channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil)
@ -485,10 +489,10 @@ func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr
log.Logf("[ssh-tcp] %s - %s", h.options.Addr, raddr) log.Logf("[ssh-tcp] %s - %s", h.options.Addr, raddr)
//! if !s.Base.Node.Can("tcp", raddr) { if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) {
//! glog.Errorf("Unauthorized to tcp connect to %s", raddr) log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr)
//! return return
//! } }
conn, err := h.options.Chain.Dial(raddr) conn, err := h.options.Chain.Dial(raddr)
if err != nil { if err != nil {
@ -514,11 +518,11 @@ func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Reque
addr := fmt.Sprintf("%s:%d", t.Host, t.Port) addr := fmt.Sprintf("%s:%d", t.Host, t.Port)
//! if !s.Base.Node.Can("rtcp", addr) { if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) {
//! glog.Errorf("Unauthorized to tcp bind to %s", addr) log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr)
//! req.Reply(false, nil) req.Reply(false, nil)
//! return return
//! } }
log.Log("[ssh-rtcp] listening on tcp", addr) log.Log("[ssh-rtcp] listening on tcp", addr)
ln, err := net.Listen("tcp", addr) //tie to the client connection ln, err := net.Listen("tcp", addr) //tie to the client connection

View File

@ -25,7 +25,7 @@ func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
if opts.TLSConfig == nil { if opts.TLSConfig == nil {
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
} }
return tls.Client(conn, opts.TLSConfig), nil return wrapTLSClient(conn, opts.TLSConfig)
} }
type tlsListener struct { type tlsListener struct {