SOCKS5: add mux-bind request CMD

This commit is contained in:
zhengrui 2018-03-18 18:22:19 +08:00
parent 22a4d48cf9
commit 2b5655890c
4 changed files with 172 additions and 59 deletions

View File

@ -14,7 +14,6 @@ import (
"net/url"
"os"
"runtime"
"strconv"
"strings"
"time"
@ -117,7 +116,7 @@ func (r *route) initChain() (*gost.Chain, error) {
ngroup.AddNode(nodes...)
// parse peer nodes if exists
peerCfg, err := loadPeerConfig(nodes[0].Values.Get("peer"))
peerCfg, err := loadPeerConfig(nodes[0].Get("peer"))
if err != nil {
log.Log(err)
}
@ -156,7 +155,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
return
}
users, err := parseUsers(node.Values.Get("secrets"))
users, err := parseUsers(node.Get("secrets"))
if err != nil {
return
}
@ -168,20 +167,20 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
serverName = "localhost" // default server name
}
rootCAs, err := loadCA(node.Values.Get("ca"))
rootCAs, err := loadCA(node.Get("ca"))
if err != nil {
return
}
tlsCfg := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: !toBool(node.Values.Get("secure")),
InsecureSkipVerify: !node.GetBool("secure"),
RootCAs: rootCAs,
}
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
wsOpts.UserAgent = node.Values.Get("agent")
wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize = node.GetInt("rbuf")
wsOpts.WriteBufferSize = node.GetInt("wbuf")
wsOpts.UserAgent = node.Get("agent")
var tr gost.Transporter
switch node.Transport {
@ -203,7 +202,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
return nil, errors.New("KCP must be the first node in the proxy chain")
}
*/
config, err := parseKCPConfig(node.Values.Get("c"))
config, err := parseKCPConfig(node.Get("c"))
if err != nil {
return nil, err
}
@ -222,16 +221,13 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
*/
config := &gost.QUICConfig{
TLSConfig: tlsCfg,
KeepAlive: toBool(node.Values.Get("keepalive")),
KeepAlive: node.GetBool("keepalive"),
}
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
config.Timeout = time.Duration(timeout) * time.Second
config.Timeout = time.Duration(node.GetInt("timeout")) * time.Second
config.IdleTimeout = time.Duration(node.GetInt("idle")) * time.Second
idle, _ := strconv.Atoi(node.Values.Get("idle"))
config.IdleTimeout = time.Duration(idle) * time.Second
if key := node.Values.Get("key"); key != "" {
if key := node.Get("key"); key != "" {
sum := sha256.Sum256([]byte(key))
config.Key = sum[:]
}
@ -274,7 +270,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
case "forward":
connector = gost.ForwardConnector()
case "sni":
connector = gost.SNIConnector(node.Values.Get("host"))
connector = gost.SNIConnector(node.Get("host"))
case "http":
fallthrough
default:
@ -282,28 +278,26 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
connector = gost.HTTPConnector(node.User)
}
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
timeout := node.GetInt("timeout")
node.DialOptions = append(node.DialOptions,
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
)
interval, _ := strconv.Atoi(node.Values.Get("ping"))
retry, _ := strconv.Atoi(node.Values.Get("retry"))
handshakeOptions := []gost.HandshakeOption{
gost.AddrHandshakeOption(node.Addr),
gost.HostHandshakeOption(node.Host),
gost.UserHandshakeOption(node.User),
gost.TLSConfigHandshakeOption(tlsCfg),
gost.IntervalHandshakeOption(time.Duration(interval) * time.Second),
gost.IntervalHandshakeOption(time.Duration(node.GetInt("ping")) * time.Second),
gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second),
gost.RetryHandshakeOption(retry),
gost.RetryHandshakeOption(node.GetInt("retry")),
}
node.Client = &gost.Client{
Connector: connector,
Transporter: tr,
}
ips := parseIP(node.Values.Get("ip"), sport)
ips := parseIP(node.Get("ip"), sport)
for _, ip := range ips {
node.Addr = ip
node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip))
@ -328,23 +322,23 @@ func (r *route) serve() error {
if err != nil {
return err
}
users, err := parseUsers(node.Values.Get("secrets"))
users, err := parseUsers(node.Get("secrets"))
if err != nil {
return err
}
if node.User != nil {
users = append(users, node.User)
}
certFile, keyFile := node.Values.Get("cert"), node.Values.Get("key")
certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile)
if err != nil && certFile != "" && keyFile != "" {
return err
}
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize = node.GetInt("rbuf")
wsOpts.WriteBufferSize = node.GetInt("wbuf")
var ln gost.Listener
switch node.Transport {
@ -353,7 +347,7 @@ func (r *route) serve() error {
case "mtls":
ln, err = gost.MTLSListener(node.Addr, tlsCfg)
case "ws":
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
wsOpts.WriteBufferSize = node.GetInt("wbuf")
ln, err = gost.WSListener(node.Addr, wsOpts)
case "mws":
ln, err = gost.MWSListener(node.Addr, wsOpts)
@ -362,7 +356,7 @@ func (r *route) serve() error {
case "mwss":
ln, err = gost.MWSSListener(node.Addr, tlsCfg, wsOpts)
case "kcp":
config, er := parseKCPConfig(node.Values.Get("c"))
config, er := parseKCPConfig(node.Get("c"))
if er != nil {
return er
}
@ -380,15 +374,12 @@ func (r *route) serve() error {
case "quic":
config := &gost.QUICConfig{
TLSConfig: tlsCfg,
KeepAlive: toBool(node.Values.Get("keepalive")),
KeepAlive: node.GetBool("keepalive"),
}
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
config.Timeout = time.Duration(timeout) * time.Second
config.Timeout = time.Duration(node.GetInt("timeout")) * time.Second
config.IdleTimeout = time.Duration(node.GetInt("idle")) * time.Second
idle, _ := strconv.Atoi(node.Values.Get("idle"))
config.IdleTimeout = time.Duration(idle) * time.Second
if key := node.Values.Get("key"); key != "" {
if key := node.Get("key"); key != "" {
sum := sha256.Sum256([]byte(key))
config.Key = sum[:]
}
@ -415,14 +406,11 @@ func (r *route) serve() error {
}
ln, err = gost.TCPRemoteForwardListener(node.Addr, chain)
case "udp":
ttl, _ := strconv.Atoi(node.Values.Get("ttl"))
ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(ttl)*time.Second)
ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(node.GetInt("ttl"))*time.Second)
case "rudp":
ttl, _ := strconv.Atoi(node.Values.Get("ttl"))
ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(ttl)*time.Second)
ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(node.GetInt("ttl"))*time.Second)
case "ssu":
ttl, _ := strconv.Atoi(node.Values.Get("ttl"))
ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(ttl)*time.Second)
ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second)
case "obfs4":
if err = gost.Obfs4Init(node, true); err != nil {
return err
@ -572,14 +560,6 @@ func (l *stringList) Set(value string) error {
return nil
}
func toBool(s string) bool {
if b, _ := strconv.ParseBool(s); b {
return b
}
n, _ := strconv.Atoi(s)
return n > 0
}
func parseKCPConfig(configFile string) (*gost.KCPConfig, error) {
if configFile == "" {
return nil, nil

View File

@ -425,6 +425,9 @@ type tcpRemoteForwardListener struct {
addr net.Addr
chain *Chain
ln net.Listener
session *muxSession
once sync.Once
mutex sync.Mutex
closed chan struct{}
}
@ -474,6 +477,10 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" {
conn, err = l.chain.Dial(l.addr.String())
} else if lastNode.Protocol == "socks5" {
if lastNode.GetBool("mbind") {
return l.muxAccept() // multiplexing support for binding.
}
cc, er := l.chain.Conn()
if er != nil {
return nil, er
@ -494,6 +501,14 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
return
}
func (l *tcpRemoteForwardListener) muxAccept() (conn net.Conn, err error) {
l.mutex.Lock()
defer l.mutex.Unlock()
return nil, nil
}
func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) {
conn, err := socks5Handshake(conn, l.chain.LastNode().User)
if err != nil {

18
node.go
View File

@ -3,6 +3,7 @@ package gost
import (
"fmt"
"net/url"
"strconv"
"strings"
"sync/atomic"
"time"
@ -141,6 +142,23 @@ func (node *Node) Clone() Node {
}
}
// Get returns node parameter specified by key.
func (node *Node) Get(key string) string {
return node.Values.Get(key)
}
// GetBool likes Get, but convert parameter value to bool.
func (node *Node) GetBool(key string) bool {
b, _ := strconv.ParseBool(node.Values.Get(key))
return b
}
// GetInt likes Get, but convert parameter value to int.
func (node *Node) GetInt(key string) int {
n, _ := strconv.Atoi(node.Values.Get(key))
return n
}
func (node *Node) String() string {
return fmt.Sprintf("%d@%s", node.ID, node.Addr)
}

106
socks.go
View File

@ -5,16 +5,16 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/url"
"strconv"
"time"
"io"
"github.com/ginuerzh/gosocks4"
"github.com/ginuerzh/gosocks5"
"github.com/go-log/log"
smux "gopkg.in/xtaci/smux.v1"
)
const (
@ -22,10 +22,15 @@ const (
MethodTLS uint8 = 0x80
// MethodTLSAuth is an extended SOCKS5 method for TLS+AUTH.
MethodTLSAuth uint8 = 0x82
// MethodMux is an extended SOCKS5 method for stream multiplexing.
MethodMux = 0x88
)
const (
// CmdUDPTun is an extended SOCKS5 method for UDP over TCP.
// CMDMuxBind is an extended SOCKS5 request CMD for
// multiplexing transport with the binding server.
CMDMuxBind uint8 = 0xF2
// CmdUDPTun is an extended SOCKS5 request CMD for UDP over TCP.
CmdUDPTun uint8 = 0xF3
)
@ -392,6 +397,9 @@ func (h *socks5Handler) Handle(conn net.Conn) {
case gosocks5.CmdUdp:
h.handleUDPRelay(conn, req)
case CMDMuxBind:
h.handleMuxBind(conn, req)
case CmdUDPTun:
h.handleUDPTunnel(conn, req)
@ -942,6 +950,98 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error
return
}
func (h *socks5Handler) handleMuxBind(conn net.Conn, req *gosocks5.Request) {
if h.options.Chain.IsEmpty() {
addr := req.Addr.String()
if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) {
log.Logf("Unauthorized to tcp mbind to %s", addr)
return
}
h.muxBindOn(conn, addr)
return
}
cc, err := h.options.Chain.Conn()
if err != nil {
log.Logf("[socks5-mbind] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn)
if Debug {
log.Logf("[socks5-mbind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply)
}
return
}
// forward request
// note: this type of request forwarding is defined when starting server,
// so we don't need to authenticate it, as it's as explicit as whitelisting.
defer cc.Close()
req.Write(cc)
log.Logf("[socks5-mbind] %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr())
transport(conn, cc)
log.Logf("[socks5-mbind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr())
}
func (h *socks5Handler) muxBindOn(conn net.Conn, addr string) {
bindAddr, _ := net.ResolveTCPAddr("tcp", addr)
ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error
if err != nil {
log.Logf("[socks5-mbind] %s -> %s : %s", conn.RemoteAddr(), addr, err)
gosocks5.NewReply(gosocks5.Failure, nil).Write(conn)
return
}
defer ln.Close()
socksAddr := toSocksAddr(ln.Addr())
// Issue: may not reachable when host has multi-interface.
socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr)
if err := reply.Write(conn); err != nil {
log.Logf("[socks5-mbind] %s <- %s : %s", conn.RemoteAddr(), addr, err)
return
}
if Debug {
log.Logf("[socks5-mbind] %s <- %s\n%s", conn.RemoteAddr(), addr, reply)
}
log.Logf("[socks5-mbind] %s - %s BIND ON %s OK", conn.RemoteAddr(), addr, socksAddr)
// Upgrade connection to multiplex stream.
s, err := smux.Client(conn, smux.DefaultConfig())
if err != nil {
log.Logf("[socks5-mbind] %s - %s : %s", conn.RemoteAddr(), socksAddr, err)
return
}
log.Logf("[socks5-mbind] %s <-> %s", conn.RemoteAddr(), socksAddr)
defer log.Logf("[socks5-mbind] %s >-< %s", conn.RemoteAddr(), socksAddr)
session := &muxSession{
conn: conn,
session: s,
}
for {
cc, err := ln.Accept()
if err != nil {
log.Logf("[socks5-mbind] %s <- %s : %v", conn.RemoteAddr(), socksAddr, err)
return
}
log.Logf("[socks5-mbind %s <- %s : ACCEPT peer %s",
conn.RemoteAddr(), socksAddr, cc.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
sc, err := session.GetConn()
if err != nil {
log.Logf("[socks5-mbind %s <- %s : %s", conn.RemoteAddr(), socksAddr, err)
return
}
transport(sc, c)
}(cc)
}
}
func toSocksAddr(addr net.Addr) *gosocks5.Addr {
host := "0.0.0.0"
port := 0