diff --git a/client.go b/client.go index fca8f04..f63cff4 100644 --- a/client.go +++ b/client.go @@ -62,10 +62,10 @@ type Transporter interface { Multiplex() bool } -type tcpTransporter struct { -} +// tcpTransporter is a raw TCP transporter. +type tcpTransporter struct{} -// TCPTransporter creates a transporter for TCP proxy client. +// TCPTransporter creates a raw TCP client. func TCPTransporter() Transporter { return &tcpTransporter{} } @@ -90,6 +90,30 @@ func (tr *tcpTransporter) Multiplex() bool { return false } +// udpTransporter is a raw UDP transporter. +type udpTransporter struct{} + +// UDPTransporter creates a raw UDP client. +func UDPTransporter() Transporter { + return &udpTransporter{} +} + +func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + return net.DialTimeout("udp", addr, opts.Timeout) +} + +func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return conn, nil +} + +func (tr *udpTransporter) Multiplex() bool { + return false +} + // DialOptions describes the options for Transporter.Dial. type DialOptions struct { Timeout time.Duration diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..745c883 --- /dev/null +++ b/common_test.go @@ -0,0 +1,121 @@ +package gost + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "sync" +) + +func init() { + // SetLogger(&LogLogger{}) + // Debug = true + + cert, err := GenCertificate() + if err != nil { + panic(err) + } + DefaultTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } +} + +var ( + httpTestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, r.Body) + }) + + udpTestHandler = udpHandlerFunc(func(w io.Writer, r *udpRequest) { + io.Copy(w, r.Body) + }) +) + +type udpRequest struct { + Body io.Reader + RemoteAddr string +} + +type udpResponseWriter struct { + conn net.PacketConn + addr net.Addr +} + +func (w *udpResponseWriter) Write(p []byte) (int, error) { + return w.conn.WriteTo(p, w.addr) +} + +type udpHandlerFunc func(w io.Writer, r *udpRequest) + +// udpTestServer is a UDP server for test. +type udpTestServer struct { + ln net.PacketConn + handler udpHandlerFunc + wg sync.WaitGroup + mu sync.Mutex // guards closed and conns + closed bool +} + +func newUDPTestServer(handler udpHandlerFunc) *udpTestServer { + laddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + ln, err := net.ListenUDP("udp", laddr) + if err != nil { + panic(fmt.Sprintf("udptest: failed to listen on a port: %v", err)) + } + return &udpTestServer{ + ln: ln, + handler: handler, + } +} + +func (s *udpTestServer) Start() { + go s.serve() +} + +func (s *udpTestServer) serve() { + for { + data := make([]byte, 1024) + n, raddr, err := s.ln.ReadFrom(data) + if err != nil { + return + } + if s.handler != nil { + s.wg.Add(1) + go func() { + defer s.wg.Done() + w := &udpResponseWriter{ + conn: s.ln, + addr: raddr, + } + r := &udpRequest{ + Body: bytes.NewReader(data[:n]), + RemoteAddr: raddr.String(), + } + s.handler(w, r) + }() + } + } +} + +func (s *udpTestServer) Addr() string { + return s.ln.LocalAddr().String() +} + +func (s *udpTestServer) Close() error { + s.mu.Lock() + + if s.closed { + s.mu.Unlock() + return nil + } + + err := s.ln.Close() + s.closed = true + s.mu.Unlock() + + s.wg.Wait() + + return err +} diff --git a/forward.go b/forward.go index c901c60..a185bf6 100644 --- a/forward.go +++ b/forward.go @@ -272,9 +272,6 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { defer conn.Close() retries := 1 - if h.options.Chain != nil && h.options.Chain.Retries > 0 { - retries = h.options.Chain.Retries - } if h.options.Retries > 0 { retries = h.options.Retries } @@ -422,7 +419,7 @@ func (l *udpDirectForwardListener) listenLoop() { n, raddr, err := l.ln.ReadFrom(b) if err != nil { log.Logf("[udp] peer -> %s : %s", l.Addr(), err) - l.ln.Close() + l.Close() l.errChan <- err close(l.errChan) return @@ -593,10 +590,14 @@ func (c *udpServerConn) ttlWait() { ttl = defaultTTL } timer := time.NewTimer(ttl) + defer timer.Stop() for { select { case <-c.nopChan: + if !timer.Stop() { + <-timer.C + } timer.Reset(ttl) case <-timer.C: close(c.brokenChan) @@ -628,12 +629,15 @@ func (c *udpServerConn) SetWriteDeadline(t time.Time) error { } type tcpRemoteForwardListener struct { - addr net.Addr - chain *Chain - ln net.Listener - session *muxSession - mutex sync.Mutex - closed chan struct{} + addr net.Addr + chain *Chain + connChan chan net.Conn + ln net.Listener + session *muxSession + sessionMux sync.Mutex + closed chan struct{} + closeMux sync.Mutex + errChan chan error } // TCPRemoteForwardListener creates a Listener for TCP remote port forwarding server. @@ -643,23 +647,56 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) { return nil, err } - return &tcpRemoteForwardListener{ - addr: laddr, - chain: chain, - closed: make(chan struct{}), - }, nil -} - -func (l *tcpRemoteForwardListener) Accept() (net.Conn, error) { - select { - case <-l.closed: - return nil, errors.New("closed") - default: + ln := &tcpRemoteForwardListener{ + addr: laddr, + chain: chain, + connChan: make(chan net.Conn, 1024), + closed: make(chan struct{}), + errChan: make(chan error), } + if !ln.isChainValid() { + ln.ln, err = net.Listen("tcp", ln.addr.String()) + return ln, err + } + + go ln.listenLoop() + + if err = <-ln.errChan; err != nil { + ln.Close() + } + + return ln, err +} + +func (l *tcpRemoteForwardListener) isChainValid() bool { + lastNode := l.chain.LastNode() + if (lastNode.Protocol == "forward" && lastNode.Transport == "ssh") || + lastNode.Protocol == "socks5" { + return true + } + return false +} + +func (l *tcpRemoteForwardListener) listenLoop() { var tempDelay time.Duration + var once sync.Once + for { conn, err := l.accept() + + once.Do(func() { + l.errChan <- err + close(l.errChan) + }) + + select { + case <-l.closed: + conn.Close() + return + default: + } + if err != nil { if tempDelay == 0 { tempDelay = 1000 * time.Millisecond @@ -673,15 +710,37 @@ func (l *tcpRemoteForwardListener) Accept() (net.Conn, error) { time.Sleep(tempDelay) continue } - return conn, nil + + select { + case l.connChan <- conn: + default: + conn.Close() + log.Logf("[rtcp] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } } } +func (l *tcpRemoteForwardListener) Accept() (conn net.Conn, err error) { + if l.ln != nil { + return l.ln.Accept() + } + + select { + case conn = <-l.connChan: + case <-l.closed: + err = errors.New("closed") + } + + return +} + func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { lastNode := l.chain.LastNode() if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" { - conn, err = l.chain.Dial(l.addr.String()) - } else if lastNode.Protocol == "socks5" { + return l.chain.Dial(l.addr.String()) + } + + if lastNode.Protocol == "socks5" { if lastNode.GetBool("mbind") { return l.muxAccept() // multiplexing support for binding. } @@ -694,14 +753,6 @@ func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { if err != nil { cc.Close() } - } else { - if l.ln == nil { - l.ln, err = net.Listen("tcp", l.addr.String()) - if err != nil { - return - } - } - conn, err = l.ln.Accept() } return } @@ -721,8 +772,8 @@ func (l *tcpRemoteForwardListener) muxAccept() (conn net.Conn, err error) { } func (l *tcpRemoteForwardListener) getSession() (*muxSession, error) { - l.mutex.Lock() - defer l.mutex.Unlock() + l.sessionMux.Lock() + defer l.sessionMux.Unlock() if l.session != nil && !l.session.IsClosed() { return l.session, nil @@ -810,22 +861,40 @@ func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, e } func (l *tcpRemoteForwardListener) Addr() net.Addr { + if l.ln != nil { + return l.ln.Addr() + } return l.addr } func (l *tcpRemoteForwardListener) Close() error { - close(l.closed) + if l.ln != nil { + return l.ln.Close() + } + + l.closeMux.Lock() + defer l.closeMux.Unlock() + + select { + case <-l.closed: + return nil + default: + close(l.closed) + } return nil } type udpRemoteForwardListener struct { - addr *net.UDPAddr + addr net.Addr chain *Chain conns map[string]*udpServerConn connChan chan net.Conn + ln *net.UDPConn errChan chan error ttl time.Duration closed chan struct{} + closeMux sync.Mutex + once sync.Once } // UDPRemoteForwardListener creates a Listener for UDP remote port forwarding server. @@ -844,8 +913,17 @@ func UDPRemoteForwardListener(addr string, chain *Chain, ttl time.Duration) (Lis ttl: ttl, closed: make(chan struct{}), } + go ln.listenLoop() - return ln, nil + + err = <-ln.errChan + + return ln, err +} + +func (l *udpRemoteForwardListener) isChainValid() bool { + lastNode := l.chain.LastNode() + return lastNode.Protocol == "socks5" } func (l *udpRemoteForwardListener) listenLoop() { @@ -855,7 +933,6 @@ func (l *udpRemoteForwardListener) listenLoop() { log.Logf("[rudp] %s : %s", l.Addr(), err) return } - defer conn.Close() for { @@ -911,9 +988,19 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { conn = &udpTunnelConn{Conn: cc} } } else { - conn, err = net.ListenUDP("udp", l.addr) + var uc *net.UDPConn + uc, err = net.ListenUDP("udp", l.addr.(*net.UDPAddr)) + if err == nil { + l.addr = uc.LocalAddr() + conn = uc + } } + l.once.Do(func() { + l.errChan <- err + close(l.errChan) + }) + if err != nil { if tempDelay == 0 { tempDelay = 1000 * time.Millisecond @@ -932,13 +1019,10 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { } func (l *udpRemoteForwardListener) Accept() (conn net.Conn, err error) { - var ok bool select { case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } + case <-l.closed: + err = errors.New("accpet on closed listener") } return } @@ -948,6 +1032,15 @@ func (l *udpRemoteForwardListener) Addr() net.Addr { } func (l *udpRemoteForwardListener) Close() error { - close(l.closed) + l.closeMux.Lock() + defer l.closeMux.Unlock() + + select { + case <-l.closed: + return nil + default: + close(l.closed) + } + return nil } diff --git a/forward_test.go b/forward_test.go new file mode 100644 index 0000000..e75dd71 --- /dev/null +++ b/forward_test.go @@ -0,0 +1,333 @@ +package gost + +import ( + "bytes" + "crypto/rand" + "fmt" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +func tcpDirectForwardRoundtrip(targetURL string, data []byte) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestTCPDirectForward(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := tcpDirectForwardRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func BenchmarkTCPDirectForward(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TCPTransporter(), + } + + u, err := url.Parse(httpSrv.URL) + if err != nil { + b.Error(err) + } + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkTCPDirectForwardParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TCPTransporter(), + } + + u, err := url.Parse(httpSrv.URL) + if err != nil { + b.Error(err) + } + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func udpRoundtrip(client *Client, server *Server, host string, data []byte) (err error) { + conn, err := proxyConn(client, server) + if err != nil { + return + } + defer conn.Close() + + conn.SetDeadline(time.Now().Add(1 * time.Second)) + defer conn.SetDeadline(time.Time{}) + + conn, err = client.Connect(conn, host) + if err != nil { + return + } + + if _, err = conn.Write(data); err != nil { + return + } + + recv := make([]byte, len(data)) + if _, err = conn.Read(recv); err != nil { + return + } + + if !bytes.Equal(data, recv) { + return fmt.Errorf("data not equal") + } + + return +} + +func udpDirectForwardRoundtrip(host string, data []byte) error { + ln, err := UDPDirectForwardListener("localhost:0", 0) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: UDPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: UDPDirectForwardHandler(host), + } + + go server.Run() + defer server.Close() + + return udpRoundtrip(client, server, host, data) +} + +func TestUDPDirectForward(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + err := udpDirectForwardRoundtrip(udpSrv.Addr(), sendData) + if err != nil { + t.Error(err) + } +} + +func BenchmarkUDPDirectForward(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := UDPDirectForwardListener("localhost:0", 0) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: UDPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: UDPDirectForwardHandler(udpSrv.Addr()), + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkUDPDirectForwardParallel(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := UDPDirectForwardListener("localhost:0", 0) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: UDPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: UDPDirectForwardHandler(udpSrv.Addr()), + } + + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := udpRoundtrip(client, server, udpSrv.Addr(), sendData); err != nil { + b.Error(err) + } + } + }) +} + +func tcpRemoteForwardRoundtrip(t *testing.T, targetURL string, data []byte) error { + ln, err := TCPRemoteForwardListener("localhost:0", nil) // listening on localhost + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPRemoteForwardHandler(u.Host), // forward to u.Host + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestTCPRemoteForward(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := tcpRemoteForwardRoundtrip(t, httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func udpRemoteForwardRoundtrip(t *testing.T, host string, data []byte) error { + ln, err := UDPRemoteForwardListener("localhost:0", nil, 0) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: UDPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: UDPRemoteForwardHandler(host), + } + + go server.Run() + defer server.Close() + + return udpRoundtrip(client, server, host, data) +} + +func TestUDPRemoteForward(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := udpRemoteForwardRoundtrip(t, udpSrv.Addr(), sendData) + if err != nil { + t.Error(err) + } +} diff --git a/http2_test.go b/http2_test.go index fb8f80a..58c4784 100644 --- a/http2_test.go +++ b/http2_test.go @@ -484,6 +484,46 @@ func TestSNIOverH2(t *testing.T) { } } +func h2ForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := H2Listener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: H2Transporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestH2ForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := h2ForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + func httpOverH2CRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { @@ -840,3 +880,43 @@ func TestSNIOverH2C(t *testing.T) { }) } } + +func h2cForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := H2CListener("") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: H2CTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestH2CForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := h2cForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/http_test.go b/http_test.go index e5a5cd6..e8ac188 100644 --- a/http_test.go +++ b/http_test.go @@ -4,10 +4,8 @@ import ( "bufio" "bytes" "crypto/rand" - "crypto/tls" "errors" "fmt" - "io" "io/ioutil" "net" "net/http" @@ -17,23 +15,6 @@ import ( "time" ) -func init() { - // SetLogger(&LogLogger{}) - // Debug = true - - cert, err := GenCertificate() - if err != nil { - panic(err) - } - DefaultTLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } -} - -var httpTestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.Copy(w, r.Body) -}) - // proxyConn obtains a connection to the proxy server. func proxyConn(client *Client, server *Server) (net.Conn, error) { conn, err := client.Dial(server.Addr().String()) diff --git a/kcp_test.go b/kcp_test.go index 7159575..0e915a0 100644 --- a/kcp_test.go +++ b/kcp_test.go @@ -365,3 +365,43 @@ func TestSNIOverKCP(t *testing.T) { }) } } + +func kcpForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestKCPForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := kcpForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/quic_test.go b/quic_test.go index c55ac49..45a78b3 100644 --- a/quic_test.go +++ b/quic_test.go @@ -365,3 +365,43 @@ func TestSNIOverQUIC(t *testing.T) { }) } } + +func quicForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestQUICForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := quicForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/selector.go b/selector.go index 8913ac0..e8d11e2 100644 --- a/selector.go +++ b/selector.go @@ -22,9 +22,7 @@ type defaultSelector struct { } func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, error) { - sopts := SelectOptions{ - Strategy: &RoundStrategy{}, - } + sopts := SelectOptions{} for _, opt := range opts { opt(&sopts) } @@ -35,7 +33,11 @@ func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, erro if len(nodes) == 0 { return Node{}, ErrNoneAvailable } - return sopts.Strategy.Apply(nodes), nil + strategy := sopts.Strategy + if strategy == nil { + strategy = &RoundStrategy{} + } + return strategy.Apply(nodes), nil } // SelectOption is the option used when making a select call. diff --git a/socks.go b/socks.go index c72f46c..480623f 100644 --- a/socks.go +++ b/socks.go @@ -172,7 +172,7 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con return nil, err } if Debug { - log.Log("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), resp) + log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), resp) } log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr()) return nil, gosocks5.ErrAuthFailure diff --git a/ssh_test.go b/ssh_test.go index a30fd64..0395260 100644 --- a/ssh_test.go +++ b/ssh_test.go @@ -367,3 +367,43 @@ func TestSNIOverSSHTunnel(t *testing.T) { }) } } + +func sshForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSHForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := sshForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/tls_test.go b/tls_test.go index 498302a..79eed68 100644 --- a/tls_test.go +++ b/tls_test.go @@ -368,6 +368,46 @@ func TestSNIOverTLS(t *testing.T) { } } +func tlsForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := TLSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestTLSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := tlsForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + func httpOverMTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { @@ -726,3 +766,43 @@ func TestSNIOverMTLS(t *testing.T) { }) } } + +func mtlsForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := MTLSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestMTLSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := mtlsForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/ws_test.go b/ws_test.go index a1850a6..1ea9d63 100644 --- a/ws_test.go +++ b/ws_test.go @@ -366,6 +366,46 @@ func TestSNIOverWS(t *testing.T) { } } +func wsForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := WSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestWSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := wsForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + func httpOverMWSRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { @@ -724,3 +764,43 @@ func TestSNIOverMWS(t *testing.T) { }) } } + +func mwsForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestMWSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := mwsForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/wss_test.go b/wss_test.go index 531fdb2..2c8719a 100644 --- a/wss_test.go +++ b/wss_test.go @@ -367,6 +367,46 @@ func TestSNIOverWSS(t *testing.T) { } } +func wssForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := WSSListener("", nil, nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestWSSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := wssForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + func httpOverMWSSRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { @@ -725,3 +765,43 @@ func TestSNIOverMWSS(t *testing.T) { }) } } + +func mwssForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestMWSSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := mwssForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +}