diff --git a/gost/chain.go b/gost/chain.go index f976cf8..3cd0172 100644 --- a/gost/chain.go +++ b/gost/chain.go @@ -78,12 +78,12 @@ func (c *Chain) Conn() (net.Conn, error) { } nodes := c.nodes - conn, err := nodes[0].Client.Dial(nodes[0].Addr, TimeoutDialOption(DialTimeout)) + conn, err := nodes[0].Client.Dial(nodes[0].Addr, nodes[0].DialOptions...) if err != nil { return nil, err } - conn, err = nodes[0].Client.Handshake(conn, AddrHandshakeOption(nodes[0].Addr)) + conn, err = nodes[0].Client.Handshake(conn, nodes[0].HandshakeOptions...) if err != nil { return nil, err } @@ -99,7 +99,7 @@ func (c *Chain) Conn() (net.Conn, error) { conn.Close() return nil, err } - cc, err = next.Client.Handshake(cc, AddrHandshakeOption(next.Addr)) + cc, err = next.Client.Handshake(cc, next.HandshakeOptions...) if err != nil { conn.Close() return nil, err diff --git a/gost/cmd/gost/main.go b/gost/cmd/gost/main.go index 8d8f192..55a7d75 100644 --- a/gost/cmd/gost/main.go +++ b/gost/cmd/gost/main.go @@ -1,12 +1,20 @@ package main import ( + "bufio" + "crypto/tls" "encoding/json" + "errors" "flag" "fmt" "io/ioutil" + "net" + "net/url" "os" "runtime" + "strconv" + "strings" + "time" "github.com/ginuerzh/gost/gost" "github.com/go-log/log" @@ -51,10 +59,19 @@ func init() { } func main() { - + chain, err := initChain() + if err != nil { + log.Log(err) + os.Exit(1) + } + if err := serve(chain); err != nil { + log.Log(err) + os.Exit(1) + } + select {} } -func buildChain() (*gost.Chain, error) { +func initChain() (*gost.Chain, error) { chain := gost.NewChain() for _, ns := range options.chainNodes { node, err := gost.ParseNode(ns) @@ -62,19 +79,41 @@ func buildChain() (*gost.Chain, error) { return nil, err } + serverName, _, _ := net.SplitHostPort(node.Addr) + if serverName == "" { + serverName = "localhost" // default server name + } + + tlsCfg := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: !toBool(node.Values.Get("scure")), + } var tr gost.Transporter switch node.Transport { case "tls": tr = gost.TLSTransporter() case "ws": + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = toBool(node.Values.Get("compression")) + wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) + wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + node.HandshakeOptions = append(node.HandshakeOptions, + gost.WSOptionsHandshakeOption(wsOpts), + ) tr = gost.WSTransporter(nil) case "wss": tr = gost.WSSTransporter(nil) case "kcp": if !chain.IsEmpty() { - log.Log("KCP must be the first node in the proxy chain") - return nil, err + return nil, errors.New("KCP must be the first node in the proxy chain") } + config, err := parseKCPConfig(node.Values.Get("c")) + if err != nil { + log.Log("[kcp]", err) + } + node.HandshakeOptions = append(node.HandshakeOptions, + gost.KCPConfigHandshakeOption(config), + ) tr = gost.KCPTransporter(nil) case "ssh": if node.Protocol == "direct" || node.Protocol == "remote" { @@ -82,14 +121,24 @@ func buildChain() (*gost.Chain, error) { } else { tr = gost.SSHTunnelTransporter() } + node.Chain = chain // cutoff the chain for multiplex + chain = gost.NewChain() case "quic": if !chain.IsEmpty() { - log.Log("QUIC must be the first node in the proxy chain") - return nil, err + return nil, errors.New("QUIC must be the first node in the proxy chain") } + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: toBool(node.Values.Get("keepalive")), + } + node.HandshakeOptions = append(node.HandshakeOptions, + gost.QUICConfigHandshakeOption(config), + ) tr = gost.QUICTransporter(nil) case "http2": tr = gost.HTTP2Transporter(nil) + node.Chain = chain // cutoff the chain for multiplex + chain = gost.NewChain() case "h2": tr = gost.H2Transporter(nil) case "h2c": @@ -110,6 +159,10 @@ func buildChain() (*gost.Chain, error) { connector = gost.SOCKS4AConnector() case "ss": connector = gost.ShadowConnector(nil) + case "direct": + connector = gost.SSHDirectForwardConnector() + case "remote": + connector = gost.SSHRemoteForwardConnector() case "http": fallthrough default: @@ -117,6 +170,18 @@ func buildChain() (*gost.Chain, error) { connector = gost.HTTPConnector(nil) } + node.DialOptions = append(node.DialOptions, + gost.TimeoutDialOption(gost.DialTimeout), + gost.ChainDialOption(node.Chain), + ) + + interval, _ := strconv.Atoi(node.Values.Get("ping")) + node.HandshakeOptions = append(node.HandshakeOptions, + gost.AddrHandshakeOption(node.Addr), + gost.UserHandshakeOption(node.User), + gost.TLSConfigHandshakeOption(tlsCfg), + gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), + ) node.Client = &gost.Client{ Connector: connector, Transporter: tr, @@ -127,6 +192,166 @@ func buildChain() (*gost.Chain, error) { return chain, nil } +func serve(chain *gost.Chain) error { + for _, ns := range options.serveNodes { + node, err := gost.ParseNode(ns) + if err != nil { + return err + } + users, err := parseUsers(node.Values.Get("secrets")) + if err != nil { + return err + } + tlsCfg, err := tlsConfig(node.Values.Get("cert"), node.Values.Get("key")) + if err != nil { + return err + } + + var ln gost.Listener + switch node.Transport { + case "tls": + ln, err = gost.TLSListener(node.Addr, tlsCfg) + case "ws": + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = toBool(node.Values.Get("compression")) + wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) + wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + ln, err = gost.WSListener(node.Addr, wsOpts) + case "wss": + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = toBool(node.Values.Get("compression")) + wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf")) + wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf")) + ln, err = gost.WSSListener(node.Addr, tlsCfg, wsOpts) + case "kcp": + config, err := parseKCPConfig(node.Values.Get("c")) + if err != nil { + log.Log("[kcp]", err) + } + ln, err = gost.KCPListener(node.Addr, config) + case "ssh": + config := &gost.SSHConfig{ + Users: users, + TLSConfig: tlsCfg, + } + if node.Protocol == "forward" { + ln, err = gost.TCPListener(node.Addr) + } else { + ln, err = gost.SSHTunnelListener(node.Addr, config) + } + case "quic": + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: toBool(node.Values.Get("keepalive")), + } + timeout, _ := strconv.Atoi(node.Values.Get("timeout")) + config.Timeout = time.Duration(timeout) * time.Second + ln, err = gost.QUICListener(node.Addr, config) + case "http2": + ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) + case "h2": + ln, err = gost.H2Listener(node.Addr, tlsCfg) + case "h2c": + ln, err = gost.H2CListener(node.Addr) + case "tcp": + ln, err = gost.TCPListener(node.Addr) + case "rtcp": + ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) + case "udp": + ttl, _ := strconv.Atoi(node.Values.Get("ttl")) + ln, err = gost.UDPDirectForwardListener(node.Addr, time.Duration(ttl)*time.Second) + case "rudp": + ttl, _ := strconv.Atoi(node.Values.Get("ttl")) + ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, time.Duration(ttl)*time.Second) + case "redirect": + ln, err = gost.TCPListener(node.Addr) + case "ssu": + ttl, _ := strconv.Atoi(node.Values.Get("ttl")) + ln, err = gost.ShadowUDPListener(node.Addr, node.User, time.Duration(ttl)*time.Second) + default: + ln, err = gost.TCPListener(node.Addr) + } + if err != nil { + return err + } + + var whitelist, blacklist *gost.Permissions + if node.Values.Get("whitelist") != "" { + if whitelist, err = gost.ParsePermissions(node.Values.Get("whitelist")); err != nil { + return err + } + } else { + // By default allow for everyting + whitelist, _ = gost.ParsePermissions("*:*:*") + } + + if node.Values.Get("blacklist") != "" { + if blacklist, err = gost.ParsePermissions(node.Values.Get("blacklist")); err != nil { + return err + } + } else { + // By default block nothing + blacklist, _ = gost.ParsePermissions("") + } + + var handlerOptions []gost.HandlerOption + + handlerOptions = append(handlerOptions, + gost.AddrHandlerOption(node.Addr), + gost.ChainHandlerOption(chain), + gost.UsersHandlerOption(users...), + gost.TLSConfigHandlerOption(tlsCfg), + gost.WhitelistHandlerOption(whitelist), + gost.BlacklistHandlerOption(blacklist), + ) + var handler gost.Handler + switch node.Protocol { + case "http2": + handler = gost.HTTP2Handler(handlerOptions...) + case "socks", "socks5": + handler = gost.SOCKS5Handler(handlerOptions...) + case "socks4", "socks4a": + handler = gost.SOCKS4Handler(handlerOptions...) + case "ss": + handler = gost.ShadowHandler(handlerOptions...) + case "http": + handler = gost.HTTPHandler(handlerOptions...) + case "tcp": + handler = gost.TCPDirectForwardHandler(node.Remote, handlerOptions...) + case "rtcp": + handler = gost.TCPRemoteForwardHandler(node.Remote, handlerOptions...) + case "udp": + handler = gost.UDPDirectForwardHandler(node.Remote, handlerOptions...) + case "rudp": + handler = gost.UDPRemoteForwardHandler(node.Remote, handlerOptions...) + case "forward": + handler = gost.SSHForwardHandler(handlerOptions...) + case "redirect": + handler = gost.TCPRedirectHandler(handlerOptions...) + case "ssu": + handler = gost.ShadowUDPdHandler(handlerOptions...) + default: + // TODO: auto poroxy handler + handler = gost.HTTPHandler(handlerOptions...) + } + go new(gost.Server).Serve(ln, handler) + } + + return nil +} + +// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid. +func tlsConfig(certFile, keyFile string) (*tls.Config, error) { + if certFile == "" || keyFile == "" { + return nil, nil + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + return &tls.Config{Certificates: []tls.Certificate{cert}}, nil +} + func loadConfigureFile(configureFile string) error { if configureFile == "" { return nil @@ -150,3 +375,56 @@ func (l *stringList) Set(value string) error { *l = append(*l, value) return nil } + +func toBool(s string) bool { + if b, _ := strconv.ParseBool(s); b { + return b + } + n, _ := strconv.Atoi(s) + return n > 0 +} + +func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { + if configFile == "" { + return nil, nil + } + file, err := os.Open(configFile) + if err != nil { + return nil, err + } + defer file.Close() + + config := &gost.KCPConfig{} + if err = json.NewDecoder(file).Decode(config); err != nil { + return nil, err + } + return config, nil +} + +func parseUsers(authFile string) (users []*url.Userinfo, err error) { + if authFile == "" { + return + } + + file, err := os.Open(authFile) + if err != nil { + return + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + s := strings.SplitN(line, " ", 2) + if len(s) == 1 { + users = append(users, url.User(strings.TrimSpace(s[0]))) + } else if len(s) == 2 { + users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) + } + } + + err = scanner.Err() + return +} diff --git a/gost/forward.go b/gost/forward.go index 9763a88..bdbe361 100644 --- a/gost/forward.go +++ b/gost/forward.go @@ -367,12 +367,16 @@ func (c *udpServerConn) writeLoop() { } func (c *udpServerConn) ttlWait() { - timer := time.NewTimer(c.ttl) + ttl := c.ttl + if ttl == 0 { + ttl = defaultTTL + } + timer := time.NewTimer(ttl) for { select { case <-c.nopChan: - timer.Reset(c.ttl) + timer.Reset(ttl) case <-timer.C: close(c.brokenChan) return @@ -452,7 +456,7 @@ func (l *tcpRemoteForwardListener) Accept() (net.Conn, error) { func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { lastNode := l.chain.LastNode() - if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" { + if lastNode.Protocol == "remote" && lastNode.Transport == "ssh" { conn, err = l.chain.Dial(l.addr.String()) } else if lastNode.Protocol == "socks5" { cc, er := l.chain.Conn() diff --git a/gost/gost.go b/gost/gost.go index e479024..91768e2 100644 --- a/gost/gost.go +++ b/gost/gost.go @@ -13,7 +13,7 @@ import ( ) // Version is the gost version. -const Version = "2.4-dev20170722" +const Version = "2.4-dev20170803" // Debug is a flag that enables the debug log. var Debug bool @@ -39,7 +39,7 @@ var ( // PingRetries is the reties of ping. PingRetries = 3 // default udp node TTL in second for udp port forwarding. - defaultTTL = 60 + defaultTTL = 60 * time.Second ) var ( diff --git a/gost/http2.go b/gost/http2.go index 853ae95..2927afe 100644 --- a/gost/http2.go +++ b/gost/http2.go @@ -391,6 +391,15 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { connChan: make(chan *http2ServerConn, 1024), errChan: make(chan error, 1), } + if config == nil { + cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey) + if err != nil { + return nil, err + } + config = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } server := &http.Server{ Addr: addr, Handler: http.HandlerFunc(l.handleFunc), @@ -400,6 +409,7 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { return nil, err } l.server = server + go server.ListenAndServeTLS("", "") return l, nil @@ -462,6 +472,16 @@ func H2Listener(addr string, config *tls.Config) (Listener, error) { if err != nil { return nil, err } + if config == nil { + cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey) + if err != nil { + return nil, err + } + config = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } + l := &h2Listener{ Listener: ln, server: &http2.Server{ diff --git a/gost/node.go b/gost/node.go index 31b3b2f..b7d103e 100644 --- a/gost/node.go +++ b/gost/node.go @@ -1,10 +1,8 @@ package gost import ( - "bufio" "net" "net/url" - "os" "strconv" "strings" @@ -13,18 +11,16 @@ import ( // Node is a proxy node, mainly used to construct a proxy chain. type Node struct { - Addr string - Protocol string - Transport string - Remote string // remote address, used by tcp/udp port forwarding - User *url.Userinfo - users []*url.Userinfo // authentication or cipher for proxy - Whitelist *Permissions - Blacklist *Permissions - values url.Values - serverName string - Client *Client - Server *Server + Addr string + Protocol string + Transport string + Remote string // remote address, used by tcp/udp port forwarding + User *url.Userinfo + Chain *Chain + Values url.Values + Client *Client + DialOptions []DialOption + HandshakeOptions []HandshakeOption } func ParseNode(s string) (node Node, err error) { @@ -36,49 +32,10 @@ func ParseNode(s string) (node Node, err error) { return } - query := u.Query() node = Node{ - Addr: u.Host, - values: query, - serverName: u.Host, - } - - if query.Get("whitelist") != "" { - if node.Whitelist, err = ParsePermissions(query.Get("whitelist")); err != nil { - return - } - } else { - // By default allow for everyting - node.Whitelist, _ = ParsePermissions("*:*:*") - } - - if query.Get("blacklist") != "" { - if node.Blacklist, err = ParsePermissions(query.Get("blacklist")); err != nil { - return - } - } else { - // By default block nothing - node.Blacklist, _ = ParsePermissions("") - } - - if u.User != nil { - node.User = u.User - node.users = append(node.users, u.User) - } - - users, er := parseUsers(node.values.Get("secrets")) - if users != nil { - node.users = append(node.users, users...) - } - if er != nil { - log.Log("load secrets:", er) - } - - if strings.Contains(u.Host, ":") { - node.serverName, _, _ = net.SplitHostPort(u.Host) - if node.serverName == "" { - node.serverName = "localhost" // default server name - } + Addr: u.Host, + Values: u.Query(), + User: u.User, } schemes := strings.Split(u.Scheme, "+") @@ -105,9 +62,9 @@ func ParseNode(s string) (node Node, err error) { } switch node.Protocol { - case "http", "http2", "socks4", "socks4a", "socks", "socks5", "ss": + case "http", "http2", "socks4", "socks4a", "socks", "socks5", "ss", "ssu": case "tcp", "udp", "rtcp", "rudp": // port forwarding - case "direct", "remote": // SSH port forwarding + case "direct", "remote", "forward": // SSH port forwarding default: node.Protocol = "" } @@ -115,34 +72,6 @@ func ParseNode(s string) (node Node, err error) { return } -func parseUsers(authFile string) (users []*url.Userinfo, err error) { - if authFile == "" { - return - } - - file, err := os.Open(authFile) - if err != nil { - return - } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - s := strings.SplitN(line, " ", 2) - if len(s) == 1 { - users = append(users, url.User(strings.TrimSpace(s[0]))) - } else if len(s) == 2 { - users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) - } - } - - err = scanner.Err() - return -} - func Can(action string, addr string, whitelist, blacklist *Permissions) bool { if !strings.Contains(addr, ":") { addr = addr + ":80" @@ -159,7 +88,8 @@ func Can(action string, addr string, whitelist, blacklist *Permissions) bool { return false } - log.Logf("Can action: %s, host: %s, port %d", action, host, port) - + if Debug { + log.Logf("Can action: %s, host: %s, port %d", action, host, port) + } return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port) } diff --git a/gost/ssh.go b/gost/ssh.go index 8c69603..a09dc9d 100644 --- a/gost/ssh.go +++ b/gost/ssh.go @@ -318,7 +318,6 @@ type sshSession struct { } func (s *sshSession) Ping(interval time.Duration, retries int) { - interval = 30 * time.Second if interval <= 0 { return } @@ -620,14 +619,25 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) { if len(config.Users) == 0 { sshConfig.NoClientAuth = true } - if config.TLSConfig != nil && len(config.TLSConfig.Certificates) > 0 { - signer, err := ssh.NewSignerFromKey(config.TLSConfig.Certificates[0].PrivateKey) + if config.TLSConfig == nil { + cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey) if err != nil { - log.Log("[sshf]", err) + ln.Close() + return nil, err + } + config.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, } - sshConfig.AddHostKey(signer) } + signer, err := ssh.NewSignerFromKey(config.TLSConfig.Certificates[0].PrivateKey) + if err != nil { + ln.Close() + return nil, err + + } + sshConfig.AddHostKey(signer) + l := &sshTunnelListener{ Listener: ln, config: sshConfig,