diff --git a/.gitignore b/.gitignore index 1d6095c..0cd5cca 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,9 @@ _cgo_export.* _testmain.go +*.swp +*.swo + *.exe *.test diff --git a/common_test.go b/common_test.go index a68d4a3..9e030d6 100644 --- a/common_test.go +++ b/common_test.go @@ -19,6 +19,7 @@ import ( func init() { SetLogger(&NopLogger{}) + // SetLogger(&LogLogger{}) Debug = true DialTimeout = 1000 * time.Millisecond HandshakeTimeout = 1000 * time.Millisecond diff --git a/forward.go b/forward.go index d403102..f2671df 100644 --- a/forward.go +++ b/forward.go @@ -5,6 +5,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" "fmt" @@ -338,12 +339,45 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr) } +type udpConnMap struct { + m sync.Map + size int64 +} + +func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) { + v, ok := m.m.Load(key) + if ok { + conn, ok = v.(*udpServerConn) + } + return +} + +func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) { + m.m.Store(key, conn) + atomic.AddInt64(&m.size, 1) +} + +func (m *udpConnMap) Delete(key interface{}) { + m.m.Delete(key) + atomic.AddInt64(&m.size, -1) +} + +func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) { + m.m.Range(func(k, v interface{}) bool { + return f(k, v.(*udpServerConn)) + }) +} + +func (m *udpConnMap) Size() int64 { + return atomic.LoadInt64(&m.size) +} + type udpDirectForwardListener struct { ln net.PacketConn - conns map[string]*udpServerConn connChan chan net.Conn errChan chan error ttl time.Duration + connMap udpConnMap } // UDPDirectForwardListener creates a Listener for UDP port forwarding server. @@ -358,8 +392,7 @@ func UDPDirectForwardListener(addr string, ttl time.Duration) (Listener, error) } l := &udpDirectForwardListener{ ln: ln, - conns: make(map[string]*udpServerConn), - connChan: make(chan net.Conn, 1024), + connChan: make(chan net.Conn, 128), errChan: make(chan error, 1), ttl: ttl, } @@ -378,16 +411,19 @@ func (l *udpDirectForwardListener) listenLoop() { close(l.errChan) return } - if Debug { - log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) - } - conn, ok := l.conns[raddr.String()] - if !ok || conn.Closed() { + + conn, ok := l.connMap.Get(raddr) + if !ok { conn = newUDPServerConn(l.ln, raddr, l.ttl) - l.conns[raddr.String()] = conn + conn.onClose = func() { + l.connMap.Delete(raddr) + log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size()) + } select { case l.connChan <- conn: + l.connMap.Set(raddr, conn) + log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) default: conn.Close() log.Logf("[udp] %s - %s: connection queue is full", raddr, l.Addr()) @@ -396,6 +432,9 @@ func (l *udpDirectForwardListener) listenLoop() { select { case conn.rChan <- b[:n]: + if Debug { + log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) + } default: log.Logf("[udp] %s -> %s : read queue is full", raddr, l.Addr()) } @@ -419,32 +458,35 @@ func (l *udpDirectForwardListener) Addr() net.Addr { } func (l *udpDirectForwardListener) Close() error { - return l.ln.Close() + err := l.ln.Close() + l.connMap.Range(func(k interface{}, v *udpServerConn) bool { + v.Close() + return true + }) + + return err } type udpServerConn struct { - conn net.PacketConn - raddr net.Addr - rChan, wChan chan []byte - closed chan struct{} - brokenChan chan struct{} - closeMutex sync.Mutex - ttl time.Duration - nopChan chan int + conn net.PacketConn + raddr net.Addr + rChan chan []byte + closed chan struct{} + closeMutex sync.Mutex + ttl time.Duration + nopChan chan int + onClose func() } func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration) *udpServerConn { c := &udpServerConn{ - conn: conn, - raddr: raddr, - rChan: make(chan []byte, 128), - wChan: make(chan []byte, 128), - closed: make(chan struct{}), - brokenChan: make(chan struct{}), - nopChan: make(chan int), - ttl: ttl, + conn: conn, + raddr: raddr, + rChan: make(chan []byte, 128), + closed: make(chan struct{}), + nopChan: make(chan int), + ttl: ttl, } - go c.writeLoop() go c.ttlWait() return c } @@ -453,12 +495,6 @@ func (c *udpServerConn) Read(b []byte) (n int, err error) { select { case bb := <-c.rChan: n = copy(b, bb) - if n != len(bb) { - err = errors.New("read partial data") - return - } - case <-c.brokenChan: - err = errors.New("Broken pipe") case <-c.closed: err = errors.New("read from closed connection") return @@ -468,26 +504,22 @@ func (c *udpServerConn) Read(b []byte) (n int, err error) { case c.nopChan <- n: default: } + return } func (c *udpServerConn) Write(b []byte) (n int, err error) { - if len(b) == 0 { - return 0, nil - } - select { - case c.wChan <- b: - n = len(b) - case <-c.brokenChan: - err = errors.New("Broken pipe") - case <-c.closed: - err = errors.New("write to closed connection") - return - } + n, err = c.conn.WriteTo(b, c.raddr) - select { - case c.nopChan <- n: - default: + if n > 0 { + if Debug { + log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n) + } + + select { + case c.nopChan <- n: + default: + } } return @@ -501,43 +533,14 @@ func (c *udpServerConn) Close() error { case <-c.closed: return errors.New("connection is closed") default: + if c.onClose != nil { + c.onClose() + } close(c.closed) } return nil } -func (c *udpServerConn) Closed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - -func (c *udpServerConn) writeLoop() { - for { - select { - case b, ok := <-c.wChan: - if !ok { - return - } - n, err := c.conn.WriteTo(b, c.raddr) - if err != nil { - log.Logf("[udp] %s - %s : %s", c.RemoteAddr(), c.LocalAddr(), err) - return - } - if Debug { - log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n) - } - case <-c.brokenChan: - return - case <-c.closed: - return - } - } -} - func (c *udpServerConn) ttlWait() { ttl := c.ttl if ttl == 0 { @@ -554,7 +557,7 @@ func (c *udpServerConn) ttlWait() { } timer.Reset(ttl) case <-timer.C: - close(c.brokenChan) + c.Close() return case <-c.closed: return @@ -842,7 +845,7 @@ func (l *tcpRemoteForwardListener) Close() error { type udpRemoteForwardListener struct { addr net.Addr chain *Chain - conns map[string]*udpServerConn + connMap udpConnMap connChan chan net.Conn ln *net.UDPConn errChan chan error @@ -862,8 +865,7 @@ func UDPRemoteForwardListener(addr string, chain *Chain, ttl time.Duration) (Lis ln := &udpRemoteForwardListener{ addr: laddr, chain: chain, - conns: make(map[string]*udpServerConn), - connChan: make(chan net.Conn, 1024), + connChan: make(chan net.Conn, 128), errChan: make(chan error, 1), ttl: ttl, closed: make(chan struct{}), @@ -888,39 +890,47 @@ func (l *udpRemoteForwardListener) listenLoop() { log.Logf("[rudp] %s : %s", l.Addr(), err) return } - defer conn.Close() - for { - b := make([]byte, mediumBufferSize) - n, raddr, err := conn.ReadFrom(b) - if err != nil { - log.Logf("[rudp] %s : %s", l.Addr(), err) - break - } - if Debug { - log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) - } - uc, ok := l.conns[raddr.String()] - if !ok || uc.Closed() { - uc = newUDPServerConn(conn, raddr, l.ttl) - l.conns[raddr.String()] = uc + func() { + defer conn.Close() + + for { + b := make([]byte, mediumBufferSize) + n, raddr, err := conn.ReadFrom(b) + if err != nil { + log.Logf("[rudp] %s : %s", l.Addr(), err) + break + } + + uc, ok := l.connMap.Get(raddr) + if !ok { + uc = newUDPServerConn(conn, raddr, l.ttl) + uc.onClose = func() { + l.connMap.Delete(raddr) + log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size()) + } + + select { + case l.connChan <- uc: + l.connMap.Set(raddr, uc) + log.Logf("[rudp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) + default: + uc.Close() + log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr()) + } + } select { - case l.connChan <- uc: + case uc.rChan <- b[:n]: + if Debug { + log.Logf("[rudp] %s >>> %s : length %d", raddr, l.Addr(), n) + } default: - uc.Close() - log.Logf("[rudp] %s - %s: connection queue is full", raddr, l.Addr()) + log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr()) } } - - select { - case uc.rChan <- b[:n]: - default: - log.Logf("[rudp] %s -> %s : write queue is full", raddr, l.Addr()) - } - } + }() } - } func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { @@ -994,6 +1004,10 @@ func (l *udpRemoteForwardListener) Close() error { case <-l.closed: return nil default: + l.connMap.Range(func(k interface{}, v *udpServerConn) bool { + v.Close() + return true + }) close(l.closed) } diff --git a/ss.go b/ss.go index 3cfe0e2..8848ae7 100644 --- a/ss.go +++ b/ss.go @@ -298,10 +298,10 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn type shadowUDPListener struct { ln net.PacketConn - conns map[string]*udpServerConn connChan chan net.Conn errChan chan error ttl time.Duration + connMap udpConnMap } // ShadowUDPListener creates a Listener for shadowsocks UDP relay server. @@ -325,10 +325,10 @@ func ShadowUDPListener(addr string, cipher *url.Userinfo, ttl time.Duration) (Li ln.Close() return nil, err } + l := &shadowUDPListener{ ln: ss.NewSecurePacketConn(ln, cp, false), - conns: make(map[string]*udpServerConn), - connChan: make(chan net.Conn, 1024), + connChan: make(chan net.Conn, 128), errChan: make(chan error, 1), ttl: ttl, } @@ -347,17 +347,19 @@ func (l *shadowUDPListener) listenLoop() { close(l.errChan) return } - if Debug { - log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n) - } - conn, ok := l.conns[raddr.String()] - if !ok || conn.Closed() { + conn, ok := l.connMap.Get(raddr) + if !ok { conn = newUDPServerConn(l.ln, raddr, l.ttl) - l.conns[raddr.String()] = conn + conn.onClose = func() { + l.connMap.Delete(raddr) + log.Logf("[ssu] %s closed (%d)", raddr, l.connMap.Size()) + } select { case l.connChan <- conn: + l.connMap.Set(raddr, conn) + log.Logf("[ssu] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) default: conn.Close() log.Logf("[ssu] %s - %s: connection queue is full", raddr, l.Addr()) @@ -366,6 +368,9 @@ func (l *shadowUDPListener) listenLoop() { select { case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination. + if Debug { + log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n) + } default: log.Logf("[ssu] %s -> %s : read queue is full", raddr, l.Addr()) } @@ -389,7 +394,13 @@ func (l *shadowUDPListener) Addr() net.Addr { } func (l *shadowUDPListener) Close() error { - return l.ln.Close() + err := l.ln.Close() + l.connMap.Range(func(k interface{}, v *udpServerConn) bool { + v.Close() + return true + }) + + return err } type shadowUDPdHandler struct { @@ -474,7 +485,7 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error { return } if h.options.Bypass.Contains(addr.String()) { - log.Log("[ssu] [bypass] write to", addr) + log.Log("[ssu] bypass", addr) continue // bypass } if _, err := cc.WriteTo(dgram.Data, addr); err != nil { @@ -498,7 +509,7 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error { log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n) } if h.options.Bypass.Contains(addr.String()) { - log.Log("[ssu] [bypass] read from", addr) + log.Log("[ssu] bypass", addr) continue // bypass } dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n]) diff --git a/ss_test.go b/ss_test.go index 8390ca4..ff91273 100644 --- a/ss_test.go +++ b/ss_test.go @@ -391,8 +391,7 @@ func BenchmarkShadowUDP(b *testing.B) { } for i := 0; i < b.N; i++ { - conn.SetDeadline(time.Now().Add(1 * time.Second)) - defer conn.SetDeadline(time.Time{}) + conn.SetDeadline(time.Now().Add(3 * time.Second)) if _, err = conn.Write(sendData); err != nil { b.Error(err) @@ -403,6 +402,8 @@ func BenchmarkShadowUDP(b *testing.B) { b.Error(err) } + conn.SetDeadline(time.Time{}) + if !bytes.Equal(sendData, recv) { b.Error("data not equal") }