fix udp forward: clear closed client info

This commit is contained in:
ginuerzh 2019-12-09 16:44:10 +08:00
parent 4798f7d76f
commit f30de10a5d
5 changed files with 154 additions and 124 deletions

3
.gitignore vendored
View File

@ -22,6 +22,9 @@ _cgo_export.*
_testmain.go
*.swp
*.swo
*.exe
*.test

View File

@ -19,6 +19,7 @@ import (
func init() {
SetLogger(&NopLogger{})
// SetLogger(&LogLogger{})
Debug = true
DialTimeout = 1000 * time.Millisecond
HandshakeTimeout = 1000 * time.Millisecond

View File

@ -5,6 +5,7 @@ import (
"net"
"strings"
"sync"
"sync/atomic"
"time"
"fmt"
@ -338,12 +339,45 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr)
}
type udpConnMap struct {
m sync.Map
size int64
}
func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) {
v, ok := m.m.Load(key)
if ok {
conn, ok = v.(*udpServerConn)
}
return
}
func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) {
m.m.Store(key, conn)
atomic.AddInt64(&m.size, 1)
}
func (m *udpConnMap) Delete(key interface{}) {
m.m.Delete(key)
atomic.AddInt64(&m.size, -1)
}
func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) {
m.m.Range(func(k, v interface{}) bool {
return f(k, v.(*udpServerConn))
})
}
func (m *udpConnMap) Size() int64 {
return atomic.LoadInt64(&m.size)
}
type udpDirectForwardListener struct {
ln net.PacketConn
conns map[string]*udpServerConn
connChan chan net.Conn
errChan chan error
ttl time.Duration
connMap udpConnMap
}
// UDPDirectForwardListener creates a Listener for UDP port forwarding server.
@ -358,8 +392,7 @@ func UDPDirectForwardListener(addr string, ttl time.Duration) (Listener, error)
}
l := &udpDirectForwardListener{
ln: ln,
conns: make(map[string]*udpServerConn),
connChan: make(chan net.Conn, 1024),
connChan: make(chan net.Conn, 128),
errChan: make(chan error, 1),
ttl: ttl,
}
@ -378,16 +411,19 @@ func (l *udpDirectForwardListener) listenLoop() {
close(l.errChan)
return
}
if Debug {
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
}
conn, ok := l.conns[raddr.String()]
if !ok || conn.Closed() {
conn, ok := l.connMap.Get(raddr)
if !ok {
conn = newUDPServerConn(l.ln, raddr, l.ttl)
l.conns[raddr.String()] = conn
conn.onClose = func() {
l.connMap.Delete(raddr)
log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size())
}
select {
case l.connChan <- conn:
l.connMap.Set(raddr, conn)
log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default:
conn.Close()
log.Logf("[udp] %s - %s: connection queue is full", raddr, l.Addr())
@ -396,6 +432,9 @@ func (l *udpDirectForwardListener) listenLoop() {
select {
case conn.rChan <- b[:n]:
if Debug {
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default:
log.Logf("[udp] %s -> %s : read queue is full", raddr, l.Addr())
}
@ -419,32 +458,35 @@ func (l *udpDirectForwardListener) Addr() net.Addr {
}
func (l *udpDirectForwardListener) Close() error {
return l.ln.Close()
err := l.ln.Close()
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
return err
}
type udpServerConn struct {
conn net.PacketConn
raddr net.Addr
rChan, wChan chan []byte
closed chan struct{}
brokenChan chan struct{}
closeMutex sync.Mutex
ttl time.Duration
nopChan chan int
conn net.PacketConn
raddr net.Addr
rChan chan []byte
closed chan struct{}
closeMutex sync.Mutex
ttl time.Duration
nopChan chan int
onClose func()
}
func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration) *udpServerConn {
c := &udpServerConn{
conn: conn,
raddr: raddr,
rChan: make(chan []byte, 128),
wChan: make(chan []byte, 128),
closed: make(chan struct{}),
brokenChan: make(chan struct{}),
nopChan: make(chan int),
ttl: ttl,
conn: conn,
raddr: raddr,
rChan: make(chan []byte, 128),
closed: make(chan struct{}),
nopChan: make(chan int),
ttl: ttl,
}
go c.writeLoop()
go c.ttlWait()
return c
}
@ -453,12 +495,6 @@ func (c *udpServerConn) Read(b []byte) (n int, err error) {
select {
case bb := <-c.rChan:
n = copy(b, bb)
if n != len(bb) {
err = errors.New("read partial data")
return
}
case <-c.brokenChan:
err = errors.New("Broken pipe")
case <-c.closed:
err = errors.New("read from closed connection")
return
@ -468,26 +504,22 @@ func (c *udpServerConn) Read(b []byte) (n int, err error) {
case c.nopChan <- n:
default:
}
return
}
func (c *udpServerConn) Write(b []byte) (n int, err error) {
if len(b) == 0 {
return 0, nil
}
select {
case c.wChan <- b:
n = len(b)
case <-c.brokenChan:
err = errors.New("Broken pipe")
case <-c.closed:
err = errors.New("write to closed connection")
return
}
n, err = c.conn.WriteTo(b, c.raddr)
select {
case c.nopChan <- n:
default:
if n > 0 {
if Debug {
log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n)
}
select {
case c.nopChan <- n:
default:
}
}
return
@ -501,43 +533,14 @@ func (c *udpServerConn) Close() error {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.onClose != nil {
c.onClose()
}
close(c.closed)
}
return nil
}
func (c *udpServerConn) Closed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
func (c *udpServerConn) writeLoop() {
for {
select {
case b, ok := <-c.wChan:
if !ok {
return
}
n, err := c.conn.WriteTo(b, c.raddr)
if err != nil {
log.Logf("[udp] %s - %s : %s", c.RemoteAddr(), c.LocalAddr(), err)
return
}
if Debug {
log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n)
}
case <-c.brokenChan:
return
case <-c.closed:
return
}
}
}
func (c *udpServerConn) ttlWait() {
ttl := c.ttl
if ttl == 0 {
@ -554,7 +557,7 @@ func (c *udpServerConn) ttlWait() {
}
timer.Reset(ttl)
case <-timer.C:
close(c.brokenChan)
c.Close()
return
case <-c.closed:
return
@ -842,7 +845,7 @@ func (l *tcpRemoteForwardListener) Close() error {
type udpRemoteForwardListener struct {
addr net.Addr
chain *Chain
conns map[string]*udpServerConn
connMap udpConnMap
connChan chan net.Conn
ln *net.UDPConn
errChan chan error
@ -862,8 +865,7 @@ func UDPRemoteForwardListener(addr string, chain *Chain, ttl time.Duration) (Lis
ln := &udpRemoteForwardListener{
addr: laddr,
chain: chain,
conns: make(map[string]*udpServerConn),
connChan: make(chan net.Conn, 1024),
connChan: make(chan net.Conn, 128),
errChan: make(chan error, 1),
ttl: ttl,
closed: make(chan struct{}),
@ -888,39 +890,47 @@ func (l *udpRemoteForwardListener) listenLoop() {
log.Logf("[rudp] %s : %s", l.Addr(), err)
return
}
defer conn.Close()
for {
b := make([]byte, mediumBufferSize)
n, raddr, err := conn.ReadFrom(b)
if err != nil {
log.Logf("[rudp] %s : %s", l.Addr(), err)
break
}
if Debug {
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
}
uc, ok := l.conns[raddr.String()]
if !ok || uc.Closed() {
uc = newUDPServerConn(conn, raddr, l.ttl)
l.conns[raddr.String()] = uc
func() {
defer conn.Close()
for {
b := make([]byte, mediumBufferSize)
n, raddr, err := conn.ReadFrom(b)
if err != nil {
log.Logf("[rudp] %s : %s", l.Addr(), err)
break
}
uc, ok := l.connMap.Get(raddr)
if !ok {
uc = newUDPServerConn(conn, raddr, l.ttl)
uc.onClose = func() {
l.connMap.Delete(raddr)
log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size())
}
select {
case l.connChan <- uc:
l.connMap.Set(raddr, uc)
log.Logf("[rudp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default:
uc.Close()
log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr())
}
}
select {
case l.connChan <- uc:
case uc.rChan <- b[:n]:
if Debug {
log.Logf("[rudp] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default:
uc.Close()
log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr())
log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr())
}
}
select {
case uc.rChan <- b[:n]:
default:
log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr())
}
}
}()
}
}
func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
@ -994,6 +1004,10 @@ func (l *udpRemoteForwardListener) Close() error {
case <-l.closed:
return nil
default:
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
close(l.closed)
}

35
ss.go
View File

@ -298,10 +298,10 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
type shadowUDPListener struct {
ln net.PacketConn
conns map[string]*udpServerConn
connChan chan net.Conn
errChan chan error
ttl time.Duration
connMap udpConnMap
}
// ShadowUDPListener creates a Listener for shadowsocks UDP relay server.
@ -325,10 +325,10 @@ func ShadowUDPListener(addr string, cipher *url.Userinfo, ttl time.Duration) (Li
ln.Close()
return nil, err
}
l := &shadowUDPListener{
ln: ss.NewSecurePacketConn(ln, cp, false),
conns: make(map[string]*udpServerConn),
connChan: make(chan net.Conn, 1024),
connChan: make(chan net.Conn, 128),
errChan: make(chan error, 1),
ttl: ttl,
}
@ -347,17 +347,19 @@ func (l *shadowUDPListener) listenLoop() {
close(l.errChan)
return
}
if Debug {
log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n)
}
conn, ok := l.conns[raddr.String()]
if !ok || conn.Closed() {
conn, ok := l.connMap.Get(raddr)
if !ok {
conn = newUDPServerConn(l.ln, raddr, l.ttl)
l.conns[raddr.String()] = conn
conn.onClose = func() {
l.connMap.Delete(raddr)
log.Logf("[ssu] %s closed (%d)", raddr, l.connMap.Size())
}
select {
case l.connChan <- conn:
l.connMap.Set(raddr, conn)
log.Logf("[ssu] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default:
conn.Close()
log.Logf("[ssu] %s - %s: connection queue is full", raddr, l.Addr())
@ -366,6 +368,9 @@ func (l *shadowUDPListener) listenLoop() {
select {
case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination.
if Debug {
log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default:
log.Logf("[ssu] %s -> %s : read queue is full", raddr, l.Addr())
}
@ -389,7 +394,13 @@ func (l *shadowUDPListener) Addr() net.Addr {
}
func (l *shadowUDPListener) Close() error {
return l.ln.Close()
err := l.ln.Close()
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
return err
}
type shadowUDPdHandler struct {
@ -474,7 +485,7 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
return
}
if h.options.Bypass.Contains(addr.String()) {
log.Log("[ssu] [bypass] write to", addr)
log.Log("[ssu] bypass", addr)
continue // bypass
}
if _, err := cc.WriteTo(dgram.Data, addr); err != nil {
@ -498,7 +509,7 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n)
}
if h.options.Bypass.Contains(addr.String()) {
log.Log("[ssu] [bypass] read from", addr)
log.Log("[ssu] bypass", addr)
continue // bypass
}
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n])

View File

@ -391,8 +391,7 @@ func BenchmarkShadowUDP(b *testing.B) {
}
for i := 0; i < b.N; i++ {
conn.SetDeadline(time.Now().Add(1 * time.Second))
defer conn.SetDeadline(time.Time{})
conn.SetDeadline(time.Now().Add(3 * time.Second))
if _, err = conn.Write(sendData); err != nil {
b.Error(err)
@ -403,6 +402,8 @@ func BenchmarkShadowUDP(b *testing.B) {
b.Error(err)
}
conn.SetDeadline(time.Time{})
if !bytes.Equal(sendData, recv) {
b.Error("data not equal")
}