From aef1c86b438488ea3af555341d7e5f93688d742a Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Fri, 6 Jan 2017 17:57:59 +0800 Subject: [PATCH] #62 fix UDP port forwarding --- forward.go | 351 ++++++++++++++++++++++++++++++----------------------- gost.go | 2 + server.go | 7 +- socks.go | 98 +-------------- 4 files changed, 206 insertions(+), 252 deletions(-) diff --git a/forward.go b/forward.go index 6bdbd9c..fcfd67f 100644 --- a/forward.go +++ b/forward.go @@ -62,12 +62,164 @@ func (s *TcpForwardServer) handleTcpForward(conn net.Conn, raddr net.Addr) { glog.V(LINFO).Infof("[tcp] %s >-< %s", conn.RemoteAddr(), raddr) } -type UdpForwardServer struct { - Base *ProxyServer +type packet struct { + srcAddr *net.UDPAddr // src address + dstAddr *net.UDPAddr // dest address + data []byte } -func NewUdpForwardServer(base *ProxyServer) *UdpForwardServer { - return &UdpForwardServer{Base: base} +type cnode struct { + chain *ProxyChain + conn net.Conn + srcAddr, dstAddr *net.UDPAddr + rChan, wChan chan *packet + err error + ttl time.Duration +} + +func (node *cnode) getUDPTunnel() (net.Conn, error) { + conn, err := node.chain.GetConn() + if err != nil { + return nil, err + } + + conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if err = gosocks5.NewRequest(CmdUdpTun, nil).Write(conn); err != nil { + conn.Close() + return nil, err + } + conn.SetWriteDeadline(time.Time{}) + + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) + reply, err := gosocks5.ReadReply(conn) + if err != nil { + conn.Close() + return nil, err + } + conn.SetReadDeadline(time.Time{}) + + if reply.Rep != gosocks5.Succeeded { + conn.Close() + return nil, errors.New("UDP tunnel failure") + } + + return conn, nil +} + +func (node *cnode) run() { + if len(node.chain.Nodes()) == 0 { + lconn, err := net.ListenUDP("udp", nil) + if err != nil { + glog.V(LWARNING).Infof("[udp] %s -> %s : %s", node.srcAddr, node.dstAddr, err) + return + } + node.conn = lconn + } else { + tc, err := node.getUDPTunnel() + if err != nil { + glog.V(LWARNING).Infof("[udp-tun] %s -> %s : %s", node.srcAddr, node.dstAddr, err) + return + } + node.conn = tc + } + + defer node.conn.Close() + + timer := time.NewTimer(node.ttl) + errChan := make(chan error, 2) + + go func() { + for { + switch c := node.conn.(type) { + case *net.UDPConn: + b := make([]byte, MediumBufferSize) + n, addr, err := c.ReadFromUDP(b) + if err != nil { + glog.V(LWARNING).Infof("[udp] %s <- %s : %s", node.srcAddr, node.dstAddr, err) + node.err = err + errChan <- err + return + } + + timer.Reset(node.ttl) + glog.V(LDEBUG).Infof("[udp] %s <<< %s : length %d", node.srcAddr, addr, n) + + if node.dstAddr.String() != addr.String() { + glog.V(LWARNING).Infof("[udp] %s <- %s : dst-addr mismatch (%s)", node.srcAddr, node.dstAddr, addr) + break + } + select { + // swap srcAddr with dstAddr + case node.rChan <- &packet{srcAddr: node.dstAddr, dstAddr: node.srcAddr, data: b[:n]}: + case <-time.After(time.Second * 3): + glog.V(LWARNING).Infof("[udp] %s <- %s : %s", node.srcAddr, node.dstAddr, "recv queue is full, discard") + } + + default: + dgram, err := gosocks5.ReadUDPDatagram(c) + if err != nil { + glog.V(LWARNING).Infof("[udp-tun] %s <- %s : %s", node.srcAddr, node.dstAddr, err) + node.err = err + errChan <- err + return + } + + timer.Reset(node.ttl) + glog.V(LDEBUG).Infof("[udp-tun] %s <<< %s : length %d", node.srcAddr, dgram.Header.Addr.String(), len(dgram.Data)) + + if dgram.Header.Addr.String() != node.dstAddr.String() { + glog.V(LWARNING).Infof("[udp-tun] %s <- %s : dst-addr mismatch (%s)", node.srcAddr, node.dstAddr, dgram.Header.Addr) + break + } + select { + // swap srcAddr with dstAddr + case node.rChan <- &packet{srcAddr: node.dstAddr, dstAddr: node.srcAddr, data: dgram.Data}: + case <-time.After(time.Second * 3): + glog.V(LWARNING).Infof("[udp-tun] %s <- %s : %s", node.srcAddr, node.dstAddr, "recv queue is full, discard") + } + } + } + }() + + go func() { + for pkt := range node.wChan { + glog.V(LDEBUG).Infof("[udp] %s >>> %s : length %d", pkt.srcAddr, pkt.dstAddr, len(pkt.data)) + timer.Reset(node.ttl) + + switch c := node.conn.(type) { + case *net.UDPConn: + if _, err := c.WriteToUDP(pkt.data, pkt.dstAddr); err != nil { + glog.V(LWARNING).Infof("[udp] %s -> %s : %s", pkt.srcAddr, pkt.dstAddr, err) + node.err = err + errChan <- err + return + } + + default: + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(pkt.data)), 0, ToSocksAddr(pkt.dstAddr)), pkt.data) + if err := dgram.Write(c); err != nil { + glog.V(LWARNING).Infof("[udp-tun] %s -> %s : %s", pkt.srcAddr, pkt.dstAddr, err) + node.err = err + errChan <- err + return + } + } + } + }() + + select { + case <-errChan: + case <-timer.C: + } +} + +type UdpForwardServer struct { + Base *ProxyServer + TTL int +} + +func NewUdpForwardServer(base *ProxyServer, ttl int) *UdpForwardServer { + return &UdpForwardServer{Base: base, TTL: ttl} } func (s *UdpForwardServer) ListenAndServe() error { @@ -88,7 +240,9 @@ func (s *UdpForwardServer) ListenAndServe() error { } defer conn.Close() - if len(s.Base.Chain.nodes) == 0 { + rChan, wChan := make(chan *packet, 128), make(chan *packet, 128) + // start send queue + go func(ch chan<- *packet) { for { b := make([]byte, MediumBufferSize) n, addr, err := conn.ReadFromUDP(b) @@ -96,172 +250,61 @@ func (s *UdpForwardServer) ListenAndServe() error { glog.V(LWARNING).Infof("[udp] %s -> %s : %s", laddr, raddr, err) continue } - go func() { - s.handleUdpForwardLocal(conn, addr, raddr, b[:n]) - }() - } - } - - rChan, wChan := make(chan *gosocks5.UDPDatagram, 32), make(chan *gosocks5.UDPDatagram, 32) - - go func() { - for { - b := make([]byte, MediumBufferSize) - n, addr, err := conn.ReadFromUDP(b) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", laddr, raddr, err) - return - } select { - case rChan <- gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(n), 0, ToSocksAddr(addr)), b[:n]): - default: - // glog.V(LWARNING).Infof("[udp-connect] %s -> %s : rbuf is full", laddr, raddr) + case ch <- &packet{srcAddr: addr, dstAddr: raddr, data: b[:n]}: + case <-time.After(time.Second * 3): + glog.V(LWARNING).Infof("[udp] %s -> %s : %s", addr, raddr, "send queue is full, discard") } } - }() - - go func() { - for { - dgram := <-wChan - addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", laddr, raddr, err) - continue // drop silently - } - if _, err = conn.WriteToUDP(dgram.Data, addr); err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", laddr, raddr, err) + }(wChan) + // start recv queue + go func(ch <-chan *packet) { + for pkt := range ch { + if _, err := conn.WriteToUDP(pkt.data, pkt.dstAddr); err != nil { + glog.V(LWARNING).Infof("[udp] %s <- %s : %s", pkt.dstAddr, pkt.srcAddr, err) return } } - }() + }(rChan) - for { - s.handleUdpForwardTunnel(laddr, raddr, rChan, wChan) - } -} + // mapping client to node + m := make(map[string]*cnode) -func (s *UdpForwardServer) handleUdpForwardLocal(conn *net.UDPConn, laddr, raddr *net.UDPAddr, data []byte) { - lconn, err := net.ListenUDP("udp", nil) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", laddr, raddr, err) - return - } - defer lconn.Close() - - if _, err := lconn.WriteToUDP(data, raddr); err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", laddr, raddr, err) - return - } - glog.V(LDEBUG).Infof("[udp] %s >>> %s length %d", laddr, raddr, len(data)) - - b := make([]byte, MediumBufferSize) - lconn.SetReadDeadline(time.Now().Add(ReadTimeout)) - n, addr, err := lconn.ReadFromUDP(b) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", laddr, raddr, err) - return - } - glog.V(LDEBUG).Infof("[udp] %s <<< %s length %d", laddr, addr, n) - - if _, err := conn.WriteToUDP(b[:n], laddr); err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", laddr, raddr, err) - } - return -} - -func (s *UdpForwardServer) handleUdpForwardTunnel(laddr, raddr *net.UDPAddr, rChan, wChan chan *gosocks5.UDPDatagram) { - var cc net.Conn - var err error - retry := 0 - - for { - cc, err = s.prepareUdpConnectTunnel(raddr) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", laddr, raddr, err) - time.Sleep((1 << uint(retry)) * time.Second) - if retry < 5 { - retry++ - } - continue - } - break - } - defer cc.Close() - - glog.V(LINFO).Infof("[udp] %s <-> %s", laddr, raddr) - - rExit := make(chan interface{}) - errc := make(chan error, 2) - - go func() { - for { - select { - case dgram := <-rChan: - if err := dgram.Write(cc); err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", laddr, raddr, err) - errc <- err - return - } - glog.V(LDEBUG).Infof("[udp-tun] %s >>> %s length: %d", laddr, raddr, len(dgram.Data)) - case <-rExit: - // glog.V(LDEBUG).Infof("[udp-connect] %s -> %s : exited", laddr, raddr) - return + // start dispatcher + for pkt := range wChan { + // clear obsolete nodes + for k, node := range m { + if node != nil && node.err != nil { + close(node.wChan) + delete(m, k) + glog.V(LINFO).Infof("[udp] clear node %s", k) } } - }() - go func() { - for { - dgram, err := gosocks5.ReadUDPDatagram(cc) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", laddr, raddr, err) - close(rExit) - errc <- err - return - } - select { - case wChan <- dgram: - glog.V(LDEBUG).Infof("[udp-tun] %s <<< %s length: %d", laddr, raddr, len(dgram.Data)) - default: + node, ok := m[pkt.srcAddr.String()] + if !ok { + node = &cnode{ + chain: s.Base.Chain, + srcAddr: pkt.srcAddr, + dstAddr: pkt.dstAddr, + rChan: rChan, + wChan: make(chan *packet, 32), + ttl: time.Duration(s.TTL) * time.Second, } + m[pkt.srcAddr.String()] = node + go node.run() + glog.V(LDEBUG).Infof("[udp] %s -> %s : new client (%d)", pkt.srcAddr, pkt.dstAddr, len(m)) } - }() - select { - case <-errc: - //log.Println("w exit", err) - } - glog.V(LINFO).Infof("[udp] %s >-< %s", laddr, raddr) -} - -func (s *UdpForwardServer) prepareUdpConnectTunnel(addr net.Addr) (net.Conn, error) { - conn, err := s.Base.Chain.GetConn() - if err != nil { - return nil, err + select { + case node.wChan <- pkt: + case <-time.After(time.Second * 3): + glog.V(LWARNING).Infof("[udp] %s -> %s : %s", pkt.srcAddr, pkt.dstAddr, "node send queue is full, discard") + } } - conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err = gosocks5.NewRequest(CmdUdpConnect, ToSocksAddr(addr)).Write(conn); err != nil { - conn.Close() - return nil, err - } - conn.SetWriteDeadline(time.Time{}) - - conn.SetReadDeadline(time.Now().Add(ReadTimeout)) - reply, err := gosocks5.ReadReply(conn) - if err != nil { - conn.Close() - return nil, err - } - conn.SetReadDeadline(time.Time{}) - - if reply.Rep != gosocks5.Succeeded { - conn.Close() - return nil, errors.New("failure") - } - - return conn, nil + return nil } type RTcpForwardServer struct { diff --git a/gost.go b/gost.go index 9ce97cc..90b1798 100644 --- a/gost.go +++ b/gost.go @@ -28,6 +28,8 @@ var ( DialTimeout = 30 * time.Second ReadTimeout = 90 * time.Second WriteTimeout = 90 * time.Second + + DefaultTTL = 60 // default udp node TTL in second for udp port forwarding ) var ( diff --git a/server.go b/server.go index ceb804d..3b40988 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "strconv" ) type ProxyServer struct { @@ -75,7 +76,11 @@ func (s *ProxyServer) Serve() error { case "tcp": // Local TCP port forwarding return NewTcpForwardServer(s).ListenAndServe() case "udp": // Local UDP port forwarding - return NewUdpForwardServer(s).ListenAndServe() + ttl, _ := strconv.Atoi(s.Node.Get("ttl")) + if ttl <= 0 { + ttl = DefaultTTL + } + return NewUdpForwardServer(s, ttl).ListenAndServe() case "rtcp": // Remote TCP port forwarding return NewRTcpForwardServer(s).Serve() case "rudp": // Remote UDP port forwarding diff --git a/socks.go b/socks.go index a570e0b..41fe5bd 100644 --- a/socks.go +++ b/socks.go @@ -21,8 +21,7 @@ const ( ) const ( - CmdUdpConnect uint8 = 0xF1 // extended method for udp local port forwarding - CmdUdpTun uint8 = 0xF3 // extended method for udp over tcp + CmdUdpTun uint8 = 0xF3 // extended method for udp over tcp ) type clientSelector struct { @@ -190,10 +189,6 @@ func (s *Socks5Server) HandleRequest(req *gosocks5.Request) { glog.V(LINFO).Infof("[socks5-bind] %s - %s", s.conn.RemoteAddr(), req.Addr) s.handleBind(req) - case CmdUdpConnect: - glog.V(LINFO).Infof("[udp] %s - %s", s.conn.RemoteAddr(), req.Addr) - s.handleUDPConnect(req) - case gosocks5.CmdUdp: glog.V(LINFO).Infof("[socks5-udp] %s - %s", s.conn.RemoteAddr(), req.Addr) s.handleUDPRelay(req) @@ -257,38 +252,6 @@ func (s *Socks5Server) handleBind(req *gosocks5.Request) { glog.V(LINFO).Infof("[socks5-bind] %s >-< %s", s.conn.RemoteAddr(), cc.RemoteAddr()) } -func (s *Socks5Server) handleUDPConnect(req *gosocks5.Request) { - cc, err := s.Base.Chain.GetConn() - - // connection error - if err != nil && err != ErrEmptyChain { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", s.conn.RemoteAddr(), req.Addr, err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(s.conn) - glog.V(LDEBUG).Infof("[udp] %s <- %s\n%s", s.conn.RemoteAddr(), req.Addr, reply) - return - } - - // serve udp connect - if err == ErrEmptyChain { - s.udpConnect(req.Addr.String()) - return - } - - defer cc.Close() - - // forward request - if err := req.Write(cc); err != nil { - glog.V(LINFO).Infof("[udp] %s -> %s : %s", s.conn.RemoteAddr(), req.Addr, err) - gosocks5.NewReply(gosocks5.Failure, nil).Write(s.conn) - return - } - - glog.V(LINFO).Infof("[udp] %s <-> %s", s.conn.RemoteAddr(), req.Addr) - s.Base.transport(s.conn, cc) - glog.V(LINFO).Infof("[udp] %s >-< %s", s.conn.RemoteAddr(), req.Addr) -} - func (s *Socks5Server) handleUDPRelay(req *gosocks5.Request) { bindAddr, _ := net.ResolveUDPAddr("udp", req.Addr.String()) relay, err := net.ListenUDP("udp", bindAddr) // udp associate, strict mode: if the port already in use, it will return error @@ -506,65 +469,6 @@ func (s *Socks5Server) bindOn(addr string) { } } -func (s *Socks5Server) udpConnect(addr string) { - raddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - glog.V(LINFO).Infof("[udp] %s -> %s : %s", s.conn.RemoteAddr(), addr, err) - gosocks5.NewReply(gosocks5.Failure, nil).Write(s.conn) - return - } - - if err := gosocks5.NewReply(gosocks5.Succeeded, nil).Write(s.conn); err != nil { - glog.V(LINFO).Infof("[udp] %s <- %s : %s", s.conn.RemoteAddr(), addr, err) - return - } - - glog.V(LINFO).Infof("[udp] %s <-> %s", s.conn.RemoteAddr(), raddr) - defer glog.V(LINFO).Infof("[udp] %s >-< %s", s.conn.RemoteAddr(), raddr) - - for { - dgram, err := gosocks5.ReadUDPDatagram(s.conn) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", s.conn.RemoteAddr(), addr, err) - return - } - - go func() { - b := make([]byte, LargeBufferSize) - - relay, err := net.DialUDP("udp", nil, raddr) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", s.conn.RemoteAddr(), raddr, err) - return - } - defer relay.Close() - - if _, err := relay.Write(dgram.Data); err != nil { - glog.V(LWARNING).Infof("[udp] %s -> %s : %s", s.conn.RemoteAddr(), raddr, err) - return - } - glog.V(LDEBUG).Infof("[udp-tun] %s >>> %s length: %d", s.conn.RemoteAddr(), raddr, len(dgram.Data)) - - relay.SetReadDeadline(time.Now().Add(time.Second * 60)) - n, err := relay.Read(b) - if err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", s.conn.RemoteAddr(), raddr, err) - return - } - relay.SetReadDeadline(time.Time{}) - - glog.V(LDEBUG).Infof("[udp-tun] %s <<< %s length: %d", s.conn.RemoteAddr(), raddr, n) - - s.conn.SetWriteDeadline(time.Now().Add(time.Second * 90)) - if err := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(n), 0, dgram.Header.Addr), b[:n]).Write(s.conn); err != nil { - glog.V(LWARNING).Infof("[udp] %s <- %s : %s", s.conn.RemoteAddr(), raddr, err) - return - } - s.conn.SetWriteDeadline(time.Time{}) - }() - } -} - func (s *Socks5Server) transportUDP(relay, peer *net.UDPConn) (err error) { errc := make(chan error, 2)