From 0dd8bbb935a49aa6cfd10ee5b46245a3d5ce863c Mon Sep 17 00:00:00 2001 From: tongsq Date: Mon, 18 Apr 2022 19:57:41 +0800 Subject: [PATCH 1/3] =?UTF-8?q?add=20http=E3=80=81http2=E3=80=81socks5?= =?UTF-8?q?=E3=80=81relay=20access=20speed=20limiter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + cmd/gost/cfg.go | 20 ++++ cmd/gost/route.go | 14 +++ handler.go | 8 ++ http.go | 16 +++ http2.go | 13 ++- limiter.go | 259 ++++++++++++++++++++++++++++++++++++++++++++++ limiter_test.go | 69 ++++++++++++ relay.go | 11 ++ socks.go | 11 +- 10 files changed, 420 insertions(+), 2 deletions(-) create mode 100644 limiter.go create mode 100644 limiter_test.go diff --git a/.gitignore b/.gitignore index 2016845..28e7880 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ _testmain.go .vscode/ cmd/gost/gost +.idea \ No newline at end of file diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index df0c53c..8b2912a 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -148,6 +148,26 @@ func parseAuthenticator(s string) (gost.Authenticator, error) { return au, nil } +func parseLimiter(s string) (gost.Limiter, error) { + if s == "" { + return nil, nil + } + f, err := os.Open(s) + if err != nil { + return nil, err + } + defer f.Close() + + l, _ := gost.NewLocalLimiter("", "") + err = l.Reload(f) + if err != nil { + return nil, err + } + go gost.PeriodReload(l, s) + + return l, nil +} + func parseIP(s string, port string) (ips []string) { if s == "" { return diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 360bc2d..9470041 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -386,6 +386,19 @@ func (r *route) GenRouters() ([]router, error) { node.User = users[0] } } + + //init rate limiter + limiterHandler, err := parseLimiter(node.Get("secrets")) + if err != nil { + return nil, err + } + if limiterHandler == nil && strings.TrimSpace(node.Get("limiter")) != "" && node.User != nil { + limiterHandler, err = gost.NewLocalLimiter(node.User.Username(), strings.TrimSpace(node.Get("limiter"))) + if err != nil { + return nil, err + } + } + certFile, keyFile := node.Get("cert"), node.Get("key") tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca")) if err != nil && certFile != "" && keyFile != "" { @@ -671,6 +684,7 @@ func (r *route) GenRouters() ([]router, error) { gost.IPRoutesHandlerOption(tunRoutes...), gost.ProxyAgentHandlerOption(node.Get("proxyAgent")), gost.HTTPTunnelHandlerOption(node.GetBool("httpTunnel")), + gost.LimiterHandlerOption(limiterHandler), ) rt := router{ diff --git a/handler.go b/handler.go index ee82cea..105d030 100644 --- a/handler.go +++ b/handler.go @@ -44,6 +44,7 @@ type HandlerOptions struct { IPRoutes []IPRoute ProxyAgent string HTTPTunnel bool + Limiter Limiter } // HandlerOption allows a common way to set handler options. @@ -87,6 +88,13 @@ func AuthenticatorHandlerOption(au Authenticator) HandlerOption { } } +// LimiterHandlerOption sets the Rate limiter option of HandlerOptions +func LimiterHandlerOption(l Limiter) HandlerOption { + return func(opts *HandlerOptions) { + opts.Limiter = l + } +} + // TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions. func TLSConfigHandlerOption(config *tls.Config) HandlerOption { return func(opts *HandlerOptions) { diff --git a/http.go b/http.go index 8f9e3fd..846f7ce 100644 --- a/http.go +++ b/http.go @@ -212,7 +212,23 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { if !h.authenticate(conn, req, resp) { return } + user, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) + if h.options.Limiter != nil { + done, ok := h.options.Limiter.CheckRate(user, true) + if !ok { + resp.StatusCode = http.StatusTooManyRequests + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + resp.Write(conn) + return + } else { + defer done() + } + } if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") { resp.StatusCode = http.StatusBadRequest diff --git a/http2.go b/http2.go index de152ea..8ba40e0 100644 --- a/http2.go +++ b/http2.go @@ -394,7 +394,18 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { if !h.authenticate(w, r, resp) { return } - + user, _, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) + if h.options.Limiter != nil { + done, ok := h.options.Limiter.CheckRate(user, true) + if !ok { + log.Logf("[http2] %s - %s rate limiter %s, user is %s", + r.RemoteAddr, laddr, host, user) + w.WriteHeader(http.StatusTooManyRequests) + return + } else { + defer done() + } + } // delete the proxy related headers. r.Header.Del("Proxy-Authorization") r.Header.Del("Proxy-Connection") diff --git a/limiter.go b/limiter.go new file mode 100644 index 0000000..754aa05 --- /dev/null +++ b/limiter.go @@ -0,0 +1,259 @@ +package gost + +import ( + "bufio" + "errors" + "io" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +type Limiter interface { + CheckRate(key string, checkConcurrent bool) (func(), bool) +} + +func NewLocalLimiter(user string, cfg string) (*LocalLimiter, error) { + limiter := LocalLimiter{ + buckets: map[string]*limiterBucket{}, + concurrent: map[string]chan bool{}, + stopped: make(chan struct{}), + } + if cfg == "" || user == "" { + return &limiter, nil + } + if err := limiter.AddRule(user, cfg); err != nil { + return nil, err + } + return &limiter, nil +} + +// Token Bucket +type limiterBucket struct { + max int64 + cur int64 + duration int64 + batch int64 +} + +type LocalLimiter struct { + buckets map[string]*limiterBucket + concurrent map[string]chan bool + mux sync.RWMutex + stopped chan struct{} + period time.Duration +} + +func (l *LocalLimiter) CheckRate(key string, checkConcurrent bool) (func(), bool) { + if checkConcurrent { + done, ok := l.checkConcurrent(key) + if !ok { + return nil, false + } + if t := l.getToken(key); !t { + done() + return nil, false + } + return done, true + } else { + if t := l.getToken(key); !t { + return nil, false + } + return nil, true + } +} + +func (l *LocalLimiter) AddRule(user string, cfg string) error { + if user == "" { + return nil + } + if cfg == "" { + //reload need check old limit exists + if _, ok := l.buckets[user]; ok { + delete(l.buckets, user) + } + if _, ok := l.concurrent[user]; ok { + delete(l.concurrent, user) + } + return nil + } + args := strings.Split(cfg, ",") + if len(args) < 2 || len(args) > 3 { + return errors.New("parse limiter fail:" + cfg) + } + if len(args) == 2 { + args = append(args, "0") + } + + duration, e1 := strconv.ParseInt(strings.TrimSpace(args[0]), 10, 64) + count, e2 := strconv.ParseInt(strings.TrimSpace(args[1]), 10, 64) + cur, e3 := strconv.ParseInt(strings.TrimSpace(args[2]), 10, 64) + if e1 != nil || e2 != nil || e3 != nil { + return errors.New("parse limiter fail:" + cfg) + } + // 0 means not limit + if duration > 0 && count > 0 { + bu := &limiterBucket{ + cur: count * 10, + max: count * 10, + duration: duration * 100, + batch: count, + } + go func() { + for { + time.Sleep(time.Millisecond * time.Duration(bu.duration)) + if bu.cur+bu.batch > bu.max { + bu.cur = bu.max + } else { + atomic.AddInt64(&bu.cur, bu.batch) + } + } + }() + l.buckets[user] = bu + } else { + if _, ok := l.buckets[user]; ok { + delete(l.buckets, user) + } + } + // zero means not limit + if cur > 0 { + l.concurrent[user] = make(chan bool, cur) + } else { + if _, ok := l.concurrent[user]; ok { + delete(l.concurrent, user) + } + } + return nil +} + +// Reload parses config from r, then live reloads the LocalLimiter. +func (l *LocalLimiter) Reload(r io.Reader) error { + var period time.Duration + kvs := make(map[string]string) + + if r == nil || l.Stopped() { + return nil + } + + // splitLine splits a line text by white space. + // A line started with '#' will be ignored, otherwise it is valid. + split := func(line string) []string { + if line == "" { + return nil + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + + if strings.IndexByte(line, '#') == 0 { + return nil + } + + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + return ss + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + ss := split(line) + if len(ss) == 0 { + continue + } + + switch ss[0] { + case "reload": // reload option + if len(ss) > 1 { + period, _ = time.ParseDuration(ss[1]) + } + default: + var k, v string + k = ss[0] + if len(ss) > 2 { + v = ss[2] + } + kvs[k] = v + } + } + + if err := scanner.Err(); err != nil { + return err + } + + l.mux.Lock() + defer l.mux.Unlock() + + l.period = period + for user, args := range kvs { + err := l.AddRule(user, args) + if err != nil { + return err + } + } + + return nil +} + +// Period returns the reload period. +func (l *LocalLimiter) Period() time.Duration { + if l.Stopped() { + return -1 + } + + l.mux.RLock() + defer l.mux.RUnlock() + + return l.period +} + +// Stop stops reloading. +func (l *LocalLimiter) Stop() { + select { + case <-l.stopped: + default: + close(l.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (l *LocalLimiter) Stopped() bool { + select { + case <-l.stopped: + return true + default: + return false + } +} + +func (l *LocalLimiter) getToken(key string) bool { + b, ok := l.buckets[key] + if !ok || b == nil { + return true + } + if b.cur <= 0 { + return false + } + atomic.AddInt64(&b.cur, -10) + return true +} + +func (l *LocalLimiter) checkConcurrent(key string) (func(), bool) { + c, ok := l.concurrent[key] + if !ok || c == nil { + return func() {}, true + } + select { + case c <- true: + return func() { + <-c + }, true + default: + return nil, false + } +} diff --git a/limiter_test.go b/limiter_test.go new file mode 100644 index 0000000..d491352 --- /dev/null +++ b/limiter_test.go @@ -0,0 +1,69 @@ +package gost + +import ( + "fmt" + "testing" +) + +func TestNewLocalLimiter(t *testing.T) { + items := []struct { + user string + args string + success bool + }{ + {"admin", "10,1", true}, + {"admin", "", true}, + {"admin", "10,1,1", true}, + {"admin", "10", false}, + {"admin", "0,1", true}, + {"admin", "0,1,1", true}, + {"admin", "a,b", false}, + {"", "", true}, + {"", "1,2", true}, + } + for i, item := range items { + item := item + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + _, err := NewLocalLimiter(item.user, item.args) + if (err == nil) != item.success { + t.Error("test NewLocalLimiter fail", item.user, item.args) + } + }) + } +} + +func TestCheckRate(t *testing.T) { + items := []struct { + user string + args string + testUser string + checkCount int + shouldSuccessCount int + }{ + {"admin", "10,3", "admin", 10, 3}, + {"admin", "10,3,0", "admin", 10, 3}, + {"admin", "10,3,2", "admin", 10, 2}, + {"admin", "0,0", "admin", 10, 10}, + {"admin", "10,3,5", "admin", 10, 3}, + {"admin", "10,3,5", "admin22", 10, 10}, + {"admin", "0,0,5", "admin", 10, 5}, + } + for i, item := range items { + item := item + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + l, err := NewLocalLimiter(item.user, item.args) + if err != nil { + t.Error("test NewLocalLimiter fail", item.user, item.args) + } + successCount := 0 + for j := 0; j < item.checkCount; j++ { + if _, ok := l.CheckRate(item.testUser, true); ok { + successCount++ + } + } + if successCount != item.shouldSuccessCount { + t.Error("test localLimiter fail", item) + } + }) + } +} diff --git a/relay.go b/relay.go index 74423f4..103cac3 100644 --- a/relay.go +++ b/relay.go @@ -171,6 +171,17 @@ func (h *relayHandler) Handle(conn net.Conn) { log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user) return } + if h.options.Limiter != nil { + done, ok := h.options.Limiter.CheckRate(user, true) + if !ok { + resp.Status = relay.StatusForbidden + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : %s rate limiter", conn.RemoteAddr(), conn.LocalAddr(), user) + return + } else { + defer done() + } + } if raddr != "" { if len(h.group.Nodes()) > 0 { diff --git a/socks.go b/socks.go index d59dd89..5554a72 100644 --- a/socks.go +++ b/socks.go @@ -112,6 +112,7 @@ type serverSelector struct { // Users []*url.Userinfo Authenticator Authenticator TLSConfig *tls.Config + Limiter Limiter } func (selector *serverSelector) Methods() []uint8 { @@ -181,7 +182,14 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr()) return nil, gosocks5.ErrAuthFailure } - + if req.Username != "" && selector.Limiter != nil { + if _, ok := selector.Limiter.CheckRate(req.Username, false); !ok { + if Debug { + log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), req.Username) + } + return nil, errors.New("rate limiter check fail") + } + } resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) if err := resp.Write(conn); err != nil { log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -836,6 +844,7 @@ func (h *socks5Handler) Init(options ...HandlerOption) { // Users: h.options.Users, Authenticator: h.options.Authenticator, TLSConfig: tlsConfig, + Limiter: h.options.Limiter, } // methods that socks5 server supported h.selector.AddMethod( From bece31e0bbf896ee1560df398fea75fd0a961827 Mon Sep 17 00:00:00 2001 From: tongsq Date: Thu, 17 Oct 2024 18:20:29 +0800 Subject: [PATCH 2/3] update --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index d633666..8a5a765 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/ginuerzh/gost +module github.com/tongsq/gost go 1.22 From 087b05ae3f722a06f17ed2693a02208077a9ec22 Mon Sep 17 00:00:00 2001 From: tongsq Date: Thu, 17 Oct 2024 18:26:19 +0800 Subject: [PATCH 3/3] update namespace --- cmd/gost/cfg.go | 2 +- cmd/gost/main.go | 2 +- cmd/gost/peer.go | 2 +- cmd/gost/route.go | 2 +- examples/bench/cli.go | 2 +- examples/bench/srv.go | 2 +- examples/forward/direct/client.go | 2 +- examples/forward/direct/server.go | 2 +- examples/forward/remote/client.go | 2 +- examples/forward/remote/server.go | 2 +- examples/forward/udp/direct.go | 30 +++++++++++++++--------------- examples/forward/udp/remote.go | 2 +- examples/http2/http2.go | 2 +- examples/quic/quicc.go | 2 +- examples/quic/quics.go | 2 +- examples/ssh/sshc.go | 2 +- examples/ssh/sshd.go | 2 +- 17 files changed, 31 insertions(+), 31 deletions(-) diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index 8b2912a..048852b 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -11,7 +11,7 @@ import ( "os" "strings" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) var ( diff --git a/cmd/gost/main.go b/cmd/gost/main.go index ce9184c..dde751b 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -11,8 +11,8 @@ import ( _ "net/http/pprof" - "github.com/ginuerzh/gost" "github.com/go-log/log" + "github.com/tongsq/gost" ) var ( diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go index af062c7..a023ad2 100644 --- a/cmd/gost/peer.go +++ b/cmd/gost/peer.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) type peerConfig struct { diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 9470041..413d74a 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -12,8 +12,8 @@ import ( "strings" "time" - "github.com/ginuerzh/gost" "github.com/go-log/log" + "github.com/tongsq/gost" ) type stringList []string diff --git a/examples/bench/cli.go b/examples/bench/cli.go index 57c189c..bc4f133 100644 --- a/examples/bench/cli.go +++ b/examples/bench/cli.go @@ -10,7 +10,7 @@ import ( "sync" "time" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" "golang.org/x/net/http2" ) diff --git a/examples/bench/srv.go b/examples/bench/srv.go index 36da8e4..d9c6d47 100644 --- a/examples/bench/srv.go +++ b/examples/bench/srv.go @@ -9,7 +9,7 @@ import ( "net/url" "time" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" "golang.org/x/net/http2" ) diff --git a/examples/forward/direct/client.go b/examples/forward/direct/client.go index c824b54..9ddfef5 100644 --- a/examples/forward/direct/client.go +++ b/examples/forward/direct/client.go @@ -3,7 +3,7 @@ package main import ( "log" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) func main() { diff --git a/examples/forward/direct/server.go b/examples/forward/direct/server.go index 20aca6d..bcb4ac8 100644 --- a/examples/forward/direct/server.go +++ b/examples/forward/direct/server.go @@ -4,7 +4,7 @@ import ( "crypto/tls" "log" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) func main() { diff --git a/examples/forward/remote/client.go b/examples/forward/remote/client.go index 68f1737..2e05020 100644 --- a/examples/forward/remote/client.go +++ b/examples/forward/remote/client.go @@ -3,7 +3,7 @@ package main import ( "log" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) func main() { diff --git a/examples/forward/remote/server.go b/examples/forward/remote/server.go index cc83aa8..f5ed5b8 100644 --- a/examples/forward/remote/server.go +++ b/examples/forward/remote/server.go @@ -4,7 +4,7 @@ import ( "crypto/tls" "log" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) func main() { diff --git a/examples/forward/udp/direct.go b/examples/forward/udp/direct.go index 79f9a35..1a9a097 100644 --- a/examples/forward/udp/direct.go +++ b/examples/forward/udp/direct.go @@ -5,7 +5,7 @@ import ( "log" "time" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) var ( @@ -37,20 +37,20 @@ func udpDirectForwardServer() { } h := gost.UDPDirectForwardHandler( faddr, - /* - gost.ChainHandlerOption(gost.NewChain(gost.Node{ - Protocol: "socks5", - Transport: "tcp", - Addr: ":11080", - User: url.UserPassword("admin", "123456"), - Client: &gost.Client{ - Connector: gost.SOCKS5Connector( - url.UserPassword("admin", "123456"), - ), - Transporter: gost.TCPTransporter(), - }, - })), - */ + /* + gost.ChainHandlerOption(gost.NewChain(gost.Node{ + Protocol: "socks5", + Transport: "tcp", + Addr: ":11080", + User: url.UserPassword("admin", "123456"), + Client: &gost.Client{ + Connector: gost.SOCKS5Connector( + url.UserPassword("admin", "123456"), + ), + Transporter: gost.TCPTransporter(), + }, + })), + */ ) s := &gost.Server{ln} log.Fatal(s.Serve(h)) diff --git a/examples/forward/udp/remote.go b/examples/forward/udp/remote.go index b0c4d50..0f8c9c1 100644 --- a/examples/forward/udp/remote.go +++ b/examples/forward/udp/remote.go @@ -5,7 +5,7 @@ import ( "log" "time" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) var ( diff --git a/examples/http2/http2.go b/examples/http2/http2.go index 369fdfe..5de66de 100644 --- a/examples/http2/http2.go +++ b/examples/http2/http2.go @@ -8,7 +8,7 @@ import ( "golang.org/x/net/http2" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) var ( diff --git a/examples/quic/quicc.go b/examples/quic/quicc.go index 119e9e6..af866e5 100644 --- a/examples/quic/quicc.go +++ b/examples/quic/quicc.go @@ -6,7 +6,7 @@ import ( "log" "time" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) var ( diff --git a/examples/quic/quics.go b/examples/quic/quics.go index 246995a..da01952 100644 --- a/examples/quic/quics.go +++ b/examples/quic/quics.go @@ -5,7 +5,7 @@ import ( "flag" "log" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) var ( diff --git a/examples/ssh/sshc.go b/examples/ssh/sshc.go index a90a026..47c97bd 100644 --- a/examples/ssh/sshc.go +++ b/examples/ssh/sshc.go @@ -6,7 +6,7 @@ import ( "log" "time" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) var ( diff --git a/examples/ssh/sshd.go b/examples/ssh/sshd.go index 72aa2f7..07aabdc 100644 --- a/examples/ssh/sshd.go +++ b/examples/ssh/sshd.go @@ -5,7 +5,7 @@ import ( "flag" "log" - "github.com/ginuerzh/gost" + "github.com/tongsq/gost" ) var (