From 149bc83706269d23398f6ba4993352c119f073c9 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Wed, 1 Apr 2015 07:37:33 +0800 Subject: [PATCH] add mempool --- client.go | 31 +++++++++++++--- main.go | 7 ++++ pool.go | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ server.go | 4 +- util.go | 4 +- 5 files changed, 147 insertions(+), 7 deletions(-) create mode 100644 pool.go diff --git a/client.go b/client.go index 964ef3e..112c7a3 100644 --- a/client.go +++ b/client.go @@ -15,8 +15,11 @@ import ( "net/http" "strconv" "strings" + //"sync/atomic" ) +var sessionCount int64 + func listenAndServe(addr string, handler func(net.Conn)) error { laddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { @@ -45,7 +48,10 @@ func handshake(conn net.Conn, methods ...uint8) (method uint8, err error) { if nm == 0 { nm = 1 } - b := make([]byte, 2+nm) + b := spool.Take() + defer spool.put(b) + + b = b[:2+nm] b[0] = gosocks5.Ver5 b[1] = uint8(nm) copy(b[2:], methods) @@ -68,6 +74,12 @@ func handshake(conn net.Conn, methods ...uint8) (method uint8, err error) { func cliHandle(conn net.Conn) { defer conn.Close() + /* + fmt.Println("new session", atomic.AddInt64(&sessionCount, 1)) + defer func() { + fmt.Println("session end", atomic.AddInt64(&sessionCount, -1)) + }() + */ sconn, err := Connect(Saddr, Proxy) if err != nil { @@ -105,7 +117,9 @@ func cliHandle(conn net.Conn) { return } - b := make([]byte, 8192) + b := mpool.Take() + defer mpool.put(b) + n, err := io.ReadAtLeast(conn, b, 2) if err != nil { log.Println(err) @@ -197,7 +211,9 @@ func cliTunnelUDP(uconn *net.UDPConn, sconn net.Conn) { var raddr *net.UDPAddr go func() { - b := make([]byte, 65535) + b := lpool.Take() + defer lpool.put(b) + for { n, addr, err := uconn.ReadFromUDP(b) if err != nil { @@ -221,6 +237,9 @@ func cliTunnelUDP(uconn *net.UDPConn, sconn net.Conn) { }() for { + b := lpool.Take() + defer lpool.put(b) + udp, err := gosocks5.ReadUDPDatagram(sconn) if err != nil { log.Println(err) @@ -228,7 +247,7 @@ func cliTunnelUDP(uconn *net.UDPConn, sconn net.Conn) { } //log.Println("w", udp.Header) udp.Header.Rsv = 0 - buf := &bytes.Buffer{} + buf := bytes.NewBuffer(b[0:0]) udp.Write(buf) if _, err := uconn.WriteTo(buf.Bytes(), raddr); err != nil { log.Println(err) @@ -331,7 +350,9 @@ func getShadowRequest(conn net.Conn) (addr *gosocks5.Addr, extra []byte, err err // buf size should at least have the same size with the largest possible // request size (when addrType is 3, domain name has at most 256 bytes) // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) - buf := make([]byte, 260) + buf := spool.Take() + defer spool.put(buf) + var n int // read till we get possible domain length field //shadowsocks.SetReadTimeout(conn) diff --git a/main.go b/main.go index 5e0b83a..508f55a 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "flag" "github.com/ginuerzh/gosocks5" "log" + "time" ) var ( @@ -31,6 +32,12 @@ func init() { log.SetFlags(log.LstdFlags | log.Lshortfile) } +var ( + spool = NewMemPool(1024, 120*time.Minute, 1024) // 1k size buffer pool + mpool = NewMemPool(16*1024, 60*time.Minute, 512) // 16k size buffer pool + lpool = NewMemPool(32*1024, 30*time.Minute, 256) // 32k size buffer pool +) + func main() { //log.Fatal(gost.Run()) if len(Saddr) == 0 { diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..58493ec --- /dev/null +++ b/pool.go @@ -0,0 +1,108 @@ +// pool for buffer +package main + +import ( + "container/list" + //"log" + "time" +) + +type poolItem struct { + when time.Time + item interface{} +} + +type pool struct { + quque *list.List + takeChan, putChan chan interface{} + age time.Duration + max int +} + +func (p *pool) run() { + for { + if p.size() == 0 { + select { + case b := <-p.putChan: + p.put(b) + } + continue + } + + i := p.quque.Front() + timeout := time.NewTimer(p.age) + + select { + case b := <-p.putChan: + timeout.Stop() + p.put(b) + case p.takeChan <- i.Value.(*poolItem).item: + timeout.Stop() + p.quque.Remove(i) + case <-timeout.C: + i = p.quque.Back() + for i != nil { + if time.Since(i.Value.(*poolItem).when) < p.age { + break + } + e := i.Prev() + p.quque.Remove(i) + i = e + } + } + } +} + +func (p *pool) size() int { + return p.quque.Len() +} + +func (p *pool) put(v interface{}) { + if p.size() < p.max { + p.quque.PushFront(&poolItem{when: time.Now(), item: v}) + return + } +} + +type MemPool struct { + pool + bs int +} + +func NewMemPool(bs int, age time.Duration, max int) *MemPool { + if bs <= 0 { + bs = 8192 + } + + if age == 0 { + age = 1 * time.Minute + } + + p := &MemPool{ + pool: pool{ + quque: list.New(), + takeChan: make(chan interface{}), + putChan: make(chan interface{}), + age: age, + max: max, + }, + bs: bs, + } + + go p.run() + + return p +} + +func (p *MemPool) Take() []byte { + select { + case v := <-p.takeChan: + return v.([]byte) + default: + return make([]byte, p.bs) + } +} + +func (p *MemPool) Put(b []byte) { + p.putChan <- b +} diff --git a/server.go b/server.go index 3c04461..8957f8c 100644 --- a/server.go +++ b/server.go @@ -124,7 +124,9 @@ func srvHandle(conn net.Conn) { func srvTunnelUDP(conn net.Conn, uconn *net.UDPConn) { go func() { - b := make([]byte, 65535) + b := lpool.Take() + defer lpool.put(b) + for { n, addr, err := uconn.ReadFromUDP(b) if err != nil { diff --git a/util.go b/util.go index 52f6555..43a8953 100644 --- a/util.go +++ b/util.go @@ -100,7 +100,9 @@ func Connect(addr, proxy string) (net.Conn, error) { // based on io.Copy func Copy(dst io.Writer, src io.Reader) (written int64, err error) { - buf := make([]byte, 32*1024) + buf := lpool.Take() + defer lpool.put(buf) + for { nr, er := src.Read(buf) //log.Println("cp r", nr, er)