diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 14f5062..25de116 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -14,7 +14,6 @@ import ( "net/url" "os" "runtime" - "strconv" "strings" "time" @@ -117,7 +116,7 @@ func (r *route) initChain() (*gost.Chain, error) { ngroup.AddNode(nodes...) // parse peer nodes if exists - peerCfg, err := loadPeerConfig(nodes[0].Values.Get("peer")) + peerCfg, err := loadPeerConfig(nodes[0].Get("peer")) if err != nil { log.Log(err) } @@ -156,7 +155,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { return } - users, err := parseUsers(node.Values.Get("secrets")) + users, err := parseUsers(node.Get("secrets")) if err != nil { return } @@ -168,20 +167,20 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { serverName = "localhost" // default server name } - rootCAs, err := loadCA(node.Values.Get("ca")) + rootCAs, err := loadCA(node.Get("ca")) if err != nil { return } tlsCfg := &tls.Config{ ServerName: serverName, - InsecureSkipVerify: !toBool(node.Values.Get("secure")), + InsecureSkipVerify: !node.GetBool("secure"), RootCAs: rootCAs, } wsOpts := &gost.WSOptions{} - wsOpts.EnableCompression = toBool(node.Values.Get("compression")) - wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) - wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) - wsOpts.UserAgent = node.Values.Get("agent") + wsOpts.EnableCompression = node.GetBool("compression") + wsOpts.ReadBufferSize = node.GetInt("rbuf") + wsOpts.WriteBufferSize = node.GetInt("wbuf") + wsOpts.UserAgent = node.Get("agent") var tr gost.Transporter switch node.Transport { @@ -203,7 +202,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { return nil, errors.New("KCP must be the first node in the proxy chain") } */ - config, err := parseKCPConfig(node.Values.Get("c")) + config, err := parseKCPConfig(node.Get("c")) if err != nil { return nil, err } @@ -222,16 +221,13 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { */ config := &gost.QUICConfig{ TLSConfig: tlsCfg, - KeepAlive: toBool(node.Values.Get("keepalive")), + KeepAlive: node.GetBool("keepalive"), } - timeout, _ := strconv.Atoi(node.Values.Get("timeout")) - config.Timeout = time.Duration(timeout) * time.Second + config.Timeout = time.Duration(node.GetInt("timeout")) * time.Second + config.IdleTimeout = time.Duration(node.GetInt("idle")) * time.Second - idle, _ := strconv.Atoi(node.Values.Get("idle")) - config.IdleTimeout = time.Duration(idle) * time.Second - - if key := node.Values.Get("key"); key != "" { + if key := node.Get("key"); key != "" { sum := sha256.Sum256([]byte(key)) config.Key = sum[:] } @@ -274,7 +270,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { case "forward": connector = gost.ForwardConnector() case "sni": - connector = gost.SNIConnector(node.Values.Get("host")) + connector = gost.SNIConnector(node.Get("host")) case "http": fallthrough default: @@ -282,28 +278,26 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { connector = gost.HTTPConnector(node.User) } - timeout, _ := strconv.Atoi(node.Values.Get("timeout")) + timeout := node.GetInt("timeout") node.DialOptions = append(node.DialOptions, gost.TimeoutDialOption(time.Duration(timeout)*time.Second), ) - interval, _ := strconv.Atoi(node.Values.Get("ping")) - retry, _ := strconv.Atoi(node.Values.Get("retry")) handshakeOptions := []gost.HandshakeOption{ gost.AddrHandshakeOption(node.Addr), gost.HostHandshakeOption(node.Host), gost.UserHandshakeOption(node.User), gost.TLSConfigHandshakeOption(tlsCfg), - gost.IntervalHandshakeOption(time.Duration(interval) * time.Second), + gost.IntervalHandshakeOption(time.Duration(node.GetInt("ping")) * time.Second), gost.TimeoutHandshakeOption(time.Duration(timeout) * time.Second), - gost.RetryHandshakeOption(retry), + gost.RetryHandshakeOption(node.GetInt("retry")), } node.Client = &gost.Client{ Connector: connector, Transporter: tr, } - ips := parseIP(node.Values.Get("ip"), sport) + ips := parseIP(node.Get("ip"), sport) for _, ip := range ips { node.Addr = ip node.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip)) @@ -328,23 +322,23 @@ func (r *route) serve() error { if err != nil { return err } - users, err := parseUsers(node.Values.Get("secrets")) + users, err := parseUsers(node.Get("secrets")) if err != nil { return err } if node.User != nil { users = append(users, node.User) } - certFile, keyFile := node.Values.Get("cert"), node.Values.Get("key") + certFile, keyFile := node.Get("cert"), node.Get("key") tlsCfg, err := tlsConfig(certFile, keyFile) if err != nil && certFile != "" && keyFile != "" { return err } wsOpts := &gost.WSOptions{} - wsOpts.EnableCompression = toBool(node.Values.Get("compression")) - wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) - wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + wsOpts.EnableCompression = node.GetBool("compression") + wsOpts.ReadBufferSize = node.GetInt("rbuf") + wsOpts.WriteBufferSize = node.GetInt("wbuf") var ln gost.Listener switch node.Transport { @@ -353,7 +347,7 @@ func (r *route) serve() error { case "mtls": ln, err = gost.MTLSListener(node.Addr, tlsCfg) case "ws": - wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + wsOpts.WriteBufferSize = node.GetInt("wbuf") ln, err = gost.WSListener(node.Addr, wsOpts) case "mws": ln, err = gost.MWSListener(node.Addr, wsOpts) @@ -362,7 +356,7 @@ func (r *route) serve() error { case "mwss": ln, err = gost.MWSSListener(node.Addr, tlsCfg, wsOpts) case "kcp": - config, er := parseKCPConfig(node.Values.Get("c")) + config, er := parseKCPConfig(node.Get("c")) if er != nil { return er } @@ -380,15 +374,12 @@ func (r *route) serve() error { case "quic": config := &gost.QUICConfig{ TLSConfig: tlsCfg, - KeepAlive: toBool(node.Values.Get("keepalive")), + KeepAlive: node.GetBool("keepalive"), } - timeout, _ := strconv.Atoi(node.Values.Get("timeout")) - config.Timeout = time.Duration(timeout) * time.Second + config.Timeout = time.Duration(node.GetInt("timeout")) * time.Second + config.IdleTimeout = time.Duration(node.GetInt("idle")) * time.Second - idle, _ := strconv.Atoi(node.Values.Get("idle")) - config.IdleTimeout = time.Duration(idle) * time.Second - - if key := node.Values.Get("key"); key != "" { + if key := node.Get("key"); key != "" { sum := sha256.Sum256([]byte(key)) config.Key = sum[:] } @@ -415,14 +406,11 @@ func (r *route) serve() error { } ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) case "udp": - ttl, _ := strconv.Atoi(node.Values.Get("ttl")) - ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(ttl)*time.Second) + ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(node.GetInt("ttl"))*time.Second) case "rudp": - ttl, _ := strconv.Atoi(node.Values.Get("ttl")) - ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(ttl)*time.Second) + ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(node.GetInt("ttl"))*time.Second) case "ssu": - ttl, _ := strconv.Atoi(node.Values.Get("ttl")) - ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(ttl)*time.Second) + ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(node.GetInt("ttl"))*time.Second) case "obfs4": if err = gost.Obfs4Init(node, true); err != nil { return err @@ -572,14 +560,6 @@ func (l *stringList) Set(value string) error { return nil } -func toBool(s string) bool { - if b, _ := strconv.ParseBool(s); b { - return b - } - n, _ := strconv.Atoi(s) - return n > 0 -} - func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { if configFile == "" { return nil, nil diff --git a/forward.go b/forward.go index e200e6b..f5d3592 100644 --- a/forward.go +++ b/forward.go @@ -422,10 +422,13 @@ func (c *udpServerConn) SetWriteDeadline(t time.Time) error { } type tcpRemoteForwardListener struct { - addr net.Addr - chain *Chain - ln net.Listener - closed chan struct{} + addr net.Addr + chain *Chain + ln net.Listener + session *muxSession + once sync.Once + mutex sync.Mutex + closed chan struct{} } // TCPRemoteForwardListener creates a Listener for TCP remote port forwarding server. @@ -474,6 +477,10 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" { conn, err = l.chain.Dial(l.addr.String()) } else if lastNode.Protocol == "socks5" { + if lastNode.GetBool("mbind") { + return l.muxAccept() // multiplexing support for binding. + } + cc, er := l.chain.Conn() if er != nil { return nil, er @@ -494,6 +501,14 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { return } +func (l *tcpRemoteForwardListener) muxAccept() (conn net.Conn, err error) { + l.mutex.Lock() + defer l.mutex.Unlock() + + + return nil, nil +} + func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) { conn, err := socks5Handshake(conn, l.chain.LastNode().User) if err != nil { diff --git a/node.go b/node.go index 14d5b17..b63fac9 100644 --- a/node.go +++ b/node.go @@ -3,6 +3,7 @@ package gost import ( "fmt" "net/url" + "strconv" "strings" "sync/atomic" "time" @@ -141,6 +142,23 @@ func (node *Node) Clone() Node { } } +// Get returns node parameter specified by key. +func (node *Node) Get(key string) string { + return node.Values.Get(key) +} + +// GetBool likes Get, but convert parameter value to bool. +func (node *Node) GetBool(key string) bool { + b, _ := strconv.ParseBool(node.Values.Get(key)) + return b +} + +// GetInt likes Get, but convert parameter value to int. +func (node *Node) GetInt(key string) int { + n, _ := strconv.Atoi(node.Values.Get(key)) + return n +} + func (node *Node) String() string { return fmt.Sprintf("%d@%s", node.ID, node.Addr) } diff --git a/socks.go b/socks.go index 1bb4533..156cf04 100644 --- a/socks.go +++ b/socks.go @@ -5,16 +5,16 @@ import ( "crypto/tls" "errors" "fmt" + "io" "net" "net/url" "strconv" "time" - "io" - "github.com/ginuerzh/gosocks4" "github.com/ginuerzh/gosocks5" "github.com/go-log/log" + smux "gopkg.in/xtaci/smux.v1" ) const ( @@ -22,10 +22,15 @@ const ( MethodTLS uint8 = 0x80 // MethodTLSAuth is an extended SOCKS5 method for TLS+AUTH. MethodTLSAuth uint8 = 0x82 + // MethodMux is an extended SOCKS5 method for stream multiplexing. + MethodMux = 0x88 ) const ( - // CmdUDPTun is an extended SOCKS5 method for UDP over TCP. + // CMDMuxBind is an extended SOCKS5 request CMD for + // multiplexing transport with the binding server. + CMDMuxBind uint8 = 0xF2 + // CmdUDPTun is an extended SOCKS5 request CMD for UDP over TCP. CmdUDPTun uint8 = 0xF3 ) @@ -392,6 +397,9 @@ func (h *socks5Handler) Handle(conn net.Conn) { case gosocks5.CmdUdp: h.handleUDPRelay(conn, req) + case CMDMuxBind: + h.handleMuxBind(conn, req) + case CmdUDPTun: h.handleUDPTunnel(conn, req) @@ -942,6 +950,98 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, uc *net.UDPConn) (err error return } +func (h *socks5Handler) handleMuxBind(conn net.Conn, req *gosocks5.Request) { + if h.options.Chain.IsEmpty() { + addr := req.Addr.String() + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("Unauthorized to tcp mbind to %s", addr) + return + } + h.muxBindOn(conn, addr) + return + } + + cc, err := h.options.Chain.Conn() + if err != nil { + log.Logf("[socks5-mbind] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks5-mbind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) + } + return + } + + // forward request + // note: this type of request forwarding is defined when starting server, + // so we don't need to authenticate it, as it's as explicit as whitelisting. + defer cc.Close() + req.Write(cc) + log.Logf("[socks5-mbind] %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + transport(conn, cc) + log.Logf("[socks5-mbind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) +} + +func (h *socks5Handler) muxBindOn(conn net.Conn, addr string) { + bindAddr, _ := net.ResolveTCPAddr("tcp", addr) + ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error + if err != nil { + log.Logf("[socks5-mbind] %s -> %s : %s", conn.RemoteAddr(), addr, err) + gosocks5.NewReply(gosocks5.Failure, nil).Write(conn) + return + } + defer ln.Close() + + socksAddr := toSocksAddr(ln.Addr()) + // Issue: may not reachable when host has multi-interface. + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + if err := reply.Write(conn); err != nil { + log.Logf("[socks5-mbind] %s <- %s : %s", conn.RemoteAddr(), addr, err) + return + } + if Debug { + log.Logf("[socks5-mbind] %s <- %s\n%s", conn.RemoteAddr(), addr, reply) + } + log.Logf("[socks5-mbind] %s - %s BIND ON %s OK", conn.RemoteAddr(), addr, socksAddr) + + // Upgrade connection to multiplex stream. + s, err := smux.Client(conn, smux.DefaultConfig()) + if err != nil { + log.Logf("[socks5-mbind] %s - %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + + log.Logf("[socks5-mbind] %s <-> %s", conn.RemoteAddr(), socksAddr) + defer log.Logf("[socks5-mbind] %s >-< %s", conn.RemoteAddr(), socksAddr) + + session := &muxSession{ + conn: conn, + session: s, + } + + for { + cc, err := ln.Accept() + if err != nil { + log.Logf("[socks5-mbind] %s <- %s : %v", conn.RemoteAddr(), socksAddr, err) + return + } + log.Logf("[socks5-mbind %s <- %s : ACCEPT peer %s", + conn.RemoteAddr(), socksAddr, cc.RemoteAddr()) + + go func(c net.Conn) { + defer c.Close() + + sc, err := session.GetConn() + if err != nil { + log.Logf("[socks5-mbind %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + transport(sc, c) + }(cc) + } +} + func toSocksAddr(addr net.Addr) *gosocks5.Addr { host := "0.0.0.0" port := 0