SOCKS5: add mux-bind request CMD

This commit is contained in:
zhengrui 2018-03-18 18:22:19 +08:00
parent 22a4d48cf9
commit 2b5655890c
4 changed files with 172 additions and 59 deletions

View File

@ -14,7 +14,6 @@ import (
"net/url" "net/url"
"os" "os"
"runtime" "runtime"
"strconv"
"strings" "strings"
"time" "time"
@ -117,7 +116,7 @@ func (r *route) initChain() (*gost.Chain, error) {
ngroup.AddNode(nodes...) ngroup.AddNode(nodes...)
// parse peer nodes if exists // parse peer nodes if exists
peerCfg, err := loadPeerConfig(nodes[0].Values.Get("peer")) peerCfg, err := loadPeerConfig(nodes[0].Get("peer"))
if err != nil { if err != nil {
log.Log(err) log.Log(err)
} }
@ -156,7 +155,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
return return
} }
users, err := parseUsers(node.Values.Get("secrets")) users, err := parseUsers(node.Get("secrets"))
if err != nil { if err != nil {
return return
} }
@ -168,20 +167,20 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
serverName = "localhost" // default server name serverName = "localhost" // default server name
} }
rootCAs, err := loadCA(node.Values.Get("ca")) rootCAs, err := loadCA(node.Get("ca"))
if err != nil { if err != nil {
return return
} }
tlsCfg := &tls.Config{ tlsCfg := &tls.Config{
ServerName: serverName, ServerName: serverName,
InsecureSkipVerify: !toBool(node.Values.Get("secure")), InsecureSkipVerify: !node.GetBool("secure"),
RootCAs: rootCAs, RootCAs: rootCAs,
} }
wsOpts := &gost.WSOptions{} wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression")) wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) wsOpts.ReadBufferSize = node.GetInt("rbuf")
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) wsOpts.WriteBufferSize = node.GetInt("wbuf")
wsOpts.UserAgent = node.Values.Get("agent") wsOpts.UserAgent = node.Get("agent")
var tr gost.Transporter var tr gost.Transporter
switch node.Transport { switch node.Transport {
@ -203,7 +202,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
return nil, errors.New("KCP must be the first node in the proxy chain") return nil, errors.New("KCP must be the first node in the proxy chain")
} }
*/ */
config, err := parseKCPConfig(node.Values.Get("c")) config, err := parseKCPConfig(node.Get("c"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -222,16 +221,13 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
*/ */
config := &gost.QUICConfig{ config := &gost.QUICConfig{
TLSConfig: tlsCfg, TLSConfig: tlsCfg,
KeepAlive: toBool(node.Values.Get("keepalive")), KeepAlive: node.GetBool("keepalive"),
} }
timeout, _ := strconv.Atoi(node.Values.Get("timeout")) config.Timeout = time.Duration(node.GetInt("timeout")) * time.Second
config.Timeout = time.Duration(timeout) * time.Second config.IdleTimeout = time.Duration(node.GetInt("idle")) * time.Second
idle, _ := strconv.Atoi(node.Values.Get("idle")) if key := node.Get("key"); key != "" {
config.IdleTimeout = time.Duration(idle) * time.Second
if key := node.Values.Get("key"); key != "" {
sum := sha256.Sum256([]byte(key)) sum := sha256.Sum256([]byte(key))
config.Key = sum[:] config.Key = sum[:]
} }
@ -274,7 +270,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
case "forward": case "forward":
connector = gost.ForwardConnector() connector = gost.ForwardConnector()
case "sni": case "sni":
connector = gost.SNIConnector(node.Values.Get("host")) connector = gost.SNIConnector(node.Get("host"))
case "http": case "http":
fallthrough fallthrough
default: default:
@ -282,28 +278,26 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
connector = gost.HTTPConnector(node.User) connector = gost.HTTPConnector(node.User)
} }
timeout, _ := strconv.Atoi(node.Values.Get("timeout")) timeout := node.GetInt("timeout")
node.DialOptions = append(node.DialOptions, node.DialOptions = append(node.DialOptions,
gost.TimeoutDialOption(time.Duration(timeout)*time.Second), gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
) )
interval, _ := strconv.Atoi(node.Values.Get("ping"))
retry, _ := strconv.Atoi(node.Values.Get("retry"))
handshakeOptions := []gost.HandshakeOption{ handshakeOptions := []gost.HandshakeOption{
gost.AddrHandshakeOption(node.Addr), gost.AddrHandshakeOption(node.Addr),
gost.HostHandshakeOption(node.Host), gost.HostHandshakeOption(node.Host),
gost.UserHandshakeOption(node.User), gost.UserHandshakeOption(node.User),
gost.TLSConfigHandshakeOption(tlsCfg), gost.TLSConfigHandshakeOption(tlsCfg),
gost.IntervalHandshakeOption(time.Duration(interval) * time.Second), gost.IntervalHandshakeOption(time.Duration(node.GetInt("ping")) * time.Second),
gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second), gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second),
gost.RetryHandshakeOption(retry), gost.RetryHandshakeOption(node.GetInt("retry")),
} }
node.Client = &gost.Client{ node.Client = &gost.Client{
Connector: connector, Connector: connector,
Transporter: tr, Transporter: tr,
} }
ips := parseIP(node.Values.Get("ip"), sport) ips := parseIP(node.Get("ip"), sport)
for _, ip := range ips { for _, ip := range ips {
node.Addr = ip node.Addr = ip
node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip)) node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip))
@ -328,23 +322,23 @@ func (r *route) serve() error {
if err != nil { if err != nil {
return err return err
} }
users, err := parseUsers(node.Values.Get("secrets")) users, err := parseUsers(node.Get("secrets"))
if err != nil { if err != nil {
return err return err
} }
if node.User != nil { if node.User != nil {
users = append(users, node.User) users = append(users, node.User)
} }
certFile, keyFile := node.Values.Get("cert"), node.Values.Get("key") certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile) tlsCfg, err := tlsConfig(certFile, keyFile)
if err != nil && certFile != "" && keyFile != "" { if err != nil && certFile != "" && keyFile != "" {
return err return err
} }
wsOpts := &gost.WSOptions{} wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression")) wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) wsOpts.ReadBufferSize = node.GetInt("rbuf")
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) wsOpts.WriteBufferSize = node.GetInt("wbuf")
var ln gost.Listener var ln gost.Listener
switch node.Transport { switch node.Transport {
@ -353,7 +347,7 @@ func (r *route) serve() error {
case "mtls": case "mtls":
ln, err = gost.MTLSListener(node.Addr, tlsCfg) ln, err = gost.MTLSListener(node.Addr, tlsCfg)
case "ws": case "ws":
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) wsOpts.WriteBufferSize = node.GetInt("wbuf")
ln, err = gost.WSListener(node.Addr, wsOpts) ln, err = gost.WSListener(node.Addr, wsOpts)
case "mws": case "mws":
ln, err = gost.MWSListener(node.Addr, wsOpts) ln, err = gost.MWSListener(node.Addr, wsOpts)
@ -362,7 +356,7 @@ func (r *route) serve() error {
case "mwss": case "mwss":
ln, err = gost.MWSSListener(node.Addr, tlsCfg, wsOpts) ln, err = gost.MWSSListener(node.Addr, tlsCfg, wsOpts)
case "kcp": case "kcp":
config, er := parseKCPConfig(node.Values.Get("c")) config, er := parseKCPConfig(node.Get("c"))
if er != nil { if er != nil {
return er return er
} }
@ -380,15 +374,12 @@ func (r *route) serve() error {
case "quic": case "quic":
config := &gost.QUICConfig{ config := &gost.QUICConfig{
TLSConfig: tlsCfg, TLSConfig: tlsCfg,
KeepAlive: toBool(node.Values.Get("keepalive")), KeepAlive: node.GetBool("keepalive"),
} }
timeout, _ := strconv.Atoi(node.Values.Get("timeout")) config.Timeout = time.Duration(node.GetInt("timeout")) * time.Second
config.Timeout = time.Duration(timeout) * time.Second config.IdleTimeout = time.Duration(node.GetInt("idle")) * time.Second
idle, _ := strconv.Atoi(node.Values.Get("idle")) if key := node.Get("key"); key != "" {
config.IdleTimeout = time.Duration(idle) * time.Second
if key := node.Values.Get("key"); key != "" {
sum := sha256.Sum256([]byte(key)) sum := sha256.Sum256([]byte(key))
config.Key = sum[:] config.Key = sum[:]
} }
@ -415,14 +406,11 @@ func (r *route) serve() error {
} }
ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) ln, err = gost.TCPRemoteForwardListener(node.Addr, chain)
case "udp": case "udp":
ttl, _ := strconv.Atoi(node.Values.Get("ttl")) ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(node.GetInt("ttl"))*time.Second)
ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(ttl)*time.Second)
case "rudp": case "rudp":
ttl, _ := strconv.Atoi(node.Values.Get("ttl")) ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(node.GetInt("ttl"))*time.Second)
ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(ttl)*time.Second)
case "ssu": case "ssu":
ttl, _ := strconv.Atoi(node.Values.Get("ttl")) ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second)
ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(ttl)*time.Second)
case "obfs4": case "obfs4":
if err = gost.Obfs4Init(node, true); err != nil { if err = gost.Obfs4Init(node, true); err != nil {
return err return err
@ -572,14 +560,6 @@ func (l *stringList) Set(value string) error {
return nil return nil
} }
func toBool(s string) bool {
if b, _ := strconv.ParseBool(s); b {
return b
}
n, _ := strconv.Atoi(s)
return n > 0
}
func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { func parseKCPConfig(configFile string) (*gost.KCPConfig, error) {
if configFile == "" { if configFile == "" {
return nil, nil return nil, nil

View File

@ -425,6 +425,9 @@ type tcpRemoteForwardListener struct {
addr net.Addr addr net.Addr
chain *Chain chain *Chain
ln net.Listener ln net.Listener
session *muxSession
once sync.Once
mutex sync.Mutex
closed chan struct{} closed chan struct{}
} }
@ -474,6 +477,10 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" { if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" {
conn, err = l.chain.Dial(l.addr.String()) conn, err = l.chain.Dial(l.addr.String())
} else if lastNode.Protocol == "socks5" { } else if lastNode.Protocol == "socks5" {
if lastNode.GetBool("mbind") {
return l.muxAccept() // multiplexing support for binding.
}
cc, er := l.chain.Conn() cc, er := l.chain.Conn()
if er != nil { if er != nil {
return nil, er return nil, er
@ -494,6 +501,14 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
return return
} }
func (l *tcpRemoteForwardListener) muxAccept() (conn net.Conn, err error) {
l.mutex.Lock()
defer l.mutex.Unlock()
return nil, nil
}
func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) { func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) {
conn, err := socks5Handshake(conn, l.chain.LastNode().User) conn, err := socks5Handshake(conn, l.chain.LastNode().User)
if err != nil { if err != nil {

18
node.go
View File

@ -3,6 +3,7 @@ package gost
import ( import (
"fmt" "fmt"
"net/url" "net/url"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@ -141,6 +142,23 @@ func (node *Node) Clone() Node {
} }
} }
// Get returns node parameter specified by key.
func (node *Node) Get(key string) string {
return node.Values.Get(key)
}
// GetBool likes Get, but convert parameter value to bool.
func (node *Node) GetBool(key string) bool {
b, _ := strconv.ParseBool(node.Values.Get(key))
return b
}
// GetInt likes Get, but convert parameter value to int.
func (node *Node) GetInt(key string) int {
n, _ := strconv.Atoi(node.Values.Get(key))
return n
}
func (node *Node) String() string { func (node *Node) String() string {
return fmt.Sprintf("%d@%s", node.ID, node.Addr) return fmt.Sprintf("%d@%s", node.ID, node.Addr)
} }

106
socks.go
View File

@ -5,16 +5,16 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/url" "net/url"
"strconv" "strconv"
"time" "time"
"io"
"github.com/ginuerzh/gosocks4" "github.com/ginuerzh/gosocks4"
"github.com/ginuerzh/gosocks5" "github.com/ginuerzh/gosocks5"
"github.com/go-log/log" "github.com/go-log/log"
smux "gopkg.in/xtaci/smux.v1"
) )
const ( const (
@ -22,10 +22,15 @@ const (
MethodTLS uint8 = 0x80 MethodTLS uint8 = 0x80
// MethodTLSAuth is an extended SOCKS5 method for TLS+AUTH. // MethodTLSAuth is an extended SOCKS5 method for TLS+AUTH.
MethodTLSAuth uint8 = 0x82 MethodTLSAuth uint8 = 0x82
// MethodMux is an extended SOCKS5 method for stream multiplexing.
MethodMux = 0x88
) )
const ( const (
// CmdUDPTun is an extended SOCKS5 method for UDP over TCP. // CMDMuxBind is an extended SOCKS5 request CMD for
// multiplexing transport with the binding server.
CMDMuxBind uint8 = 0xF2
// CmdUDPTun is an extended SOCKS5 request CMD for UDP over TCP.
CmdUDPTun uint8 = 0xF3 CmdUDPTun uint8 = 0xF3
) )
@ -392,6 +397,9 @@ func (h *socks5Handler) Handle(conn net.Conn) {
case gosocks5.CmdUdp: case gosocks5.CmdUdp:
h.handleUDPRelay(conn, req) h.handleUDPRelay(conn, req)
case CMDMuxBind:
h.handleMuxBind(conn, req)
case CmdUDPTun: case CmdUDPTun:
h.handleUDPTunnel(conn, req) h.handleUDPTunnel(conn, req)
@ -942,6 +950,98 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error
return return
} }
func (h *socks5Handler) handleMuxBind(conn net.Conn, req *gosocks5.Request) {
if h.options.Chain.IsEmpty() {
addr := req.Addr.String()
if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) {
log.Logf("Unauthorized to tcp mbind to %s", addr)
return
}
h.muxBindOn(conn, addr)
return
}
cc, err := h.options.Chain.Conn()
if err != nil {
log.Logf("[socks5-mbind] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn)
if Debug {
log.Logf("[socks5-mbind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply)
}
return
}
// forward request
// note: this type of request forwarding is defined when starting server,
// so we don't need to authenticate it, as it's as explicit as whitelisting.
defer cc.Close()
req.Write(cc)
log.Logf("[socks5-mbind] %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr())
transport(conn, cc)
log.Logf("[socks5-mbind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr())
}
func (h *socks5Handler) muxBindOn(conn net.Conn, addr string) {
bindAddr, _ := net.ResolveTCPAddr("tcp", addr)
ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error
if err != nil {
log.Logf("[socks5-mbind] %s -> %s : %s", conn.RemoteAddr(), addr, err)
gosocks5.NewReply(gosocks5.Failure, nil).Write(conn)
return
}
defer ln.Close()
socksAddr := toSocksAddr(ln.Addr())
// Issue: may not reachable when host has multi-interface.
socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr)
if err := reply.Write(conn); err != nil {
log.Logf("[socks5-mbind] %s <- %s : %s", conn.RemoteAddr(), addr, err)
return
}
if Debug {
log.Logf("[socks5-mbind] %s <- %s\n%s", conn.RemoteAddr(), addr, reply)
}
log.Logf("[socks5-mbind] %s - %s BIND ON %s OK", conn.RemoteAddr(), addr, socksAddr)
// Upgrade connection to multiplex stream.
s, err := smux.Client(conn, smux.DefaultConfig())
if err != nil {
log.Logf("[socks5-mbind] %s - %s : %s", conn.RemoteAddr(), socksAddr, err)
return
}
log.Logf("[socks5-mbind] %s <-> %s", conn.RemoteAddr(), socksAddr)
defer log.Logf("[socks5-mbind] %s >-< %s", conn.RemoteAddr(), socksAddr)
session := &muxSession{
conn: conn,
session: s,
}
for {
cc, err := ln.Accept()
if err != nil {
log.Logf("[socks5-mbind] %s <- %s : %v", conn.RemoteAddr(), socksAddr, err)
return
}
log.Logf("[socks5-mbind %s <- %s : ACCEPT peer %s",
conn.RemoteAddr(), socksAddr, cc.RemoteAddr())
go func(c net.Conn) {
defer c.Close()
sc, err := session.GetConn()
if err != nil {
log.Logf("[socks5-mbind %s <- %s : %s", conn.RemoteAddr(), socksAddr, err)
return
}
transport(sc, c)
}(cc)
}
}
func toSocksAddr(addr net.Addr) *gosocks5.Addr { func toSocksAddr(addr net.Addr) *gosocks5.Addr {
host := "0.0.0.0" host := "0.0.0.0"
port := 0 port := 0