260 lines
4.8 KiB
Go
260 lines
4.8 KiB
Go
package gost
|
|
|
|
import (
|
|
"bufio"
|
|
"errors"
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Limiter decides whether a request identified by a key may proceed.
type Limiter interface {
	// CheckRate reports, through its bool result, whether the request for
	// key is admitted. checkConcurrent asks the implementation to also
	// evaluate a concurrency limit. The returned func may be nil; see the
	// implementation for when it has to be called.
	CheckRate(key string, checkConcurrent bool) (func(), bool)
}
|
|
|
|
func NewLocalLimiter(user string, cfg string) (*LocalLimiter, error) {
|
|
limiter := LocalLimiter{
|
|
buckets: map[string]*limiterBucket{},
|
|
concurrent: map[string]chan bool{},
|
|
stopped: make(chan struct{}),
|
|
}
|
|
if cfg == "" || user == "" {
|
|
return &limiter, nil
|
|
}
|
|
if err := limiter.AddRule(user, cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
return &limiter, nil
|
|
}
|
|
|
|
// limiterBucket is the token-bucket state kept for a single user.
type limiterBucket struct {
	max      int64 // ceiling that cur is clamped to when refilling
	cur      int64 // tokens currently available
	duration int64 // pause between two refills, in milliseconds
	batch    int64 // tokens added by one refill
}
|
|
|
|
type LocalLimiter struct {
|
|
buckets map[string]*limiterBucket
|
|
concurrent map[string]chan bool
|
|
mux sync.RWMutex
|
|
stopped chan struct{}
|
|
period time.Duration
|
|
}
|
|
|
|
func (l *LocalLimiter) CheckRate(key string, checkConcurrent bool) (func(), bool) {
|
|
if checkConcurrent {
|
|
done, ok := l.checkConcurrent(key)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
if t := l.getToken(key); !t {
|
|
done()
|
|
return nil, false
|
|
}
|
|
return done, true
|
|
} else {
|
|
if t := l.getToken(key); !t {
|
|
return nil, false
|
|
}
|
|
return nil, true
|
|
}
|
|
}
|
|
|
|
func (l *LocalLimiter) AddRule(user string, cfg string) error {
|
|
if user == "" {
|
|
return nil
|
|
}
|
|
if cfg == "" {
|
|
//reload need check old limit exists
|
|
if _, ok := l.buckets[user]; ok {
|
|
delete(l.buckets, user)
|
|
}
|
|
if _, ok := l.concurrent[user]; ok {
|
|
delete(l.concurrent, user)
|
|
}
|
|
return nil
|
|
}
|
|
args := strings.Split(cfg, ",")
|
|
if len(args) < 2 || len(args) > 3 {
|
|
return errors.New("parse limiter fail:" + cfg)
|
|
}
|
|
if len(args) == 2 {
|
|
args = append(args, "0")
|
|
}
|
|
|
|
duration, e1 := strconv.ParseInt(strings.TrimSpace(args[0]), 10, 64)
|
|
count, e2 := strconv.ParseInt(strings.TrimSpace(args[1]), 10, 64)
|
|
cur, e3 := strconv.ParseInt(strings.TrimSpace(args[2]), 10, 64)
|
|
if e1 != nil || e2 != nil || e3 != nil {
|
|
return errors.New("parse limiter fail:" + cfg)
|
|
}
|
|
// 0 means not limit
|
|
if duration > 0 && count > 0 {
|
|
bu := &limiterBucket{
|
|
cur: count * 10,
|
|
max: count * 10,
|
|
duration: duration * 100,
|
|
batch: count,
|
|
}
|
|
go func() {
|
|
for {
|
|
time.Sleep(time.Millisecond * time.Duration(bu.duration))
|
|
if bu.cur+bu.batch > bu.max {
|
|
bu.cur = bu.max
|
|
} else {
|
|
atomic.AddInt64(&bu.cur, bu.batch)
|
|
}
|
|
}
|
|
}()
|
|
l.buckets[user] = bu
|
|
} else {
|
|
if _, ok := l.buckets[user]; ok {
|
|
delete(l.buckets, user)
|
|
}
|
|
}
|
|
// zero means not limit
|
|
if cur > 0 {
|
|
l.concurrent[user] = make(chan bool, cur)
|
|
} else {
|
|
if _, ok := l.concurrent[user]; ok {
|
|
delete(l.concurrent, user)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Reload parses config from r, then live reloads the LocalLimiter.
|
|
func (l *LocalLimiter) Reload(r io.Reader) error {
|
|
var period time.Duration
|
|
kvs := make(map[string]string)
|
|
|
|
if r == nil || l.Stopped() {
|
|
return nil
|
|
}
|
|
|
|
// splitLine splits a line text by white space.
|
|
// A line started with '#' will be ignored, otherwise it is valid.
|
|
split := func(line string) []string {
|
|
if line == "" {
|
|
return nil
|
|
}
|
|
line = strings.Replace(line, "\t", " ", -1)
|
|
line = strings.TrimSpace(line)
|
|
|
|
if strings.IndexByte(line, '#') == 0 {
|
|
return nil
|
|
}
|
|
|
|
var ss []string
|
|
for _, s := range strings.Split(line, " ") {
|
|
if s = strings.TrimSpace(s); s != "" {
|
|
ss = append(ss, s)
|
|
}
|
|
}
|
|
return ss
|
|
}
|
|
|
|
scanner := bufio.NewScanner(r)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
ss := split(line)
|
|
if len(ss) == 0 {
|
|
continue
|
|
}
|
|
|
|
switch ss[0] {
|
|
case "reload": // reload option
|
|
if len(ss) > 1 {
|
|
period, _ = time.ParseDuration(ss[1])
|
|
}
|
|
default:
|
|
var k, v string
|
|
k = ss[0]
|
|
if len(ss) > 2 {
|
|
v = ss[2]
|
|
}
|
|
kvs[k] = v
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
l.mux.Lock()
|
|
defer l.mux.Unlock()
|
|
|
|
l.period = period
|
|
for user, args := range kvs {
|
|
err := l.AddRule(user, args)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Period returns the reload period.
|
|
func (l *LocalLimiter) Period() time.Duration {
|
|
if l.Stopped() {
|
|
return -1
|
|
}
|
|
|
|
l.mux.RLock()
|
|
defer l.mux.RUnlock()
|
|
|
|
return l.period
|
|
}
|
|
|
|
// Stop stops reloading.
|
|
func (l *LocalLimiter) Stop() {
|
|
select {
|
|
case <-l.stopped:
|
|
default:
|
|
close(l.stopped)
|
|
}
|
|
}
|
|
|
|
// Stopped checks whether the reloader is stopped.
|
|
func (l *LocalLimiter) Stopped() bool {
|
|
select {
|
|
case <-l.stopped:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (l *LocalLimiter) getToken(key string) bool {
|
|
b, ok := l.buckets[key]
|
|
if !ok || b == nil {
|
|
return true
|
|
}
|
|
if b.cur <= 0 {
|
|
return false
|
|
}
|
|
atomic.AddInt64(&b.cur, -10)
|
|
return true
|
|
}
|
|
|
|
func (l *LocalLimiter) checkConcurrent(key string) (func(), bool) {
|
|
c, ok := l.concurrent[key]
|
|
if !ok || c == nil {
|
|
return func() {}, true
|
|
}
|
|
select {
|
|
case c <- true:
|
|
return func() {
|
|
<-c
|
|
}, true
|
|
default:
|
|
return nil, false
|
|
}
|
|
}
|