fix udp forward: clear closed client info

This commit is contained in:
ginuerzh 2019-12-09 16:44:10 +08:00
parent 4798f7d76f
commit f30de10a5d
5 changed files with 154 additions and 124 deletions

3
.gitignore vendored
View File

@ -22,6 +22,9 @@ _cgo_export.*
_testmain.go _testmain.go
*.swp
*.swo
*.exe *.exe
*.test *.test

View File

@ -19,6 +19,7 @@ import (
func init() { func init() {
SetLogger(&NopLogger{}) SetLogger(&NopLogger{})
// SetLogger(&LogLogger{})
Debug = true Debug = true
DialTimeout = 1000 * time.Millisecond DialTimeout = 1000 * time.Millisecond
HandshakeTimeout = 1000 * time.Millisecond HandshakeTimeout = 1000 * time.Millisecond

View File

@ -5,6 +5,7 @@ import (
"net" "net"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"fmt" "fmt"
@ -338,12 +339,45 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr) log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr)
} }
type udpConnMap struct {
m sync.Map
size int64
}
func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) {
v, ok := m.m.Load(key)
if ok {
conn, ok = v.(*udpServerConn)
}
return
}
func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) {
m.m.Store(key, conn)
atomic.AddInt64(&m.size, 1)
}
func (m *udpConnMap) Delete(key interface{}) {
m.m.Delete(key)
atomic.AddInt64(&m.size, -1)
}
func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) {
m.m.Range(func(k, v interface{}) bool {
return f(k, v.(*udpServerConn))
})
}
func (m *udpConnMap) Size() int64 {
return atomic.LoadInt64(&m.size)
}
type udpDirectForwardListener struct { type udpDirectForwardListener struct {
ln net.PacketConn ln net.PacketConn
conns map[string]*udpServerConn
connChan chan net.Conn connChan chan net.Conn
errChan chan error errChan chan error
ttl time.Duration ttl time.Duration
connMap udpConnMap
} }
// UDPDirectForwardListener creates a Listener for UDP port forwarding server. // UDPDirectForwardListener creates a Listener for UDP port forwarding server.
@ -358,8 +392,7 @@ func UDPDirectForwardListener(addr string, ttl time.Duration) (Listener, error)
} }
l := &udpDirectForwardListener{ l := &udpDirectForwardListener{
ln: ln, ln: ln,
conns: make(map[string]*udpServerConn), connChan: make(chan net.Conn, 128),
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1), errChan: make(chan error, 1),
ttl: ttl, ttl: ttl,
} }
@ -378,16 +411,19 @@ func (l *udpDirectForwardListener) listenLoop() {
close(l.errChan) close(l.errChan)
return return
} }
if Debug {
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) conn, ok := l.connMap.Get(raddr)
} if !ok {
conn, ok := l.conns[raddr.String()]
if !ok || conn.Closed() {
conn = newUDPServerConn(l.ln, raddr, l.ttl) conn = newUDPServerConn(l.ln, raddr, l.ttl)
l.conns[raddr.String()] = conn conn.onClose = func() {
l.connMap.Delete(raddr)
log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size())
}
select { select {
case l.connChan <- conn: case l.connChan <- conn:
l.connMap.Set(raddr, conn)
log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default: default:
conn.Close() conn.Close()
log.Logf("[udp] %s - %s: connection queue is full", raddr, l.Addr()) log.Logf("[udp] %s - %s: connection queue is full", raddr, l.Addr())
@ -396,6 +432,9 @@ func (l *udpDirectForwardListener) listenLoop() {
select { select {
case conn.rChan <- b[:n]: case conn.rChan <- b[:n]:
if Debug {
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default: default:
log.Logf("[udp] %s -> %s : read queue is full", raddr, l.Addr()) log.Logf("[udp] %s -> %s : read queue is full", raddr, l.Addr())
} }
@ -419,32 +458,35 @@ func (l *udpDirectForwardListener) Addr() net.Addr {
} }
func (l *udpDirectForwardListener) Close() error { func (l *udpDirectForwardListener) Close() error {
return l.ln.Close() err := l.ln.Close()
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
return err
} }
type udpServerConn struct { type udpServerConn struct {
conn net.PacketConn conn net.PacketConn
raddr net.Addr raddr net.Addr
rChan, wChan chan []byte rChan chan []byte
closed chan struct{} closed chan struct{}
brokenChan chan struct{} closeMutex sync.Mutex
closeMutex sync.Mutex ttl time.Duration
ttl time.Duration nopChan chan int
nopChan chan int onClose func()
} }
func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration) *udpServerConn { func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration) *udpServerConn {
c := &udpServerConn{ c := &udpServerConn{
conn: conn, conn: conn,
raddr: raddr, raddr: raddr,
rChan: make(chan []byte, 128), rChan: make(chan []byte, 128),
wChan: make(chan []byte, 128), closed: make(chan struct{}),
closed: make(chan struct{}), nopChan: make(chan int),
brokenChan: make(chan struct{}), ttl: ttl,
nopChan: make(chan int),
ttl: ttl,
} }
go c.writeLoop()
go c.ttlWait() go c.ttlWait()
return c return c
} }
@ -453,12 +495,6 @@ func (c *udpServerConn) Read(b []byte) (n int, err error) {
select { select {
case bb := <-c.rChan: case bb := <-c.rChan:
n = copy(b, bb) n = copy(b, bb)
if n != len(bb) {
err = errors.New("read partial data")
return
}
case <-c.brokenChan:
err = errors.New("Broken pipe")
case <-c.closed: case <-c.closed:
err = errors.New("read from closed connection") err = errors.New("read from closed connection")
return return
@ -468,26 +504,22 @@ func (c *udpServerConn) Read(b []byte) (n int, err error) {
case c.nopChan <- n: case c.nopChan <- n:
default: default:
} }
return return
} }
func (c *udpServerConn) Write(b []byte) (n int, err error) { func (c *udpServerConn) Write(b []byte) (n int, err error) {
if len(b) == 0 { n, err = c.conn.WriteTo(b, c.raddr)
return 0, nil
}
select {
case c.wChan <- b:
n = len(b)
case <-c.brokenChan:
err = errors.New("Broken pipe")
case <-c.closed:
err = errors.New("write to closed connection")
return
}
select { if n > 0 {
case c.nopChan <- n: if Debug {
default: log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n)
}
select {
case c.nopChan <- n:
default:
}
} }
return return
@ -501,43 +533,14 @@ func (c *udpServerConn) Close() error {
case <-c.closed: case <-c.closed:
return errors.New("connection is closed") return errors.New("connection is closed")
default: default:
if c.onClose != nil {
c.onClose()
}
close(c.closed) close(c.closed)
} }
return nil return nil
} }
func (c *udpServerConn) Closed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
func (c *udpServerConn) writeLoop() {
for {
select {
case b, ok := <-c.wChan:
if !ok {
return
}
n, err := c.conn.WriteTo(b, c.raddr)
if err != nil {
log.Logf("[udp] %s - %s : %s", c.RemoteAddr(), c.LocalAddr(), err)
return
}
if Debug {
log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n)
}
case <-c.brokenChan:
return
case <-c.closed:
return
}
}
}
func (c *udpServerConn) ttlWait() { func (c *udpServerConn) ttlWait() {
ttl := c.ttl ttl := c.ttl
if ttl == 0 { if ttl == 0 {
@ -554,7 +557,7 @@ func (c *udpServerConn) ttlWait() {
} }
timer.Reset(ttl) timer.Reset(ttl)
case <-timer.C: case <-timer.C:
close(c.brokenChan) c.Close()
return return
case <-c.closed: case <-c.closed:
return return
@ -842,7 +845,7 @@ func (l *tcpRemoteForwardListener) Close() error {
type udpRemoteForwardListener struct { type udpRemoteForwardListener struct {
addr net.Addr addr net.Addr
chain *Chain chain *Chain
conns map[string]*udpServerConn connMap udpConnMap
connChan chan net.Conn connChan chan net.Conn
ln *net.UDPConn ln *net.UDPConn
errChan chan error errChan chan error
@ -862,8 +865,7 @@ func UDPRemoteForwardListener(addr string, chain *Chain, ttl time.Duration) (Lis
ln := &udpRemoteForwardListener{ ln := &udpRemoteForwardListener{
addr: laddr, addr: laddr,
chain: chain, chain: chain,
conns: make(map[string]*udpServerConn), connChan: make(chan net.Conn, 128),
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1), errChan: make(chan error, 1),
ttl: ttl, ttl: ttl,
closed: make(chan struct{}), closed: make(chan struct{}),
@ -888,39 +890,47 @@ func (l *udpRemoteForwardListener) listenLoop() {
log.Logf("[rudp] %s : %s", l.Addr(), err) log.Logf("[rudp] %s : %s", l.Addr(), err)
return return
} }
defer conn.Close()
for { func() {
b := make([]byte, mediumBufferSize) defer conn.Close()
n, raddr, err := conn.ReadFrom(b)
if err != nil { for {
log.Logf("[rudp] %s : %s", l.Addr(), err) b := make([]byte, mediumBufferSize)
break n, raddr, err := conn.ReadFrom(b)
} if err != nil {
if Debug { log.Logf("[rudp] %s : %s", l.Addr(), err)
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) break
} }
uc, ok := l.conns[raddr.String()]
if !ok || uc.Closed() { uc, ok := l.connMap.Get(raddr)
uc = newUDPServerConn(conn, raddr, l.ttl) if !ok {
l.conns[raddr.String()] = uc uc = newUDPServerConn(conn, raddr, l.ttl)
uc.onClose = func() {
l.connMap.Delete(raddr)
log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size())
}
select {
case l.connChan <- uc:
l.connMap.Set(raddr, uc)
log.Logf("[rudp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default:
uc.Close()
log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr())
}
}
select { select {
case l.connChan <- uc: case uc.rChan <- b[:n]:
if Debug {
log.Logf("[rudp] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default: default:
uc.Close() log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr())
log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr())
} }
} }
}()
select {
case uc.rChan <- b[:n]:
default:
log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr())
}
}
} }
} }
func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
@ -994,6 +1004,10 @@ func (l *udpRemoteForwardListener) Close() error {
case <-l.closed: case <-l.closed:
return nil return nil
default: default:
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
close(l.closed) close(l.closed)
} }

