Implement white/blacklisting for socks and ssh

This commit is contained in:
Adam Stankiewicz 2017-03-12 15:22:21 +01:00
parent cf09f68827
commit 81b90010e8
No known key found for this signature in database
GPG Key ID: A62480DCEAC884DF
7 changed files with 511 additions and 25 deletions

View File

@ -65,6 +65,8 @@ func main() {
glog.Fatal(err)
}
glog.Info(serverNode)
wg.Add(1)
go func(node gost.ProxyNode) {
defer wg.Done()

48
node.go
View File

@ -19,6 +19,8 @@ type ProxyNode struct {
Transport string // transport: ws/wss/tls/http2/tcp/udp/rtcp/rudp
Remote string // remote address, used by tcp/udp port forwarding
Users []*url.Userinfo // authentication for proxy
Whitelist *Permissions
Blacklist *Permissions
values url.Values
serverName string
conn net.Conn
@ -36,12 +38,36 @@ func ParseProxyNode(s string) (node ProxyNode, err error) {
return
}
query := u.Query()
node = ProxyNode{
Addr: u.Host,
values: u.Query(),
values: query,
serverName: u.Host,
}
if query.Get("whitelist") != "" {
node.Whitelist, err = ParsePermissions(query.Get("whitelist"))
if err != nil {
glog.Fatal(err)
}
} else {
// By default allow for everyting
node.Whitelist, _ = ParsePermissions("*:*:*")
}
if query.Get("blacklist") != "" {
node.Blacklist, err = ParsePermissions(query.Get("blacklist"))
if err != nil {
glog.Fatal(err)
}
} else {
// By default block nothing
node.Blacklist, _ = ParsePermissions("")
}
if u.User != nil {
node.Users = append(node.Users, u.User)
}
@ -126,6 +152,24 @@ func (node *ProxyNode) Get(key string) string {
return node.values.Get(key)
}
func (node *ProxyNode) Can(action string, addr string) bool {
host, strport, err := net.SplitHostPort(addr)
if err != nil {
return false
}
port, err := strconv.Atoi(strport)
if err != nil {
return false
}
glog.V(LDEBUG).Infof("Can action: %s, host: %s, port %d", action, host, port)
return node.Whitelist.Can(action, host, port) && !node.Blacklist.Can(action, host, port)
}
func (node *ProxyNode) getBool(key string) bool {
s := node.Get(key)
if b, _ := strconv.ParseBool(s); b {
@ -162,5 +206,5 @@ func (node *ProxyNode) keyFile() string {
}
func (node ProxyNode) String() string {
return fmt.Sprintf("transport: %s, protocol: %s, addr: %s", node.Transport, node.Protocol, node.Addr)
return fmt.Sprintf("transport: %s, protocol: %s, addr: %s, whitelist: %v, blacklist: %v", node.Transport, node.Protocol, node.Addr, node.Whitelist, node.Blacklist)
}

43
node_test.go Normal file
View File

@ -0,0 +1,43 @@
package gost
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNodeDefaultWhitelist(t *testing.T) {
assert := assert.New(t)
node, _ := ParseProxyNode("http2://localhost:8000")
assert.True(node.Can("connect", "google.pl:80"))
assert.True(node.Can("connect", "google.pl:443"))
assert.True(node.Can("connect", "google.pl:22"))
assert.True(node.Can("bind", "google.pl:80"))
assert.True(node.Can("bind", "google.com:80"))
}
func TestNodeWhitelist(t *testing.T) {
assert := assert.New(t)
node, _ := ParseProxyNode("http2://localhost:8000?whitelist=connect:google.pl:80,443")
assert.True(node.Can("connect", "google.pl:80"))
assert.True(node.Can("connect", "google.pl:443"))
assert.False(node.Can("connect", "google.pl:22"))
assert.False(node.Can("bind", "google.pl:80"))
assert.False(node.Can("bind", "google.com:80"))
}
func TestNodeBlacklist(t *testing.T) {
assert := assert.New(t)
node, _ := ParseProxyNode("http2://localhost:8000?blacklist=connect:google.pl:80,443")
assert.False(node.Can("connect", "google.pl:80"))
assert.False(node.Can("connect", "google.pl:443"))
assert.True(node.Can("connect", "google.pl:22"))
assert.True(node.Can("bind", "google.pl:80"))
assert.True(node.Can("bind", "google.com:80"))
}

185
permissions.go Normal file
View File

@ -0,0 +1,185 @@
package gost
import (
"errors"
"fmt"
"strconv"
"strings"
glob "github.com/ryanuber/go-glob"
)
type PortRange struct {
Min, Max int
}
type PortSet []PortRange
type StringSet []string
type Permission struct {
Actions StringSet
Hosts StringSet
Ports PortSet
}
type Permissions []Permission
func minint(x, y int) int {
if x < y {
return x
}
return y
}
func maxint(x, y int) int {
if x > y {
return x
}
return y
}
func (ir *PortRange) Contains(value int) bool {
return value >= ir.Min && value <= ir.Max
}
func ParsePortRange(s string) (*PortRange, error) {
if s == "*" {
return &PortRange{Min: 0, Max: 65535}, nil
}
minmax := strings.Split(s, "-")
switch len(minmax) {
case 1:
port, err := strconv.Atoi(s)
if err != nil {
return nil, err
}
if port < 0 || port > 65535 {
return nil, fmt.Errorf("invalid port: %s", s)
}
return &PortRange{Min: port, Max: port}, nil
case 2:
min, err := strconv.Atoi(minmax[0])
if err != nil {
return nil, err
}
max, err := strconv.Atoi(minmax[1])
if err != nil {
return nil, err
}
realmin := maxint(0, minint(min, max))
realmax := minint(65535, maxint(min, max))
return &PortRange{Min: realmin, Max: realmax}, nil
default:
return nil, fmt.Errorf("invalid range: %s", s)
}
}
func (ps *PortSet) Contains(value int) bool {
for _, portRange := range *ps {
if portRange.Contains(value) {
return true
}
}
return false
}
func ParsePortSet(s string) (*PortSet, error) {
ps := &PortSet{}
if s == "" {
return nil, errors.New("must specify at least one port")
}
ranges := strings.Split(s, ",")
for _, r := range ranges {
portRange, err := ParsePortRange(r)
if err != nil {
return nil, err
}
*ps = append(*ps, *portRange)
}
return ps, nil
}
func (ss *StringSet) Contains(subj string) bool {
for _, s := range *ss {
if glob.Glob(s, subj) {
return true
}
}
return false
}
func ParseStringSet(s string) (*StringSet, error) {
ss := &StringSet{}
if s == "" {
return nil, errors.New("cannot be empty")
}
*ss = strings.Split(s, ",")
return ss, nil
}
func (ps *Permissions) Can(action string, host string, port int) bool {
for _, p := range *ps {
if p.Actions.Contains(action) && p.Hosts.Contains(host) && p.Ports.Contains(port) {
return true
}
}
return false
}
func ParsePermissions(s string) (*Permissions, error) {
ps := &Permissions{}
if s == "" {
return &Permissions{}, nil
}
perms := strings.Split(s, "+")
for _, perm := range perms {
parts := strings.Split(perm, ":")
switch len(parts) {
case 3:
actions, err := ParseStringSet(parts[0])
if err != nil {
return nil, fmt.Errorf("action list must look like connect,bind given: %s", parts[0])
}
hosts, err := ParseStringSet(parts[1])
if err != nil {
return nil, fmt.Errorf("hosts list must look like google.pl,*.google.com given: %s", parts[1])
}
ports, err := ParsePortSet(parts[2])
if err != nil {
return nil, fmt.Errorf("ports list must look like 80,8000-9000, given: %s", parts[2])
}
permission := Permission{Actions: *actions, Hosts: *hosts, Ports: *ports}
*ps = append(*ps, permission)
default:
return nil, fmt.Errorf("permission must have format [actions]:[hosts]:[ports] given: %s", perm)
}
}
return ps, nil
}

152
permissions_test.go Normal file
View File

@ -0,0 +1,152 @@
package gost
import (
"fmt"
"testing"
)
var portRangeTests = []struct {
in string
out *PortRange
}{
{"1", &PortRange{Min: 1, Max: 1}},
{"1-3", &PortRange{Min: 1, Max: 3}},
{"3-1", &PortRange{Min: 1, Max: 3}},
{"0-100000", &PortRange{Min: 0, Max: 65535}},
{"*", &PortRange{Min: 0, Max: 65535}},
}
var stringSetTests = []struct {
in string
out *StringSet
}{
{"*", &StringSet{"*"}},
{"google.pl,google.com", &StringSet{"google.pl", "google.com"}},
}
var portSetTests = []struct {
in string
out *PortSet
}{
{"1,3", &PortSet{PortRange{Min: 1, Max: 1}, PortRange{Min: 3, Max: 3}}},
{"1-3,7-5", &PortSet{PortRange{Min: 1, Max: 3}, PortRange{Min: 5, Max: 7}}},
{"0-100000", &PortSet{PortRange{Min: 0, Max: 65535}}},
{"*", &PortSet{PortRange{Min: 0, Max: 65535}}},
}
var permissionsTests = []struct {
in string
out *Permissions
}{
{"", &Permissions{}},
{"*:*:*", &Permissions{
Permission{
Actions: StringSet{"*"},
Hosts: StringSet{"*"},
Ports: PortSet{PortRange{Min: 0, Max: 65535}},
},
}},
{"bind:127.0.0.1,localhost:80,443,8000-8100+connect:*.google.pl:80,443", &Permissions{
Permission{
Actions: StringSet{"bind"},
Hosts: StringSet{"127.0.0.1", "localhost"},
Ports: PortSet{
PortRange{Min: 80, Max: 80},
PortRange{Min: 443, Max: 443},
PortRange{Min: 8000, Max: 8100},
},
},
Permission{
Actions: StringSet{"connect"},
Hosts: StringSet{"*.google.pl"},
Ports: PortSet{
PortRange{Min: 80, Max: 80},
PortRange{Min: 443, Max: 443},
},
},
}},
}
func TestPortRangeParse(t *testing.T) {
for _, test := range portRangeTests {
actual, err := ParsePortRange(test.in)
if err != nil {
t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err)
} else if *actual != *test.out {
t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out)
}
}
}
func TestPortRangeContains(t *testing.T) {
actual, _ := ParsePortRange("5-10")
if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) {
t.Errorf("5-10 should contain 5, 7 and 10")
}
if actual.Contains(4) || actual.Contains(11) {
t.Errorf("5-10 should not contain 4, 11")
}
}
func TestStringSetParse(t *testing.T) {
for _, test := range stringSetTests {
actual, err := ParseStringSet(test.in)
if err != nil {
t.Errorf("ParseStringSet(%q) returned error: %v", test.in, err)
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) {
t.Errorf("ParseStringSet(%q): got %v, want %v", test.in, actual, test.out)
}
}
}
func TestStringSetContains(t *testing.T) {
ss, _ := ParseStringSet("google.pl,*.google.com")
if !ss.Contains("google.pl") || !ss.Contains("www.google.com") {
t.Errorf("google.pl,*.google.com should contain google.pl and www.google.com")
}
if ss.Contains("www.google.pl") || ss.Contains("foobar.com") {
t.Errorf("google.pl,*.google.com shound not contain www.google.pl and foobar.com")
}
}
func TestPortSetParse(t *testing.T) {
for _, test := range portSetTests {
actual, err := ParsePortSet(test.in)
if err != nil {
t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err)
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) {
t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out)
}
}
}
func TestPortSetContains(t *testing.T) {
actual, _ := ParsePortSet("5-10,20-30")
if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) {
t.Errorf("5-10,20-30 should contain 5, 7 and 10")
}
if !actual.Contains(20) || !actual.Contains(27) || !actual.Contains(30) {
t.Errorf("5-10,20-30 should contain 20, 27 and 30")
}
if actual.Contains(4) || actual.Contains(11) || actual.Contains(31) {
t.Errorf("5-10,20-30 should not contain 4, 11, 31")
}
}
func TestPermissionsParse(t *testing.T) {
for _, test := range permissionsTests {
actual, err := ParsePermissions(test.in)
if err != nil {
t.Errorf("ParsePermissions(%q) returned error: %v", test.in, err)
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) {
t.Errorf("ParsePermissions(%q): got %v, want %v", test.in, actual, test.out)
}
}
}

View File

@ -3,13 +3,14 @@ package gost
import (
"bytes"
"crypto/tls"
"github.com/ginuerzh/gosocks4"
"github.com/ginuerzh/gosocks5"
"github.com/golang/glog"
"net"
"net/url"
"strconv"
"time"
"github.com/ginuerzh/gosocks4"
"github.com/ginuerzh/gosocks5"
"github.com/golang/glog"
)
const (
@ -191,7 +192,7 @@ func (s *Socks5Server) HandleRequest(req *gosocks5.Request) {
s.handleUDPRelay(req)
case CmdUdpTun:
glog.V(LINFO).Infof("[socks5-udp] %s - %s", s.conn.RemoteAddr(), req.Addr)
glog.V(LINFO).Infof("[socks5-rudp] %s - %s", s.conn.RemoteAddr(), req.Addr)
s.handleUDPTunnel(req)
default:
@ -200,7 +201,16 @@ func (s *Socks5Server) HandleRequest(req *gosocks5.Request) {
}
func (s *Socks5Server) handleConnect(req *gosocks5.Request) {
cc, err := s.Base.Chain.Dial(req.Addr.String())
addr := req.Addr.String()
if !s.Base.Node.Can("tcp", addr) {
glog.Errorf("Unauthorized to tcp connect to %s", addr)
rep := gosocks5.NewReply(gosocks5.NotAllowed, nil)
rep.Write(s.conn)
return
}
cc, err := s.Base.Chain.Dial(addr)
if err != nil {
glog.V(LWARNING).Infof("[socks5-connect] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err)
rep := gosocks5.NewReply(gosocks5.HostUnreachable, nil)
@ -226,7 +236,7 @@ func (s *Socks5Server) handleConnect(req *gosocks5.Request) {
func (s *Socks5Server) handleBind(req *gosocks5.Request) {
cc, err := s.Base.Chain.GetConn()
// connection error
// connection error when forwarding bind
if err != nil && err != ErrEmptyChain {
glog.V(LWARNING).Infof("[socks5-bind] %s <- %s : %s", s.conn.RemoteAddr(), req.Addr, err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
@ -234,23 +244,42 @@ func (s *Socks5Server) handleBind(req *gosocks5.Request) {
glog.V(LDEBUG).Infof("[socks5-bind] %s <- %s\n%s", s.conn.RemoteAddr(), req.Addr, reply)
return
}
// serve socks5 bind
if err == ErrEmptyChain {
s.bindOn(req.Addr.String())
addr := req.Addr.String()
if !s.Base.Node.Can("rtcp", addr) {
glog.Errorf("Unauthorized to tcp bind to %s", addr)
return
}
s.bindOn(addr)
return
}
defer cc.Close()
// forward request
// note: this type of request forwarding is defined when starting server
// so we don't need to authenticate it, as it's as explicit as whitelisting
defer cc.Close()
req.Write(cc)
glog.V(LINFO).Infof("[socks5-bind] %s <-> %s", s.conn.RemoteAddr(), cc.RemoteAddr())
s.Base.transport(s.conn, cc)
glog.V(LINFO).Infof("[socks5-bind] %s >-< %s", s.conn.RemoteAddr(), cc.RemoteAddr())
}
func (s *Socks5Server) handleUDPRelay(req *gosocks5.Request) {
bindAddr, _ := net.ResolveUDPAddr("udp", req.Addr.String())
addr := req.Addr.String()
if !s.Base.Node.Can("udp", addr) {
glog.Errorf("Unauthorized to udp connect to %s", addr)
rep := gosocks5.NewReply(gosocks5.NotAllowed, nil)
rep.Write(s.conn)
return
}
bindAddr, _ := net.ResolveUDPAddr("udp", addr)
relay, err := net.ListenUDP("udp", bindAddr) // udp associate, strict mode: if the port already in use, it will return error
if err != nil {
glog.V(LWARNING).Infof("[socks5-udp] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err)
@ -338,19 +367,26 @@ func (s *Socks5Server) handleUDPTunnel(req *gosocks5.Request) {
// connection error
if err != nil && err != ErrEmptyChain {
glog.V(LWARNING).Infof("[socks5-udp] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err)
glog.V(LWARNING).Infof("[socks5-rudp] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(s.conn)
glog.V(LDEBUG).Infof("[socks5-udp] %s -> %s\n%s", s.conn.RemoteAddr(), req.Addr, reply)
glog.V(LDEBUG).Infof("[socks5-rudp] %s -> %s\n%s", s.conn.RemoteAddr(), req.Addr, reply)
return
}
// serve tunnel udp, tunnel <-> remote, handle tunnel udp request
if err == ErrEmptyChain {
bindAddr, _ := net.ResolveUDPAddr("udp", req.Addr.String())
addr := req.Addr.String()
if !s.Base.Node.Can("rudp", addr) {
glog.Errorf("Unauthorized to udp bind to %s", addr)
return
}
bindAddr, _ := net.ResolveUDPAddr("udp", addr)
uc, err := net.ListenUDP("udp", bindAddr)
if err != nil {
glog.V(LWARNING).Infof("[socks5-udp] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err)
glog.V(LWARNING).Infof("[socks5-rudp] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err)
return
}
defer uc.Close()
@ -359,25 +395,27 @@ func (s *Socks5Server) handleUDPTunnel(req *gosocks5.Request) {
socksAddr.Host, _, _ = net.SplitHostPort(s.conn.LocalAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr)
if err := reply.Write(s.conn); err != nil {
glog.V(LWARNING).Infof("[socks5-udp] %s <- %s : %s", s.conn.RemoteAddr(), socksAddr, err)
glog.V(LWARNING).Infof("[socks5-rudp] %s <- %s : %s", s.conn.RemoteAddr(), socksAddr, err)
return
}
glog.V(LDEBUG).Infof("[socks5-udp] %s <- %s\n%s", s.conn.RemoteAddr(), socksAddr, reply)
glog.V(LDEBUG).Infof("[socks5-rudp] %s <- %s\n%s", s.conn.RemoteAddr(), socksAddr, reply)
glog.V(LINFO).Infof("[socks5-udp] %s <-> %s", s.conn.RemoteAddr(), socksAddr)
glog.V(LINFO).Infof("[socks5-rudp] %s <-> %s", s.conn.RemoteAddr(), socksAddr)
s.tunnelServerUDP(s.conn, uc)
glog.V(LINFO).Infof("[socks5-udp] %s >-< %s", s.conn.RemoteAddr(), socksAddr)
glog.V(LINFO).Infof("[socks5-rudp] %s >-< %s", s.conn.RemoteAddr(), socksAddr)
return
}
defer cc.Close()
// tunnel <-> tunnel, direct forwarding
// note: this type of request forwarding is defined when starting server
// so we don't need to authenticate it, as it's as explicit as whitelisting
req.Write(cc)
glog.V(LINFO).Infof("[socks5-udp] %s <-> %s [tun]", s.conn.RemoteAddr(), cc.RemoteAddr())
glog.V(LINFO).Infof("[socks5-rudp] %s <-> %s [tun]", s.conn.RemoteAddr(), cc.RemoteAddr())
s.Base.transport(s.conn, cc)
glog.V(LINFO).Infof("[socks5-udp] %s >-< %s [tun]", s.conn.RemoteAddr(), cc.RemoteAddr())
glog.V(LINFO).Infof("[socks5-rudp] %s >-< %s [tun]", s.conn.RemoteAddr(), cc.RemoteAddr())
}
func (s *Socks5Server) bindOn(addr string) {
@ -697,7 +735,16 @@ func (s *Socks4Server) HandleRequest(req *gosocks4.Request) {
}
func (s *Socks4Server) handleConnect(req *gosocks4.Request) {
cc, err := s.Base.Chain.Dial(req.Addr.String())
addr := req.Addr.String()
if !s.Base.Node.Can("tcp", addr) {
glog.Errorf("Unauthorized to tcp connect to %s", addr)
rep := gosocks5.NewReply(gosocks4.Rejected, nil)
rep.Write(s.conn)
return
}
cc, err := s.Base.Chain.Dial(addr)
if err != nil {
glog.V(LWARNING).Infof("[socks4-connect] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err)
rep := gosocks4.NewReply(gosocks4.Failed, nil)

17
ssh.go
View File

@ -5,11 +5,12 @@ package gost
import (
"encoding/binary"
"fmt"
"github.com/golang/glog"
"golang.org/x/crypto/ssh"
"net"
"net/url"
"strconv"
"github.com/golang/glog"
"golang.org/x/crypto/ssh"
)
// Applicaple SSH Request types for Port Forwarding - RFC 4254 7.X
@ -121,6 +122,11 @@ func (s *SSHServer) directPortForwardChannel(channel ssh.Channel, raddr string)
glog.V(LINFO).Infof("[ssh-tcp] %s - %s", s.Addr, raddr)
if !s.Base.Node.Can("tcp", raddr) {
glog.Errorf("Unauthorized to tcp connect to %s", raddr)
return
}
conn, err := s.Base.Chain.Dial(raddr)
if err != nil {
glog.V(LINFO).Infof("[ssh-tcp] %s - %s : %s", s.Addr, raddr, err)
@ -143,6 +149,13 @@ func (s *SSHServer) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit
t := tcpipForward{}
ssh.Unmarshal(req.Payload, &t)
addr := fmt.Sprintf("%s:%d", t.Host, t.Port)
if !s.Base.Node.Can("rtcp", addr) {
glog.Errorf("Unauthorized to tcp bind to %s", addr)
req.Reply(false, nil)
return
}
glog.V(LINFO).Infoln("[ssh-rtcp] listening tcp", addr)
ln, err := net.Listen("tcp", addr) //tie to the client connection
if err != nil {