From 62bccd7e65612bb880c592c9a94a602a432283a5 Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Mon, 31 Jul 2017 18:19:22 +0800 Subject: [PATCH] add UDP port forwarding support --- gost/examples/bench/srv.go | 5 +- gost/examples/forward/udp/cli.go | 53 +++++ gost/examples/forward/udp/forward.go | 56 +++++ gost/examples/forward/udp/srv.go | 44 ++++ gost/forward.go | 329 +++++++++++++++++++++++++-- gost/kcp.go | 10 +- gost/quic.go | 5 +- gost/socks.go | 8 +- gost/ss.go | 4 +- gost/ssh.go | 5 +- gost/ws.go | 1 + 11 files changed, 482 insertions(+), 38 deletions(-) create mode 100644 gost/examples/forward/udp/cli.go create mode 100644 gost/examples/forward/udp/forward.go create mode 100644 gost/examples/forward/udp/srv.go diff --git a/gost/examples/bench/srv.go b/gost/examples/bench/srv.go index 7d5d42b..68f66ff 100644 --- a/gost/examples/bench/srv.go +++ b/gost/examples/bench/srv.go @@ -156,7 +156,10 @@ func rtcpForwardServer() { if err != nil { log.Fatal() } - h := gost.RTCPForwardHandler(":1222", "localhost:22") + h := gost.RTCPForwardHandler( + ":1222", + gost.AddrHandlerOption("127.0.0.1:22"), + ) log.Fatal(s.Serve(ln, h)) } diff --git a/gost/examples/forward/udp/cli.go b/gost/examples/forward/udp/cli.go new file mode 100644 index 0000000..42dd16b --- /dev/null +++ b/gost/examples/forward/udp/cli.go @@ -0,0 +1,53 @@ +package main + +import ( + "flag" + "log" + "net" + "time" +) + +var ( + concurrency int + saddr string +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&saddr, "S", ":18080", "server address") + flag.IntVar(&concurrency, "c", 1, "Number of multiple echo to make at a time") + flag.Parse() +} + +func main() { + for i := 0; i < concurrency; i++ { + go udpEchoLoop() + } + select{} +} + +func udpEchoLoop() { + addr, err := net.ResolveUDPAddr("udp", saddr) + if err != nil { + log.Fatal(err) + } + conn, err := net.DialUDP("udp", nil, addr) + if err != nil { + log.Fatal(err) + } + + msg := []byte(`abcdefghijklmnopqrstuvwxyz`) + for { + if _, err := conn.Write(msg); err != nil { + log.Fatal(err) + } + b := make([]byte, 1024) + _, err := conn.Read(b) + if err != nil { + log.Fatal(err) + } + // log.Println(string(b[:n])) + time.Sleep(100 * time.Millisecond) + } +} diff --git a/gost/examples/forward/udp/forward.go b/gost/examples/forward/udp/forward.go new file mode 100644 index 0000000..0469766 --- /dev/null +++ b/gost/examples/forward/udp/forward.go @@ -0,0 +1,56 @@ +package main + +import ( + "flag" + "log" + "net/url" + "time" + + "github.com/ginuerzh/gost/gost" +) + +var ( + laddr, faddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":18080", "listen address") + flag.StringVar(&faddr, "F", ":8080", "forward address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} +func main() { + udpForwardServer() +} + +func udpForwardServer() { + s := &gost.Server{} + ln, err := gost.UDPForwardListener(laddr, time.Second*3) + if err != nil { + log.Fatal(err) + } + h := gost.UDPForwardHandler( + faddr, + gost.ChainHandlerOption(gost.NewChain(gost.Node{ + Protocol: "socks5", + Transport: "tcp", + Addr: ":11080", + User: url.UserPassword("admin", "123456"), + Client: &gost.Client{ + Connector: gost.SOCKS5Connector( + url.UserPassword("admin", "123456"), + ), + Transporter: gost.TCPTransporter(), + }, + })), + ) + log.Fatal(s.Serve(ln, h)) +} diff --git a/gost/examples/forward/udp/srv.go b/gost/examples/forward/udp/srv.go new file mode 100644 index 0000000..3aadf2d --- /dev/null +++ b/gost/examples/forward/udp/srv.go @@ -0,0 +1,44 @@ +package main + +import ( + "flag" + "log" + "net" +) + +var ( + laddr string +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":8080", "listen address") + flag.Parse() +} +func main() { + udpEchoServer() +} + +func udpEchoServer() { + addr, err := net.ResolveUDPAddr("udp", laddr) + if err != nil { + log.Fatal(err) + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + log.Fatal(err) + } + + for { + b := make([]byte, 1024) + n, raddr, err := conn.ReadFromUDP(b) + if err != nil { + log.Fatal(err) + } + if _, err = conn.WriteToUDP(b[:n], raddr); err != nil { + log.Fatal(err) + } + + } +} diff --git a/gost/forward.go b/gost/forward.go index 77d41da..2f3a19c 100644 --- a/gost/forward.go +++ b/gost/forward.go @@ -3,6 +3,7 @@ package gost import ( "errors" "net" + "sync" "time" "fmt" @@ -20,10 +21,8 @@ type tcpForwardHandler struct { // The raddr is the remote address that the server will forward to. func TCPForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &tcpForwardHandler{ - raddr: raddr, - options: &HandlerOptions{ - Chain: new(Chain), - }, + raddr: raddr, + options: &HandlerOptions{}, } for _, opt := range opts { opt(h.options) @@ -55,13 +54,10 @@ type udpForwardHandler struct { // UDPForwardHandler creates a server Handler for UDP port forwarding server. // The raddr is the remote address that the server will forward to. -func UDPForwardHandler(raddr string, ttl time.Duration, opts ...HandlerOption) Handler { +func UDPForwardHandler(raddr string, opts ...HandlerOption) Handler { h := &udpForwardHandler{ - raddr: raddr, - ttl: ttl, - options: &HandlerOptions{ - Chain: new(Chain), - }, + raddr: raddr, + options: &HandlerOptions{}, } for _, opt := range opts { opt(h.options) @@ -71,6 +67,68 @@ func UDPForwardHandler(raddr string, ttl time.Duration, opts ...HandlerOption) H func (h *udpForwardHandler) Handle(conn net.Conn) { defer conn.Close() + + var cc net.Conn + if h.options.Chain.IsEmpty() { + raddr, err := net.ResolveUDPAddr("udp", h.raddr) + if err != nil { + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + return + } + cc, err = net.DialUDP("udp", nil, raddr) + if err != nil { + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + return + } + } else { + var err error + cc, err = h.getUDPTunnel() + if err != nil { + log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + return + } + cc = &udpTunnelConn{Conn: cc, raddr: h.raddr} + } + + defer cc.Close() + + log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), h.raddr) + transport(conn, cc) + log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), h.raddr) +} + +func (h *udpForwardHandler) getUDPTunnel() (net.Conn, error) { + conn, err := h.options.Chain.Conn() + if err != nil { + return nil, err + } + cc, err := socks5Handshake(conn, h.options.Chain.LastNode().User) + if err != nil { + conn.Close() + return nil, err + } + conn = cc + + conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if err = gosocks5.NewRequest(CmdUDPTun, nil).Write(conn); err != nil { + conn.Close() + return nil, err + } + conn.SetWriteDeadline(time.Time{}) + + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) + reply, err := gosocks5.ReadReply(conn) + if err != nil { + conn.Close() + return nil, err + } + conn.SetReadDeadline(time.Time{}) + + if reply.Rep != gosocks5.Succeeded { + conn.Close() + return nil, errors.New("UDP tunnel failure") + } + return conn, nil } type rtcpForwardHandler struct { @@ -184,34 +242,265 @@ func (h *rudpForwardHandler) Handle(conn net.Conn) { } type udpForwardListener struct { - addr *net.UDPAddr - conn *net.UDPConn + ln *net.UDPConn + conns map[string]*udpServerConn + connMutex sync.Mutex + connChan chan net.Conn + errChan chan error + ttl time.Duration } // UDPForwardListener creates a Listener for UDP port forwarding server. -func UDPForwardListener(addr string) (Listener, error) { +func UDPForwardListener(addr string, ttl time.Duration) (Listener, error) { laddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - conn, err := net.ListenUDP("udp", laddr) + ln, err := net.ListenUDP("udp", laddr) if err != nil { return nil, err } - return &udpForwardListener{conn: conn}, nil + l := &udpForwardListener{ + ln: ln, + conns: make(map[string]*udpServerConn), + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + ttl: ttl, + } + go l.listenLoop() + return l, nil } -func (l *udpForwardListener) Accept() (net.Conn, error) { - // TODO: create udp forward connection - return nil, nil +func (l *udpForwardListener) listenLoop() { + for { + b := make([]byte, mediumBufferSize) + n, raddr, err := l.ln.ReadFromUDP(b) + if err != nil { + log.Logf("[udp] peer -> %s : %s", l.Addr(), err) + l.ln.Close() + l.errChan <- err + close(l.errChan) + } + if Debug { + log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) + } + conn, ok := l.conns[raddr.String()] + if !ok || conn.Closed() { + conn = newUDPServerConn(l.ln, raddr, l.ttl) + l.conns[raddr.String()] = conn + + select { + case l.connChan <- conn: + default: + conn.Close() + log.Logf("[udp] %s - %s: connection queue is full", raddr, l.Addr()) + } + } + + select { + case conn.rChan <- b[:n]: + default: + log.Logf("[udp] %s -> %s : write queue is full", raddr, l.Addr()) + } + } +} + +func (l *udpForwardListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return } func (l *udpForwardListener) Addr() net.Addr { - return l.addr + return l.ln.LocalAddr() } func (l *udpForwardListener) Close() error { - return l.conn.Close() + return l.ln.Close() +} + +type udpServerConn struct { + conn *net.UDPConn + raddr *net.UDPAddr + rChan, wChan chan []byte + closed chan struct{} + brokenChan chan struct{} + closeMutex sync.Mutex + ttl time.Duration + nopChan chan int +} + +func newUDPServerConn(conn *net.UDPConn, raddr *net.UDPAddr, ttl time.Duration) *udpServerConn { + c := &udpServerConn{ + conn: conn, + raddr: raddr, + rChan: make(chan []byte, 128), + wChan: make(chan []byte, 128), + closed: make(chan struct{}), + brokenChan: make(chan struct{}), + nopChan: make(chan int), + ttl: ttl, + } + go c.writeLoop() + go c.ttlWait() + return c +} + +func (c *udpServerConn) Read(b []byte) (n int, err error) { + select { + case bb := <-c.rChan: + n = copy(b, bb) + if n != len(bb) { + err = errors.New("read partial data") + return + } + case <-c.brokenChan: + err = errors.New("Broken pipe") + case <-c.closed: + err = errors.New("read from closed connection") + return + } + + select { + case c.nopChan <- n: + default: + } + return +} + +func (c *udpServerConn) Write(b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + select { + case c.wChan <- b: + n = len(b) + case <-c.brokenChan: + err = errors.New("Broken pipe") + case <-c.closed: + err = errors.New("write to closed connection") + return + } + + select { + case c.nopChan <- n: + default: + } + + return +} + +func (c *udpServerConn) Close() error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + + select { + case <-c.closed: + return errors.New("connection is closed") + default: + close(c.closed) + } + return nil +} + +func (c *udpServerConn) Closed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + +func (c *udpServerConn) writeLoop() { + for { + select { + case b, ok := <-c.wChan: + if !ok { + return + } + n, err := c.conn.WriteToUDP(b, c.raddr) + if err != nil { + log.Logf("[udp] %s <<< %s : %s", c.RemoteAddr(), c.LocalAddr(), err) + return + } + if Debug { + log.Logf("[udp] %s <<< %s : length %d", c.RemoteAddr(), c.LocalAddr(), n) + } + case <-c.brokenChan: + return + case <-c.closed: + return + } + } +} + +func (c *udpServerConn) ttlWait() { + timer := time.NewTimer(c.ttl) + + for { + select { + case <-c.nopChan: + timer.Reset(c.ttl) + case <-timer.C: + close(c.brokenChan) + return + case <-c.closed: + return + } + } +} + +func (c *udpServerConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *udpServerConn) RemoteAddr() net.Addr { + return c.raddr +} + +func (c *udpServerConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *udpServerConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *udpServerConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type udpTunnelConn struct { + raddr string + net.Conn +} + +func (c *udpTunnelConn) Read(b []byte) (n int, err error) { + dgram, err := gosocks5.ReadUDPDatagram(c.Conn) + if err != nil { + return + } + n = copy(b, dgram.Data) + return +} + +func (c *udpTunnelConn) Write(b []byte) (n int, err error) { + addr, err := net.ResolveUDPAddr("udp", c.raddr) + if err != nil { + return + } + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b) + if err = dgram.Write(c.Conn); err != nil { + return + } + return len(b), nil } type rtcpForwardListener struct { diff --git a/gost/kcp.go b/gost/kcp.go index ed6cf70..857c697 100644 --- a/gost/kcp.go +++ b/gost/kcp.go @@ -95,10 +95,6 @@ type kcpConn struct { stream *smux.Stream } -func newKCPConn(conn net.Conn, stream *smux.Stream) *kcpConn { - return &kcpConn{conn: conn, stream: stream} -} - func (c *kcpConn) Read(b []byte) (n int, err error) { return c.stream.Read(b) } @@ -141,7 +137,7 @@ func (session *kcpSession) GetConn() (*kcpConn, error) { if err != nil { return nil, err } - return newKCPConn(session.conn, stream), nil + return &kcpConn{conn: session.conn, stream: stream}, nil } func (session *kcpSession) Close() error { @@ -378,9 +374,11 @@ func (l *kcpListener) mux(conn net.Conn) { return } + cc := &kcpConn{conn: conn, stream: stream} select { - case l.connChan <- newKCPConn(conn, stream): + case l.connChan <- cc: default: + cc.Close() log.Logf("[kcp] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) } } diff --git a/gost/quic.go b/gost/quic.go index a82d08a..e4ce855 100644 --- a/gost/quic.go +++ b/gost/quic.go @@ -191,9 +191,12 @@ func (l *quicListener) sessionLoop(session quic.Session) { log.Log("[quic] accept stream:", err) return } + + cc := &quicConn{Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr()} select { - case l.connChan <- &quicConn{Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr()}: + case l.connChan <- cc: default: + cc.Close() log.Logf("[quic] %s - %s: connection queue is full", session.RemoteAddr(), session.LocalAddr()) } } diff --git a/gost/socks.go b/gost/socks.go index a44192f..9d8799d 100644 --- a/gost/socks.go +++ b/gost/socks.go @@ -340,9 +340,7 @@ type socks5Handler struct { // SOCKS5Handler creates a server Handler for SOCKS5 proxy server. func SOCKS5Handler(opts ...HandlerOption) Handler { - options := &HandlerOptions{ - Chain: new(Chain), - } + options := &HandlerOptions{} for _, opt := range opts { opt(options) } @@ -972,9 +970,7 @@ type socks4Handler struct { // SOCKS4Handler creates a server Handler for SOCKS4(A) proxy server. func SOCKS4Handler(opts ...HandlerOption) Handler { - options := &HandlerOptions{ - Chain: new(Chain), - } + options := &HandlerOptions{} for _, opt := range opts { opt(options) } diff --git a/gost/ss.go b/gost/ss.go index e6babd9..f3e2889 100644 --- a/gost/ss.go +++ b/gost/ss.go @@ -95,9 +95,7 @@ type shadowHandler struct { // ShadowHandler creates a server Handler for shadowsocks proxy server. func ShadowHandler(opts ...HandlerOption) Handler { h := &shadowHandler{ - options: &HandlerOptions{ - Chain: new(Chain), - }, + options: &HandlerOptions{}, } for _, opt := range opts { opt(h.options) diff --git a/gost/ssh.go b/gost/ssh.go index 98305bb..d570006 100644 --- a/gost/ssh.go +++ b/gost/ssh.go @@ -85,6 +85,7 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string) (net.Con select { case cc.session.connChan <- rc: default: + rc.Close() log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), addr) } } @@ -670,9 +671,11 @@ func (l *sshTunnelListener) serveConn(conn net.Conn) { continue } go ssh.DiscardRequests(requests) + cc := &sshConn{conn: conn, channel: channel} select { - case l.connChan <- &sshConn{conn: conn, channel: channel}: + case l.connChan <- cc: default: + cc.Close() log.Logf("[ssh] %s - %s: connection queue is full", conn.RemoteAddr(), l.Addr()) } diff --git a/gost/ws.go b/gost/ws.go index fb56254..ff4605c 100644 --- a/gost/ws.go +++ b/gost/ws.go @@ -217,6 +217,7 @@ func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { select { case l.connChan <- websocketServerConn(conn): default: + conn.Close() log.Logf("[ws] %s - %s: connection queue is full", r.RemoteAddr, l.addr) } }