diff --git a/cmd/gost/main.go b/cmd/gost/main.go index a7ce5c5..1ee4221 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -361,6 +361,8 @@ func serve(chain *gost.Chain) error { handler = gost.TCPRedirectHandler(handlerOptions...) case "ssu": handler = gost.ShadowUDPdHandler(handlerOptions...) + case "sni": + handler = gost.SNIHandler(handlerOptions...) default: handler = gost.AutoHandler(handlerOptions...) } diff --git a/handler.go b/handler.go index 451c4a6..2b51e4e 100644 --- a/handler.go +++ b/handler.go @@ -84,12 +84,11 @@ func AutoHandler(opts ...HandlerOption) Handler { } func (h *autoHandler) Handle(conn net.Conn) { - defer conn.Close() - br := bufio.NewReader(conn) b, err := br.Peek(1) if err != nil { log.Log(err) + conn.Close() return } diff --git a/node.go b/node.go index e9da9ec..8b1e943 100644 --- a/node.go +++ b/node.go @@ -64,7 +64,7 @@ func ParseNode(s string) (node Node, err error) { } switch node.Protocol { - case "http", "http2", "socks4", "socks4a", "ss", "ssu": + case "http", "http2", "socks4", "socks4a", "ss", "ssu", "sni": case "socks", "socks5": node.Protocol = "socks5" case "tcp", "udp", "rtcp", "rudp": // port forwarding diff --git a/sni.go b/sni.go new file mode 100644 index 0000000..3d984c0 --- /dev/null +++ b/sni.go @@ -0,0 +1,106 @@ +// SNI proxy based on https://github.com/bradfitz/tcpproxy + +package gost + +import ( + "bufio" + "bytes" + "crypto/tls" + "io" + "net" + + "github.com/go-log/log" +) + +type sniHandler struct { + options []HandlerOption +} + +// SNIHandler creates a server Handler for SNI proxy server. +func SNIHandler(opts ...HandlerOption) Handler { + h := &sniHandler{ + options: opts, + } + return h +} + +func (h *sniHandler) Handle(conn net.Conn) { + br := bufio.NewReader(conn) + isTLS, sni, err := clientHelloServerName(br) + if err != nil { + log.Log("[sni]", err) + return + } + + conn = &bufferdConn{br: br, Conn: conn} + // We assume that it is HTTP request + if !isTLS { + HTTPHandler(h.options...).Handle(conn) + return + } + + defer conn.Close() + + if sni == "" { + log.Log("[sni] The client does not support SNI") + return + } + + options := &HandlerOptions{} + for _, opt := range h.options { + opt(options) + } + + if !Can("tcp", sni, options.Whitelist, options.Blacklist) { + log.Logf("[sni] Unauthorized to tcp connect to %s", sni) + return + } + + cc, err := options.Chain.Dial(sni) + if err != nil { + log.Logf("[sni] %s -> %s : %s", conn.RemoteAddr(), sni, err) + return + } + defer cc.Close() + log.Logf("[sni] %s <-> %s", cc.LocalAddr(), sni) + transport(conn, cc) + log.Logf("[sni] %s >-< %s", cc.LocalAddr(), sni) +} + +// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +// without consuming any bytes from br. +// On any error, the empty string is returned. +func clientHelloServerName(br *bufio.Reader) (isTLS bool, sni string, err error) { + const recordHeaderLen = 5 + hdr, err := br.Peek(recordHeaderLen) + if err != nil { + return + } + const recordTypeHandshake = 0x16 + if hdr[0] != recordTypeHandshake { + return // Not TLS. + } + isTLS = true + recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] + helloBytes, err := br.Peek(recordHeaderLen + recLen) + if err != nil { + return + } + tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ + GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + sni = hello.ServerName + return nil, nil + }, + }).Handshake() + return +} + +// sniSniffConn is a net.Conn that reads from r, fails on Writes, +// and crashes otherwise. +type sniSniffConn struct { + r io.Reader + net.Conn // nil; crash on any unexpected use +} + +func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }