This commit is contained in:
tongsq 2024-10-17 10:27:47 +00:00 committed by GitHub
commit eddb63474c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 452 additions and 34 deletions

1
.gitignore vendored
View File

@ -33,3 +33,4 @@ _testmain.go
.vscode/ .vscode/
cmd/gost/gost cmd/gost/gost
.idea

View File

@ -11,7 +11,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
var ( var (
@ -148,6 +148,26 @@ func parseAuthenticator(s string) (gost.Authenticator, error) {
return au, nil return au, nil
} }
func parseLimiter(s string) (gost.Limiter, error) {
if s == "" {
return nil, nil
}
f, err := os.Open(s)
if err != nil {
return nil, err
}
defer f.Close()
l, _ := gost.NewLocalLimiter("", "")
err = l.Reload(f)
if err != nil {
return nil, err
}
go gost.PeriodReload(l, s)
return l, nil
}
func parseIP(s string, port string) (ips []string) { func parseIP(s string, port string) (ips []string) {
if s == "" { if s == "" {
return return

View File

@ -11,8 +11,8 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"github.com/ginuerzh/gost"
"github.com/go-log/log" "github.com/go-log/log"
"github.com/tongsq/gost"
) )
var ( var (

View File

@ -9,7 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
type peerConfig struct { type peerConfig struct {

View File

@ -12,8 +12,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/ginuerzh/gost"
"github.com/go-log/log" "github.com/go-log/log"
"github.com/tongsq/gost"
) )
type stringList []string type stringList []string
@ -386,6 +386,19 @@ func (r *route) GenRouters() ([]router, error) {
node.User = users[0] node.User = users[0]
} }
} }
//init rate limiter
limiterHandler, err := parseLimiter(node.Get("secrets"))
if err != nil {
return nil, err
}
if limiterHandler == nil && strings.TrimSpace(node.Get("limiter")) != "" && node.User != nil {
limiterHandler, err = gost.NewLocalLimiter(node.User.Username(), strings.TrimSpace(node.Get("limiter")))
if err != nil {
return nil, err
}
}
certFile, keyFile := node.Get("cert"), node.Get("key") certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca")) tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca"))
if err != nil && certFile != "" && keyFile != "" { if err != nil && certFile != "" && keyFile != "" {
@ -671,6 +684,7 @@ func (r *route) GenRouters() ([]router, error) {
gost.IPRoutesHandlerOption(tunRoutes...), gost.IPRoutesHandlerOption(tunRoutes...),
gost.ProxyAgentHandlerOption(node.Get("proxyAgent")), gost.ProxyAgentHandlerOption(node.Get("proxyAgent")),
gost.HTTPTunnelHandlerOption(node.GetBool("httpTunnel")), gost.HTTPTunnelHandlerOption(node.GetBool("httpTunnel")),
gost.LimiterHandlerOption(limiterHandler),
) )
rt := router{ rt := router{

View File

@ -10,7 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )

View File

@ -9,7 +9,7 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
func main() { func main() {

View File

@ -4,7 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"log" "log"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
func main() { func main() {

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
func main() { func main() {

View File

@ -4,7 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"log" "log"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
func main() { func main() {

View File

@ -5,7 +5,7 @@ import (
"log" "log"
"time" "time"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
var ( var (
@ -37,20 +37,20 @@ func udpDirectForwardServer() {
} }
h := gost.UDPDirectForwardHandler( h := gost.UDPDirectForwardHandler(
faddr, faddr,
/* /*
gost.ChainHandlerOption(gost.NewChain(gost.Node{ gost.ChainHandlerOption(gost.NewChain(gost.Node{
Protocol: "socks5", Protocol: "socks5",
Transport: "tcp", Transport: "tcp",
Addr: ":11080", Addr: ":11080",
User: url.UserPassword("admin", "123456"), User: url.UserPassword("admin", "123456"),
Client: &gost.Client{ Client: &gost.Client{
Connector: gost.SOCKS5Connector( Connector: gost.SOCKS5Connector(
url.UserPassword("admin", "123456"), url.UserPassword("admin", "123456"),
), ),
Transporter: gost.TCPTransporter(), Transporter: gost.TCPTransporter(),
}, },
})), })),
*/ */
) )
s := &gost.Server{ln} s := &gost.Server{ln}
log.Fatal(s.Serve(h)) log.Fatal(s.Serve(h))

View File

@ -5,7 +5,7 @@ import (
"log" "log"
"time" "time"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
var ( var (

View File

@ -8,7 +8,7 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
var ( var (

View File

@ -6,7 +6,7 @@ import (
"log" "log"
"time" "time"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
var ( var (

View File

@ -5,7 +5,7 @@ import (
"flag" "flag"
"log" "log"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
var ( var (

View File

@ -6,7 +6,7 @@ import (
"log" "log"
"time" "time"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
var ( var (

View File

@ -5,7 +5,7 @@ import (
"flag" "flag"
"log" "log"
"github.com/ginuerzh/gost" "github.com/tongsq/gost"
) )
var ( var (

2
go.mod
View File

@ -1,4 +1,4 @@
module github.com/ginuerzh/gost module github.com/tongsq/gost
go 1.22 go 1.22

View File

@ -44,6 +44,7 @@ type HandlerOptions struct {
IPRoutes []IPRoute IPRoutes []IPRoute
ProxyAgent string ProxyAgent string
HTTPTunnel bool HTTPTunnel bool
Limiter Limiter
} }
// HandlerOption allows a common way to set handler options. // HandlerOption allows a common way to set handler options.
@ -87,6 +88,13 @@ func AuthenticatorHandlerOption(au Authenticator) HandlerOption {
} }
} }
// LimiterHandlerOption sets the Rate limiter option of HandlerOptions
func LimiterHandlerOption(l Limiter) HandlerOption {
return func(opts *HandlerOptions) {
opts.Limiter = l
}
}
// TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions. // TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions.
func TLSConfigHandlerOption(config *tls.Config) HandlerOption { func TLSConfigHandlerOption(config *tls.Config) HandlerOption {
return func(opts *HandlerOptions) { return func(opts *HandlerOptions) {

16
http.go
View File

@ -212,7 +212,23 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
if !h.authenticate(conn, req, resp) { if !h.authenticate(conn, req, resp) {
return return
} }
user, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization"))
if h.options.Limiter != nil {
done, ok := h.options.Limiter.CheckRate(user, true)
if !ok {
resp.StatusCode = http.StatusTooManyRequests
if Debug {
dump, _ := httputil.DumpResponse(resp, false)
log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
}
resp.Write(conn)
return
} else {
defer done()
}
}
if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") { if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") {
resp.StatusCode = http.StatusBadRequest resp.StatusCode = http.StatusBadRequest

View File

@ -394,7 +394,18 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
if !h.authenticate(w, r, resp) { if !h.authenticate(w, r, resp) {
return return
} }
user, _, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
if h.options.Limiter != nil {
done, ok := h.options.Limiter.CheckRate(user, true)
if !ok {
log.Logf("[http2] %s - %s rate limiter %s, user is %s",
r.RemoteAddr, laddr, host, user)
w.WriteHeader(http.StatusTooManyRequests)
return
} else {
defer done()
}
}
// delete the proxy related headers. // delete the proxy related headers.
r.Header.Del("Proxy-Authorization") r.Header.Del("Proxy-Authorization")
r.Header.Del("Proxy-Connection") r.Header.Del("Proxy-Connection")

259
limiter.go Normal file
View File

@ -0,0 +1,259 @@
package gost
import (
"bufio"
"errors"
"io"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
type Limiter interface {
CheckRate(key string, checkConcurrent bool) (func(), bool)
}
func NewLocalLimiter(user string, cfg string) (*LocalLimiter, error) {
limiter := LocalLimiter{
buckets: map[string]*limiterBucket{},
concurrent: map[string]chan bool{},
stopped: make(chan struct{}),
}
if cfg == "" || user == "" {
return &limiter, nil
}
if err := limiter.AddRule(user, cfg); err != nil {
return nil, err
}
return &limiter, nil
}
// Token Bucket
type limiterBucket struct {
max int64
cur int64
duration int64
batch int64
}
type LocalLimiter struct {
buckets map[string]*limiterBucket
concurrent map[string]chan bool
mux sync.RWMutex
stopped chan struct{}
period time.Duration
}
func (l *LocalLimiter) CheckRate(key string, checkConcurrent bool) (func(), bool) {
if checkConcurrent {
done, ok := l.checkConcurrent(key)
if !ok {
return nil, false
}
if t := l.getToken(key); !t {
done()
return nil, false
}
return done, true
} else {
if t := l.getToken(key); !t {
return nil, false
}
return nil, true
}
}
func (l *LocalLimiter) AddRule(user string, cfg string) error {
if user == "" {
return nil
}
if cfg == "" {
//reload need check old limit exists
if _, ok := l.buckets[user]; ok {
delete(l.buckets, user)
}
if _, ok := l.concurrent[user]; ok {
delete(l.concurrent, user)
}
return nil
}
args := strings.Split(cfg, ",")
if len(args) < 2 || len(args) > 3 {
return errors.New("parse limiter fail:" + cfg)
}
if len(args) == 2 {
args = append(args, "0")
}
duration, e1 := strconv.ParseInt(strings.TrimSpace(args[0]), 10, 64)
count, e2 := strconv.ParseInt(strings.TrimSpace(args[1]), 10, 64)
cur, e3 := strconv.ParseInt(strings.TrimSpace(args[2]), 10, 64)
if e1 != nil || e2 != nil || e3 != nil {
return errors.New("parse limiter fail:" + cfg)
}
// 0 means not limit
if duration > 0 && count > 0 {
bu := &limiterBucket{
cur: count * 10,
max: count * 10,
duration: duration * 100,
batch: count,
}
go func() {
for {
time.Sleep(time.Millisecond * time.Duration(bu.duration))
if bu.cur+bu.batch > bu.max {
bu.cur = bu.max
} else {
atomic.AddInt64(&bu.cur, bu.batch)
}
}
}()
l.buckets[user] = bu
} else {
if _, ok := l.buckets[user]; ok {
delete(l.buckets, user)
}
}
// zero means not limit
if cur > 0 {
l.concurrent[user] = make(chan bool, cur)
} else {
if _, ok := l.concurrent[user]; ok {
delete(l.concurrent, user)
}
}
return nil
}
// Reload parses config from r, then live reloads the LocalLimiter.
func (l *LocalLimiter) Reload(r io.Reader) error {
var period time.Duration
kvs := make(map[string]string)
if r == nil || l.Stopped() {
return nil
}
// splitLine splits a line text by white space.
// A line started with '#' will be ignored, otherwise it is valid.
split := func(line string) []string {
if line == "" {
return nil
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)
if strings.IndexByte(line, '#') == 0 {
return nil
}
var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
return ss
}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
ss := split(line)
if len(ss) == 0 {
continue
}
switch ss[0] {
case "reload": // reload option
if len(ss) > 1 {
period, _ = time.ParseDuration(ss[1])
}
default:
var k, v string
k = ss[0]
if len(ss) > 2 {
v = ss[2]
}
kvs[k] = v
}
}
if err := scanner.Err(); err != nil {
return err
}
l.mux.Lock()
defer l.mux.Unlock()
l.period = period
for user, args := range kvs {
err := l.AddRule(user, args)
if err != nil {
return err
}
}
return nil
}
// Period returns the reload period.
func (l *LocalLimiter) Period() time.Duration {
if l.Stopped() {
return -1
}
l.mux.RLock()
defer l.mux.RUnlock()
return l.period
}
// Stop stops reloading.
func (l *LocalLimiter) Stop() {
select {
case <-l.stopped:
default:
close(l.stopped)
}
}
// Stopped checks whether the reloader is stopped.
func (l *LocalLimiter) Stopped() bool {
select {
case <-l.stopped:
return true
default:
return false
}
}
func (l *LocalLimiter) getToken(key string) bool {
b, ok := l.buckets[key]
if !ok || b == nil {
return true
}
if b.cur <= 0 {
return false
}
atomic.AddInt64(&b.cur, -10)
return true
}
func (l *LocalLimiter) checkConcurrent(key string) (func(), bool) {
c, ok := l.concurrent[key]
if !ok || c == nil {
return func() {}, true
}
select {
case c <- true:
return func() {
<-c
}, true
default:
return nil, false
}
}

69
limiter_test.go Normal file
View File

@ -0,0 +1,69 @@
package gost
import (
"fmt"
"testing"
)
func TestNewLocalLimiter(t *testing.T) {
items := []struct {
user string
args string
success bool
}{
{"admin", "10,1", true},
{"admin", "", true},
{"admin", "10,1,1", true},
{"admin", "10", false},
{"admin", "0,1", true},
{"admin", "0,1,1", true},
{"admin", "a,b", false},
{"", "", true},
{"", "1,2", true},
}
for i, item := range items {
item := item
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
_, err := NewLocalLimiter(item.user, item.args)
if (err == nil) != item.success {
t.Error("test NewLocalLimiter fail", item.user, item.args)
}
})
}
}
func TestCheckRate(t *testing.T) {
items := []struct {
user string
args string
testUser string
checkCount int
shouldSuccessCount int
}{
{"admin", "10,3", "admin", 10, 3},
{"admin", "10,3,0", "admin", 10, 3},
{"admin", "10,3,2", "admin", 10, 2},
{"admin", "0,0", "admin", 10, 10},
{"admin", "10,3,5", "admin", 10, 3},
{"admin", "10,3,5", "admin22", 10, 10},
{"admin", "0,0,5", "admin", 10, 5},
}
for i, item := range items {
item := item
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
l, err := NewLocalLimiter(item.user, item.args)
if err != nil {
t.Error("test NewLocalLimiter fail", item.user, item.args)
}
successCount := 0
for j := 0; j < item.checkCount; j++ {
if _, ok := l.CheckRate(item.testUser, true); ok {
successCount++
}
}
if successCount != item.shouldSuccessCount {
t.Error("test localLimiter fail", item)
}
})
}
}

View File

@ -171,6 +171,17 @@ func (h *relayHandler) Handle(conn net.Conn) {
log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user) log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user)
return return
} }
if h.options.Limiter != nil {
done, ok := h.options.Limiter.CheckRate(user, true)
if !ok {
resp.Status = relay.StatusForbidden
resp.WriteTo(conn)
log.Logf("[relay] %s -> %s : %s rate limiter", conn.RemoteAddr(), conn.LocalAddr(), user)
return
} else {
defer done()
}
}
if raddr != "" { if raddr != "" {
if len(h.group.Nodes()) > 0 { if len(h.group.Nodes()) > 0 {

View File

@ -112,6 +112,7 @@ type serverSelector struct {
// Users []*url.Userinfo // Users []*url.Userinfo
Authenticator Authenticator Authenticator Authenticator
TLSConfig *tls.Config TLSConfig *tls.Config
Limiter Limiter
} }
func (selector *serverSelector) Methods() []uint8 { func (selector *serverSelector) Methods() []uint8 {
@ -181,7 +182,14 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr()) log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr())
return nil, gosocks5.ErrAuthFailure return nil, gosocks5.ErrAuthFailure
} }
if req.Username != "" && selector.Limiter != nil {
if _, ok := selector.Limiter.CheckRate(req.Username, false); !ok {
if Debug {
log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), req.Username)
}
return nil, errors.New("rate limiter check fail")
}
}
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded)
if err := resp.Write(conn); err != nil { if err := resp.Write(conn); err != nil {
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
@ -836,6 +844,7 @@ func (h *socks5Handler) Init(options ...HandlerOption) {
// Users: h.options.Users, // Users: h.options.Users,
Authenticator: h.options.Authenticator, Authenticator: h.options.Authenticator,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
Limiter: h.options.Limiter,
} }
// methods that socks5 server supported // methods that socks5 server supported
h.selector.AddMethod( h.selector.AddMethod(