diff --git a/.travis.yml b/.travis.yml index 6760012..e687e22 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,3 +7,6 @@ install: true script: - env GO111MODULE=on go test -race -v - cd cmd/gost && env GO111MODULE=on go build + +after_success: + - bash <(curl -s https://codecov.io/bash) diff --git a/handler.go b/handler.go index c1fab92..14b534e 100644 --- a/handler.go +++ b/handler.go @@ -34,6 +34,7 @@ type HandlerOptions struct { Hosts *Hosts ProbeResist string Node Node + Host string } // HandlerOption allows a common way to set handler options. @@ -137,6 +138,13 @@ func NodeHandlerOption(node Node) HandlerOption { } } +// HostHandlerOption sets the target host for SNI proxy. +func HostHandlerOption(host string) HandlerOption { + return func(opts *HandlerOptions) { + opts.Host = host + } +} + type autoHandler struct { options *HandlerOptions } diff --git a/http2_test.go b/http2_test.go index 336b60f..fb8f80a 100644 --- a/http2_test.go +++ b/http2_test.go @@ -3,6 +3,7 @@ package gost import ( "crypto/rand" "crypto/tls" + "fmt" "net/http/httptest" "net/url" "testing" @@ -418,6 +419,71 @@ func TestSSOverH2(t *testing.T) { } } +func sniOverH2Roundtrip(targetURL string, data []byte, host string) error { + ln, err := H2Listener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: H2Transporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverH2(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverH2Roundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + func httpOverH2CRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { @@ -709,3 +775,68 @@ func TestSSOverH2C(t *testing.T) { } } } + +func sniOverH2CRoundtrip(targetURL string, data []byte, host string) error { + ln, err := H2CListener("") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: H2CTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverH2C(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverH2CRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} diff --git a/http_test.go b/http_test.go index 14f6106..e5a5cd6 100644 --- a/http_test.go +++ b/http_test.go @@ -20,6 +20,7 @@ import ( func init() { // SetLogger(&LogLogger{}) // Debug = true + cert, err := GenCertificate() if err != nil { panic(err) diff --git a/kcp_test.go b/kcp_test.go index 454530c..7159575 100644 --- a/kcp_test.go +++ b/kcp_test.go @@ -2,6 +2,7 @@ package gost import ( "crypto/rand" + "fmt" "net/http/httptest" "net/url" "testing" @@ -299,3 +300,68 @@ func TestSSOverKCP(t *testing.T) { } } } + +func sniOverKCPRoundtrip(targetURL string, data []byte, host string) error { + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverKCP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverKCPRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} diff --git a/quic.go b/quic.go index 541bda3..0a362f2 100644 --- a/quic.go +++ b/quic.go @@ -137,7 +137,7 @@ func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICC } session, err := quic.Dial(udpConn, udpAddr, addr, config.TLSConfig, quicConfig) if err != nil { - log.Log("quic dial:", err) + log.Logf("quic dial %s: %v", addr, err) return nil, err } return &quicSession{conn: conn, session: session}, nil diff --git a/quic_test.go b/quic_test.go index 79dd7d2..c55ac49 100644 --- a/quic_test.go +++ b/quic_test.go @@ -2,6 +2,7 @@ package gost import ( "crypto/rand" + "fmt" "net/http/httptest" "net/url" "testing" @@ -299,3 +300,68 @@ func TestSSOverQUIC(t *testing.T) { } } } + +func sniOverQUICRoundtrip(targetURL string, data []byte, host string) error { + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverQUIC(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverQUICRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} diff --git a/sni.go b/sni.go index 342a343..88031c1 100644 --- a/sni.go +++ b/sni.go @@ -91,7 +91,11 @@ func (h *sniHandler) Handle(conn net.Conn) { return } - host = net.JoinHostPort(host, "443") + _, sport, _ := net.SplitHostPort(h.options.Host) + if sport == "" { + sport = "443" + } + host = net.JoinHostPort(host, sport) log.Logf("[sni] %s -> %s -> %s", conn.RemoteAddr(), h.options.Node.String(), host) diff --git a/sni_test.go b/sni_test.go index a705ce2..ffc5b52 100644 --- a/sni_test.go +++ b/sni_test.go @@ -6,6 +6,8 @@ import ( "crypto/rand" "crypto/tls" "errors" + "fmt" + "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -18,18 +20,21 @@ func sniRoundtrip(client *Client, server *Server, targetURL string, data []byte) if err != nil { return } - defer conn.Close() - conn.SetDeadline(time.Now().Add(3 * time.Second)) - - conn, err = client.Handshake(conn) + conn, err = client.Handshake(conn, AddrHandshakeOption(server.Addr().String())) if err != nil { return } + defer conn.Close() + u, err := url.Parse(targetURL) if err != nil { return } + + conn.SetDeadline(time.Now().Add(3 * time.Second)) + defer conn.SetDeadline(time.Time{}) + conn, err = client.Connect(conn, u.Host) if err != nil { return @@ -38,8 +43,8 @@ func sniRoundtrip(client *Client, server *Server, targetURL string, data []byte) if u.Scheme == "https" { conn = tls.Client(conn, &tls.Config{ - InsecureSkipVerify: false, - ServerName: u.Hostname(), + InsecureSkipVerify: true, + // ServerName: u.Hostname(), }) u.Scheme = "http" } @@ -64,6 +69,15 @@ func sniRoundtrip(client *Client, server *Server, targetURL string, data []byte) return errors.New(resp.Status) } + recv, err := ioutil.ReadAll(resp.Body) + if err != nil { + return + } + + if !bytes.Equal(data, recv) { + return fmt.Errorf("data not equal") + } + return } @@ -73,6 +87,11 @@ func sniProxyRoundtrip(targetURL string, data []byte, host string) error { return err } + u, err := url.Parse(targetURL) + if err != nil { + return err + } + client := &Client{ Connector: SNIConnector(host), Transporter: TCPTransporter(), @@ -80,7 +99,7 @@ func sniProxyRoundtrip(targetURL string, data []byte, host string) error { server := &Server{ Listener: ln, - Handler: SNIHandler(), + Handler: SNIHandler(HostHandlerOption(u.Host)), } go server.Run() @@ -90,19 +109,40 @@ func sniProxyRoundtrip(targetURL string, data []byte, host string) error { } func TestSNIProxy(t *testing.T) { - httpSrv := httptest.NewTLSServer(httpTestHandler) + httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + sendData := make([]byte, 128) rand.Read(sendData) - err := sniProxyRoundtrip("https://github.com", sendData, "") - if err != nil { - t.Errorf("got error: %v", err) + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, } - err = sniProxyRoundtrip("https://github.com", sendData, "google.com") - if err != nil { - t.Errorf("got error: %v", err) + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniProxyRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) } } diff --git a/ssh_test.go b/ssh_test.go index 5d043d5..a30fd64 100644 --- a/ssh_test.go +++ b/ssh_test.go @@ -3,6 +3,7 @@ package gost import ( "crypto/rand" "crypto/tls" + "fmt" "net/http/httptest" "net/url" "testing" @@ -301,3 +302,68 @@ func TestSSOverSSHTunnel(t *testing.T) { } } } + +func sniOverSSHTunnelRoundtrip(targetURL string, data []byte, host string) error { + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverSSHTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverSSHTunnelRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} diff --git a/tls_test.go b/tls_test.go index f571f36..498302a 100644 --- a/tls_test.go +++ b/tls_test.go @@ -3,6 +3,7 @@ package gost import ( "crypto/rand" "crypto/tls" + "fmt" "net/http/httptest" "net/url" "testing" @@ -302,6 +303,71 @@ func TestSSOverTLS(t *testing.T) { } } +func sniOverTLSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := TLSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverTLSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + func httpOverMTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { @@ -595,3 +661,68 @@ func TestSSOverMTLS(t *testing.T) { } } } + +func sniOverMTLSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := MTLSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverMTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverMTLSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} diff --git a/ws_test.go b/ws_test.go index da2e13a..a1850a6 100644 --- a/ws_test.go +++ b/ws_test.go @@ -2,6 +2,7 @@ package gost import ( "crypto/rand" + "fmt" "net/http/httptest" "net/url" "testing" @@ -300,6 +301,71 @@ func TestSSOverWS(t *testing.T) { } } +func sniOverWSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := WSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverWSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + func httpOverMWSRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { @@ -593,3 +659,68 @@ func TestSSOverMWS(t *testing.T) { } } } + +func sniOverMWSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverMWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverMWSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} diff --git a/wss_test.go b/wss_test.go index 7f5ae71..531fdb2 100644 --- a/wss_test.go +++ b/wss_test.go @@ -3,6 +3,7 @@ package gost import ( "crypto/rand" "crypto/tls" + "fmt" "net/http/httptest" "net/url" "testing" @@ -301,6 +302,71 @@ func TestSSOverWSS(t *testing.T) { } } +func sniOverWSSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := WSSListener("", nil, nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverWSSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + func httpOverMWSSRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { @@ -594,3 +660,68 @@ func TestSSOverMWSS(t *testing.T) { } } } + +func sniOverMWSSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverMWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverMWSSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +}