add mempool

This commit is contained in:
rui.zheng 2015-04-01 07:37:33 +08:00
parent df6abe8565
commit 149bc83706
5 changed files with 147 additions and 7 deletions

View File

@ -15,8 +15,11 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
//"sync/atomic"
) )
var sessionCount int64
func listenAndServe(addr string, handler func(net.Conn)) error { func listenAndServe(addr string, handler func(net.Conn)) error {
laddr, err := net.ResolveTCPAddr("tcp", addr) laddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil { if err != nil {
@ -45,7 +48,10 @@ func handshake(conn net.Conn, methods ...uint8) (method uint8, err error) {
if nm == 0 { if nm == 0 {
nm = 1 nm = 1
} }
b := make([]byte, 2+nm) b := spool.Take()
defer spool.put(b)
b = b[:2+nm]
b[0] = gosocks5.Ver5 b[0] = gosocks5.Ver5
b[1] = uint8(nm) b[1] = uint8(nm)
copy(b[2:], methods) copy(b[2:], methods)
@ -68,6 +74,12 @@ func handshake(conn net.Conn, methods ...uint8) (method uint8, err error) {
func cliHandle(conn net.Conn) { func cliHandle(conn net.Conn) {
defer conn.Close() defer conn.Close()
/*
fmt.Println("new session", atomic.AddInt64(&sessionCount, 1))
defer func() {
fmt.Println("session end", atomic.AddInt64(&sessionCount, -1))
}()
*/
sconn, err := Connect(Saddr, Proxy) sconn, err := Connect(Saddr, Proxy)
if err != nil { if err != nil {
@ -105,7 +117,9 @@ func cliHandle(conn net.Conn) {
return return
} }
b := make([]byte, 8192) b := mpool.Take()
defer mpool.put(b)
n, err := io.ReadAtLeast(conn, b, 2) n, err := io.ReadAtLeast(conn, b, 2)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
@ -197,7 +211,9 @@ func cliTunnelUDP(uconn *net.UDPConn, sconn net.Conn) {
var raddr *net.UDPAddr var raddr *net.UDPAddr
go func() { go func() {
b := make([]byte, 65535) b := lpool.Take()
defer lpool.put(b)
for { for {
n, addr, err := uconn.ReadFromUDP(b) n, addr, err := uconn.ReadFromUDP(b)
if err != nil { if err != nil {
@ -221,6 +237,9 @@ func cliTunnelUDP(uconn *net.UDPConn, sconn net.Conn) {
}() }()
for { for {
b := lpool.Take()
defer lpool.put(b)
udp, err := gosocks5.ReadUDPDatagram(sconn) udp, err := gosocks5.ReadUDPDatagram(sconn)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
@ -228,7 +247,7 @@ func cliTunnelUDP(uconn *net.UDPConn, sconn net.Conn) {
} }
//log.Println("w", udp.Header) //log.Println("w", udp.Header)
udp.Header.Rsv = 0 udp.Header.Rsv = 0
buf := &bytes.Buffer{} buf := bytes.NewBuffer(b[0:0])
udp.Write(buf) udp.Write(buf)
if _, err := uconn.WriteTo(buf.Bytes(), raddr); err != nil { if _, err := uconn.WriteTo(buf.Bytes(), raddr); err != nil {
log.Println(err) log.Println(err)
@ -331,7 +350,9 @@ func getShadowRequest(conn net.Conn) (addr *gosocks5.Addr, extra []byte, err err
// buf size should at least have the same size with the largest possible // buf size should at least have the same size with the largest possible
// request size (when addrType is 3, domain name has at most 256 bytes) // request size (when addrType is 3, domain name has at most 256 bytes)
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
buf := make([]byte, 260) buf := spool.Take()
defer spool.put(buf)
var n int var n int
// read till we get possible domain length field // read till we get possible domain length field
//shadowsocks.SetReadTimeout(conn) //shadowsocks.SetReadTimeout(conn)

View File

@ -5,6 +5,7 @@ import (
"flag" "flag"
"github.com/ginuerzh/gosocks5" "github.com/ginuerzh/gosocks5"
"log" "log"
"time"
) )
var ( var (
@ -31,6 +32,12 @@ func init() {
log.SetFlags(log.LstdFlags | log.Lshortfile) log.SetFlags(log.LstdFlags | log.Lshortfile)
} }
var (
spool = NewMemPool(1024, 120*time.Minute, 1024) // 1k size buffer pool
mpool = NewMemPool(16*1024, 60*time.Minute, 512) // 16k size buffer pool
lpool = NewMemPool(32*1024, 30*time.Minute, 256) // 32k size buffer pool
)
func main() { func main() {
//log.Fatal(gost.Run()) //log.Fatal(gost.Run())
if len(Saddr) == 0 { if len(Saddr) == 0 {

108
pool.go Normal file
View File

@ -0,0 +1,108 @@
// pool for buffer
package main
import (
"container/list"
//"log"
"time"
)
type poolItem struct {
when time.Time
item interface{}
}
type pool struct {
quque *list.List
takeChan, putChan chan interface{}
age time.Duration
max int
}
func (p *pool) run() {
for {
if p.size() == 0 {
select {
case b := <-p.putChan:
p.put(b)
}
continue
}
i := p.quque.Front()
timeout := time.NewTimer(p.age)
select {
case b := <-p.putChan:
timeout.Stop()
p.put(b)
case p.takeChan <- i.Value.(*poolItem).item:
timeout.Stop()
p.quque.Remove(i)
case <-timeout.C:
i = p.quque.Back()
for i != nil {
if time.Since(i.Value.(*poolItem).when) < p.age {
break
}
e := i.Prev()
p.quque.Remove(i)
i = e
}
}
}
}
func (p *pool) size() int {
return p.quque.Len()
}
func (p *pool) put(v interface{}) {
if p.size() < p.max {
p.quque.PushFront(&poolItem{when: time.Now(), item: v})
return
}
}
type MemPool struct {
pool
bs int
}
func NewMemPool(bs int, age time.Duration, max int) *MemPool {
if bs <= 0 {
bs = 8192
}
if age == 0 {
age = 1 * time.Minute
}
p := &MemPool{
pool: pool{
quque: list.New(),
takeChan: make(chan interface{}),
putChan: make(chan interface{}),
age: age,
max: max,
},
bs: bs,
}
go p.run()
return p
}
func (p *MemPool) Take() []byte {
select {
case v := <-p.takeChan:
return v.([]byte)
default:
return make([]byte, p.bs)
}
}
func (p *MemPool) Put(b []byte) {
p.putChan <- b
}

View File

@ -124,7 +124,9 @@ func srvHandle(conn net.Conn) {
func srvTunnelUDP(conn net.Conn, uconn *net.UDPConn) { func srvTunnelUDP(conn net.Conn, uconn *net.UDPConn) {
go func() { go func() {
b := make([]byte, 65535) b := lpool.Take()
defer lpool.put(b)
for { for {
n, addr, err := uconn.ReadFromUDP(b) n, addr, err := uconn.ReadFromUDP(b)
if err != nil { if err != nil {

View File

@ -100,7 +100,9 @@ func Connect(addr, proxy string) (net.Conn, error) {
// based on io.Copy // based on io.Copy
func Copy(dst io.Writer, src io.Reader) (written int64, err error) { func Copy(dst io.Writer, src io.Reader) (written int64, err error) {
buf := make([]byte, 32*1024) buf := lpool.Take()
defer lpool.put(buf)
for { for {
nr, er := src.Read(buf) nr, er := src.Read(buf)
//log.Println("cp r", nr, er) //log.Println("cp r", nr, er)