fix udp forward: clear closed client info
This commit is contained in:
parent
4798f7d76f
commit
f30de10a5d
3
.gitignore
vendored
3
.gitignore
vendored
@ -22,6 +22,9 @@ _cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
*.exe
|
||||
*.test
|
||||
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
|
||||
func init() {
|
||||
SetLogger(&NopLogger{})
|
||||
// SetLogger(&LogLogger{})
|
||||
Debug = true
|
||||
DialTimeout = 1000 * time.Millisecond
|
||||
HandshakeTimeout = 1000 * time.Millisecond
|
||||
|
234
forward.go
234
forward.go
@ -5,6 +5,7 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
@ -338,12 +339,45 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
|
||||
log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr)
|
||||
}
|
||||
|
||||
type udpConnMap struct {
|
||||
m sync.Map
|
||||
size int64
|
||||
}
|
||||
|
||||
func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) {
|
||||
v, ok := m.m.Load(key)
|
||||
if ok {
|
||||
conn, ok = v.(*udpServerConn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) {
|
||||
m.m.Store(key, conn)
|
||||
atomic.AddInt64(&m.size, 1)
|
||||
}
|
||||
|
||||
func (m *udpConnMap) Delete(key interface{}) {
|
||||
m.m.Delete(key)
|
||||
atomic.AddInt64(&m.size, -1)
|
||||
}
|
||||
|
||||
func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) {
|
||||
m.m.Range(func(k, v interface{}) bool {
|
||||
return f(k, v.(*udpServerConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *udpConnMap) Size() int64 {
|
||||
return atomic.LoadInt64(&m.size)
|
||||
}
|
||||
|
||||
type udpDirectForwardListener struct {
|
||||
ln net.PacketConn
|
||||
conns map[string]*udpServerConn
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
ttl time.Duration
|
||||
connMap udpConnMap
|
||||
}
|
||||
|
||||
// UDPDirectForwardListener creates a Listener for UDP port forwarding server.
|
||||
@ -358,8 +392,7 @@ func UDPDirectForwardListener(addr string, ttl time.Duration) (Listener, error)
|
||||
}
|
||||
l := &udpDirectForwardListener{
|
||||
ln: ln,
|
||||
conns: make(map[string]*udpServerConn),
|
||||
connChan: make(chan net.Conn, 1024),
|
||||
connChan: make(chan net.Conn, 128),
|
||||
errChan: make(chan error, 1),
|
||||
ttl: ttl,
|
||||
}
|
||||
@ -378,16 +411,19 @@ func (l *udpDirectForwardListener) listenLoop() {
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
if Debug {
|
||||
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
|
||||
}
|
||||
conn, ok := l.conns[raddr.String()]
|
||||
if !ok || conn.Closed() {
|
||||
|
||||
conn, ok := l.connMap.Get(raddr)
|
||||
if !ok {
|
||||
conn = newUDPServerConn(l.ln, raddr, l.ttl)
|
||||
l.conns[raddr.String()] = conn
|
||||
conn.onClose = func() {
|
||||
l.connMap.Delete(raddr)
|
||||
log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size())
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
l.connMap.Set(raddr, conn)
|
||||
log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
|
||||
default:
|
||||
conn.Close()
|
||||
log.Logf("[udp] %s - %s: connection queue is full", raddr, l.Addr())
|
||||
@ -396,6 +432,9 @@ func (l *udpDirectForwardListener) listenLoop() {
|
||||
|
||||
select {
|
||||
case conn.rChan <- b[:n]:
|
||||
if Debug {
|
||||
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
|
||||
}
|
||||
default:
|
||||
log.Logf("[udp] %s -> %s : read queue is full", raddr, l.Addr())
|
||||
}
|
||||
@ -419,32 +458,35 @@ func (l *udpDirectForwardListener) Addr() net.Addr {
|
||||
}
|
||||
|
||||
func (l *udpDirectForwardListener) Close() error {
|
||||
return l.ln.Close()
|
||||
err := l.ln.Close()
|
||||
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
|
||||
v.Close()
|
||||
return true
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type udpServerConn struct {
|
||||
conn net.PacketConn
|
||||
raddr net.Addr
|
||||
rChan, wChan chan []byte
|
||||
closed chan struct{}
|
||||
brokenChan chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
ttl time.Duration
|
||||
nopChan chan int
|
||||
conn net.PacketConn
|
||||
raddr net.Addr
|
||||
rChan chan []byte
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
ttl time.Duration
|
||||
nopChan chan int
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration) *udpServerConn {
|
||||
c := &udpServerConn{
|
||||
conn: conn,
|
||||
raddr: raddr,
|
||||
rChan: make(chan []byte, 128),
|
||||
wChan: make(chan []byte, 128),
|
||||
closed: make(chan struct{}),
|
||||
brokenChan: make(chan struct{}),
|
||||
nopChan: make(chan int),
|
||||
ttl: ttl,
|
||||
conn: conn,
|
||||
raddr: raddr,
|
||||
rChan: make(chan []byte, 128),
|
||||
closed: make(chan struct{}),
|
||||
nopChan: make(chan int),
|
||||
ttl: ttl,
|
||||
}
|
||||
go c.writeLoop()
|
||||
go c.ttlWait()
|
||||
return c
|
||||
}
|
||||
@ -453,12 +495,6 @@ func (c *udpServerConn) Read(b []byte) (n int, err error) {
|
||||
select {
|
||||
case bb := <-c.rChan:
|
||||
n = copy(b, bb)
|
||||
if n != len(bb) {
|
||||
err = errors.New("read partial data")
|
||||
return
|
||||
}
|
||||
case <-c.brokenChan:
|
||||
err = errors.New("Broken pipe")
|
||||
case <-c.closed:
|
||||
err = errors.New("read from closed connection")
|
||||
return
|
||||
@ -468,26 +504,22 @@ func (c *udpServerConn) Read(b []byte) (n int, err error) {
|
||||
case c.nopChan <- n:
|
||||
default:
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *udpServerConn) Write(b []byte) (n int, err error) {
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
select {
|
||||
case c.wChan <- b:
|
||||
n = len(b)
|
||||
case <-c.brokenChan:
|
||||
err = errors.New("Broken pipe")
|
||||
case <-c.closed:
|
||||
err = errors.New("write to closed connection")
|
||||
return
|
||||
}
|
||||
n, err = c.conn.WriteTo(b, c.raddr)
|
||||
|
||||
select {
|
||||
case c.nopChan <- n:
|
||||
default:
|
||||
if n > 0 {
|
||||
if Debug {
|
||||
log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n)
|
||||
}
|
||||
|
||||
select {
|
||||
case c.nopChan <- n:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@ -501,43 +533,14 @@ func (c *udpServerConn) Close() error {
|
||||
case <-c.closed:
|
||||
return errors.New("connection is closed")
|
||||
default:
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpServerConn) Closed() bool {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpServerConn) writeLoop() {
|
||||
for {
|
||||
select {
|
||||
case b, ok := <-c.wChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
n, err := c.conn.WriteTo(b, c.raddr)
|
||||
if err != nil {
|
||||
log.Logf("[udp] %s - %s : %s", c.RemoteAddr(), c.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
if Debug {
|
||||
log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n)
|
||||
}
|
||||
case <-c.brokenChan:
|
||||
return
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpServerConn) ttlWait() {
|
||||
ttl := c.ttl
|
||||
if ttl == 0 {
|
||||
@ -554,7 +557,7 @@ func (c *udpServerConn) ttlWait() {
|
||||
}
|
||||
timer.Reset(ttl)
|
||||
case <-timer.C:
|
||||
close(c.brokenChan)
|
||||
c.Close()
|
||||
return
|
||||
case <-c.closed:
|
||||
return
|
||||
@ -842,7 +845,7 @@ func (l *tcpRemoteForwardListener) Close() error {
|
||||
type udpRemoteForwardListener struct {
|
||||
addr net.Addr
|
||||
chain *Chain
|
||||
conns map[string]*udpServerConn
|
||||
connMap udpConnMap
|
||||
connChan chan net.Conn
|
||||
ln *net.UDPConn
|
||||
errChan chan error
|
||||
@ -862,8 +865,7 @@ func UDPRemoteForwardListener(addr string, chain *Chain, ttl time.Duration) (Lis
|
||||
ln := &udpRemoteForwardListener{
|
||||
addr: laddr,
|
||||
chain: chain,
|
||||
conns: make(map[string]*udpServerConn),
|
||||
connChan: make(chan net.Conn, 1024),
|
||||
connChan: make(chan net.Conn, 128),
|
||||
errChan: make(chan error, 1),
|
||||
ttl: ttl,
|
||||
closed: make(chan struct{}),
|
||||
@ -888,39 +890,47 @@ func (l *udpRemoteForwardListener) listenLoop() {
|
||||
log.Logf("[rudp] %s : %s", l.Addr(), err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
b := make([]byte, mediumBufferSize)
|
||||
n, raddr, err := conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
log.Logf("[rudp] %s : %s", l.Addr(), err)
|
||||
break
|
||||
}
|
||||
if Debug {
|
||||
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
|
||||
}
|
||||
uc, ok := l.conns[raddr.String()]
|
||||
if !ok || uc.Closed() {
|
||||
uc = newUDPServerConn(conn, raddr, l.ttl)
|
||||
l.conns[raddr.String()] = uc
|
||||
func() {
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
b := make([]byte, mediumBufferSize)
|
||||
n, raddr, err := conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
log.Logf("[rudp] %s : %s", l.Addr(), err)
|
||||
break
|
||||
}
|
||||
|
||||
uc, ok := l.connMap.Get(raddr)
|
||||
if !ok {
|
||||
uc = newUDPServerConn(conn, raddr, l.ttl)
|
||||
uc.onClose = func() {
|
||||
l.connMap.Delete(raddr)
|
||||
log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size())
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connChan <- uc:
|
||||
l.connMap.Set(raddr, uc)
|
||||
log.Logf("[rudp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
|
||||
default:
|
||||
uc.Close()
|
||||
log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connChan <- uc:
|
||||
case uc.rChan <- b[:n]:
|
||||
if Debug {
|
||||
log.Logf("[rudp] %s >>> %s : length %d", raddr, l.Addr(), n)
|
||||
}
|
||||
default:
|
||||
uc.Close()
|
||||
log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr())
|
||||
log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case uc.rChan <- b[:n]:
|
||||
default:
|
||||
log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr())
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
|
||||
@ -994,6 +1004,10 @@ func (l *udpRemoteForwardListener) Close() error {
|
||||
case <-l.closed:
|
||||
return nil
|
||||
default:
|
||||
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
|
||||
v.Close()
|
||||
return true
|
||||
})
|
||||
close(l.closed)
|
||||
}
|
||||
|
||||
|
35
ss.go
35
ss.go
@ -298,10 +298,10 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
|
||||
|
||||
type shadowUDPListener struct {
|
||||
ln net.PacketConn
|
||||
conns map[string]*udpServerConn
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
ttl time.Duration
|
||||
connMap udpConnMap
|
||||
}
|
||||
|
||||
// ShadowUDPListener creates a Listener for shadowsocks UDP relay server.
|
||||
@ -325,10 +325,10 @@ func ShadowUDPListener(addr string, cipher *url.Userinfo, ttl time.Duration) (Li
|
||||
ln.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := &shadowUDPListener{
|
||||
ln: ss.NewSecurePacketConn(ln, cp, false),
|
||||
conns: make(map[string]*udpServerConn),
|
||||
connChan: make(chan net.Conn, 1024),
|
||||
connChan: make(chan net.Conn, 128),
|
||||
errChan: make(chan error, 1),
|
||||
ttl: ttl,
|
||||
}
|
||||
@ -347,17 +347,19 @@ func (l *shadowUDPListener) listenLoop() {
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
if Debug {
|
||||
log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n)
|
||||
}
|
||||
|
||||
conn, ok := l.conns[raddr.String()]
|
||||
if !ok || conn.Closed() {
|
||||
conn, ok := l.connMap.Get(raddr)
|
||||
if !ok {
|
||||
conn = newUDPServerConn(l.ln, raddr, l.ttl)
|
||||
l.conns[raddr.String()] = conn
|
||||
conn.onClose = func() {
|
||||
l.connMap.Delete(raddr)
|
||||
log.Logf("[ssu] %s closed (%d)", raddr, l.connMap.Size())
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
l.connMap.Set(raddr, conn)
|
||||
log.Logf("[ssu] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
|
||||
default:
|
||||
conn.Close()
|
||||
log.Logf("[ssu] %s - %s: connection queue is full", raddr, l.Addr())
|
||||
@ -366,6 +368,9 @@ func (l *shadowUDPListener) listenLoop() {
|
||||
|
||||
select {
|
||||
case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination.
|
||||
if Debug {
|
||||
log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n)
|
||||
}
|
||||
default:
|
||||
log.Logf("[ssu] %s -> %s : read queue is full", raddr, l.Addr())
|
||||
}
|
||||
@ -389,7 +394,13 @@ func (l *shadowUDPListener) Addr() net.Addr {
|
||||
}
|
||||
|
||||
func (l *shadowUDPListener) Close() error {
|
||||
return l.ln.Close()
|
||||
err := l.ln.Close()
|
||||
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
|
||||
v.Close()
|
||||
return true
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type shadowUDPdHandler struct {
|
||||
@ -474,7 +485,7 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
|
||||
return
|
||||
}
|
||||
if h.options.Bypass.Contains(addr.String()) {
|
||||
log.Log("[ssu] [bypass] write to", addr)
|
||||
log.Log("[ssu] bypass", addr)
|
||||
continue // bypass
|
||||
}
|
||||
if _, err := cc.WriteTo(dgram.Data, addr); err != nil {
|
||||
@ -498,7 +509,7 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
|
||||
log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n)
|
||||
}
|
||||
if h.options.Bypass.Contains(addr.String()) {
|
||||
log.Log("[ssu] [bypass] read from", addr)
|
||||
log.Log("[ssu] bypass", addr)
|
||||
continue // bypass
|
||||
}
|
||||
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n])
|
||||
|
@ -391,8 +391,7 @@ func BenchmarkShadowUDP(b *testing.B) {
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.SetDeadline(time.Now().Add(1 * time.Second))
|
||||
defer conn.SetDeadline(time.Time{})
|
||||
conn.SetDeadline(time.Now().Add(3 * time.Second))
|
||||
|
||||
if _, err = conn.Write(sendData); err != nil {
|
||||
b.Error(err)
|
||||
@ -403,6 +402,8 @@ func BenchmarkShadowUDP(b *testing.B) {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
conn.SetDeadline(time.Time{})
|
||||
|
||||
if !bytes.Equal(sendData, recv) {
|
||||
b.Error("data not equal")
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user