35
ss.go
View File

@ -298,10 +298,10 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
type shadowUDPListener struct { type shadowUDPListener struct {
ln net.PacketConn ln net.PacketConn
conns map[string]*udpServerConn
connChan chan net.Conn connChan chan net.Conn
errChan chan error errChan chan error
ttl time.Duration ttl time.Duration
connMap udpConnMap
} }
// ShadowUDPListener creates a Listener for shadowsocks UDP relay server. // ShadowUDPListener creates a Listener for shadowsocks UDP relay server.
@ -325,10 +325,10 @@ func ShadowUDPListener(addr string, cipher *url.Userinfo, ttl time.Duration) (Li
ln.Close() ln.Close()
return nil, err return nil, err
} }
l := &shadowUDPListener{ l := &shadowUDPListener{
ln: ss.NewSecurePacketConn(ln, cp, false), ln: ss.NewSecurePacketConn(ln, cp, false),
conns: make(map[string]*udpServerConn), connChan: make(chan net.Conn, 128),
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1), errChan: make(chan error, 1),
ttl: ttl, ttl: ttl,
} }
@ -347,17 +347,19 @@ func (l *shadowUDPListener) listenLoop() {
close(l.errChan) close(l.errChan)
return return
} }
if Debug {
log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n)
}
conn, ok := l.conns[raddr.String()] conn, ok := l.connMap.Get(raddr)
if !ok || conn.Closed() { if !ok {
conn = newUDPServerConn(l.ln, raddr, l.ttl) conn = newUDPServerConn(l.ln, raddr, l.ttl)
l.conns[raddr.String()] = conn conn.onClose = func() {
l.connMap.Delete(raddr)
log.Logf("[ssu] %s closed (%d)", raddr, l.connMap.Size())
}
select { select {
case l.connChan <- conn: case l.connChan <- conn:
l.connMap.Set(raddr, conn)
log.Logf("[ssu] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default: default:
conn.Close() conn.Close()
log.Logf("[ssu] %s - %s: connection queue is full", raddr, l.Addr()) log.Logf("[ssu] %s - %s: connection queue is full", raddr, l.Addr())
@ -366,6 +368,9 @@ func (l *shadowUDPListener) listenLoop() {
select { select {
case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination. case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination.
if Debug {
log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default: default:
log.Logf("[ssu] %s -> %s : read queue is full", raddr, l.Addr()) log.Logf("[ssu] %s -> %s : read queue is full", raddr, l.Addr())
} }
@ -389,7 +394,13 @@ func (l *shadowUDPListener) Addr() net.Addr {
} }
func (l *shadowUDPListener) Close() error { func (l *shadowUDPListener) Close() error {
return l.ln.Close() err := l.ln.Close()
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
return err
} }
type shadowUDPdHandler struct { type shadowUDPdHandler struct {
@ -474,7 +485,7 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
return return
} }
if h.options.Bypass.Contains(addr.String()) { if h.options.Bypass.Contains(addr.String()) {
log.Log("[ssu] [bypass] write to", addr) log.Log("[ssu] bypass", addr)
continue // bypass continue // bypass
} }
if _, err := cc.WriteTo(dgram.Data, addr); err != nil { if _, err := cc.WriteTo(dgram.Data, addr); err != nil {
@ -498,7 +509,7 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n) log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n)
} }
if h.options.Bypass.Contains(addr.String()) { if h.options.Bypass.Contains(addr.String()) {
log.Log("[ssu] [bypass] read from", addr) log.Log("[ssu] bypass", addr)
continue // bypass continue // bypass
} }
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n]) dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n])

View File

@ -391,8 +391,7 @@ func BenchmarkShadowUDP(b *testing.B) {
} }
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
conn.SetDeadline(time.Now().Add(1 * time.Second)) conn.SetDeadline(time.Now().Add(3 * time.Second))
defer conn.SetDeadline(time.Time{})
if _, err = conn.Write(sendData); err != nil { if _, err = conn.Write(sendData); err != nil {
b.Error(err) b.Error(err)
@ -403,6 +402,8 @@ func BenchmarkShadowUDP(b *testing.B) {
b.Error(err) b.Error(err)
} }
conn.SetDeadline(time.Time{})
if !bytes.Equal(sendData, recv) { if !bytes.Equal(sendData, recv) {
b.Error("data not equal") b.Error("data not equal")
} }