Merge 087b05ae3f
into 87d6a2fdc2
This commit is contained in:
commit
eddb63474c
1
.gitignore
vendored
1
.gitignore
vendored
@ -33,3 +33,4 @@ _testmain.go
|
|||||||
|
|
||||||
.vscode/
|
.vscode/
|
||||||
cmd/gost/gost
|
cmd/gost/gost
|
||||||
|
.idea
|
@ -11,7 +11,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -148,6 +148,26 @@ func parseAuthenticator(s string) (gost.Authenticator, error) {
|
|||||||
return au, nil
|
return au, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseLimiter(s string) (gost.Limiter, error) {
|
||||||
|
if s == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
f, err := os.Open(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
l, _ := gost.NewLocalLimiter("", "")
|
||||||
|
err = l.Reload(f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
go gost.PeriodReload(l, s)
|
||||||
|
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
func parseIP(s string, port string) (ips []string) {
|
func parseIP(s string, port string) (ips []string) {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return
|
return
|
||||||
|
@ -11,8 +11,8 @@ import (
|
|||||||
|
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
|
||||||
"github.com/go-log/log"
|
"github.com/go-log/log"
|
||||||
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
type peerConfig struct {
|
type peerConfig struct {
|
||||||
|
@ -12,8 +12,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
|
||||||
"github.com/go-log/log"
|
"github.com/go-log/log"
|
||||||
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
type stringList []string
|
type stringList []string
|
||||||
@ -386,6 +386,19 @@ func (r *route) GenRouters() ([]router, error) {
|
|||||||
node.User = users[0]
|
node.User = users[0]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//init rate limiter
|
||||||
|
limiterHandler, err := parseLimiter(node.Get("secrets"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if limiterHandler == nil && strings.TrimSpace(node.Get("limiter")) != "" && node.User != nil {
|
||||||
|
limiterHandler, err = gost.NewLocalLimiter(node.User.Username(), strings.TrimSpace(node.Get("limiter")))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
certFile, keyFile := node.Get("cert"), node.Get("key")
|
certFile, keyFile := node.Get("cert"), node.Get("key")
|
||||||
tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca"))
|
tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca"))
|
||||||
if err != nil && certFile != "" && keyFile != "" {
|
if err != nil && certFile != "" && keyFile != "" {
|
||||||
@ -671,6 +684,7 @@ func (r *route) GenRouters() ([]router, error) {
|
|||||||
gost.IPRoutesHandlerOption(tunRoutes...),
|
gost.IPRoutesHandlerOption(tunRoutes...),
|
||||||
gost.ProxyAgentHandlerOption(node.Get("proxyAgent")),
|
gost.ProxyAgentHandlerOption(node.Get("proxyAgent")),
|
||||||
gost.HTTPTunnelHandlerOption(node.GetBool("httpTunnel")),
|
gost.HTTPTunnelHandlerOption(node.GetBool("httpTunnel")),
|
||||||
|
gost.LimiterHandlerOption(limiterHandler),
|
||||||
)
|
)
|
||||||
|
|
||||||
rt := router{
|
rt := router{
|
||||||
|
@ -10,7 +10,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -3,7 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/ginuerzh/gost"
|
"github.com/tongsq/gost"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
2
go.mod
2
go.mod
@ -1,4 +1,4 @@
|
|||||||
module github.com/ginuerzh/gost
|
module github.com/tongsq/gost
|
||||||
|
|
||||||
go 1.22
|
go 1.22
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ type HandlerOptions struct {
|
|||||||
IPRoutes []IPRoute
|
IPRoutes []IPRoute
|
||||||
ProxyAgent string
|
ProxyAgent string
|
||||||
HTTPTunnel bool
|
HTTPTunnel bool
|
||||||
|
Limiter Limiter
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlerOption allows a common way to set handler options.
|
// HandlerOption allows a common way to set handler options.
|
||||||
@ -87,6 +88,13 @@ func AuthenticatorHandlerOption(au Authenticator) HandlerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LimiterHandlerOption sets the Rate limiter option of HandlerOptions
|
||||||
|
func LimiterHandlerOption(l Limiter) HandlerOption {
|
||||||
|
return func(opts *HandlerOptions) {
|
||||||
|
opts.Limiter = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions.
|
// TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions.
|
||||||
func TLSConfigHandlerOption(config *tls.Config) HandlerOption {
|
func TLSConfigHandlerOption(config *tls.Config) HandlerOption {
|
||||||
return func(opts *HandlerOptions) {
|
return func(opts *HandlerOptions) {
|
||||||
|
16
http.go
16
http.go
@ -212,7 +212,23 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
|
|||||||
if !h.authenticate(conn, req, resp) {
|
if !h.authenticate(conn, req, resp) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
user, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization"))
|
||||||
|
if h.options.Limiter != nil {
|
||||||
|
done, ok := h.options.Limiter.CheckRate(user, true)
|
||||||
|
if !ok {
|
||||||
|
resp.StatusCode = http.StatusTooManyRequests
|
||||||
|
|
||||||
|
if Debug {
|
||||||
|
dump, _ := httputil.DumpResponse(resp, false)
|
||||||
|
log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Write(conn)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
defer done()
|
||||||
|
}
|
||||||
|
}
|
||||||
if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") {
|
if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") {
|
||||||
resp.StatusCode = http.StatusBadRequest
|
resp.StatusCode = http.StatusBadRequest
|
||||||
|
|
||||||
|
13
http2.go
13
http2.go
@ -394,7 +394,18 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
|
|||||||
if !h.authenticate(w, r, resp) {
|
if !h.authenticate(w, r, resp) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
user, _, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
|
||||||
|
if h.options.Limiter != nil {
|
||||||
|
done, ok := h.options.Limiter.CheckRate(user, true)
|
||||||
|
if !ok {
|
||||||
|
log.Logf("[http2] %s - %s rate limiter %s, user is %s",
|
||||||
|
r.RemoteAddr, laddr, host, user)
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
defer done()
|
||||||
|
}
|
||||||
|
}
|
||||||
// delete the proxy related headers.
|
// delete the proxy related headers.
|
||||||
r.Header.Del("Proxy-Authorization")
|
r.Header.Del("Proxy-Authorization")
|
||||||
r.Header.Del("Proxy-Connection")
|
r.Header.Del("Proxy-Connection")
|
||||||
|
259
limiter.go
Normal file
259
limiter.go
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
package gost
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Limiter interface {
|
||||||
|
CheckRate(key string, checkConcurrent bool) (func(), bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLocalLimiter(user string, cfg string) (*LocalLimiter, error) {
|
||||||
|
limiter := LocalLimiter{
|
||||||
|
buckets: map[string]*limiterBucket{},
|
||||||
|
concurrent: map[string]chan bool{},
|
||||||
|
stopped: make(chan struct{}),
|
||||||
|
}
|
||||||
|
if cfg == "" || user == "" {
|
||||||
|
return &limiter, nil
|
||||||
|
}
|
||||||
|
if err := limiter.AddRule(user, cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &limiter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token Bucket
|
||||||
|
type limiterBucket struct {
|
||||||
|
max int64
|
||||||
|
cur int64
|
||||||
|
duration int64
|
||||||
|
batch int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type LocalLimiter struct {
|
||||||
|
buckets map[string]*limiterBucket
|
||||||
|
concurrent map[string]chan bool
|
||||||
|
mux sync.RWMutex
|
||||||
|
stopped chan struct{}
|
||||||
|
period time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalLimiter) CheckRate(key string, checkConcurrent bool) (func(), bool) {
|
||||||
|
if checkConcurrent {
|
||||||
|
done, ok := l.checkConcurrent(key)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if t := l.getToken(key); !t {
|
||||||
|
done()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return done, true
|
||||||
|
} else {
|
||||||
|
if t := l.getToken(key); !t {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalLimiter) AddRule(user string, cfg string) error {
|
||||||
|
if user == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cfg == "" {
|
||||||
|
//reload need check old limit exists
|
||||||
|
if _, ok := l.buckets[user]; ok {
|
||||||
|
delete(l.buckets, user)
|
||||||
|
}
|
||||||
|
if _, ok := l.concurrent[user]; ok {
|
||||||
|
delete(l.concurrent, user)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
args := strings.Split(cfg, ",")
|
||||||
|
if len(args) < 2 || len(args) > 3 {
|
||||||
|
return errors.New("parse limiter fail:" + cfg)
|
||||||
|
}
|
||||||
|
if len(args) == 2 {
|
||||||
|
args = append(args, "0")
|
||||||
|
}
|
||||||
|
|
||||||
|
duration, e1 := strconv.ParseInt(strings.TrimSpace(args[0]), 10, 64)
|
||||||
|
count, e2 := strconv.ParseInt(strings.TrimSpace(args[1]), 10, 64)
|
||||||
|
cur, e3 := strconv.ParseInt(strings.TrimSpace(args[2]), 10, 64)
|
||||||
|
if e1 != nil || e2 != nil || e3 != nil {
|
||||||
|
return errors.New("parse limiter fail:" + cfg)
|
||||||
|
}
|
||||||
|
// 0 means not limit
|
||||||
|
if duration > 0 && count > 0 {
|
||||||
|
bu := &limiterBucket{
|
||||||
|
cur: count * 10,
|
||||||
|
max: count * 10,
|
||||||
|
duration: duration * 100,
|
||||||
|
batch: count,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Millisecond * time.Duration(bu.duration))
|
||||||
|
if bu.cur+bu.batch > bu.max {
|
||||||
|
bu.cur = bu.max
|
||||||
|
} else {
|
||||||
|
atomic.AddInt64(&bu.cur, bu.batch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
l.buckets[user] = bu
|
||||||
|
} else {
|
||||||
|
if _, ok := l.buckets[user]; ok {
|
||||||
|
delete(l.buckets, user)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// zero means not limit
|
||||||
|
if cur > 0 {
|
||||||
|
l.concurrent[user] = make(chan bool, cur)
|
||||||
|
} else {
|
||||||
|
if _, ok := l.concurrent[user]; ok {
|
||||||
|
delete(l.concurrent, user)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload parses config from r, then live reloads the LocalLimiter.
|
||||||
|
func (l *LocalLimiter) Reload(r io.Reader) error {
|
||||||
|
var period time.Duration
|
||||||
|
kvs := make(map[string]string)
|
||||||
|
|
||||||
|
if r == nil || l.Stopped() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitLine splits a line text by white space.
|
||||||
|
// A line started with '#' will be ignored, otherwise it is valid.
|
||||||
|
split := func(line string) []string {
|
||||||
|
if line == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
line = strings.Replace(line, "\t", " ", -1)
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
if strings.IndexByte(line, '#') == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ss []string
|
||||||
|
for _, s := range strings.Split(line, " ") {
|
||||||
|
if s = strings.TrimSpace(s); s != "" {
|
||||||
|
ss = append(ss, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
ss := split(line)
|
||||||
|
if len(ss) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ss[0] {
|
||||||
|
case "reload": // reload option
|
||||||
|
if len(ss) > 1 {
|
||||||
|
period, _ = time.ParseDuration(ss[1])
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
var k, v string
|
||||||
|
k = ss[0]
|
||||||
|
if len(ss) > 2 {
|
||||||
|
v = ss[2]
|
||||||
|
}
|
||||||
|
kvs[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
l.mux.Lock()
|
||||||
|
defer l.mux.Unlock()
|
||||||
|
|
||||||
|
l.period = period
|
||||||
|
for user, args := range kvs {
|
||||||
|
err := l.AddRule(user, args)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Period returns the reload period.
|
||||||
|
func (l *LocalLimiter) Period() time.Duration {
|
||||||
|
if l.Stopped() {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
l.mux.RLock()
|
||||||
|
defer l.mux.RUnlock()
|
||||||
|
|
||||||
|
return l.period
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops reloading.
|
||||||
|
func (l *LocalLimiter) Stop() {
|
||||||
|
select {
|
||||||
|
case <-l.stopped:
|
||||||
|
default:
|
||||||
|
close(l.stopped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stopped checks whether the reloader is stopped.
|
||||||
|
func (l *LocalLimiter) Stopped() bool {
|
||||||
|
select {
|
||||||
|
case <-l.stopped:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalLimiter) getToken(key string) bool {
|
||||||
|
b, ok := l.buckets[key]
|
||||||
|
if !ok || b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if b.cur <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
atomic.AddInt64(&b.cur, -10)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalLimiter) checkConcurrent(key string) (func(), bool) {
|
||||||
|
c, ok := l.concurrent[key]
|
||||||
|
if !ok || c == nil {
|
||||||
|
return func() {}, true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case c <- true:
|
||||||
|
return func() {
|
||||||
|
<-c
|
||||||
|
}, true
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
69
limiter_test.go
Normal file
69
limiter_test.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package gost
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewLocalLimiter(t *testing.T) {
|
||||||
|
items := []struct {
|
||||||
|
user string
|
||||||
|
args string
|
||||||
|
success bool
|
||||||
|
}{
|
||||||
|
{"admin", "10,1", true},
|
||||||
|
{"admin", "", true},
|
||||||
|
{"admin", "10,1,1", true},
|
||||||
|
{"admin", "10", false},
|
||||||
|
{"admin", "0,1", true},
|
||||||
|
{"admin", "0,1,1", true},
|
||||||
|
{"admin", "a,b", false},
|
||||||
|
{"", "", true},
|
||||||
|
{"", "1,2", true},
|
||||||
|
}
|
||||||
|
for i, item := range items {
|
||||||
|
item := item
|
||||||
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||||
|
_, err := NewLocalLimiter(item.user, item.args)
|
||||||
|
if (err == nil) != item.success {
|
||||||
|
t.Error("test NewLocalLimiter fail", item.user, item.args)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckRate(t *testing.T) {
|
||||||
|
items := []struct {
|
||||||
|
user string
|
||||||
|
args string
|
||||||
|
testUser string
|
||||||
|
checkCount int
|
||||||
|
shouldSuccessCount int
|
||||||
|
}{
|
||||||
|
{"admin", "10,3", "admin", 10, 3},
|
||||||
|
{"admin", "10,3,0", "admin", 10, 3},
|
||||||
|
{"admin", "10,3,2", "admin", 10, 2},
|
||||||
|
{"admin", "0,0", "admin", 10, 10},
|
||||||
|
{"admin", "10,3,5", "admin", 10, 3},
|
||||||
|
{"admin", "10,3,5", "admin22", 10, 10},
|
||||||
|
{"admin", "0,0,5", "admin", 10, 5},
|
||||||
|
}
|
||||||
|
for i, item := range items {
|
||||||
|
item := item
|
||||||
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
|
||||||
|
l, err := NewLocalLimiter(item.user, item.args)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("test NewLocalLimiter fail", item.user, item.args)
|
||||||
|
}
|
||||||
|
successCount := 0
|
||||||
|
for j := 0; j < item.checkCount; j++ {
|
||||||
|
if _, ok := l.CheckRate(item.testUser, true); ok {
|
||||||
|
successCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if successCount != item.shouldSuccessCount {
|
||||||
|
t.Error("test localLimiter fail", item)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
11
relay.go
11
relay.go
@ -171,6 +171,17 @@ func (h *relayHandler) Handle(conn net.Conn) {
|
|||||||
log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user)
|
log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.options.Limiter != nil {
|
||||||
|
done, ok := h.options.Limiter.CheckRate(user, true)
|
||||||
|
if !ok {
|
||||||
|
resp.Status = relay.StatusForbidden
|
||||||
|
resp.WriteTo(conn)
|
||||||
|
log.Logf("[relay] %s -> %s : %s rate limiter", conn.RemoteAddr(), conn.LocalAddr(), user)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
defer done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if raddr != "" {
|
if raddr != "" {
|
||||||
if len(h.group.Nodes()) > 0 {
|
if len(h.group.Nodes()) > 0 {
|
||||||
|
11
socks.go
11
socks.go
@ -112,6 +112,7 @@ type serverSelector struct {
|
|||||||
// Users []*url.Userinfo
|
// Users []*url.Userinfo
|
||||||
Authenticator Authenticator
|
Authenticator Authenticator
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
|
Limiter Limiter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (selector *serverSelector) Methods() []uint8 {
|
func (selector *serverSelector) Methods() []uint8 {
|
||||||
@ -181,7 +182,14 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
|
|||||||
log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr())
|
log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr())
|
||||||
return nil, gosocks5.ErrAuthFailure
|
return nil, gosocks5.ErrAuthFailure
|
||||||
}
|
}
|
||||||
|
if req.Username != "" && selector.Limiter != nil {
|
||||||
|
if _, ok := selector.Limiter.CheckRate(req.Username, false); !ok {
|
||||||
|
if Debug {
|
||||||
|
log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), req.Username)
|
||||||
|
}
|
||||||
|
return nil, errors.New("rate limiter check fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded)
|
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded)
|
||||||
if err := resp.Write(conn); err != nil {
|
if err := resp.Write(conn); err != nil {
|
||||||
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||||
@ -836,6 +844,7 @@ func (h *socks5Handler) Init(options ...HandlerOption) {
|
|||||||
// Users: h.options.Users,
|
// Users: h.options.Users,
|
||||||
Authenticator: h.options.Authenticator,
|
Authenticator: h.options.Authenticator,
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
|
Limiter: h.options.Limiter,
|
||||||
}
|
}
|
||||||
// methods that socks5 server supported
|
// methods that socks5 server supported
|
||||||
h.selector.AddMethod(
|
h.selector.AddMethod(
|
||||||
|
Loading…
Reference in New Issue
Block a user