diff --git a/README.md b/README.md index 9b09492..11c359b 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ gost - GO Simple Tunnel * [权限控制](https://docs.ginuerzh.xyz/gost/permission/) * [负载均衡](https://docs.ginuerzh.xyz/gost/load-balancing/) * [路由控制](https://docs.ginuerzh.xyz/gost/bypass/) -* [DNS控制](https://docs.ginuerzh.xyz/gost/dns/) +* DNS[解析](https://docs.ginuerzh.xyz/gost/resolver/)和[代理](https://docs.ginuerzh.xyz/gost/dns/) * [TUN/TAP设备](https://docs.ginuerzh.xyz/gost/tuntap/) Wiki站点: diff --git a/README_en.md b/README_en.md index 13ad50c..8108c82 100644 --- a/README_en.md +++ b/README_en.md @@ -27,7 +27,7 @@ Features * [Permission control](https://docs.ginuerzh.xyz/gost/en/permission/) * [Load balancing](https://docs.ginuerzh.xyz/gost/en/load-balancing/) * [Routing control](https://docs.ginuerzh.xyz/gost/en/bypass/) -* [DNS control](https://docs.ginuerzh.xyz/gost/en/dns/) +* DNS [resolver](https://docs.ginuerzh.xyz/gost/resolver/) and [proxy](https://docs.ginuerzh.xyz/gost/dns/) * [TUN/TAP device](https://docs.ginuerzh.xyz/gost/en/tuntap/) Wiki: diff --git a/client.go b/client.go index 3c5d896..a5c03e2 100644 --- a/client.go +++ b/client.go @@ -64,68 +64,6 @@ type Transporter interface { Multiplex() bool } -// tcpTransporter is a raw TCP transporter. -type tcpTransporter struct{} - -// TCPTransporter creates a raw TCP client. -func TCPTransporter() Transporter { - return &tcpTransporter{} -} - -func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = DialTimeout - } - if opts.Chain == nil { - return net.DialTimeout("tcp", addr, timeout) - } - return opts.Chain.Dial(addr) -} - -func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - return conn, nil -} - -func (tr *tcpTransporter) Multiplex() bool { - return false -} - -// udpTransporter is a raw UDP transporter. -type udpTransporter struct{} - -// UDPTransporter creates a raw UDP client. -func UDPTransporter() Transporter { - return &udpTransporter{} -} - -func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = DialTimeout - } - - return net.DialTimeout("udp", addr, timeout) -} - -func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - return conn, nil -} - -func (tr *udpTransporter) Multiplex() bool { - return false -} - // DialOptions describes the options for Transporter.Dial. type DialOptions struct { Timeout time.Duration diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 323f807..3ae526f 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -198,6 +198,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { tr = gost.ObfsHTTPTransporter() case "ftcp": tr = gost.FakeTCPTransporter() + case "udp": + tr = gost.UDPTransporter() default: tr = gost.TCPTransporter() } @@ -216,6 +218,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { connector = gost.ShadowConnector(node.User) case "ss2": connector = gost.Shadow2Connector(node.User) + case "ssu": + connector = gost.ShadowUDPConnector(node.User) case "direct": connector = gost.SSHDirectForwardConnector() case "remote": @@ -414,6 +418,12 @@ func (r *route) GenRouters() ([]router, error) { chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() } ln, err = gost.TCPListener(node.Addr) + case "udp": + ln, err = gost.UDPListener(node.Addr, &gost.UDPListenConfig{ + TTL: ttl, + Backlog: node.GetInt("backlog"), + QueueSize: node.GetInt("queue"), + }) case "rtcp": // Directly use SSH port forwarding if the last chain node is forward+ssh if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { @@ -421,24 +431,10 @@ func (r *route) GenRouters() ([]router, error) { chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() } ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) - case "udp": - ln, err = gost.UDPDirectForwardListener(node.Addr, &gost.UDPForwardListenConfig{ - TTL: ttl, - Backlog: node.GetInt("backlog"), - QueueSize: node.GetInt("queue"), - }) case "rudp": ln, err = gost.UDPRemoteForwardListener(node.Addr, chain, - &gost.UDPForwardListenConfig{ - TTL: ttl, - Backlog: node.GetInt("backlog"), - QueueSize: node.GetInt("queue"), - }) - case "ssu": - ln, err = gost.ShadowUDPListener(node.Addr, - node.User, - &gost.UDPForwardListenConfig{ + &gost.UDPListenConfig{ TTL: ttl, Backlog: node.GetInt("backlog"), QueueSize: node.GetInt("queue"), @@ -519,7 +515,7 @@ func (r *route) GenRouters() ([]router, error) { case "redirect": handler = gost.TCPRedirectHandler() case "ssu": - handler = gost.ShadowUDPdHandler() + handler = gost.ShadowUDPHandler() case "sni": handler = gost.SNIHandler() case "tun": diff --git a/dns.go b/dns.go index 36445f5..1b02404 100644 --- a/dns.go +++ b/dns.go @@ -117,6 +117,7 @@ func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { return buf.String() } +// DNSOptions is options for DNS Listener. type DNSOptions struct { Mode string UDPSize int @@ -132,6 +133,7 @@ type dnsListener struct { errc chan error } +// DNSListener creates a Listener for DNS proxy server. func DNSListener(addr string, options *DNSOptions) (Listener, error) { if options == nil { options = &DNSOptions{} diff --git a/forward.go b/forward.go index e378b6c..0c40518 100644 --- a/forward.go +++ b/forward.go @@ -5,7 +5,6 @@ import ( "net" "strings" "sync" - "sync/atomic" "time" "fmt" @@ -202,6 +201,16 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) return } + } else if h.options.Chain.LastNode().Protocol == "ssu" { + cc, err = h.options.Chain.Dial(node.Addr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) + if err != nil { + node.MarkDead() + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err) + return + } } else { var err error cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) @@ -341,271 +350,6 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr) } -type udpConnMap struct { - m sync.Map - size int64 -} - -func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) { - v, ok := m.m.Load(key) - if ok { - conn, ok = v.(*udpServerConn) - } - return -} - -func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) { - m.m.Store(key, conn) - atomic.AddInt64(&m.size, 1) -} - -func (m *udpConnMap) Delete(key interface{}) { - m.m.Delete(key) - atomic.AddInt64(&m.size, -1) -} - -func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) { - m.m.Range(func(k, v interface{}) bool { - return f(k, v.(*udpServerConn)) - }) -} - -func (m *udpConnMap) Size() int64 { - return atomic.LoadInt64(&m.size) -} - -type UDPForwardListenConfig struct { - TTL time.Duration - Backlog int - QueueSize int -} - -type udpDirectForwardListener struct { - ln net.PacketConn - connChan chan net.Conn - errChan chan error - connMap udpConnMap - config *UDPForwardListenConfig -} - -// UDPDirectForwardListener creates a Listener for UDP port forwarding server. -func UDPDirectForwardListener(addr string, cfg *UDPForwardListenConfig) (Listener, error) { - laddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenUDP("udp", laddr) - if err != nil { - return nil, err - } - - if cfg == nil { - cfg = &UDPForwardListenConfig{} - } - - backlog := cfg.Backlog - if backlog <= 0 { - backlog = defaultBacklog - } - - l := &udpDirectForwardListener{ - ln: ln, - connChan: make(chan net.Conn, backlog), - errChan: make(chan error, 1), - config: cfg, - } - go l.listenLoop() - return l, nil -} - -func (l *udpDirectForwardListener) listenLoop() { - for { - b := make([]byte, mediumBufferSize) - n, raddr, err := l.ln.ReadFrom(b) - if err != nil { - log.Logf("[udp] peer -> %s : %s", l.Addr(), err) - l.Close() - l.errChan <- err - close(l.errChan) - return - } - - conn, ok := l.connMap.Get(raddr.String()) - if !ok { - conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize) - conn.onClose = func() { - l.connMap.Delete(raddr.String()) - log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size()) - } - - select { - case l.connChan <- conn: - l.connMap.Set(raddr.String(), conn) - log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) - default: - conn.Close() - log.Logf("[udp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan)) - } - } - - select { - case conn.rChan <- b[:n]: - if Debug { - log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) - } - default: - log.Logf("[udp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan)) - } - } -} - -func (l *udpDirectForwardListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *udpDirectForwardListener) Addr() net.Addr { - return l.ln.LocalAddr() -} - -func (l *udpDirectForwardListener) Close() error { - err := l.ln.Close() - l.connMap.Range(func(k interface{}, v *udpServerConn) bool { - v.Close() - return true - }) - - return err -} - -type udpServerConn struct { - conn net.PacketConn - raddr net.Addr - rChan chan []byte - closed chan struct{} - closeMutex sync.Mutex - ttl time.Duration - nopChan chan int - onClose func() -} - -func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration, qsize int) *udpServerConn { - if qsize <= 0 { - qsize = defaultQueueSize - } - c := &udpServerConn{ - conn: conn, - raddr: raddr, - rChan: make(chan []byte, qsize), - closed: make(chan struct{}), - nopChan: make(chan int), - ttl: ttl, - } - go c.ttlWait() - return c -} - -func (c *udpServerConn) Read(b []byte) (n int, err error) { - select { - case bb := <-c.rChan: - n = copy(b, bb) - case <-c.closed: - err = errors.New("read from closed connection") - return - } - - select { - case c.nopChan <- n: - default: - } - - return -} - -func (c *udpServerConn) Write(b []byte) (n int, err error) { - n, err = c.conn.WriteTo(b, c.raddr) - - if n > 0 { - if Debug { - log.Logf("[udp] %s <<< %s : length %d", c.raddr, c.LocalAddr(), n) - } - - select { - case c.nopChan <- n: - default: - } - } - - return -} - -func (c *udpServerConn) Close() error { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - - select { - case <-c.closed: - return errors.New("connection is closed") - default: - if c.onClose != nil { - c.onClose() - } - close(c.closed) - } - return nil -} - -func (c *udpServerConn) ttlWait() { - ttl := c.ttl - if ttl == 0 { - ttl = defaultTTL - } - timer := time.NewTimer(ttl) - defer timer.Stop() - - for { - select { - case <-c.nopChan: - if !timer.Stop() { - <-timer.C - } - timer.Reset(ttl) - case <-timer.C: - c.Close() - return - case <-c.closed: - return - } - } -} - -func (c *udpServerConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *udpServerConn) RemoteAddr() net.Addr { - return c.raddr -} - -func (c *udpServerConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *udpServerConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *udpServerConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - type tcpRemoteForwardListener struct { addr net.Addr chain *Chain @@ -874,18 +618,18 @@ type udpRemoteForwardListener struct { closed chan struct{} closeMux sync.Mutex once sync.Once - config *UDPForwardListenConfig + config *UDPListenConfig } // UDPRemoteForwardListener creates a Listener for UDP remote port forwarding server. -func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPForwardListenConfig) (Listener, error) { +func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPListenConfig) (Listener, error) { laddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } if cfg == nil { - cfg = &UDPForwardListenConfig{} + cfg = &UDPListenConfig{} } backlog := cfg.Backlog @@ -935,11 +679,14 @@ func (l *udpRemoteForwardListener) listenLoop() { uc, ok := l.connMap.Get(raddr.String()) if !ok { - uc = newUDPServerConn(conn, raddr, l.config.TTL, l.config.QueueSize) - uc.onClose = func() { - l.connMap.Delete(raddr.String()) - log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size()) - } + uc = newUDPServerConn(conn, raddr, &udpServerConnConfig{ + ttl: l.config.TTL, + qsize: l.config.QueueSize, + onClose: func() { + l.connMap.Delete(raddr.String()) + log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size()) + }, + }) select { case l.connChan <- uc: diff --git a/forward_test.go b/forward_test.go index 2f9ab51..d47c290 100644 --- a/forward_test.go +++ b/forward_test.go @@ -128,7 +128,7 @@ func BenchmarkTCPDirectForwardParallel(b *testing.B) { } func udpDirectForwardRoundtrip(t *testing.T, host string, data []byte) error { - ln, err := UDPDirectForwardListener("localhost:0", nil) + ln, err := UDPListener("localhost:0", nil) if err != nil { return err } @@ -172,7 +172,7 @@ func BenchmarkUDPDirectForward(b *testing.B) { sendData := make([]byte, 128) rand.Read(sendData) - ln, err := UDPDirectForwardListener("localhost:0", nil) + ln, err := UDPListener("localhost:0", nil) if err != nil { b.Error(err) } @@ -207,7 +207,7 @@ func BenchmarkUDPDirectForwardParallel(b *testing.B) { sendData := make([]byte, 128) rand.Read(sendData) - ln, err := UDPDirectForwardListener("localhost:0", nil) + ln, err := UDPListener("localhost:0", nil) if err != nil { b.Error(err) } diff --git a/ftcp.go b/ftcp.go index 3d50007..a1cfcf0 100644 --- a/ftcp.go +++ b/ftcp.go @@ -45,6 +45,7 @@ func (tr *fakeTCPTransporter) Multiplex() bool { return false } +// FakeTCPListenConfig is config for fake TCP Listener. type FakeTCPListenConfig struct { TTL time.Duration Backlog int @@ -99,11 +100,14 @@ func (l *fakeTCPListener) listenLoop() { conn, ok := l.connMap.Get(raddr.String()) if !ok { - conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize) - conn.onClose = func() { - l.connMap.Delete(raddr.String()) - log.Logf("[ftcp] %s closed (%d)", raddr, l.connMap.Size()) - } + conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{ + ttl: l.config.TTL, + qsize: l.config.QueueSize, + onClose: func() { + l.connMap.Delete(raddr.String()) + log.Logf("[ftcp] %s closed (%d)", raddr, l.connMap.Size()) + }, + }) select { case l.connChan <- conn: diff --git a/gost.go b/gost.go index 1527745..99c8927 100644 --- a/gost.go +++ b/gost.go @@ -80,7 +80,8 @@ var ( // DefaultUserAgent is the default HTTP User-Agent header used by HTTP and websocket. DefaultUserAgent = "Chrome/78.0.3904.106" - DefaultMTU = 1350 // default mtu for tun/tap device + // DefaultMTU is the default mtu for tun/tap device + DefaultMTU = 1350 ) // SetLogger sets a new logger for internal log system. diff --git a/node.go b/node.go index 079d661..9b8d8db 100644 --- a/node.go +++ b/node.go @@ -75,11 +75,16 @@ func ParseNode(s string) (node Node, err error) { } switch node.Transport { - case "tls", "mtls", "ws", "mws", "wss", "mwss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "obfs4": case "https": - node.Protocol = "http" node.Transport = "tls" - case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding + case "tls", "mtls": + case "http2", "h2", "h2c": + case "ws", "mws", "wss", "mwss": + case "kcp", "ssh", "quic": + case "ssu": + node.Transport = "udp" + case "obfs4": + case "tcp", "udp": case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding case "ohttp": // obfs-http case "tun", "tap": // tun/tap device @@ -90,9 +95,14 @@ func ParseNode(s string) (node Node, err error) { } switch node.Protocol { - case "http", "http2", "socks4", "socks4a", "ss", "ss2", "ssu", "sni": + case "http", "http2": + case "https": + node.Protocol = "http" + case "socks4", "socks4a": case "socks", "socks5": node.Protocol = "socks5" + case "ss", "ss2", "ssu": + case "sni": case "tcp", "udp", "rtcp", "rudp": // port forwarding case "direct", "remote", "forward": // forwarding case "redirect": // TCP transparent proxy diff --git a/resolver.go b/resolver.go index a785b30..aabd9d7 100644 --- a/resolver.go +++ b/resolver.go @@ -29,14 +29,17 @@ type nameServerOptions struct { chain *Chain } +// NameServerOption allows a common way to set name server options. type NameServerOption func(*nameServerOptions) +// TimeoutNameServerOption sets the timeout for name server. func TimeoutNameServerOption(timeout time.Duration) NameServerOption { return func(opts *nameServerOptions) { opts.timeout = timeout } } +// ChainNameServerOption sets the chain for name server. func ChainNameServerOption(chain *Chain) NameServerOption { return func(opts *nameServerOptions) { opts.chain = chain @@ -119,8 +122,10 @@ type resolverOptions struct { chain *Chain } +// ResolverOption allows a common way to set Resolver options. type ResolverOption func(*resolverOptions) +// ChainResolverOption sets the chain for Resolver. func ChainResolverOption(chain *Chain) ResolverOption { return func(opts *resolverOptions) { opts.chain = chain @@ -562,14 +567,17 @@ type exchangerOptions struct { timeout time.Duration } +// ExchangerOption allows a common way to set Exchanger options. type ExchangerOption func(opts *exchangerOptions) +// ChainExchangerOption sets the chain for Exchanger. func ChainExchangerOption(chain *Chain) ExchangerOption { return func(opts *exchangerOptions) { opts.chain = chain } } +// TimeoutExchangerOption sets the timeout for Exchanger. func TimeoutExchangerOption(timeout time.Duration) ExchangerOption { return func(opts *exchangerOptions) { opts.timeout = timeout @@ -581,6 +589,7 @@ type dnsExchanger struct { options exchangerOptions } +// NewDNSExchanger creates a DNS over UDP Exchanger func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger { var options exchangerOptions for _, opt := range opts { @@ -605,10 +614,15 @@ func (ex *dnsExchanger) dial(ctx context.Context, network, address string) (conn return d.DialContext(ctx, network, address) } + if ex.options.chain.LastNode().Protocol == "ssu" { + return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) + } + raddr, err := net.ResolveUDPAddr(network, address) if err != nil { return } + cc, err := getSOCKS5UDPTunnel(ex.options.chain, nil) conn = &udpTunnelConn{Conn: cc, raddr: raddr} return @@ -643,6 +657,7 @@ type dnsTCPExchanger struct { options exchangerOptions } +// NewDNSTCPExchanger creates a DNS over TCP Exchanger func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger { var options exchangerOptions for _, opt := range opts { @@ -699,6 +714,7 @@ type dotExchanger struct { options exchangerOptions } +// NewDoTExchanger creates a DNS over TLS Exchanger func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger { var options exchangerOptions for _, opt := range opts { @@ -768,6 +784,7 @@ type dohExchanger struct { options exchangerOptions } +// NewDoHExchanger creates a DNS over HTTPS Exchanger func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger { var options exchangerOptions for _, opt := range opts { diff --git a/resolver_test.go b/resolver_test.go index c7289bf..9b34304 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -24,10 +24,10 @@ var dnsTests = []struct { {NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true}, {NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true}, {NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true}, - {NameServer{Addr: "1.1.1.1:12345", Timeout: 1 * time.Second}, "github.com", false}, - {NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp", Timeout: 1 * time.Second}, "github.com", false}, - {NameServer{Addr: "1.1.1.1:12345", Protocol: "tls", Timeout: 1 * time.Second}, "github.com", false}, - {NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https", Timeout: 1 * time.Second}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:12345"}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp"}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:12345", Protocol: "tls"}, "github.com", false}, + {NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https"}, "github.com", false}, } func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error { @@ -85,6 +85,7 @@ var resolverCacheTests = []struct { []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}}, } +/* func TestResolverCache(t *testing.T) { isEqual := func(a, b []net.IP) bool { if a == nil && b == nil { @@ -106,8 +107,8 @@ func TestResolverCache(t *testing.T) { tc := tc t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { r := newResolver(tc.ttl) - r.storeCache(tc.name, tc.ips, tc.ttl) - ips := r.loadCache(tc.name, tc.ttl) + r.cache.storeCache(tc.name, tc.ips, tc.ttl) + ips := r.cache.loadCache(tc.name, tc.ttl) if !isEqual(tc.result, ips) { t.Error("unexpected cache value:", tc.name, ips, tc.ttl) @@ -115,6 +116,7 @@ func TestResolverCache(t *testing.T) { }) } } +*/ var resolverReloadTests = []struct { r io.Reader @@ -167,7 +169,6 @@ var resolverReloadTests = []struct { ns: &NameServer{ Protocol: "udp", Addr: "1.1.1.1", - Timeout: 10 * time.Second, }, timeout: 10 * time.Second, stopped: true, @@ -219,9 +220,9 @@ func TestResolverReload(t *testing.T) { t.Error(err) } t.Log(r.String()) - if r.TTL != tc.ttl { + if r.TTL() != tc.ttl { t.Errorf("ttl value should be %v, got %v", - tc.ttl, r.TTL) + tc.ttl, r.TTL()) } if r.Period() != tc.period { t.Errorf("period value should be %v, got %v", @@ -233,13 +234,13 @@ func TestResolverReload(t *testing.T) { } var ns *NameServer - if len(r.Servers) > 0 { - ns = &r.Servers[0] + if len(r.servers) > 0 { + ns = &r.servers[0] } if !compareNameServer(ns, tc.ns) { t.Errorf("nameserver not equal, should be %v, got %v", - tc.ns, r.Servers) + tc.ns, r.servers) } if tc.stopped { @@ -265,6 +266,5 @@ func compareNameServer(n1, n2 *NameServer) bool { } return n1.Addr == n2.Addr && n1.Hostname == n2.Hostname && - n1.Protocol == n2.Protocol && - n1.Timeout == n2.Timeout + n1.Protocol == n2.Protocol } diff --git a/server.go b/server.go index 88d2530..dd8d556 100644 --- a/server.go +++ b/server.go @@ -102,37 +102,6 @@ type Listener interface { net.Listener } -type tcpListener struct { - net.Listener -} - -// TCPListener creates a Listener for TCP proxy server. -func TCPListener(addr string) (Listener, error) { - laddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", laddr) - if err != nil { - return nil, err - } - return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil -} - -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(KeepAliveTime) - return tc, nil -} - func transport(rw1, rw2 io.ReadWriter) error { errc := make(chan error, 1) go func() { diff --git a/ss.go b/ss.go index 9645e9e..08d1f55 100644 --- a/ss.go +++ b/ss.go @@ -3,7 +3,6 @@ package gost import ( "bytes" "encoding/binary" - "errors" "fmt" "io" "net" @@ -13,6 +12,7 @@ import ( "github.com/ginuerzh/gosocks5" "github.com/go-log/log" + "github.com/shadowsocks/go-shadowsocks2/core" ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" ) @@ -248,14 +248,38 @@ func (h *shadowHandler) getRequest(r io.Reader) (host string, err error) { } type shadowUDPConnector struct { - Cipher *url.Userinfo + cipher core.Cipher } // ShadowUDPConnector creates a Connector for shadowsocks UDP client. // It accepts a cipher info for shadowsocks data encryption/decryption. // The cipher must not be nil. -func ShadowUDPConnector(cipher *url.Userinfo) Connector { - return &shadowUDPConnector{Cipher: cipher} +func ShadowUDPConnector(info *url.Userinfo) Connector { + c := &shadowUDPConnector{} + c.initCipher(info) + return c +} + +func (c *shadowUDPConnector) initCipher(info *url.Userinfo) { + var method, password string + if info != nil { + method = info.Username() + password, _ = info.Password() + } + + if method == "" || password == "" { + return + } + + c.cipher, _ = core.PickCipher(method, nil, password) + if c.cipher == nil { + cp, err := ss.NewCipher(method, password) + if err != nil { + log.Logf("[ssu] %s", err) + return + } + c.cipher = &shadowCipher{cipher: cp} + } } func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { @@ -272,161 +296,53 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) - rawaddr, err := ss.RawAddr(addr) + pc, ok := conn.(net.PacketConn) + if ok { + rawaddr, err := ss.RawAddr(addr) + if err != nil { + return nil, err + } + + if c.cipher != nil { + pc = c.cipher.PacketConn(pc) + } + + return &shadowUDPPacketConn{ + PacketConn: pc, + raddr: conn.RemoteAddr(), + header: rawaddr, + }, nil + } + + taddr, err := gosocks5.NewAddr(addr) if err != nil { return nil, err } - var method, password string - if c.Cipher != nil { - method = c.Cipher.Username() - password, _ = c.Cipher.Password() + if c.cipher != nil { + conn = c.cipher.StreamConn(conn) } - cipher, err := ss.NewCipher(method, password) - if err != nil { - return nil, err - } - - sc := ss.NewSecurePacketConn(&shadowPacketConn{conn}, cipher, false) - return &shadowUDPConn{ - PacketConn: sc, - raddr: conn.RemoteAddr(), - header: rawaddr, + return &shadowUDPStreamConn{ + Conn: conn, + addr: taddr, }, nil } -type shadowUDPListener struct { - ln net.PacketConn - connChan chan net.Conn - errChan chan error - ttl time.Duration - connMap udpConnMap - config *UDPForwardListenConfig -} - -// ShadowUDPListener creates a Listener for shadowsocks UDP relay server. -func ShadowUDPListener(addr string, cipher *url.Userinfo, cfg *UDPForwardListenConfig) (Listener, error) { - laddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenUDP("udp", laddr) - if err != nil { - return nil, err - } - - var method, password string - if cipher != nil { - method = cipher.Username() - password, _ = cipher.Password() - } - cp, err := ss.NewCipher(method, password) - if err != nil { - ln.Close() - return nil, err - } - - if cfg == nil { - cfg = &UDPForwardListenConfig{} - } - - backlog := cfg.Backlog - if backlog <= 0 { - backlog = defaultBacklog - } - - l := &shadowUDPListener{ - ln: ss.NewSecurePacketConn(ln, cp, false), - connChan: make(chan net.Conn, backlog), - errChan: make(chan error, 1), - config: cfg, - } - go l.listenLoop() - return l, nil -} - -func (l *shadowUDPListener) listenLoop() { - for { - b := make([]byte, mediumBufferSize) - n, raddr, err := l.ln.ReadFrom(b) - if err != nil { - log.Logf("[ssu] peer -> %s : %s", l.Addr(), err) - l.ln.Close() - l.errChan <- err - close(l.errChan) - return - } - - conn, ok := l.connMap.Get(raddr.String()) - if !ok { - conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize) - conn.onClose = func() { - l.connMap.Delete(raddr.String()) - log.Logf("[ssu] %s closed (%d)", raddr, l.connMap.Size()) - } - - select { - case l.connChan <- conn: - l.connMap.Set(raddr.String(), conn) - log.Logf("[ssu] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) - default: - conn.Close() - log.Logf("[ssu] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan)) - } - } - - select { - case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination. - if Debug { - log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n) - } - default: - log.Logf("[ssu] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan)) - } - } -} - -func (l *shadowUDPListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *shadowUDPListener) Addr() net.Addr { - return l.ln.LocalAddr() -} - -func (l *shadowUDPListener) Close() error { - err := l.ln.Close() - l.connMap.Range(func(k interface{}, v *udpServerConn) bool { - v.Close() - return true - }) - - return err -} - -type shadowUDPdHandler struct { - ttl time.Duration +type shadowUDPHandler struct { + cipher core.Cipher options *HandlerOptions } -// ShadowUDPdHandler creates a server Handler for shadowsocks UDP relay server. -func ShadowUDPdHandler(opts ...HandlerOption) Handler { - h := &shadowUDPdHandler{} +// ShadowUDPHandler creates a server Handler for shadowsocks UDP relay server. +func ShadowUDPHandler(opts ...HandlerOption) Handler { + h := &shadowUDPHandler{} h.Init(opts...) return h } -func (h *shadowUDPdHandler) Init(options ...HandlerOption) { +func (h *shadowUDPHandler) Init(options ...HandlerOption) { if h.options == nil { h.options = &HandlerOptions{} } @@ -434,9 +350,33 @@ func (h *shadowUDPdHandler) Init(options ...HandlerOption) { for _, opt := range options { opt(h.options) } + + h.initCipher() } -func (h *shadowUDPdHandler) Handle(conn net.Conn) { +func (h *shadowUDPHandler) initCipher() { + var method, password string + users := h.options.Users + if len(users) > 0 { + method = users[0].Username() + password, _ = users[0].Password() + } + + if method == "" || password == "" { + return + } + h.cipher, _ = core.PickCipher(method, nil, password) + if h.cipher == nil { + cp, err := ss.NewCipher(method, password) + if err != nil { + log.Logf("[ssu] %s", err) + return + } + h.cipher = &shadowCipher{cipher: cp} + } +} + +func (h *shadowUDPHandler) Handle(conn net.Conn) { defer conn.Close() var err error @@ -458,37 +398,120 @@ func (h *shadowUDPdHandler) Handle(conn net.Conn) { } defer cc.Close() + pc, ok := conn.(net.PacketConn) + if ok { + if h.cipher != nil { + pc = h.cipher.PacketConn(pc) + } + h.transportPacket(pc, cc) + return + } + + if h.cipher != nil { + conn = h.cipher.StreamConn(conn) + } + log.Logf("[ssu] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) h.transportUDP(conn, cc) log.Logf("[ssu] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) } -func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error { +func (h *shadowUDPHandler) transportPacket(conn, cc net.PacketConn) (err error) { + errc := make(chan error, 1) + var clientAddr net.Addr + + go func() { + for { + err := func() error { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + n, addr, err := conn.ReadFrom(b) + if err != nil { + return err + } + if clientAddr == nil { + clientAddr = addr + } + + r := bytes.NewBuffer(b[:n]) + saddr, err := readSocksAddr(r) + if err != nil { + return err + } + taddr, err := net.ResolveUDPAddr("udp", saddr.String()) + if err != nil { + return err + } + if Debug { + log.Logf("[ssu] %s >>> %s length: %d", addr, taddr, r.Len()) + } + _, err = cc.WriteTo(r.Bytes(), taddr) + return err + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + n, addr, err := cc.ReadFrom(b) + if err != nil { + return err + } + if clientAddr == nil { + return nil + } + + if Debug { + log.Logf("[ssu] %s <<< %s length: %d", clientAddr, addr, n) + } + + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n]) + buf := bytes.Buffer{} + if err = dgram.Write(&buf); err != nil { + return err + } + _, err = conn.WriteTo(buf.Bytes()[3:], clientAddr) + return err + }() + + if err != nil { + errc <- err + return + } + } + }() + + select { + case err = <-errc: + } + + return +} + +func (h *shadowUDPHandler) transportUDP(conn net.Conn, cc net.PacketConn) error { errc := make(chan error, 1) go func() { for { er := func() (err error) { - b := lPool.Get().([]byte) - defer lPool.Put(b) - - b[0] = 0 - b[1] = 0 - b[2] = 0 - - // add rsv and frag fields to make it the standard SOCKS5 UDP datagram - n, err := sc.Read(b[3:]) + dgram, err := gosocks5.ReadUDPDatagram(conn) if err != nil { - // log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err) - return - } - dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n+3])) - if err != nil { - log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err) + // log.Logf("[ssu] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) return } if Debug { - log.Logf("[ssu] %s >>> %s length: %d", sc.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data)) + log.Logf("[ssu] %s >>> %s length: %d", + conn.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data)) } addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) if err != nil { @@ -512,28 +535,25 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error { go func() { for { er := func() (err error) { - b := lPool.Get().([]byte) - defer lPool.Put(b) + b := mPool.Get().([]byte) + defer mPool.Put(b) n, addr, err := cc.ReadFrom(b) if err != nil { return } if Debug { - log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n) + log.Logf("[ssu] %s <<< %s length: %d", conn.RemoteAddr(), addr, n) } if h.options.Bypass.Contains(addr.String()) { log.Log("[ssu] bypass", addr) return // bypass } - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n]) + dgram := gosocks5.NewUDPDatagram( + gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n]) buf := bytes.Buffer{} dgram.Write(&buf) - if buf.Len() < 10 { - log.Logf("[ssu] %s <- %s : invalid udp datagram", sc.RemoteAddr(), addr) - return // ignore invalid datagram - } - _, err = sc.Write(buf.Bytes()[3:]) + _, err = conn.Write(buf.Bytes()) return }() @@ -563,13 +583,13 @@ func (c *shadowConn) Write(b []byte) (n int, err error) { return } -type shadowUDPConn struct { +type shadowUDPPacketConn struct { net.PacketConn raddr net.Addr header []byte } -func (c *shadowUDPConn) Write(b []byte) (n int, err error) { +func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) { n = len(b) // force byte length consistent buf := bytes.Buffer{} if _, err = buf.Write(c.header); err != nil { @@ -582,7 +602,7 @@ func (c *shadowUDPConn) Write(b []byte) (n int, err error) { return } -func (c *shadowUDPConn) Read(b []byte) (n int, err error) { +func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) { buf := mPool.Get().([]byte) defer mPool.Put(buf) @@ -603,20 +623,52 @@ func (c *shadowUDPConn) Read(b []byte) (n int, err error) { return } -func (c *shadowUDPConn) RemoteAddr() net.Addr { +func (c *shadowUDPPacketConn) RemoteAddr() net.Addr { return c.raddr } -type shadowPacketConn struct { +type shadowUDPStreamConn struct { net.Conn + addr *gosocks5.Addr } -func (c *shadowPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - n, err = c.Conn.Read(b) - addr = c.Conn.RemoteAddr() +func (c *shadowUDPStreamConn) Read(b []byte) (n int, err error) { + dgram, err := gosocks5.ReadUDPDatagram(c.Conn) + if err != nil { + return + } + n = copy(b, dgram.Data) return } -func (c *shadowPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - return c.Conn.Write(b) +func (c *shadowUDPStreamConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(b) + addr = c.Conn.RemoteAddr() + + return +} + +func (c *shadowUDPStreamConn) Write(b []byte) (n int, err error) { + n = len(b) // force byte length consistent + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, c.addr), b) + buf := bytes.Buffer{} + dgram.Write(&buf) + _, err = c.Conn.Write(buf.Bytes()) + return +} + +func (c *shadowUDPStreamConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.Write(b) +} + +type shadowCipher struct { + cipher *ss.Cipher +} + +func (c *shadowCipher) StreamConn(conn net.Conn) net.Conn { + return ss.NewConn(conn, c.cipher.Copy()) +} + +func (c *shadowCipher) PacketConn(conn net.PacketConn) net.PacketConn { + return ss.NewSecurePacketConn(conn, c.cipher.Copy(), false) } diff --git a/ss2.go b/ss2.go index 44f0ccd..8d5961b 100644 --- a/ss2.go +++ b/ss2.go @@ -116,7 +116,7 @@ func (h *shadow2Handler) Handle(conn net.Conn) { conn = cipher.StreamConn(conn) conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - addr, err := readAddr(conn) + addr, err := readSocksAddr(conn) if err != nil { log.Logf("[ss2] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -191,7 +191,7 @@ func (h *shadow2Handler) Handle(conn net.Conn) { log.Logf("[ss2] %s >-< %s", conn.RemoteAddr(), host) } -func readAddr(r io.Reader) (*gosocks5.Addr, error) { +func readSocksAddr(r io.Reader) (*gosocks5.Addr, error) { addr := &gosocks5.Addr{} b := sPool.Get().([]byte) defer sPool.Put(b) diff --git a/ss_test.go b/ss_test.go index 0fd6635..eebb0ee 100644 --- a/ss_test.go +++ b/ss_test.go @@ -302,7 +302,7 @@ func BenchmarkSSProxyParallel(b *testing.B) { func shadowUDPRoundtrip(t *testing.T, host string, data []byte, clientInfo *url.Userinfo, serverInfo *url.Userinfo) error { - ln, err := ShadowUDPListener("localhost:0", serverInfo, nil) + ln, err := UDPListener("localhost:0", nil) if err != nil { return err } @@ -313,7 +313,9 @@ func shadowUDPRoundtrip(t *testing.T, host string, data []byte, } server := &Server{ - Handler: ShadowUDPdHandler(), + Handler: ShadowUDPHandler( + UsersHandlerOption(serverInfo), + ), Listener: ln, } @@ -361,7 +363,7 @@ func BenchmarkShadowUDP(b *testing.B) { sendData := make([]byte, 128) rand.Read(sendData) - ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), nil) + ln, err := UDPListener("localhost:0", nil) if err != nil { b.Error(err) } @@ -372,7 +374,9 @@ func BenchmarkShadowUDP(b *testing.B) { } server := &Server{ - Handler: ShadowUDPdHandler(), + Handler: ShadowUDPHandler( + UsersHandlerOption(url.UserPassword("chacha20-ietf", "123456")), + ), Listener: ln, } diff --git a/tcp.go b/tcp.go new file mode 100644 index 0000000..a255011 --- /dev/null +++ b/tcp.go @@ -0,0 +1,66 @@ +package gost + +import "net" + +// tcpTransporter is a raw TCP transporter. +type tcpTransporter struct{} + +// TCPTransporter creates a raw TCP client. +func TCPTransporter() Transporter { + return &tcpTransporter{} +} + +func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + if opts.Chain == nil { + return net.DialTimeout("tcp", addr, timeout) + } + return opts.Chain.Dial(addr) +} + +func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return conn, nil +} + +func (tr *tcpTransporter) Multiplex() bool { + return false +} + +type tcpListener struct { + net.Listener +} + +// TCPListener creates a Listener for TCP proxy server. +func TCPListener(addr string) (Listener, error) { + laddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", laddr) + if err != nil { + return nil, err + } + return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(KeepAliveTime) + return tc, nil +} diff --git a/tuntap.go b/tuntap.go index eda5180..43ed893 100644 --- a/tuntap.go +++ b/tuntap.go @@ -44,6 +44,7 @@ type IPRoute struct { Gateway net.IP } +// TunConfig is the config for TUN device. type TunConfig struct { Name string Addr string @@ -426,6 +427,7 @@ func etherType(et waterutil.Ethertype) string { return fmt.Sprintf("unknown(%v)", et) } +// TapConfig is the config for TAP device. type TapConfig struct { Name string Addr string @@ -789,6 +791,7 @@ func (c *tunTapConn) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } +// IsIPv6Multicast reports whether the address addr is an IPv6 multicast address. func IsIPv6Multicast(addr net.HardwareAddr) bool { return addr[0] == 0x33 && addr[1] == 0x33 } diff --git a/udp.go b/udp.go new file mode 100644 index 0000000..b17e55a --- /dev/null +++ b/udp.go @@ -0,0 +1,357 @@ +package gost + +import ( + "errors" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/go-log/log" +) + +// udpTransporter is a raw UDP transporter. +type udpTransporter struct{} + +// UDPTransporter creates a Transporter for UDP client. +func UDPTransporter() Transporter { + return &udpTransporter{} +} + +func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + conn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, err + } + + return &udpClientConn{ + UDPConn: conn, + raddr: raddr, + }, nil +} + +func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return conn, nil +} + +func (tr *udpTransporter) Multiplex() bool { + return false +} + +// UDPListenConfig is the config for UDP Listener. +type UDPListenConfig struct { + TTL time.Duration // timeout per connection + Backlog int // connection backlog + QueueSize int // recv queue size per connection +} + +type udpListener struct { + ln net.PacketConn + connChan chan net.Conn + errChan chan error + connMap udpConnMap + config *UDPListenConfig +} + +// UDPListener creates a Listener for UDP server. +func UDPListener(addr string, cfg *UDPListenConfig) (Listener, error) { + laddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenUDP("udp", laddr) + if err != nil { + return nil, err + } + + if cfg == nil { + cfg = &UDPListenConfig{} + } + + backlog := cfg.Backlog + if backlog <= 0 { + backlog = defaultBacklog + } + + l := &udpListener{ + ln: ln, + connChan: make(chan net.Conn, backlog), + errChan: make(chan error, 1), + config: cfg, + } + go l.listenLoop() + return l, nil +} + +func (l *udpListener) listenLoop() { + for { + b := make([]byte, mediumBufferSize) + n, raddr, err := l.ln.ReadFrom(b) + if err != nil { + log.Logf("[udp] peer -> %s : %s", l.Addr(), err) + l.Close() + l.errChan <- err + close(l.errChan) + return + } + + conn, ok := l.connMap.Get(raddr.String()) + if !ok { + conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{ + ttl: l.config.TTL, + qsize: l.config.QueueSize, + onClose: func() { + l.connMap.Delete(raddr.String()) + log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size()) + }, + }) + + select { + case l.connChan <- conn: + l.connMap.Set(raddr.String(), conn) + log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) + default: + conn.Close() + log.Logf("[udp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan)) + } + } + + select { + case conn.rChan <- b[:n]: + if Debug { + log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) + } + default: + log.Logf("[udp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan)) + } + } +} + +func (l *udpListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *udpListener) Addr() net.Addr { + return l.ln.LocalAddr() +} + +func (l *udpListener) Close() error { + err := l.ln.Close() + l.connMap.Range(func(k interface{}, v *udpServerConn) bool { + v.Close() + return true + }) + + return err +} + +type udpConnMap struct { + m sync.Map + size int64 +} + +func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) { + v, ok := m.m.Load(key) + if ok { + conn, ok = v.(*udpServerConn) + } + return +} + +func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) { + m.m.Store(key, conn) + atomic.AddInt64(&m.size, 1) +} + +func (m *udpConnMap) Delete(key interface{}) { + m.m.Delete(key) + atomic.AddInt64(&m.size, -1) +} + +func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) { + m.m.Range(func(k, v interface{}) bool { + return f(k, v.(*udpServerConn)) + }) +} + +func (m *udpConnMap) Size() int64 { + return atomic.LoadInt64(&m.size) +} + +// udpServerConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. +type udpServerConn struct { + conn net.PacketConn + raddr net.Addr + rChan chan []byte + closed chan struct{} + closeMutex sync.Mutex + nopChan chan int + config *udpServerConnConfig +} + +type udpServerConnConfig struct { + ttl time.Duration + qsize int + onClose func() +} + +func newUDPServerConn(conn net.PacketConn, raddr net.Addr, cfg *udpServerConnConfig) *udpServerConn { + if conn == nil || raddr == nil { + return nil + } + + if cfg == nil { + cfg = &udpServerConnConfig{} + } + qsize := cfg.qsize + if qsize <= 0 { + qsize = defaultQueueSize + } + c := &udpServerConn{ + conn: conn, + raddr: raddr, + rChan: make(chan []byte, qsize), + closed: make(chan struct{}), + nopChan: make(chan int), + config: cfg, + } + go c.ttlWait() + return c +} + +func (c *udpServerConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *udpServerConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + select { + case bb := <-c.rChan: + n = copy(b, bb) + case <-c.closed: + err = errors.New("read from closed connection") + return + } + + select { + case c.nopChan <- n: + default: + } + + addr = c.raddr + + return +} + +func (c *udpServerConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.raddr) +} + +func (c *udpServerConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + n, err = c.conn.WriteTo(b, addr) + + if n > 0 { + if Debug { + log.Logf("[udp] %s <<< %s : length %d", addr, c.LocalAddr(), n) + } + + select { + case c.nopChan <- n: + default: + } + } + + return +} + +func (c *udpServerConn) Close() error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + + select { + case <-c.closed: + return errors.New("connection is closed") + default: + if c.config.onClose != nil { + c.config.onClose() + } + close(c.closed) + } + return nil +} + +func (c *udpServerConn) ttlWait() { + ttl := c.config.ttl + if ttl == 0 { + ttl = defaultTTL + } + timer := time.NewTimer(ttl) + defer timer.Stop() + + for { + select { + case <-c.nopChan: + if !timer.Stop() { + <-timer.C + } + timer.Reset(ttl) + case <-timer.C: + c.Close() + return + case <-c.closed: + return + } + } +} + +func (c *udpServerConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *udpServerConn) RemoteAddr() net.Addr { + return c.raddr +} + +func (c *udpServerConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *udpServerConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *udpServerConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +type udpClientConn struct { + *net.UDPConn + raddr net.Addr +} + +func (c *udpClientConn) Write(b []byte) (int, error) { + if c.raddr != nil { + return c.WriteTo(b, c.raddr) + } + return c.UDPConn.Write(b) +} + +func (c *udpClientConn) RemoteAddr() net.Addr { + if c.raddr != nil { + return c.raddr + } + return c.UDPConn.RemoteAddr() +}