add ssu connector

This commit is contained in:
ginuerzh 2020-01-30 13:37:27 +08:00
parent 4133cf30b4
commit 6ce3639c02
19 changed files with 763 additions and 597 deletions

View File

@ -30,7 +30,7 @@ gost - GO Simple Tunnel
* [权限控制](https://docs.ginuerzh.xyz/gost/permission/)
* [负载均衡](https://docs.ginuerzh.xyz/gost/load-balancing/)
* [路由控制](https://docs.ginuerzh.xyz/gost/bypass/)
* [DNS控制](https://docs.ginuerzh.xyz/gost/dns/)
* DNS[解析](https://docs.ginuerzh.xyz/gost/resolver/)和[代理](https://docs.ginuerzh.xyz/gost/dns/)
* [TUN/TAP设备](https://docs.ginuerzh.xyz/gost/tuntap/)
Wiki站点: <https://docs.ginuerzh.xyz/gost/>

View File

@ -27,7 +27,7 @@ Features
* [Permission control](https://docs.ginuerzh.xyz/gost/en/permission/)
* [Load balancing](https://docs.ginuerzh.xyz/gost/en/load-balancing/)
* [Routing control](https://docs.ginuerzh.xyz/gost/en/bypass/)
* [DNS control](https://docs.ginuerzh.xyz/gost/en/dns/)
* DNS [resolver](https://docs.ginuerzh.xyz/gost/resolver/) and [proxy](https://docs.ginuerzh.xyz/gost/dns/)
* [TUN/TAP device](https://docs.ginuerzh.xyz/gost/en/tuntap/)
Wiki: <https://docs.ginuerzh.xyz/gost/en/>

View File

@ -64,68 +64,6 @@ type Transporter interface {
Multiplex() bool
}
// tcpTransporter is a raw TCP transporter.
type tcpTransporter struct{}
// TCPTransporter creates a raw TCP client.
func TCPTransporter() Transporter {
return &tcpTransporter{}
}
func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
opts := &DialOptions{}
for _, option := range options {
option(opts)
}
timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}
if opts.Chain == nil {
return net.DialTimeout("tcp", addr, timeout)
}
return opts.Chain.Dial(addr)
}
func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
return conn, nil
}
func (tr *tcpTransporter) Multiplex() bool {
return false
}
// udpTransporter is a raw UDP transporter.
type udpTransporter struct{}
// UDPTransporter creates a raw UDP client.
func UDPTransporter() Transporter {
return &udpTransporter{}
}
func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
opts := &DialOptions{}
for _, option := range options {
option(opts)
}
timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}
return net.DialTimeout("udp", addr, timeout)
}
func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
return conn, nil
}
func (tr *udpTransporter) Multiplex() bool {
return false
}
// DialOptions describes the options for Transporter.Dial.
type DialOptions struct {
Timeout time.Duration

View File

@ -198,6 +198,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
tr = gost.ObfsHTTPTransporter()
case "ftcp":
tr = gost.FakeTCPTransporter()
case "udp":
tr = gost.UDPTransporter()
default:
tr = gost.TCPTransporter()
}
@ -216,6 +218,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
connector = gost.ShadowConnector(node.User)
case "ss2":
connector = gost.Shadow2Connector(node.User)
case "ssu":
connector = gost.ShadowUDPConnector(node.User)
case "direct":
connector = gost.SSHDirectForwardConnector()
case "remote":
@ -414,6 +418,12 @@ func (r *route) GenRouters() ([]router, error) {
chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter()
}
ln, err = gost.TCPListener(node.Addr)
case "udp":
ln, err = gost.UDPListener(node.Addr, &gost.UDPListenConfig{
TTL: ttl,
Backlog: node.GetInt("backlog"),
QueueSize: node.GetInt("queue"),
})
case "rtcp":
// Directly use SSH port forwarding if the last chain node is forward+ssh
if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" {
@ -421,24 +431,10 @@ func (r *route) GenRouters() ([]router, error) {
chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter()
}
ln, err = gost.TCPRemoteForwardListener(node.Addr, chain)
case "udp":
ln, err = gost.UDPDirectForwardListener(node.Addr, &gost.UDPForwardListenConfig{
TTL: ttl,
Backlog: node.GetInt("backlog"),
QueueSize: node.GetInt("queue"),
})
case "rudp":
ln, err = gost.UDPRemoteForwardListener(node.Addr,
chain,
&gost.UDPForwardListenConfig{
TTL: ttl,
Backlog: node.GetInt("backlog"),
QueueSize: node.GetInt("queue"),
})
case "ssu":
ln, err = gost.ShadowUDPListener(node.Addr,
node.User,
&gost.UDPForwardListenConfig{
&gost.UDPListenConfig{
TTL: ttl,
Backlog: node.GetInt("backlog"),
QueueSize: node.GetInt("queue"),
@ -519,7 +515,7 @@ func (r *route) GenRouters() ([]router, error) {
case "redirect":
handler = gost.TCPRedirectHandler()
case "ssu":
handler = gost.ShadowUDPdHandler()
handler = gost.ShadowUDPHandler()
case "sni":
handler = gost.SNIHandler()
case "tun":

2
dns.go
View File

@ -117,6 +117,7 @@ func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string {
return buf.String()
}
// DNSOptions is options for DNS Listener.
type DNSOptions struct {
Mode string
UDPSize int
@ -132,6 +133,7 @@ type dnsListener struct {
errc chan error
}
// DNSListener creates a Listener for DNS proxy server.
func DNSListener(addr string, options *DNSOptions) (Listener, error) {
if options == nil {
options = &DNSOptions{}

View File

@ -5,7 +5,6 @@ import (
"net"
"strings"
"sync"
"sync/atomic"
"time"
"fmt"
@ -202,6 +201,16 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
} else if h.options.Chain.LastNode().Protocol == "ssu" {
cc, err = h.options.Chain.Dial(node.Addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
} else {
var err error
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
@ -341,271 +350,6 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr)
}
type udpConnMap struct {
m sync.Map
size int64
}
func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) {
v, ok := m.m.Load(key)
if ok {
conn, ok = v.(*udpServerConn)
}
return
}
func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) {
m.m.Store(key, conn)
atomic.AddInt64(&m.size, 1)
}
func (m *udpConnMap) Delete(key interface{}) {
m.m.Delete(key)
atomic.AddInt64(&m.size, -1)
}
func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) {
m.m.Range(func(k, v interface{}) bool {
return f(k, v.(*udpServerConn))
})
}
func (m *udpConnMap) Size() int64 {
return atomic.LoadInt64(&m.size)
}
type UDPForwardListenConfig struct {
TTL time.Duration
Backlog int
QueueSize int
}
type udpDirectForwardListener struct {
ln net.PacketConn
connChan chan net.Conn
errChan chan error
connMap udpConnMap
config *UDPForwardListenConfig
}
// UDPDirectForwardListener creates a Listener for UDP port forwarding server.
func UDPDirectForwardListener(addr string, cfg *UDPForwardListenConfig) (Listener, error) {
laddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
ln, err := net.ListenUDP("udp", laddr)
if err != nil {
return nil, err
}
if cfg == nil {
cfg = &UDPForwardListenConfig{}
}
backlog := cfg.Backlog
if backlog <= 0 {
backlog = defaultBacklog
}
l := &udpDirectForwardListener{
ln: ln,
connChan: make(chan net.Conn, backlog),
errChan: make(chan error, 1),
config: cfg,
}
go l.listenLoop()
return l, nil
}
func (l *udpDirectForwardListener) listenLoop() {
for {
b := make([]byte, mediumBufferSize)
n, raddr, err := l.ln.ReadFrom(b)
if err != nil {
log.Logf("[udp] peer -> %s : %s", l.Addr(), err)
l.Close()
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connMap.Get(raddr.String())
if !ok {
conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize)
conn.onClose = func() {
l.connMap.Delete(raddr.String())
log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size())
}
select {
case l.connChan <- conn:
l.connMap.Set(raddr.String(), conn)
log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default:
conn.Close()
log.Logf("[udp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan))
}
}
select {
case conn.rChan <- b[:n]:
if Debug {
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default:
log.Logf("[udp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan))
}
}
}
func (l *udpDirectForwardListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = errors.New("accpet on closed listener")
}
}
return
}
func (l *udpDirectForwardListener) Addr() net.Addr {
return l.ln.LocalAddr()
}
func (l *udpDirectForwardListener) Close() error {
err := l.ln.Close()
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
return err
}
type udpServerConn struct {
conn net.PacketConn
raddr net.Addr
rChan chan []byte
closed chan struct{}
closeMutex sync.Mutex
ttl time.Duration
nopChan chan int
onClose func()
}
func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration, qsize int) *udpServerConn {
if qsize <= 0 {
qsize = defaultQueueSize
}
c := &udpServerConn{
conn: conn,
raddr: raddr,
rChan: make(chan []byte, qsize),
closed: make(chan struct{}),
nopChan: make(chan int),
ttl: ttl,
}
go c.ttlWait()
return c
}
func (c *udpServerConn) Read(b []byte) (n int, err error) {
select {
case bb := <-c.rChan:
n = copy(b, bb)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
select {
case c.nopChan <- n:
default:
}
return
}
func (c *udpServerConn) Write(b []byte) (n int, err error) {
n, err = c.conn.WriteTo(b, c.raddr)
if n > 0 {
if Debug {
log.Logf("[udp] %s <<< %s : length %d", c.raddr, c.LocalAddr(), n)
}
select {
case c.nopChan <- n:
default:
}
}
return
}
func (c *udpServerConn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.onClose != nil {
c.onClose()
}
close(c.closed)
}
return nil
}
func (c *udpServerConn) ttlWait() {
ttl := c.ttl
if ttl == 0 {
ttl = defaultTTL
}
timer := time.NewTimer(ttl)
defer timer.Stop()
for {
select {
case <-c.nopChan:
if !timer.Stop() {
<-timer.C
}
timer.Reset(ttl)
case <-timer.C:
c.Close()
return
case <-c.closed:
return
}
}
}
func (c *udpServerConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *udpServerConn) RemoteAddr() net.Addr {
return c.raddr
}
func (c *udpServerConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *udpServerConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *udpServerConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
type tcpRemoteForwardListener struct {
addr net.Addr
chain *Chain
@ -874,18 +618,18 @@ type udpRemoteForwardListener struct {
closed chan struct{}
closeMux sync.Mutex
once sync.Once
config *UDPForwardListenConfig
config *UDPListenConfig
}
// UDPRemoteForwardListener creates a Listener for UDP remote port forwarding server.
func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPForwardListenConfig) (Listener, error) {
func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPListenConfig) (Listener, error) {
laddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
if cfg == nil {
cfg = &UDPForwardListenConfig{}
cfg = &UDPListenConfig{}
}
backlog := cfg.Backlog
@ -935,11 +679,14 @@ func (l *udpRemoteForwardListener) listenLoop() {
uc, ok := l.connMap.Get(raddr.String())
if !ok {
uc = newUDPServerConn(conn, raddr, l.config.TTL, l.config.QueueSize)
uc.onClose = func() {
uc = newUDPServerConn(conn, raddr, &udpServerConnConfig{
ttl: l.config.TTL,
qsize: l.config.QueueSize,
onClose: func() {
l.connMap.Delete(raddr.String())
log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size())
}
},
})
select {
case l.connChan <- uc:

View File

@ -128,7 +128,7 @@ func BenchmarkTCPDirectForwardParallel(b *testing.B) {
}
func udpDirectForwardRoundtrip(t *testing.T, host string, data []byte) error {
ln, err := UDPDirectForwardListener("localhost:0", nil)
ln, err := UDPListener("localhost:0", nil)
if err != nil {
return err
}
@ -172,7 +172,7 @@ func BenchmarkUDPDirectForward(b *testing.B) {
sendData := make([]byte, 128)
rand.Read(sendData)
ln, err := UDPDirectForwardListener("localhost:0", nil)
ln, err := UDPListener("localhost:0", nil)
if err != nil {
b.Error(err)
}
@ -207,7 +207,7 @@ func BenchmarkUDPDirectForwardParallel(b *testing.B) {
sendData := make([]byte, 128)
rand.Read(sendData)
ln, err := UDPDirectForwardListener("localhost:0", nil)
ln, err := UDPListener("localhost:0", nil)
if err != nil {
b.Error(err)
}

10
ftcp.go
View File

@ -45,6 +45,7 @@ func (tr *fakeTCPTransporter) Multiplex() bool {
return false
}
// FakeTCPListenConfig is config for fake TCP Listener.
type FakeTCPListenConfig struct {
TTL time.Duration
Backlog int
@ -99,11 +100,14 @@ func (l *fakeTCPListener) listenLoop() {
conn, ok := l.connMap.Get(raddr.String())
if !ok {
conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize)
conn.onClose = func() {
conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{
ttl: l.config.TTL,
qsize: l.config.QueueSize,
onClose: func() {
l.connMap.Delete(raddr.String())
log.Logf("[ftcp] %s closed (%d)", raddr, l.connMap.Size())
}
},
})
select {
case l.connChan <- conn:

View File

@ -80,7 +80,8 @@ var (
// DefaultUserAgent is the default HTTP User-Agent header used by HTTP and websocket.
DefaultUserAgent = "Chrome/78.0.3904.106"
DefaultMTU = 1350 // default mtu for tun/tap device
// DefaultMTU is the default mtu for tun/tap device
DefaultMTU = 1350
)
// SetLogger sets a new logger for internal log system.

18
node.go
View File

@ -75,11 +75,16 @@ func ParseNode(s string) (node Node, err error) {
}
switch node.Transport {
case "tls", "mtls", "ws", "mws", "wss", "mwss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "obfs4":
case "https":
node.Protocol = "http"
node.Transport = "tls"
case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding
case "tls", "mtls":
case "http2", "h2", "h2c":
case "ws", "mws", "wss", "mwss":
case "kcp", "ssh", "quic":
case "ssu":
node.Transport = "udp"
case "obfs4":
case "tcp", "udp":
case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding
case "ohttp": // obfs-http
case "tun", "tap": // tun/tap device
@ -90,9 +95,14 @@ func ParseNode(s string) (node Node, err error) {
}
switch node.Protocol {
case "http", "http2", "socks4", "socks4a", "ss", "ss2", "ssu", "sni":
case "http", "http2":
case "https":
node.Protocol = "http"
case "socks4", "socks4a":
case "socks", "socks5":
node.Protocol = "socks5"
case "ss", "ss2", "ssu":
case "sni":
case "tcp", "udp", "rtcp", "rudp": // port forwarding
case "direct", "remote", "forward": // forwarding
case "redirect": // TCP transparent proxy

View File

@ -29,14 +29,17 @@ type nameServerOptions struct {
chain *Chain
}
// NameServerOption allows a common way to set name server options.
type NameServerOption func(*nameServerOptions)
// TimeoutNameServerOption sets the timeout for name server.
func TimeoutNameServerOption(timeout time.Duration) NameServerOption {
return func(opts *nameServerOptions) {
opts.timeout = timeout
}
}
// ChainNameServerOption sets the chain for name server.
func ChainNameServerOption(chain *Chain) NameServerOption {
return func(opts *nameServerOptions) {
opts.chain = chain
@ -119,8 +122,10 @@ type resolverOptions struct {
chain *Chain
}
// ResolverOption allows a common way to set Resolver options.
type ResolverOption func(*resolverOptions)
// ChainResolverOption sets the chain for Resolver.
func ChainResolverOption(chain *Chain) ResolverOption {
return func(opts *resolverOptions) {
opts.chain = chain
@ -562,14 +567,17 @@ type exchangerOptions struct {
timeout time.Duration
}
// ExchangerOption allows a common way to set Exchanger options.
type ExchangerOption func(opts *exchangerOptions)
// ChainExchangerOption sets the chain for Exchanger.
func ChainExchangerOption(chain *Chain) ExchangerOption {
return func(opts *exchangerOptions) {
opts.chain = chain
}
}
// TimeoutExchangerOption sets the timeout for Exchanger.
func TimeoutExchangerOption(timeout time.Duration) ExchangerOption {
return func(opts *exchangerOptions) {
opts.timeout = timeout
@ -581,6 +589,7 @@ type dnsExchanger struct {
options exchangerOptions
}
// NewDNSExchanger creates a DNS over UDP Exchanger
func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger {
var options exchangerOptions
for _, opt := range opts {
@ -605,10 +614,15 @@ func (ex *dnsExchanger) dial(ctx context.Context, network, address string) (conn
return d.DialContext(ctx, network, address)
}
if ex.options.chain.LastNode().Protocol == "ssu" {
return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
}
raddr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return
}
cc, err := getSOCKS5UDPTunnel(ex.options.chain, nil)
conn = &udpTunnelConn{Conn: cc, raddr: raddr}
return
@ -643,6 +657,7 @@ type dnsTCPExchanger struct {
options exchangerOptions
}
// NewDNSTCPExchanger creates a DNS over TCP Exchanger
func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger {
var options exchangerOptions
for _, opt := range opts {
@ -699,6 +714,7 @@ type dotExchanger struct {
options exchangerOptions
}
// NewDoTExchanger creates a DNS over TLS Exchanger
func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger {
var options exchangerOptions
for _, opt := range opts {
@ -768,6 +784,7 @@ type dohExchanger struct {
options exchangerOptions
}
// NewDoHExchanger creates a DNS over HTTPS Exchanger
func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger {
var options exchangerOptions
for _, opt := range opts {

View File

@ -24,10 +24,10 @@ var dnsTests = []struct {
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true},
{NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true},
{NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true},
{NameServer{Addr: "1.1.1.1:12345", Timeout: 1 * time.Second}, "github.com", false},
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp", Timeout: 1 * time.Second}, "github.com", false},
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tls", Timeout: 1 * time.Second}, "github.com", false},
{NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https", Timeout: 1 * time.Second}, "github.com", false},
{NameServer{Addr: "1.1.1.1:12345"}, "github.com", false},
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp"}, "github.com", false},
{NameServer{Addr: "1.1.1.1:12345", Protocol: "tls"}, "github.com", false},
{NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https"}, "github.com", false},
}
func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error {
@ -85,6 +85,7 @@ var resolverCacheTests = []struct {
[]net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}},
}
/*
func TestResolverCache(t *testing.T) {
isEqual := func(a, b []net.IP) bool {
if a == nil && b == nil {
@ -106,8 +107,8 @@ func TestResolverCache(t *testing.T) {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
r := newResolver(tc.ttl)
r.storeCache(tc.name, tc.ips, tc.ttl)
ips := r.loadCache(tc.name, tc.ttl)
r.cache.storeCache(tc.name, tc.ips, tc.ttl)
ips := r.cache.loadCache(tc.name, tc.ttl)
if !isEqual(tc.result, ips) {
t.Error("unexpected cache value:", tc.name, ips, tc.ttl)
@ -115,6 +116,7 @@ func TestResolverCache(t *testing.T) {
})
}
}
*/
var resolverReloadTests = []struct {
r io.Reader
@ -167,7 +169,6 @@ var resolverReloadTests = []struct {
ns: &NameServer{
Protocol: "udp",
Addr: "1.1.1.1",
Timeout: 10 * time.Second,
},
timeout: 10 * time.Second,
stopped: true,
@ -219,9 +220,9 @@ func TestResolverReload(t *testing.T) {
t.Error(err)
}
t.Log(r.String())
if r.TTL != tc.ttl {
if r.TTL() != tc.ttl {
t.Errorf("ttl value should be %v, got %v",
tc.ttl, r.TTL)
tc.ttl, r.TTL())
}
if r.Period() != tc.period {
t.Errorf("period value should be %v, got %v",
@ -233,13 +234,13 @@ func TestResolverReload(t *testing.T) {
}
var ns *NameServer
if len(r.Servers) > 0 {
ns = &r.Servers[0]
if len(r.servers) > 0 {
ns = &r.servers[0]
}
if !compareNameServer(ns, tc.ns) {
t.Errorf("nameserver not equal, should be %v, got %v",
tc.ns, r.Servers)
tc.ns, r.servers)
}
if tc.stopped {
@ -265,6 +266,5 @@ func compareNameServer(n1, n2 *NameServer) bool {
}
return n1.Addr == n2.Addr &&
n1.Hostname == n2.Hostname &&
n1.Protocol == n2.Protocol &&
n1.Timeout == n2.Timeout
n1.Protocol == n2.Protocol
}

View File

@ -102,37 +102,6 @@ type Listener interface {
net.Listener
}
type tcpListener struct {
net.Listener
}
// TCPListener creates a Listener for TCP proxy server.
func TCPListener(addr string) (Listener, error) {
laddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return nil, err
}
return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil
}
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(KeepAliveTime)
return tc, nil
}
func transport(rw1, rw2 io.ReadWriter) error {
errc := make(chan error, 1)
go func() {

390
ss.go
View File

@ -3,7 +3,6 @@ package gost
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
@ -13,6 +12,7 @@ import (
"github.com/ginuerzh/gosocks5"
"github.com/go-log/log"
"github.com/shadowsocks/go-shadowsocks2/core"
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
)
@ -248,14 +248,38 @@ func (h *shadowHandler) getRequest(r io.Reader) (host string, err error) {
}
type shadowUDPConnector struct {
Cipher *url.Userinfo
cipher core.Cipher
}
// ShadowUDPConnector creates a Connector for shadowsocks UDP client.
// It accepts a cipher info for shadowsocks data encryption/decryption.
// The cipher must not be nil.
func ShadowUDPConnector(cipher *url.Userinfo) Connector {
return &shadowUDPConnector{Cipher: cipher}
func ShadowUDPConnector(info *url.Userinfo) Connector {
c := &shadowUDPConnector{}
c.initCipher(info)
return c
}
func (c *shadowUDPConnector) initCipher(info *url.Userinfo) {
var method, password string
if info != nil {
method = info.Username()
password, _ = info.Password()
}
if method == "" || password == "" {
return
}
c.cipher, _ = core.PickCipher(method, nil, password)
if c.cipher == nil {
cp, err := ss.NewCipher(method, password)
if err != nil {
log.Logf("[ssu] %s", err)
return
}
c.cipher = &shadowCipher{cipher: cp}
}
}
func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
@ -272,161 +296,53 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
pc, ok := conn.(net.PacketConn)
if ok {
rawaddr, err := ss.RawAddr(addr)
if err != nil {
return nil, err
}
var method, password string
if c.Cipher != nil {
method = c.Cipher.Username()
password, _ = c.Cipher.Password()
if c.cipher != nil {
pc = c.cipher.PacketConn(pc)
}
cipher, err := ss.NewCipher(method, password)
if err != nil {
return nil, err
}
sc := ss.NewSecurePacketConn(&shadowPacketConn{conn}, cipher, false)
return &shadowUDPConn{
PacketConn: sc,
return &shadowUDPPacketConn{
PacketConn: pc,
raddr: conn.RemoteAddr(),
header: rawaddr,
}, nil
}
type shadowUDPListener struct {
ln net.PacketConn
connChan chan net.Conn
errChan chan error
ttl time.Duration
connMap udpConnMap
config *UDPForwardListenConfig
}
// ShadowUDPListener creates a Listener for shadowsocks UDP relay server.
func ShadowUDPListener(addr string, cipher *url.Userinfo, cfg *UDPForwardListenConfig) (Listener, error) {
laddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
ln, err := net.ListenUDP("udp", laddr)
taddr, err := gosocks5.NewAddr(addr)
if err != nil {
return nil, err
}
var method, password string
if cipher != nil {
method = cipher.Username()
password, _ = cipher.Password()
}
cp, err := ss.NewCipher(method, password)
if err != nil {
ln.Close()
return nil, err
if c.cipher != nil {
conn = c.cipher.StreamConn(conn)
}
if cfg == nil {
cfg = &UDPForwardListenConfig{}
}
backlog := cfg.Backlog
if backlog <= 0 {
backlog = defaultBacklog
}
l := &shadowUDPListener{
ln: ss.NewSecurePacketConn(ln, cp, false),
connChan: make(chan net.Conn, backlog),
errChan: make(chan error, 1),
config: cfg,
}
go l.listenLoop()
return l, nil
return &shadowUDPStreamConn{
Conn: conn,
addr: taddr,
}, nil
}
func (l *shadowUDPListener) listenLoop() {
for {
b := make([]byte, mediumBufferSize)
n, raddr, err := l.ln.ReadFrom(b)
if err != nil {
log.Logf("[ssu] peer -> %s : %s", l.Addr(), err)
l.ln.Close()
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connMap.Get(raddr.String())
if !ok {
conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize)
conn.onClose = func() {
l.connMap.Delete(raddr.String())
log.Logf("[ssu] %s closed (%d)", raddr, l.connMap.Size())
}
select {
case l.connChan <- conn:
l.connMap.Set(raddr.String(), conn)
log.Logf("[ssu] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default:
conn.Close()
log.Logf("[ssu] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan))
}
}
select {
case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination.
if Debug {
log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default:
log.Logf("[ssu] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan))
}
}
}
func (l *shadowUDPListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = errors.New("accpet on closed listener")
}
}
return
}
func (l *shadowUDPListener) Addr() net.Addr {
return l.ln.LocalAddr()
}
func (l *shadowUDPListener) Close() error {
err := l.ln.Close()
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
return err
}
type shadowUDPdHandler struct {
ttl time.Duration
type shadowUDPHandler struct {
cipher core.Cipher
options *HandlerOptions
}
// ShadowUDPdHandler creates a server Handler for shadowsocks UDP relay server.
func ShadowUDPdHandler(opts ...HandlerOption) Handler {
h := &shadowUDPdHandler{}
// ShadowUDPHandler creates a server Handler for shadowsocks UDP relay server.
func ShadowUDPHandler(opts ...HandlerOption) Handler {
h := &shadowUDPHandler{}
h.Init(opts...)
return h
}
func (h *shadowUDPdHandler) Init(options ...HandlerOption) {
func (h *shadowUDPHandler) Init(options ...HandlerOption) {
if h.options == nil {
h.options = &HandlerOptions{}
}
@ -434,9 +350,33 @@ func (h *shadowUDPdHandler) Init(options ...HandlerOption) {
for _, opt := range options {
opt(h.options)
}
h.initCipher()
}
func (h *shadowUDPdHandler) Handle(conn net.Conn) {
func (h *shadowUDPHandler) initCipher() {
var method, password string
users := h.options.Users
if len(users) > 0 {
method = users[0].Username()
password, _ = users[0].Password()
}
if method == "" || password == "" {
return
}
h.cipher, _ = core.PickCipher(method, nil, password)
if h.cipher == nil {
cp, err := ss.NewCipher(method, password)
if err != nil {
log.Logf("[ssu] %s", err)
return
}
h.cipher = &shadowCipher{cipher: cp}
}
}
func (h *shadowUDPHandler) Handle(conn net.Conn) {
defer conn.Close()
var err error
@ -458,37 +398,120 @@ func (h *shadowUDPdHandler) Handle(conn net.Conn) {
}
defer cc.Close()
pc, ok := conn.(net.PacketConn)
if ok {
if h.cipher != nil {
pc = h.cipher.PacketConn(pc)
}
h.transportPacket(pc, cc)
return
}
if h.cipher != nil {
conn = h.cipher.StreamConn(conn)
}
log.Logf("[ssu] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr())
h.transportUDP(conn, cc)
log.Logf("[ssu] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr())
}
func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
func (h *shadowUDPHandler) transportPacket(conn, cc net.PacketConn) (err error) {
errc := make(chan error, 1)
var clientAddr net.Addr
go func() {
for {
err := func() error {
b := mPool.Get().([]byte)
defer mPool.Put(b)
n, addr, err := conn.ReadFrom(b)
if err != nil {
return err
}
if clientAddr == nil {
clientAddr = addr
}
r := bytes.NewBuffer(b[:n])
saddr, err := readSocksAddr(r)
if err != nil {
return err
}
taddr, err := net.ResolveUDPAddr("udp", saddr.String())
if err != nil {
return err
}
if Debug {
log.Logf("[ssu] %s >>> %s length: %d", addr, taddr, r.Len())
}
_, err = cc.WriteTo(r.Bytes(), taddr)
return err
}()
if err != nil {
errc <- err
return
}
}
}()
go func() {
for {
err := func() error {
b := mPool.Get().([]byte)
defer mPool.Put(b)
n, addr, err := cc.ReadFrom(b)
if err != nil {
return err
}
if clientAddr == nil {
return nil
}
if Debug {
log.Logf("[ssu] %s <<< %s length: %d", clientAddr, addr, n)
}
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n])
buf := bytes.Buffer{}
if err = dgram.Write(&buf); err != nil {
return err
}
_, err = conn.WriteTo(buf.Bytes()[3:], clientAddr)
return err
}()
if err != nil {
errc <- err
return
}
}
}()
select {
case err = <-errc:
}
return
}
func (h *shadowUDPHandler) transportUDP(conn net.Conn, cc net.PacketConn) error {
errc := make(chan error, 1)
go func() {
for {
er := func() (err error) {
b := lPool.Get().([]byte)
defer lPool.Put(b)
b[0] = 0
b[1] = 0
b[2] = 0
// add rsv and frag fields to make it the standard SOCKS5 UDP datagram
n, err := sc.Read(b[3:])
dgram, err := gosocks5.ReadUDPDatagram(conn)
if err != nil {
// log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err)
return
}
dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n+3]))
if err != nil {
log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err)
// log.Logf("[ssu] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
if Debug {
log.Logf("[ssu] %s >>> %s length: %d", sc.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data))
log.Logf("[ssu] %s >>> %s length: %d",
conn.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data))
}
addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String())
if err != nil {
@ -512,28 +535,25 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
go func() {
for {
er := func() (err error) {
b := lPool.Get().([]byte)
defer lPool.Put(b)
b := mPool.Get().([]byte)
defer mPool.Put(b)
n, addr, err := cc.ReadFrom(b)
if err != nil {
return
}
if Debug {
log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n)
log.Logf("[ssu] %s <<< %s length: %d", conn.RemoteAddr(), addr, n)
}
if h.options.Bypass.Contains(addr.String()) {
log.Log("[ssu] bypass", addr)
return // bypass
}
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n])
dgram := gosocks5.NewUDPDatagram(
gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n])
buf := bytes.Buffer{}
dgram.Write(&buf)
if buf.Len() < 10 {
log.Logf("[ssu] %s <- %s : invalid udp datagram", sc.RemoteAddr(), addr)
return // ignore invalid datagram
}
_, err = sc.Write(buf.Bytes()[3:])
_, err = conn.Write(buf.Bytes())
return
}()
@ -563,13 +583,13 @@ func (c *shadowConn) Write(b []byte) (n int, err error) {
return
}
type shadowUDPConn struct {
type shadowUDPPacketConn struct {
net.PacketConn
raddr net.Addr
header []byte
}
func (c *shadowUDPConn) Write(b []byte) (n int, err error) {
func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
buf := bytes.Buffer{}
if _, err = buf.Write(c.header); err != nil {
@ -582,7 +602,7 @@ func (c *shadowUDPConn) Write(b []byte) (n int, err error) {
return
}
func (c *shadowUDPConn) Read(b []byte) (n int, err error) {
func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) {
buf := mPool.Get().([]byte)
defer mPool.Put(buf)
@ -603,20 +623,52 @@ func (c *shadowUDPConn) Read(b []byte) (n int, err error) {
return
}
func (c *shadowUDPConn) RemoteAddr() net.Addr {
func (c *shadowUDPPacketConn) RemoteAddr() net.Addr {
return c.raddr
}
type shadowPacketConn struct {
type shadowUDPStreamConn struct {
net.Conn
addr *gosocks5.Addr
}
func (c *shadowPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Conn.Read(b)
addr = c.Conn.RemoteAddr()
func (c *shadowUDPStreamConn) Read(b []byte) (n int, err error) {
dgram, err := gosocks5.ReadUDPDatagram(c.Conn)
if err != nil {
return
}
n = copy(b, dgram.Data)
return
}
func (c *shadowPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Conn.Write(b)
func (c *shadowUDPStreamConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
return
}
func (c *shadowUDPStreamConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, c.addr), b)
buf := bytes.Buffer{}
dgram.Write(&buf)
_, err = c.Conn.Write(buf.Bytes())
return
}
func (c *shadowUDPStreamConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
}
type shadowCipher struct {
cipher *ss.Cipher
}
func (c *shadowCipher) StreamConn(conn net.Conn) net.Conn {
return ss.NewConn(conn, c.cipher.Copy())
}
func (c *shadowCipher) PacketConn(conn net.PacketConn) net.PacketConn {
return ss.NewSecurePacketConn(conn, c.cipher.Copy(), false)
}

4
ss2.go
View File

@ -116,7 +116,7 @@ func (h *shadow2Handler) Handle(conn net.Conn) {
conn = cipher.StreamConn(conn)
conn.SetReadDeadline(time.Now().Add(ReadTimeout))
addr, err := readAddr(conn)
addr, err := readSocksAddr(conn)
if err != nil {
log.Logf("[ss2] %s -> %s : %s",
conn.RemoteAddr(), conn.LocalAddr(), err)
@ -191,7 +191,7 @@ func (h *shadow2Handler) Handle(conn net.Conn) {
log.Logf("[ss2] %s >-< %s", conn.RemoteAddr(), host)
}
func readAddr(r io.Reader) (*gosocks5.Addr, error) {
func readSocksAddr(r io.Reader) (*gosocks5.Addr, error) {
addr := &gosocks5.Addr{}
b := sPool.Get().([]byte)
defer sPool.Put(b)

View File

@ -302,7 +302,7 @@ func BenchmarkSSProxyParallel(b *testing.B) {
func shadowUDPRoundtrip(t *testing.T, host string, data []byte,
clientInfo *url.Userinfo, serverInfo *url.Userinfo) error {
ln, err := ShadowUDPListener("localhost:0", serverInfo, nil)
ln, err := UDPListener("localhost:0", nil)
if err != nil {
return err
}
@ -313,7 +313,9 @@ func shadowUDPRoundtrip(t *testing.T, host string, data []byte,
}
server := &Server{
Handler: ShadowUDPdHandler(),
Handler: ShadowUDPHandler(
UsersHandlerOption(serverInfo),
),
Listener: ln,
}
@ -361,7 +363,7 @@ func BenchmarkShadowUDP(b *testing.B) {
sendData := make([]byte, 128)
rand.Read(sendData)
ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), nil)
ln, err := UDPListener("localhost:0", nil)
if err != nil {
b.Error(err)
}
@ -372,7 +374,9 @@ func BenchmarkShadowUDP(b *testing.B) {
}
server := &Server{
Handler: ShadowUDPdHandler(),
Handler: ShadowUDPHandler(
UsersHandlerOption(url.UserPassword("chacha20-ietf", "123456")),
),
Listener: ln,
}

66
tcp.go Normal file
View File

@ -0,0 +1,66 @@
package gost
import "net"
// tcpTransporter is a raw TCP transporter.
type tcpTransporter struct{}
// TCPTransporter creates a raw TCP client.
func TCPTransporter() Transporter {
return &tcpTransporter{}
}
func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
opts := &DialOptions{}
for _, option := range options {
option(opts)
}
timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}
if opts.Chain == nil {
return net.DialTimeout("tcp", addr, timeout)
}
return opts.Chain.Dial(addr)
}
func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
return conn, nil
}
func (tr *tcpTransporter) Multiplex() bool {
return false
}
type tcpListener struct {
net.Listener
}
// TCPListener creates a Listener for TCP proxy server.
func TCPListener(addr string) (Listener, error) {
laddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
ln, err := net.ListenTCP("tcp", laddr)
if err != nil {
return nil, err
}
return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil
}
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(KeepAliveTime)
return tc, nil
}

View File

@ -44,6 +44,7 @@ type IPRoute struct {
Gateway net.IP
}
// TunConfig is the config for TUN device.
type TunConfig struct {
Name string
Addr string
@ -426,6 +427,7 @@ func etherType(et waterutil.Ethertype) string {
return fmt.Sprintf("unknown(%v)", et)
}
// TapConfig is the config for TAP device.
type TapConfig struct {
Name string
Addr string
@ -789,6 +791,7 @@ func (c *tunTapConn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
// IsIPv6Multicast reports whether the address addr is an IPv6 multicast address.
func IsIPv6Multicast(addr net.HardwareAddr) bool {
return addr[0] == 0x33 && addr[1] == 0x33
}

357
udp.go Normal file
View File

@ -0,0 +1,357 @@
package gost
import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/go-log/log"
)
// udpTransporter is a raw UDP transporter.
type udpTransporter struct{}
// UDPTransporter creates a Transporter for UDP client.
func UDPTransporter() Transporter {
return &udpTransporter{}
}
func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
raddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
return &udpClientConn{
UDPConn: conn,
raddr: raddr,
}, nil
}
func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
return conn, nil
}
func (tr *udpTransporter) Multiplex() bool {
return false
}
// UDPListenConfig is the config for UDP Listener.
type UDPListenConfig struct {
TTL time.Duration // timeout per connection
Backlog int // connection backlog
QueueSize int // recv queue size per connection
}
type udpListener struct {
ln net.PacketConn
connChan chan net.Conn
errChan chan error
connMap udpConnMap
config *UDPListenConfig
}
// UDPListener creates a Listener for UDP server.
func UDPListener(addr string, cfg *UDPListenConfig) (Listener, error) {
laddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
ln, err := net.ListenUDP("udp", laddr)
if err != nil {
return nil, err
}
if cfg == nil {
cfg = &UDPListenConfig{}
}
backlog := cfg.Backlog
if backlog <= 0 {
backlog = defaultBacklog
}
l := &udpListener{
ln: ln,
connChan: make(chan net.Conn, backlog),
errChan: make(chan error, 1),
config: cfg,
}
go l.listenLoop()
return l, nil
}
func (l *udpListener) listenLoop() {
for {
b := make([]byte, mediumBufferSize)
n, raddr, err := l.ln.ReadFrom(b)
if err != nil {
log.Logf("[udp] peer -> %s : %s", l.Addr(), err)
l.Close()
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connMap.Get(raddr.String())
if !ok {
conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{
ttl: l.config.TTL,
qsize: l.config.QueueSize,
onClose: func() {
l.connMap.Delete(raddr.String())
log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size())
},
})
select {
case l.connChan <- conn:
l.connMap.Set(raddr.String(), conn)
log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
default:
conn.Close()
log.Logf("[udp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan))
}
}
select {
case conn.rChan <- b[:n]:
if Debug {
log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
}
default:
log.Logf("[udp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan))
}
}
}
func (l *udpListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = errors.New("accpet on closed listener")
}
}
return
}
func (l *udpListener) Addr() net.Addr {
return l.ln.LocalAddr()
}
func (l *udpListener) Close() error {
err := l.ln.Close()
l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
v.Close()
return true
})
return err
}
type udpConnMap struct {
m sync.Map
size int64
}
func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) {
v, ok := m.m.Load(key)
if ok {
conn, ok = v.(*udpServerConn)
}
return
}
func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) {
m.m.Store(key, conn)
atomic.AddInt64(&m.size, 1)
}
func (m *udpConnMap) Delete(key interface{}) {
m.m.Delete(key)
atomic.AddInt64(&m.size, -1)
}
func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) {
m.m.Range(func(k, v interface{}) bool {
return f(k, v.(*udpServerConn))
})
}
func (m *udpConnMap) Size() int64 {
return atomic.LoadInt64(&m.size)
}
// udpServerConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type udpServerConn struct {
conn net.PacketConn
raddr net.Addr
rChan chan []byte
closed chan struct{}
closeMutex sync.Mutex
nopChan chan int
config *udpServerConnConfig
}
type udpServerConnConfig struct {
ttl time.Duration
qsize int
onClose func()
}
func newUDPServerConn(conn net.PacketConn, raddr net.Addr, cfg *udpServerConnConfig) *udpServerConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &udpServerConnConfig{}
}
qsize := cfg.qsize
if qsize <= 0 {
qsize = defaultQueueSize
}
c := &udpServerConn{
conn: conn,
raddr: raddr,
rChan: make(chan []byte, qsize),
closed: make(chan struct{}),
nopChan: make(chan int),
config: cfg,
}
go c.ttlWait()
return c
}
func (c *udpServerConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *udpServerConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rChan:
n = copy(b, bb)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
select {
case c.nopChan <- n:
default:
}
addr = c.raddr
return
}
func (c *udpServerConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.raddr)
}
func (c *udpServerConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
n, err = c.conn.WriteTo(b, addr)
if n > 0 {
if Debug {
log.Logf("[udp] %s <<< %s : length %d", addr, c.LocalAddr(), n)
}
select {
case c.nopChan <- n:
default:
}
}
return
}
func (c *udpServerConn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed)
}
return nil
}
func (c *udpServerConn) ttlWait() {
ttl := c.config.ttl
if ttl == 0 {
ttl = defaultTTL
}
timer := time.NewTimer(ttl)
defer timer.Stop()
for {
select {
case <-c.nopChan:
if !timer.Stop() {
<-timer.C
}
timer.Reset(ttl)
case <-timer.C:
c.Close()
return
case <-c.closed:
return
}
}
}
func (c *udpServerConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *udpServerConn) RemoteAddr() net.Addr {
return c.raddr
}
func (c *udpServerConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *udpServerConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *udpServerConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
type udpClientConn struct {
*net.UDPConn
raddr net.Addr
}
func (c *udpClientConn) Write(b []byte) (int, error) {
if c.raddr != nil {
return c.WriteTo(b, c.raddr)
}
return c.UDPConn.Write(b)
}
func (c *udpClientConn) RemoteAddr() net.Addr {
if c.raddr != nil {
return c.raddr
}
return c.UDPConn.RemoteAddr()
}