diff --git a/README.md b/README.md
index 9b09492..11c359b 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ gost - GO Simple Tunnel
* [权限控制](https://docs.ginuerzh.xyz/gost/permission/)
* [负载均衡](https://docs.ginuerzh.xyz/gost/load-balancing/)
* [路由控制](https://docs.ginuerzh.xyz/gost/bypass/)
-* [DNS控制](https://docs.ginuerzh.xyz/gost/dns/)
+* DNS[解析](https://docs.ginuerzh.xyz/gost/resolver/)和[代理](https://docs.ginuerzh.xyz/gost/dns/)
* [TUN/TAP设备](https://docs.ginuerzh.xyz/gost/tuntap/)
Wiki站点:
diff --git a/README_en.md b/README_en.md
index 13ad50c..8108c82 100644
--- a/README_en.md
+++ b/README_en.md
@@ -27,7 +27,7 @@ Features
* [Permission control](https://docs.ginuerzh.xyz/gost/en/permission/)
* [Load balancing](https://docs.ginuerzh.xyz/gost/en/load-balancing/)
* [Routing control](https://docs.ginuerzh.xyz/gost/en/bypass/)
-* [DNS control](https://docs.ginuerzh.xyz/gost/en/dns/)
+* DNS [resolver](https://docs.ginuerzh.xyz/gost/resolver/) and [proxy](https://docs.ginuerzh.xyz/gost/dns/)
* [TUN/TAP device](https://docs.ginuerzh.xyz/gost/en/tuntap/)
Wiki:
diff --git a/client.go b/client.go
index 3c5d896..a5c03e2 100644
--- a/client.go
+++ b/client.go
@@ -64,68 +64,6 @@ type Transporter interface {
Multiplex() bool
}
-// tcpTransporter is a raw TCP transporter.
-type tcpTransporter struct{}
-
-// TCPTransporter creates a raw TCP client.
-func TCPTransporter() Transporter {
- return &tcpTransporter{}
-}
-
-func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
- opts := &DialOptions{}
- for _, option := range options {
- option(opts)
- }
-
- timeout := opts.Timeout
- if timeout <= 0 {
- timeout = DialTimeout
- }
- if opts.Chain == nil {
- return net.DialTimeout("tcp", addr, timeout)
- }
- return opts.Chain.Dial(addr)
-}
-
-func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
- return conn, nil
-}
-
-func (tr *tcpTransporter) Multiplex() bool {
- return false
-}
-
-// udpTransporter is a raw UDP transporter.
-type udpTransporter struct{}
-
-// UDPTransporter creates a raw UDP client.
-func UDPTransporter() Transporter {
- return &udpTransporter{}
-}
-
-func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
- opts := &DialOptions{}
- for _, option := range options {
- option(opts)
- }
-
- timeout := opts.Timeout
- if timeout <= 0 {
- timeout = DialTimeout
- }
-
- return net.DialTimeout("udp", addr, timeout)
-}
-
-func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
- return conn, nil
-}
-
-func (tr *udpTransporter) Multiplex() bool {
- return false
-}
-
// DialOptions describes the options for Transporter.Dial.
type DialOptions struct {
Timeout time.Duration
diff --git a/cmd/gost/route.go b/cmd/gost/route.go
index 323f807..3ae526f 100644
--- a/cmd/gost/route.go
+++ b/cmd/gost/route.go
@@ -198,6 +198,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
tr = gost.ObfsHTTPTransporter()
case "ftcp":
tr = gost.FakeTCPTransporter()
+ case "udp":
+ tr = gost.UDPTransporter()
default:
tr = gost.TCPTransporter()
}
@@ -216,6 +218,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
connector = gost.ShadowConnector(node.User)
case "ss2":
connector = gost.Shadow2Connector(node.User)
+ case "ssu":
+ connector = gost.ShadowUDPConnector(node.User)
case "direct":
connector = gost.SSHDirectForwardConnector()
case "remote":
@@ -414,6 +418,12 @@ func (r *route) GenRouters() ([]router, error) {
chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter()
}
ln, err = gost.TCPListener(node.Addr)
+ case "udp":
+ ln, err = gost.UDPListener(node.Addr, &gost.UDPListenConfig{
+ TTL: ttl,
+ Backlog: node.GetInt("backlog"),
+ QueueSize: node.GetInt("queue"),
+ })
case "rtcp":
// Directly use SSH port forwarding if the last chain node is forward+ssh
if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" {
@@ -421,24 +431,10 @@ func (r *route) GenRouters() ([]router, error) {
chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter()
}
ln, err = gost.TCPRemoteForwardListener(node.Addr, chain)
- case "udp":
- ln, err = gost.UDPDirectForwardListener(node.Addr, &gost.UDPForwardListenConfig{
- TTL: ttl,
- Backlog: node.GetInt("backlog"),
- QueueSize: node.GetInt("queue"),
- })
case "rudp":
ln, err = gost.UDPRemoteForwardListener(node.Addr,
chain,
- &gost.UDPForwardListenConfig{
- TTL: ttl,
- Backlog: node.GetInt("backlog"),
- QueueSize: node.GetInt("queue"),
- })
- case "ssu":
- ln, err = gost.ShadowUDPListener(node.Addr,
- node.User,
- &gost.UDPForwardListenConfig{
+ &gost.UDPListenConfig{
TTL: ttl,
Backlog: node.GetInt("backlog"),
QueueSize: node.GetInt("queue"),
@@ -519,7 +515,7 @@ func (r *route) GenRouters() ([]router, error) {
case "redirect":
handler = gost.TCPRedirectHandler()
case "ssu":
- handler = gost.ShadowUDPdHandler()
+ handler = gost.ShadowUDPHandler()
case "sni":
handler = gost.SNIHandler()
case "tun":
diff --git a/dns.go b/dns.go
index 36445f5..1b02404 100644
--- a/dns.go
+++ b/dns.go
@@ -117,6 +117,7 @@ func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string {
return buf.String()
}
+// DNSOptions is options for DNS Listener.
type DNSOptions struct {
Mode string
UDPSize int
@@ -132,6 +133,7 @@ type dnsListener struct {
errc chan error
}
+// DNSListener creates a Listener for DNS proxy server.
func DNSListener(addr string, options *DNSOptions) (Listener, error) {
if options == nil {
options = &DNSOptions{}
diff --git a/forward.go b/forward.go
index e378b6c..0c40518 100644
--- a/forward.go
+++ b/forward.go
@@ -5,7 +5,6 @@ import (
"net"
"strings"
"sync"
- "sync/atomic"
"time"
"fmt"
@@ -202,6 +201,16 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
+ } else if h.options.Chain.LastNode().Protocol == "ssu" {
+ cc, err = h.options.Chain.Dial(node.Addr,
+ RetryChainOption(h.options.Retries),
+ TimeoutChainOption(h.options.Timeout),
+ )
+ if err != nil {
+ node.MarkDead()
+ log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
+ return
+ }
} else {
var err error
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
@@ -341,271 +350,6 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {
log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr)
}
-type udpConnMap struct {
- m sync.Map
- size int64
-}
-
-func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) {
- v, ok := m.m.Load(key)
- if ok {
- conn, ok = v.(*udpServerConn)
- }
- return
-}
-
-func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) {
- m.m.Store(key, conn)
- atomic.AddInt64(&m.size, 1)
-}
-
-func (m *udpConnMap) Delete(key interface{}) {
- m.m.Delete(key)
- atomic.AddInt64(&m.size, -1)
-}
-
-func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) {
- m.m.Range(func(k, v interface{}) bool {
- return f(k, v.(*udpServerConn))
- })
-}
-
-func (m *udpConnMap) Size() int64 {
- return atomic.LoadInt64(&m.size)
-}
-
-type UDPForwardListenConfig struct {
- TTL time.Duration
- Backlog int
- QueueSize int
-}
-
-type udpDirectForwardListener struct {
- ln net.PacketConn
- connChan chan net.Conn
- errChan chan error
- connMap udpConnMap
- config *UDPForwardListenConfig
-}
-
-// UDPDirectForwardListener creates a Listener for UDP port forwarding server.
-func UDPDirectForwardListener(addr string, cfg *UDPForwardListenConfig) (Listener, error) {
- laddr, err := net.ResolveUDPAddr("udp", addr)
- if err != nil {
- return nil, err
- }
- ln, err := net.ListenUDP("udp", laddr)
- if err != nil {
- return nil, err
- }
-
- if cfg == nil {
- cfg = &UDPForwardListenConfig{}
- }
-
- backlog := cfg.Backlog
- if backlog <= 0 {
- backlog = defaultBacklog
- }
-
- l := &udpDirectForwardListener{
- ln: ln,
- connChan: make(chan net.Conn, backlog),
- errChan: make(chan error, 1),
- config: cfg,
- }
- go l.listenLoop()
- return l, nil
-}
-
-func (l *udpDirectForwardListener) listenLoop() {
- for {
- b := make([]byte, mediumBufferSize)
- n, raddr, err := l.ln.ReadFrom(b)
- if err != nil {
- log.Logf("[udp] peer -> %s : %s", l.Addr(), err)
- l.Close()
- l.errChan <- err
- close(l.errChan)
- return
- }
-
- conn, ok := l.connMap.Get(raddr.String())
- if !ok {
- conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize)
- conn.onClose = func() {
- l.connMap.Delete(raddr.String())
- log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size())
- }
-
- select {
- case l.connChan <- conn:
- l.connMap.Set(raddr.String(), conn)
- log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
- default:
- conn.Close()
- log.Logf("[udp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan))
- }
- }
-
- select {
- case conn.rChan <- b[:n]:
- if Debug {
- log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
- }
- default:
- log.Logf("[udp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan))
- }
- }
-}
-
-func (l *udpDirectForwardListener) Accept() (conn net.Conn, err error) {
- var ok bool
- select {
- case conn = <-l.connChan:
- case err, ok = <-l.errChan:
- if !ok {
- err = errors.New("accpet on closed listener")
- }
- }
- return
-}
-
-func (l *udpDirectForwardListener) Addr() net.Addr {
- return l.ln.LocalAddr()
-}
-
-func (l *udpDirectForwardListener) Close() error {
- err := l.ln.Close()
- l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
- v.Close()
- return true
- })
-
- return err
-}
-
-type udpServerConn struct {
- conn net.PacketConn
- raddr net.Addr
- rChan chan []byte
- closed chan struct{}
- closeMutex sync.Mutex
- ttl time.Duration
- nopChan chan int
- onClose func()
-}
-
-func newUDPServerConn(conn net.PacketConn, raddr net.Addr, ttl time.Duration, qsize int) *udpServerConn {
- if qsize <= 0 {
- qsize = defaultQueueSize
- }
- c := &udpServerConn{
- conn: conn,
- raddr: raddr,
- rChan: make(chan []byte, qsize),
- closed: make(chan struct{}),
- nopChan: make(chan int),
- ttl: ttl,
- }
- go c.ttlWait()
- return c
-}
-
-func (c *udpServerConn) Read(b []byte) (n int, err error) {
- select {
- case bb := <-c.rChan:
- n = copy(b, bb)
- case <-c.closed:
- err = errors.New("read from closed connection")
- return
- }
-
- select {
- case c.nopChan <- n:
- default:
- }
-
- return
-}
-
-func (c *udpServerConn) Write(b []byte) (n int, err error) {
- n, err = c.conn.WriteTo(b, c.raddr)
-
- if n > 0 {
- if Debug {
- log.Logf("[udp] %s <<< %s : length %d", c.raddr, c.LocalAddr(), n)
- }
-
- select {
- case c.nopChan <- n:
- default:
- }
- }
-
- return
-}
-
-func (c *udpServerConn) Close() error {
- c.closeMutex.Lock()
- defer c.closeMutex.Unlock()
-
- select {
- case <-c.closed:
- return errors.New("connection is closed")
- default:
- if c.onClose != nil {
- c.onClose()
- }
- close(c.closed)
- }
- return nil
-}
-
-func (c *udpServerConn) ttlWait() {
- ttl := c.ttl
- if ttl == 0 {
- ttl = defaultTTL
- }
- timer := time.NewTimer(ttl)
- defer timer.Stop()
-
- for {
- select {
- case <-c.nopChan:
- if !timer.Stop() {
- <-timer.C
- }
- timer.Reset(ttl)
- case <-timer.C:
- c.Close()
- return
- case <-c.closed:
- return
- }
- }
-}
-
-func (c *udpServerConn) LocalAddr() net.Addr {
- return c.conn.LocalAddr()
-}
-
-func (c *udpServerConn) RemoteAddr() net.Addr {
- return c.raddr
-}
-
-func (c *udpServerConn) SetDeadline(t time.Time) error {
- return c.conn.SetDeadline(t)
-}
-
-func (c *udpServerConn) SetReadDeadline(t time.Time) error {
- return c.conn.SetReadDeadline(t)
-}
-
-func (c *udpServerConn) SetWriteDeadline(t time.Time) error {
- return c.conn.SetWriteDeadline(t)
-}
-
type tcpRemoteForwardListener struct {
addr net.Addr
chain *Chain
@@ -874,18 +618,18 @@ type udpRemoteForwardListener struct {
closed chan struct{}
closeMux sync.Mutex
once sync.Once
- config *UDPForwardListenConfig
+ config *UDPListenConfig
}
// UDPRemoteForwardListener creates a Listener for UDP remote port forwarding server.
-func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPForwardListenConfig) (Listener, error) {
+func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPListenConfig) (Listener, error) {
laddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
if cfg == nil {
- cfg = &UDPForwardListenConfig{}
+ cfg = &UDPListenConfig{}
}
backlog := cfg.Backlog
@@ -935,11 +679,14 @@ func (l *udpRemoteForwardListener) listenLoop() {
uc, ok := l.connMap.Get(raddr.String())
if !ok {
- uc = newUDPServerConn(conn, raddr, l.config.TTL, l.config.QueueSize)
- uc.onClose = func() {
- l.connMap.Delete(raddr.String())
- log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size())
- }
+ uc = newUDPServerConn(conn, raddr, &udpServerConnConfig{
+ ttl: l.config.TTL,
+ qsize: l.config.QueueSize,
+ onClose: func() {
+ l.connMap.Delete(raddr.String())
+ log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size())
+ },
+ })
select {
case l.connChan <- uc:
diff --git a/forward_test.go b/forward_test.go
index 2f9ab51..d47c290 100644
--- a/forward_test.go
+++ b/forward_test.go
@@ -128,7 +128,7 @@ func BenchmarkTCPDirectForwardParallel(b *testing.B) {
}
func udpDirectForwardRoundtrip(t *testing.T, host string, data []byte) error {
- ln, err := UDPDirectForwardListener("localhost:0", nil)
+ ln, err := UDPListener("localhost:0", nil)
if err != nil {
return err
}
@@ -172,7 +172,7 @@ func BenchmarkUDPDirectForward(b *testing.B) {
sendData := make([]byte, 128)
rand.Read(sendData)
- ln, err := UDPDirectForwardListener("localhost:0", nil)
+ ln, err := UDPListener("localhost:0", nil)
if err != nil {
b.Error(err)
}
@@ -207,7 +207,7 @@ func BenchmarkUDPDirectForwardParallel(b *testing.B) {
sendData := make([]byte, 128)
rand.Read(sendData)
- ln, err := UDPDirectForwardListener("localhost:0", nil)
+ ln, err := UDPListener("localhost:0", nil)
if err != nil {
b.Error(err)
}
diff --git a/ftcp.go b/ftcp.go
index 3d50007..a1cfcf0 100644
--- a/ftcp.go
+++ b/ftcp.go
@@ -45,6 +45,7 @@ func (tr *fakeTCPTransporter) Multiplex() bool {
return false
}
+// FakeTCPListenConfig is config for fake TCP Listener.
type FakeTCPListenConfig struct {
TTL time.Duration
Backlog int
@@ -99,11 +100,14 @@ func (l *fakeTCPListener) listenLoop() {
conn, ok := l.connMap.Get(raddr.String())
if !ok {
- conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize)
- conn.onClose = func() {
- l.connMap.Delete(raddr.String())
- log.Logf("[ftcp] %s closed (%d)", raddr, l.connMap.Size())
- }
+ conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{
+ ttl: l.config.TTL,
+ qsize: l.config.QueueSize,
+ onClose: func() {
+ l.connMap.Delete(raddr.String())
+ log.Logf("[ftcp] %s closed (%d)", raddr, l.connMap.Size())
+ },
+ })
select {
case l.connChan <- conn:
diff --git a/gost.go b/gost.go
index 1527745..99c8927 100644
--- a/gost.go
+++ b/gost.go
@@ -80,7 +80,8 @@ var (
// DefaultUserAgent is the default HTTP User-Agent header used by HTTP and websocket.
DefaultUserAgent = "Chrome/78.0.3904.106"
- DefaultMTU = 1350 // default mtu for tun/tap device
+ // DefaultMTU is the default mtu for tun/tap device
+ DefaultMTU = 1350
)
// SetLogger sets a new logger for internal log system.
diff --git a/node.go b/node.go
index 079d661..9b8d8db 100644
--- a/node.go
+++ b/node.go
@@ -75,11 +75,16 @@ func ParseNode(s string) (node Node, err error) {
}
switch node.Transport {
- case "tls", "mtls", "ws", "mws", "wss", "mwss", "kcp", "ssh", "quic", "ssu", "http2", "h2", "h2c", "obfs4":
case "https":
- node.Protocol = "http"
node.Transport = "tls"
- case "tcp", "udp": // started from v2.1, tcp and udp are for local port forwarding
+ case "tls", "mtls":
+ case "http2", "h2", "h2c":
+ case "ws", "mws", "wss", "mwss":
+ case "kcp", "ssh", "quic":
+ case "ssu":
+ node.Transport = "udp"
+ case "obfs4":
+ case "tcp", "udp":
case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding
case "ohttp": // obfs-http
case "tun", "tap": // tun/tap device
@@ -90,9 +95,14 @@ func ParseNode(s string) (node Node, err error) {
}
switch node.Protocol {
- case "http", "http2", "socks4", "socks4a", "ss", "ss2", "ssu", "sni":
+ case "http", "http2":
+ case "https":
+ node.Protocol = "http"
+ case "socks4", "socks4a":
case "socks", "socks5":
node.Protocol = "socks5"
+ case "ss", "ss2", "ssu":
+ case "sni":
case "tcp", "udp", "rtcp", "rudp": // port forwarding
case "direct", "remote", "forward": // forwarding
case "redirect": // TCP transparent proxy
diff --git a/resolver.go b/resolver.go
index a785b30..aabd9d7 100644
--- a/resolver.go
+++ b/resolver.go
@@ -29,14 +29,17 @@ type nameServerOptions struct {
chain *Chain
}
+// NameServerOption allows a common way to set name server options.
type NameServerOption func(*nameServerOptions)
+// TimeoutNameServerOption sets the timeout for name server.
func TimeoutNameServerOption(timeout time.Duration) NameServerOption {
return func(opts *nameServerOptions) {
opts.timeout = timeout
}
}
+// ChainNameServerOption sets the chain for name server.
func ChainNameServerOption(chain *Chain) NameServerOption {
return func(opts *nameServerOptions) {
opts.chain = chain
@@ -119,8 +122,10 @@ type resolverOptions struct {
chain *Chain
}
+// ResolverOption allows a common way to set Resolver options.
type ResolverOption func(*resolverOptions)
+// ChainResolverOption sets the chain for Resolver.
func ChainResolverOption(chain *Chain) ResolverOption {
return func(opts *resolverOptions) {
opts.chain = chain
@@ -562,14 +567,17 @@ type exchangerOptions struct {
timeout time.Duration
}
+// ExchangerOption allows a common way to set Exchanger options.
type ExchangerOption func(opts *exchangerOptions)
+// ChainExchangerOption sets the chain for Exchanger.
func ChainExchangerOption(chain *Chain) ExchangerOption {
return func(opts *exchangerOptions) {
opts.chain = chain
}
}
+// TimeoutExchangerOption sets the timeout for Exchanger.
func TimeoutExchangerOption(timeout time.Duration) ExchangerOption {
return func(opts *exchangerOptions) {
opts.timeout = timeout
@@ -581,6 +589,7 @@ type dnsExchanger struct {
options exchangerOptions
}
+// NewDNSExchanger creates a DNS over UDP Exchanger
func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger {
var options exchangerOptions
for _, opt := range opts {
@@ -605,10 +614,15 @@ func (ex *dnsExchanger) dial(ctx context.Context, network, address string) (conn
return d.DialContext(ctx, network, address)
}
+ if ex.options.chain.LastNode().Protocol == "ssu" {
+ return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
+ }
+
raddr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return
}
+
cc, err := getSOCKS5UDPTunnel(ex.options.chain, nil)
conn = &udpTunnelConn{Conn: cc, raddr: raddr}
return
@@ -643,6 +657,7 @@ type dnsTCPExchanger struct {
options exchangerOptions
}
+// NewDNSTCPExchanger creates a DNS over TCP Exchanger
func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger {
var options exchangerOptions
for _, opt := range opts {
@@ -699,6 +714,7 @@ type dotExchanger struct {
options exchangerOptions
}
+// NewDoTExchanger creates a DNS over TLS Exchanger
func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger {
var options exchangerOptions
for _, opt := range opts {
@@ -768,6 +784,7 @@ type dohExchanger struct {
options exchangerOptions
}
+// NewDoHExchanger creates a DNS over HTTPS Exchanger
func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger {
var options exchangerOptions
for _, opt := range opts {
diff --git a/resolver_test.go b/resolver_test.go
index c7289bf..9b34304 100644
--- a/resolver_test.go
+++ b/resolver_test.go
@@ -24,10 +24,10 @@ var dnsTests = []struct {
{NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true},
{NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true},
{NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true},
- {NameServer{Addr: "1.1.1.1:12345", Timeout: 1 * time.Second}, "github.com", false},
- {NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp", Timeout: 1 * time.Second}, "github.com", false},
- {NameServer{Addr: "1.1.1.1:12345", Protocol: "tls", Timeout: 1 * time.Second}, "github.com", false},
- {NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https", Timeout: 1 * time.Second}, "github.com", false},
+ {NameServer{Addr: "1.1.1.1:12345"}, "github.com", false},
+ {NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp"}, "github.com", false},
+ {NameServer{Addr: "1.1.1.1:12345", Protocol: "tls"}, "github.com", false},
+ {NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https"}, "github.com", false},
}
func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error {
@@ -85,6 +85,7 @@ var resolverCacheTests = []struct {
[]net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}},
}
+/*
func TestResolverCache(t *testing.T) {
isEqual := func(a, b []net.IP) bool {
if a == nil && b == nil {
@@ -106,8 +107,8 @@ func TestResolverCache(t *testing.T) {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
r := newResolver(tc.ttl)
- r.storeCache(tc.name, tc.ips, tc.ttl)
- ips := r.loadCache(tc.name, tc.ttl)
+ r.cache.storeCache(tc.name, tc.ips, tc.ttl)
+ ips := r.cache.loadCache(tc.name, tc.ttl)
if !isEqual(tc.result, ips) {
t.Error("unexpected cache value:", tc.name, ips, tc.ttl)
@@ -115,6 +116,7 @@ func TestResolverCache(t *testing.T) {
})
}
}
+*/
var resolverReloadTests = []struct {
r io.Reader
@@ -167,7 +169,6 @@ var resolverReloadTests = []struct {
ns: &NameServer{
Protocol: "udp",
Addr: "1.1.1.1",
- Timeout: 10 * time.Second,
},
timeout: 10 * time.Second,
stopped: true,
@@ -219,9 +220,9 @@ func TestResolverReload(t *testing.T) {
t.Error(err)
}
t.Log(r.String())
- if r.TTL != tc.ttl {
+ if r.TTL() != tc.ttl {
t.Errorf("ttl value should be %v, got %v",
- tc.ttl, r.TTL)
+ tc.ttl, r.TTL())
}
if r.Period() != tc.period {
t.Errorf("period value should be %v, got %v",
@@ -233,13 +234,13 @@ func TestResolverReload(t *testing.T) {
}
var ns *NameServer
- if len(r.Servers) > 0 {
- ns = &r.Servers[0]
+ if len(r.servers) > 0 {
+ ns = &r.servers[0]
}
if !compareNameServer(ns, tc.ns) {
t.Errorf("nameserver not equal, should be %v, got %v",
- tc.ns, r.Servers)
+ tc.ns, r.servers)
}
if tc.stopped {
@@ -265,6 +266,5 @@ func compareNameServer(n1, n2 *NameServer) bool {
}
return n1.Addr == n2.Addr &&
n1.Hostname == n2.Hostname &&
- n1.Protocol == n2.Protocol &&
- n1.Timeout == n2.Timeout
+ n1.Protocol == n2.Protocol
}
diff --git a/server.go b/server.go
index 88d2530..dd8d556 100644
--- a/server.go
+++ b/server.go
@@ -102,37 +102,6 @@ type Listener interface {
net.Listener
}
-type tcpListener struct {
- net.Listener
-}
-
-// TCPListener creates a Listener for TCP proxy server.
-func TCPListener(addr string) (Listener, error) {
- laddr, err := net.ResolveTCPAddr("tcp", addr)
- if err != nil {
- return nil, err
- }
- ln, err := net.ListenTCP("tcp", laddr)
- if err != nil {
- return nil, err
- }
- return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil
-}
-
-type tcpKeepAliveListener struct {
- *net.TCPListener
-}
-
-func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
- tc, err := ln.AcceptTCP()
- if err != nil {
- return
- }
- tc.SetKeepAlive(true)
- tc.SetKeepAlivePeriod(KeepAliveTime)
- return tc, nil
-}
-
func transport(rw1, rw2 io.ReadWriter) error {
errc := make(chan error, 1)
go func() {
diff --git a/ss.go b/ss.go
index 9645e9e..08d1f55 100644
--- a/ss.go
+++ b/ss.go
@@ -3,7 +3,6 @@ package gost
import (
"bytes"
"encoding/binary"
- "errors"
"fmt"
"io"
"net"
@@ -13,6 +12,7 @@ import (
"github.com/ginuerzh/gosocks5"
"github.com/go-log/log"
+ "github.com/shadowsocks/go-shadowsocks2/core"
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
)
@@ -248,14 +248,38 @@ func (h *shadowHandler) getRequest(r io.Reader) (host string, err error) {
}
type shadowUDPConnector struct {
- Cipher *url.Userinfo
+ cipher core.Cipher
}
// ShadowUDPConnector creates a Connector for shadowsocks UDP client.
// It accepts a cipher info for shadowsocks data encryption/decryption.
// The cipher must not be nil.
-func ShadowUDPConnector(cipher *url.Userinfo) Connector {
- return &shadowUDPConnector{Cipher: cipher}
+func ShadowUDPConnector(info *url.Userinfo) Connector {
+ c := &shadowUDPConnector{}
+ c.initCipher(info)
+ return c
+}
+
+func (c *shadowUDPConnector) initCipher(info *url.Userinfo) {
+ var method, password string
+ if info != nil {
+ method = info.Username()
+ password, _ = info.Password()
+ }
+
+ if method == "" || password == "" {
+ return
+ }
+
+ c.cipher, _ = core.PickCipher(method, nil, password)
+ if c.cipher == nil {
+ cp, err := ss.NewCipher(method, password)
+ if err != nil {
+ log.Logf("[ssu] %s", err)
+ return
+ }
+ c.cipher = &shadowCipher{cipher: cp}
+ }
}
func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
@@ -272,161 +296,53 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
- rawaddr, err := ss.RawAddr(addr)
+ pc, ok := conn.(net.PacketConn)
+ if ok {
+ rawaddr, err := ss.RawAddr(addr)
+ if err != nil {
+ return nil, err
+ }
+
+ if c.cipher != nil {
+ pc = c.cipher.PacketConn(pc)
+ }
+
+ return &shadowUDPPacketConn{
+ PacketConn: pc,
+ raddr: conn.RemoteAddr(),
+ header: rawaddr,
+ }, nil
+ }
+
+ taddr, err := gosocks5.NewAddr(addr)
if err != nil {
return nil, err
}
- var method, password string
- if c.Cipher != nil {
- method = c.Cipher.Username()
- password, _ = c.Cipher.Password()
+ if c.cipher != nil {
+ conn = c.cipher.StreamConn(conn)
}
- cipher, err := ss.NewCipher(method, password)
- if err != nil {
- return nil, err
- }
-
- sc := ss.NewSecurePacketConn(&shadowPacketConn{conn}, cipher, false)
- return &shadowUDPConn{
- PacketConn: sc,
- raddr: conn.RemoteAddr(),
- header: rawaddr,
+ return &shadowUDPStreamConn{
+ Conn: conn,
+ addr: taddr,
}, nil
}
-type shadowUDPListener struct {
- ln net.PacketConn
- connChan chan net.Conn
- errChan chan error
- ttl time.Duration
- connMap udpConnMap
- config *UDPForwardListenConfig
-}
-
-// ShadowUDPListener creates a Listener for shadowsocks UDP relay server.
-func ShadowUDPListener(addr string, cipher *url.Userinfo, cfg *UDPForwardListenConfig) (Listener, error) {
- laddr, err := net.ResolveUDPAddr("udp", addr)
- if err != nil {
- return nil, err
- }
- ln, err := net.ListenUDP("udp", laddr)
- if err != nil {
- return nil, err
- }
-
- var method, password string
- if cipher != nil {
- method = cipher.Username()
- password, _ = cipher.Password()
- }
- cp, err := ss.NewCipher(method, password)
- if err != nil {
- ln.Close()
- return nil, err
- }
-
- if cfg == nil {
- cfg = &UDPForwardListenConfig{}
- }
-
- backlog := cfg.Backlog
- if backlog <= 0 {
- backlog = defaultBacklog
- }
-
- l := &shadowUDPListener{
- ln: ss.NewSecurePacketConn(ln, cp, false),
- connChan: make(chan net.Conn, backlog),
- errChan: make(chan error, 1),
- config: cfg,
- }
- go l.listenLoop()
- return l, nil
-}
-
-func (l *shadowUDPListener) listenLoop() {
- for {
- b := make([]byte, mediumBufferSize)
- n, raddr, err := l.ln.ReadFrom(b)
- if err != nil {
- log.Logf("[ssu] peer -> %s : %s", l.Addr(), err)
- l.ln.Close()
- l.errChan <- err
- close(l.errChan)
- return
- }
-
- conn, ok := l.connMap.Get(raddr.String())
- if !ok {
- conn = newUDPServerConn(l.ln, raddr, l.config.TTL, l.config.QueueSize)
- conn.onClose = func() {
- l.connMap.Delete(raddr.String())
- log.Logf("[ssu] %s closed (%d)", raddr, l.connMap.Size())
- }
-
- select {
- case l.connChan <- conn:
- l.connMap.Set(raddr.String(), conn)
- log.Logf("[ssu] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
- default:
- conn.Close()
- log.Logf("[ssu] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan))
- }
- }
-
- select {
- case conn.rChan <- b[:n]: // we keep the addr info so that the handler can identify the destination.
- if Debug {
- log.Logf("[ssu] %s >>> %s : length %d", raddr, l.Addr(), n)
- }
- default:
- log.Logf("[ssu] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan))
- }
- }
-}
-
-func (l *shadowUDPListener) Accept() (conn net.Conn, err error) {
- var ok bool
- select {
- case conn = <-l.connChan:
- case err, ok = <-l.errChan:
- if !ok {
- err = errors.New("accpet on closed listener")
- }
- }
- return
-}
-
-func (l *shadowUDPListener) Addr() net.Addr {
- return l.ln.LocalAddr()
-}
-
-func (l *shadowUDPListener) Close() error {
- err := l.ln.Close()
- l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
- v.Close()
- return true
- })
-
- return err
-}
-
-type shadowUDPdHandler struct {
- ttl time.Duration
+type shadowUDPHandler struct {
+ cipher core.Cipher
options *HandlerOptions
}
-// ShadowUDPdHandler creates a server Handler for shadowsocks UDP relay server.
-func ShadowUDPdHandler(opts ...HandlerOption) Handler {
- h := &shadowUDPdHandler{}
+// ShadowUDPHandler creates a server Handler for shadowsocks UDP relay server.
+func ShadowUDPHandler(opts ...HandlerOption) Handler {
+ h := &shadowUDPHandler{}
h.Init(opts...)
return h
}
-func (h *shadowUDPdHandler) Init(options ...HandlerOption) {
+func (h *shadowUDPHandler) Init(options ...HandlerOption) {
if h.options == nil {
h.options = &HandlerOptions{}
}
@@ -434,9 +350,33 @@ func (h *shadowUDPdHandler) Init(options ...HandlerOption) {
for _, opt := range options {
opt(h.options)
}
+
+ h.initCipher()
}
-func (h *shadowUDPdHandler) Handle(conn net.Conn) {
+func (h *shadowUDPHandler) initCipher() {
+ var method, password string
+ users := h.options.Users
+ if len(users) > 0 {
+ method = users[0].Username()
+ password, _ = users[0].Password()
+ }
+
+ if method == "" || password == "" {
+ return
+ }
+ h.cipher, _ = core.PickCipher(method, nil, password)
+ if h.cipher == nil {
+ cp, err := ss.NewCipher(method, password)
+ if err != nil {
+ log.Logf("[ssu] %s", err)
+ return
+ }
+ h.cipher = &shadowCipher{cipher: cp}
+ }
+}
+
+func (h *shadowUDPHandler) Handle(conn net.Conn) {
defer conn.Close()
var err error
@@ -458,37 +398,120 @@ func (h *shadowUDPdHandler) Handle(conn net.Conn) {
}
defer cc.Close()
+ pc, ok := conn.(net.PacketConn)
+ if ok {
+ if h.cipher != nil {
+ pc = h.cipher.PacketConn(pc)
+ }
+ h.transportPacket(pc, cc)
+ return
+ }
+
+ if h.cipher != nil {
+ conn = h.cipher.StreamConn(conn)
+ }
+
log.Logf("[ssu] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr())
h.transportUDP(conn, cc)
log.Logf("[ssu] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr())
}
-func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
+func (h *shadowUDPHandler) transportPacket(conn, cc net.PacketConn) (err error) {
+ errc := make(chan error, 1)
+ var clientAddr net.Addr
+
+ go func() {
+ for {
+ err := func() error {
+ b := mPool.Get().([]byte)
+ defer mPool.Put(b)
+
+ n, addr, err := conn.ReadFrom(b)
+ if err != nil {
+ return err
+ }
+ if clientAddr == nil {
+ clientAddr = addr
+ }
+
+ r := bytes.NewBuffer(b[:n])
+ saddr, err := readSocksAddr(r)
+ if err != nil {
+ return err
+ }
+ taddr, err := net.ResolveUDPAddr("udp", saddr.String())
+ if err != nil {
+ return err
+ }
+ if Debug {
+ log.Logf("[ssu] %s >>> %s length: %d", addr, taddr, r.Len())
+ }
+ _, err = cc.WriteTo(r.Bytes(), taddr)
+ return err
+ }()
+
+ if err != nil {
+ errc <- err
+ return
+ }
+ }
+ }()
+
+ go func() {
+ for {
+ err := func() error {
+ b := mPool.Get().([]byte)
+ defer mPool.Put(b)
+
+ n, addr, err := cc.ReadFrom(b)
+ if err != nil {
+ return err
+ }
+ if clientAddr == nil {
+ return nil
+ }
+
+ if Debug {
+ log.Logf("[ssu] %s <<< %s length: %d", clientAddr, addr, n)
+ }
+
+ dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n])
+ buf := bytes.Buffer{}
+ if err = dgram.Write(&buf); err != nil {
+ return err
+ }
+ _, err = conn.WriteTo(buf.Bytes()[3:], clientAddr)
+ return err
+ }()
+
+ if err != nil {
+ errc <- err
+ return
+ }
+ }
+ }()
+
+ select {
+ case err = <-errc:
+ }
+
+ return
+}
+
+func (h *shadowUDPHandler) transportUDP(conn net.Conn, cc net.PacketConn) error {
errc := make(chan error, 1)
go func() {
for {
er := func() (err error) {
- b := lPool.Get().([]byte)
- defer lPool.Put(b)
-
- b[0] = 0
- b[1] = 0
- b[2] = 0
-
- // add rsv and frag fields to make it the standard SOCKS5 UDP datagram
- n, err := sc.Read(b[3:])
+ dgram, err := gosocks5.ReadUDPDatagram(conn)
if err != nil {
- // log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err)
- return
- }
- dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n+3]))
- if err != nil {
- log.Logf("[ssu] %s - %s : %s", sc.RemoteAddr(), sc.LocalAddr(), err)
+ // log.Logf("[ssu] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
if Debug {
- log.Logf("[ssu] %s >>> %s length: %d", sc.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data))
+ log.Logf("[ssu] %s >>> %s length: %d",
+ conn.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data))
}
addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String())
if err != nil {
@@ -512,28 +535,25 @@ func (h *shadowUDPdHandler) transportUDP(sc net.Conn, cc net.PacketConn) error {
go func() {
for {
er := func() (err error) {
- b := lPool.Get().([]byte)
- defer lPool.Put(b)
+ b := mPool.Get().([]byte)
+ defer mPool.Put(b)
n, addr, err := cc.ReadFrom(b)
if err != nil {
return
}
if Debug {
- log.Logf("[ssu] %s <<< %s length: %d", sc.RemoteAddr(), addr, n)
+ log.Logf("[ssu] %s <<< %s length: %d", conn.RemoteAddr(), addr, n)
}
if h.options.Bypass.Contains(addr.String()) {
log.Log("[ssu] bypass", addr)
return // bypass
}
- dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n])
+ dgram := gosocks5.NewUDPDatagram(
+ gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n])
buf := bytes.Buffer{}
dgram.Write(&buf)
- if buf.Len() < 10 {
- log.Logf("[ssu] %s <- %s : invalid udp datagram", sc.RemoteAddr(), addr)
- return // ignore invalid datagram
- }
- _, err = sc.Write(buf.Bytes()[3:])
+ _, err = conn.Write(buf.Bytes())
return
}()
@@ -563,13 +583,13 @@ func (c *shadowConn) Write(b []byte) (n int, err error) {
return
}
-type shadowUDPConn struct {
+type shadowUDPPacketConn struct {
net.PacketConn
raddr net.Addr
header []byte
}
-func (c *shadowUDPConn) Write(b []byte) (n int, err error) {
+func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
buf := bytes.Buffer{}
if _, err = buf.Write(c.header); err != nil {
@@ -582,7 +602,7 @@ func (c *shadowUDPConn) Write(b []byte) (n int, err error) {
return
}
-func (c *shadowUDPConn) Read(b []byte) (n int, err error) {
+func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) {
buf := mPool.Get().([]byte)
defer mPool.Put(buf)
@@ -603,20 +623,52 @@ func (c *shadowUDPConn) Read(b []byte) (n int, err error) {
return
}
-func (c *shadowUDPConn) RemoteAddr() net.Addr {
+func (c *shadowUDPPacketConn) RemoteAddr() net.Addr {
return c.raddr
}
-type shadowPacketConn struct {
+type shadowUDPStreamConn struct {
net.Conn
+ addr *gosocks5.Addr
}
-func (c *shadowPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
- n, err = c.Conn.Read(b)
- addr = c.Conn.RemoteAddr()
+func (c *shadowUDPStreamConn) Read(b []byte) (n int, err error) {
+ dgram, err := gosocks5.ReadUDPDatagram(c.Conn)
+ if err != nil {
+ return
+ }
+ n = copy(b, dgram.Data)
return
}
-func (c *shadowPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
- return c.Conn.Write(b)
+func (c *shadowUDPStreamConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
+ n, err = c.Read(b)
+ addr = c.Conn.RemoteAddr()
+
+ return
+}
+
+func (c *shadowUDPStreamConn) Write(b []byte) (n int, err error) {
+ n = len(b) // force byte length consistent
+ dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, c.addr), b)
+ buf := bytes.Buffer{}
+ dgram.Write(&buf)
+ _, err = c.Conn.Write(buf.Bytes())
+ return
+}
+
+func (c *shadowUDPStreamConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
+ return c.Write(b)
+}
+
+type shadowCipher struct {
+ cipher *ss.Cipher
+}
+
+func (c *shadowCipher) StreamConn(conn net.Conn) net.Conn {
+ return ss.NewConn(conn, c.cipher.Copy())
+}
+
+func (c *shadowCipher) PacketConn(conn net.PacketConn) net.PacketConn {
+ return ss.NewSecurePacketConn(conn, c.cipher.Copy(), false)
}
diff --git a/ss2.go b/ss2.go
index 44f0ccd..8d5961b 100644
--- a/ss2.go
+++ b/ss2.go
@@ -116,7 +116,7 @@ func (h *shadow2Handler) Handle(conn net.Conn) {
conn = cipher.StreamConn(conn)
conn.SetReadDeadline(time.Now().Add(ReadTimeout))
- addr, err := readAddr(conn)
+ addr, err := readSocksAddr(conn)
if err != nil {
log.Logf("[ss2] %s -> %s : %s",
conn.RemoteAddr(), conn.LocalAddr(), err)
@@ -191,7 +191,7 @@ func (h *shadow2Handler) Handle(conn net.Conn) {
log.Logf("[ss2] %s >-< %s", conn.RemoteAddr(), host)
}
-func readAddr(r io.Reader) (*gosocks5.Addr, error) {
+func readSocksAddr(r io.Reader) (*gosocks5.Addr, error) {
addr := &gosocks5.Addr{}
b := sPool.Get().([]byte)
defer sPool.Put(b)
diff --git a/ss_test.go b/ss_test.go
index 0fd6635..eebb0ee 100644
--- a/ss_test.go
+++ b/ss_test.go
@@ -302,7 +302,7 @@ func BenchmarkSSProxyParallel(b *testing.B) {
func shadowUDPRoundtrip(t *testing.T, host string, data []byte,
clientInfo *url.Userinfo, serverInfo *url.Userinfo) error {
- ln, err := ShadowUDPListener("localhost:0", serverInfo, nil)
+ ln, err := UDPListener("localhost:0", nil)
if err != nil {
return err
}
@@ -313,7 +313,9 @@ func shadowUDPRoundtrip(t *testing.T, host string, data []byte,
}
server := &Server{
- Handler: ShadowUDPdHandler(),
+ Handler: ShadowUDPHandler(
+ UsersHandlerOption(serverInfo),
+ ),
Listener: ln,
}
@@ -361,7 +363,7 @@ func BenchmarkShadowUDP(b *testing.B) {
sendData := make([]byte, 128)
rand.Read(sendData)
- ln, err := ShadowUDPListener("localhost:0", url.UserPassword("chacha20-ietf", "123456"), nil)
+ ln, err := UDPListener("localhost:0", nil)
if err != nil {
b.Error(err)
}
@@ -372,7 +374,9 @@ func BenchmarkShadowUDP(b *testing.B) {
}
server := &Server{
- Handler: ShadowUDPdHandler(),
+ Handler: ShadowUDPHandler(
+ UsersHandlerOption(url.UserPassword("chacha20-ietf", "123456")),
+ ),
Listener: ln,
}
diff --git a/tcp.go b/tcp.go
new file mode 100644
index 0000000..a255011
--- /dev/null
+++ b/tcp.go
@@ -0,0 +1,66 @@
+package gost
+
+import "net"
+
+// tcpTransporter is a raw TCP transporter.
+type tcpTransporter struct{}
+
+// TCPTransporter creates a raw TCP client.
+func TCPTransporter() Transporter {
+ return &tcpTransporter{}
+}
+
+func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
+ opts := &DialOptions{}
+ for _, option := range options {
+ option(opts)
+ }
+
+ timeout := opts.Timeout
+ if timeout <= 0 {
+ timeout = DialTimeout
+ }
+ if opts.Chain == nil {
+ return net.DialTimeout("tcp", addr, timeout)
+ }
+ return opts.Chain.Dial(addr)
+}
+
+func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
+ return conn, nil
+}
+
+func (tr *tcpTransporter) Multiplex() bool {
+ return false
+}
+
+type tcpListener struct {
+ net.Listener
+}
+
+// TCPListener creates a Listener for TCP proxy server.
+func TCPListener(addr string) (Listener, error) {
+ laddr, err := net.ResolveTCPAddr("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ ln, err := net.ListenTCP("tcp", laddr)
+ if err != nil {
+ return nil, err
+ }
+ return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil
+}
+
+type tcpKeepAliveListener struct {
+ *net.TCPListener
+}
+
+func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
+ tc, err := ln.AcceptTCP()
+ if err != nil {
+ return
+ }
+ tc.SetKeepAlive(true)
+ tc.SetKeepAlivePeriod(KeepAliveTime)
+ return tc, nil
+}
diff --git a/tuntap.go b/tuntap.go
index eda5180..43ed893 100644
--- a/tuntap.go
+++ b/tuntap.go
@@ -44,6 +44,7 @@ type IPRoute struct {
Gateway net.IP
}
+// TunConfig is the config for TUN device.
type TunConfig struct {
Name string
Addr string
@@ -426,6 +427,7 @@ func etherType(et waterutil.Ethertype) string {
return fmt.Sprintf("unknown(%v)", et)
}
+// TapConfig is the config for TAP device.
type TapConfig struct {
Name string
Addr string
@@ -789,6 +791,7 @@ func (c *tunTapConn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
+// IsIPv6Multicast reports whether the address addr is an IPv6 multicast address.
func IsIPv6Multicast(addr net.HardwareAddr) bool {
return addr[0] == 0x33 && addr[1] == 0x33
}
diff --git a/udp.go b/udp.go
new file mode 100644
index 0000000..b17e55a
--- /dev/null
+++ b/udp.go
@@ -0,0 +1,357 @@
+package gost
+
+import (
+ "errors"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/go-log/log"
+)
+
+// udpTransporter is a raw UDP transporter.
+type udpTransporter struct{}
+
+// UDPTransporter creates a Transporter for UDP client.
+func UDPTransporter() Transporter {
+ return &udpTransporter{}
+}
+
+func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
+ raddr, err := net.ResolveUDPAddr("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+
+ conn, err := net.ListenUDP("udp", nil)
+ if err != nil {
+ return nil, err
+ }
+
+ return &udpClientConn{
+ UDPConn: conn,
+ raddr: raddr,
+ }, nil
+}
+
+func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
+ return conn, nil
+}
+
+func (tr *udpTransporter) Multiplex() bool {
+ return false
+}
+
+// UDPListenConfig is the config for UDP Listener.
+type UDPListenConfig struct {
+ TTL time.Duration // timeout per connection
+ Backlog int // connection backlog
+ QueueSize int // recv queue size per connection
+}
+
+type udpListener struct {
+ ln net.PacketConn
+ connChan chan net.Conn
+ errChan chan error
+ connMap udpConnMap
+ config *UDPListenConfig
+}
+
+// UDPListener creates a Listener for UDP server.
+func UDPListener(addr string, cfg *UDPListenConfig) (Listener, error) {
+ laddr, err := net.ResolveUDPAddr("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+ ln, err := net.ListenUDP("udp", laddr)
+ if err != nil {
+ return nil, err
+ }
+
+ if cfg == nil {
+ cfg = &UDPListenConfig{}
+ }
+
+ backlog := cfg.Backlog
+ if backlog <= 0 {
+ backlog = defaultBacklog
+ }
+
+ l := &udpListener{
+ ln: ln,
+ connChan: make(chan net.Conn, backlog),
+ errChan: make(chan error, 1),
+ config: cfg,
+ }
+ go l.listenLoop()
+ return l, nil
+}
+
+func (l *udpListener) listenLoop() {
+ for {
+ b := make([]byte, mediumBufferSize)
+ n, raddr, err := l.ln.ReadFrom(b)
+ if err != nil {
+ log.Logf("[udp] peer -> %s : %s", l.Addr(), err)
+ l.Close()
+ l.errChan <- err
+ close(l.errChan)
+ return
+ }
+
+ conn, ok := l.connMap.Get(raddr.String())
+ if !ok {
+ conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{
+ ttl: l.config.TTL,
+ qsize: l.config.QueueSize,
+ onClose: func() {
+ l.connMap.Delete(raddr.String())
+ log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size())
+ },
+ })
+
+ select {
+ case l.connChan <- conn:
+ l.connMap.Set(raddr.String(), conn)
+ log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size())
+ default:
+ conn.Close()
+ log.Logf("[udp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan))
+ }
+ }
+
+ select {
+ case conn.rChan <- b[:n]:
+ if Debug {
+ log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n)
+ }
+ default:
+ log.Logf("[udp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan))
+ }
+ }
+}
+
+func (l *udpListener) Accept() (conn net.Conn, err error) {
+ var ok bool
+ select {
+ case conn = <-l.connChan:
+ case err, ok = <-l.errChan:
+ if !ok {
+ err = errors.New("accpet on closed listener")
+ }
+ }
+ return
+}
+
+func (l *udpListener) Addr() net.Addr {
+ return l.ln.LocalAddr()
+}
+
+func (l *udpListener) Close() error {
+ err := l.ln.Close()
+ l.connMap.Range(func(k interface{}, v *udpServerConn) bool {
+ v.Close()
+ return true
+ })
+
+ return err
+}
+
+type udpConnMap struct {
+ m sync.Map
+ size int64
+}
+
+func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) {
+ v, ok := m.m.Load(key)
+ if ok {
+ conn, ok = v.(*udpServerConn)
+ }
+ return
+}
+
+func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) {
+ m.m.Store(key, conn)
+ atomic.AddInt64(&m.size, 1)
+}
+
+func (m *udpConnMap) Delete(key interface{}) {
+ m.m.Delete(key)
+ atomic.AddInt64(&m.size, -1)
+}
+
+func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) {
+ m.m.Range(func(k, v interface{}) bool {
+ return f(k, v.(*udpServerConn))
+ })
+}
+
+func (m *udpConnMap) Size() int64 {
+ return atomic.LoadInt64(&m.size)
+}
+
+// udpServerConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
+type udpServerConn struct {
+ conn net.PacketConn
+ raddr net.Addr
+ rChan chan []byte
+ closed chan struct{}
+ closeMutex sync.Mutex
+ nopChan chan int
+ config *udpServerConnConfig
+}
+
+type udpServerConnConfig struct {
+ ttl time.Duration
+ qsize int
+ onClose func()
+}
+
+func newUDPServerConn(conn net.PacketConn, raddr net.Addr, cfg *udpServerConnConfig) *udpServerConn {
+ if conn == nil || raddr == nil {
+ return nil
+ }
+
+ if cfg == nil {
+ cfg = &udpServerConnConfig{}
+ }
+ qsize := cfg.qsize
+ if qsize <= 0 {
+ qsize = defaultQueueSize
+ }
+ c := &udpServerConn{
+ conn: conn,
+ raddr: raddr,
+ rChan: make(chan []byte, qsize),
+ closed: make(chan struct{}),
+ nopChan: make(chan int),
+ config: cfg,
+ }
+ go c.ttlWait()
+ return c
+}
+
+func (c *udpServerConn) Read(b []byte) (n int, err error) {
+ n, _, err = c.ReadFrom(b)
+ return
+}
+
+func (c *udpServerConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
+ select {
+ case bb := <-c.rChan:
+ n = copy(b, bb)
+ case <-c.closed:
+ err = errors.New("read from closed connection")
+ return
+ }
+
+ select {
+ case c.nopChan <- n:
+ default:
+ }
+
+ addr = c.raddr
+
+ return
+}
+
+func (c *udpServerConn) Write(b []byte) (n int, err error) {
+ return c.WriteTo(b, c.raddr)
+}
+
+func (c *udpServerConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
+ n, err = c.conn.WriteTo(b, addr)
+
+ if n > 0 {
+ if Debug {
+ log.Logf("[udp] %s <<< %s : length %d", addr, c.LocalAddr(), n)
+ }
+
+ select {
+ case c.nopChan <- n:
+ default:
+ }
+ }
+
+ return
+}
+
+func (c *udpServerConn) Close() error {
+ c.closeMutex.Lock()
+ defer c.closeMutex.Unlock()
+
+ select {
+ case <-c.closed:
+ return errors.New("connection is closed")
+ default:
+ if c.config.onClose != nil {
+ c.config.onClose()
+ }
+ close(c.closed)
+ }
+ return nil
+}
+
+func (c *udpServerConn) ttlWait() {
+ ttl := c.config.ttl
+ if ttl == 0 {
+ ttl = defaultTTL
+ }
+ timer := time.NewTimer(ttl)
+ defer timer.Stop()
+
+ for {
+ select {
+ case <-c.nopChan:
+ if !timer.Stop() {
+ <-timer.C
+ }
+ timer.Reset(ttl)
+ case <-timer.C:
+ c.Close()
+ return
+ case <-c.closed:
+ return
+ }
+ }
+}
+
+func (c *udpServerConn) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+
+func (c *udpServerConn) RemoteAddr() net.Addr {
+ return c.raddr
+}
+
+func (c *udpServerConn) SetDeadline(t time.Time) error {
+ return c.conn.SetDeadline(t)
+}
+
+func (c *udpServerConn) SetReadDeadline(t time.Time) error {
+ return c.conn.SetReadDeadline(t)
+}
+
+func (c *udpServerConn) SetWriteDeadline(t time.Time) error {
+ return c.conn.SetWriteDeadline(t)
+}
+
+type udpClientConn struct {
+ *net.UDPConn
+ raddr net.Addr
+}
+
+func (c *udpClientConn) Write(b []byte) (int, error) {
+ if c.raddr != nil {
+ return c.WriteTo(b, c.raddr)
+ }
+ return c.UDPConn.Write(b)
+}
+
+func (c *udpClientConn) RemoteAddr() net.Addr {
+ if c.raddr != nil {
+ return c.raddr
+ }
+ return c.UDPConn.RemoteAddr()
+}