From 7402fb2407d282b2c99b7c1ccb96aa1d73a56e6d Mon Sep 17 00:00:00 2001 From: "rui.zheng" Date: Sat, 22 Jul 2017 14:02:41 +0800 Subject: [PATCH] add server and listeners --- gost/chain.go | 13 +-- gost/cli/cli.go | 57 ++++++++-- gost/client.go | 35 ++++-- gost/gost.go | 8 ++ gost/http.go | 135 +++++++++++++++++++++- gost/server.go | 124 +++++++++++++++++++-- gost/socks.go | 220 ++++++++++++++++++++++++++++++++++-- gost/srv/cert.pem | 18 +++ gost/srv/key.pem | 27 +++++ gost/srv/srv.go | 111 +++++++++++++++++++ gost/ss.go | 131 +++++++++++++++++++++- gost/tls.go | 15 ++- gost/ws.go | 277 ++++++++++++++++++++++++++++++++++++++++++++++ 13 files changed, 1113 insertions(+), 58 deletions(-) create mode 100644 gost/srv/cert.pem create mode 100644 gost/srv/key.pem create mode 100644 gost/srv/srv.go create mode 100644 gost/ws.go diff --git a/gost/chain.go b/gost/chain.go index c2b9fae..b8bc832 100644 --- a/gost/chain.go +++ b/gost/chain.go @@ -1,7 +1,6 @@ package gost import ( - "context" "net" ) @@ -15,18 +14,18 @@ func NewChain(nodes ...Node) *Chain { } } -func (c *Chain) Dial(ctx context.Context, addr string) (net.Conn, error) { +func (c *Chain) Dial(addr string) (net.Conn, error) { if len(c.Nodes) == 0 { return net.Dial("tcp", addr) } nodes := c.Nodes - conn, err := nodes[0].Client.Dial(ctx, nodes[0].Addr) + conn, err := nodes[0].Client.Dial(nodes[0].Addr) if err != nil { return nil, err } - conn, err = nodes[0].Client.Handshake(ctx, conn) + conn, err = nodes[0].Client.Handshake(conn) if err != nil { return nil, err } @@ -37,12 +36,12 @@ func (c *Chain) Dial(ctx context.Context, addr string) (net.Conn, error) { } next := nodes[i+1] - cc, err := node.Client.Connect(ctx, conn, next.Addr) + cc, err := node.Client.Connect(conn, next.Addr) if err != nil { conn.Close() return nil, err } - cc, err = next.Client.Handshake(ctx, cc) + cc, err = next.Client.Handshake(cc) if err != nil { conn.Close() return nil, err @@ -51,7 +50,7 @@ func (c *Chain) Dial(ctx context.Context, addr string) (net.Conn, error) { conn = cc } - cc, err := nodes[len(nodes)-1].Client.Connect(ctx, conn, addr) + cc, err := nodes[len(nodes)-1].Client.Connect( conn, addr) if err != nil { conn.Close() return nil, err diff --git a/gost/cli/cli.go b/gost/cli/cli.go index 142d960..b46539b 100644 --- a/gost/cli/cli.go +++ b/gost/cli/cli.go @@ -2,11 +2,11 @@ package main import ( "bufio" - "context" "crypto/tls" "log" "net/http" "net/http/httputil" + "net/url" "github.com/ginuerzh/gost/gost" @@ -19,6 +19,30 @@ func init() { func main() { chain := gost.NewChain( + /* + // http+ws + gost.Node{ + Addr: "127.0.0.1:8000", + Client: gost.NewClient( + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.WSTransporter("127.0.0.1:8000", nil), + ), + }, + */ + + // http+wss + gost.Node{ + Addr: "127.0.0.1:8443", + Client: gost.NewClient( + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.WSSTransporter( + "127.0.0.1:8443", + &gost.WSOptions{TLSConfig: &tls.Config{InsecureSkipVerify: true}}, + ), + ), + }, + /* + // http+tcp gost.Node{ Addr: "127.0.0.1:1080", Client: gost.NewClient( @@ -26,27 +50,36 @@ func main() { gost.TCPTransporter(), ), }, + */ + + /* + // http+tls gost.Node{ - Addr: "172.24.222.54:8338", + Addr: "127.0.0.1:1443", + Client: gost.NewClient( + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.TLSTransporter(&tls.Config{InsecureSkipVerify: true}), + ), + }, + */ + + /* + // ss+tcp + gost.Node{ + Addr: "127.0.0.1:8338", Client: gost.NewClient( gost.ShadowConnector(url.UserPassword("chacha20", "123456")), gost.TCPTransporter(), ), }, - gost.Node{ - Addr: "172.24.222.54:8080", - Client: gost.NewClient( - gost.SOCKS5Connector(url.UserPassword("cmdsh", "cmdsh123456")), - gost.TCPTransporter(), - ), - }, + */ ) - conn, err := chain.Dial(context.Background(), "baidu.com:443") + conn, err := chain.Dial("localhost:10000") if err != nil { log.Fatal(err) } - conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) - req, err := http.NewRequest(http.MethodGet, "https://www.baidu.com", nil) + //conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) + req, err := http.NewRequest(http.MethodGet, "http://localhost:10000/pkg", nil) if err != nil { log.Fatal(err) } diff --git a/gost/client.go b/gost/client.go index 0b2c407..96ab013 100644 --- a/gost/client.go +++ b/gost/client.go @@ -1,7 +1,6 @@ package gost import ( - "context" "net" ) @@ -10,9 +9,6 @@ type Client struct { Transporter Transporter } -// DefaultClient is a standard HTTP proxy -var DefaultClient = NewClient(HTTPConnector(nil), TCPTransporter()) - func NewClient(c Connector, tr Transporter) *Client { return &Client{ Connector: c, @@ -21,25 +17,40 @@ func NewClient(c Connector, tr Transporter) *Client { } // Dial connects to the target address -func (c *Client) Dial(ctx context.Context, addr string) (net.Conn, error) { +func (c *Client) Dial(addr string) (net.Conn, error) { return net.Dial(c.Transporter.Network(), addr) } -func (c *Client) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) { - return c.Transporter.Handshake(ctx, conn) +func (c *Client) Handshake(conn net.Conn) (net.Conn, error) { + return c.Transporter.Handshake(conn) } -func (c *Client) Connect(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { - return c.Connector.Connect(ctx, conn, addr) +func (c *Client) Connect(conn net.Conn, addr string) (net.Conn, error) { + return c.Connector.Connect(conn, addr) +} + +// DefaultClient is a standard HTTP proxy client +var DefaultClient = NewClient(HTTPConnector(nil), TCPTransporter()) + +func Dial(addr string) (net.Conn, error) { + return DefaultClient.Dial(addr) +} + +func Handshake(conn net.Conn) (net.Conn, error) { + return DefaultClient.Handshake(conn) +} + +func Connect(conn net.Conn, addr string) (net.Conn, error) { + return DefaultClient.Connect(conn, addr) } type Connector interface { - Connect(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) + Connect(conn net.Conn, addr string) (net.Conn, error) } type Transporter interface { Network() string - Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) + Handshake(conn net.Conn) (net.Conn, error) } type tcpTransporter struct { @@ -53,6 +64,6 @@ func (tr *tcpTransporter) Network() string { return "tcp" } -func (tr *tcpTransporter) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) { +func (tr *tcpTransporter) Handshake(conn net.Conn) (net.Conn, error) { return conn, nil } diff --git a/gost/gost.go b/gost/gost.go index f57b59b..b3f7974 100644 --- a/gost/gost.go +++ b/gost/gost.go @@ -4,8 +4,16 @@ import ( "github.com/go-log/log" ) +const Version = "2.4-dev20170722" + var Debug bool +var ( + SmallBufferSize = 1 * 1024 // 1KB small buffer + MediumBufferSize = 8 * 1024 // 8KB medium buffer + LargeBufferSize = 32 * 1024 // 32KB large buffer +) + func init() { log.DefaultLogger = &logger{} } diff --git a/gost/http.go b/gost/http.go index 03c8f74..3d328f8 100644 --- a/gost/http.go +++ b/gost/http.go @@ -2,13 +2,13 @@ package gost import ( "bufio" - "context" "encoding/base64" "fmt" "net" "net/http" "net/http/httputil" "net/url" + "strings" "github.com/go-log/log" ) @@ -21,7 +21,7 @@ func HTTPConnector(user *url.Userinfo) Connector { return &httpConnector{User: user} } -func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { +func (c *httpConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { req := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Host: addr}, @@ -66,3 +66,134 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, addr string) return conn, nil } + +type httpHandler struct { + options *HandlerOptions +} + +func HTTPHandler(opts ...HandlerOption) Handler { + h := &httpHandler{ + options: &HandlerOptions{ + Chain: new(Chain), + }, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *httpHandler) Handle(conn net.Conn) { + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + log.Log("[http]", err) + return + } + + log.Logf("[http] %s %s -> %s %s", req.Method, conn.RemoteAddr(), req.Host, req.Proto) + + if Debug { + dump, _ := httputil.DumpRequest(req, false) + log.Logf(string(dump)) + } + + if req.Method == "PRI" && req.ProtoMajor == 2 { + log.Logf("[http] %s <- %s : Not an HTTP2 server", conn.RemoteAddr(), req.Host) + resp := "HTTP/1.1 400 Bad Request\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n" + conn.Write([]byte(resp)) + return + } + + valid := false + u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")) + users := h.options.Users + for _, user := range users { + username := user.Username() + password, _ := user.Password() + if (u == username && p == password) || + (u == username && password == "") || + (username == "" && p == password) { + valid = true + break + } + } + + if len(users) > 0 && !valid { + log.Logf("[http] %s <- %s : proxy authentication required", conn.RemoteAddr(), req.Host) + resp := "HTTP/1.1 407 Proxy Authentication Required\r\n" + + "Proxy-Authenticate: Basic realm=\"gost\"\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n" + conn.Write([]byte(resp)) + return + } + + req.Header.Del("Proxy-Authorization") + + // forward http request + //lastNode := s.Base.Chain.lastNode + //if lastNode != nil && lastNode.Transport == "" && (lastNode.Protocol == "http" || lastNode.Protocol == "") { + // s.forwardRequest(req) + // return + //} + + // if !s.Base.Node.Can("tcp", req.Host) { + // glog.Errorf("Unauthorized to tcp connect to %s", req.Host) + // return + // } + + cc, err := h.options.Chain.Dial(req.Host) + if err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + + b := []byte("HTTP/1.1 503 Service unavailable\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n") + if Debug { + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b)) + } + conn.Write(b) + return + } + defer cc.Close() + + if req.Method == http.MethodConnect { + b := []byte("HTTP/1.1 200 Connection established\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n") + if Debug { + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), req.Host, string(b)) + } + conn.Write(b) + } else { + req.Header.Del("Proxy-Connection") + + if err = req.Write(cc); err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), req.Host, err) + return + } + } + + log.Logf("[http] %s <-> %s", conn.RemoteAddr(), req.Host) + transport(conn, cc) + log.Logf("[http] %s >-< %s", conn.RemoteAddr(), req.Host) +} + +func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) { + if proxyAuth == "" { + return + } + + if !strings.HasPrefix(proxyAuth, "Basic ") { + return + } + c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic ")) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return + } + + return cs[:s], cs[s+1:], true +} diff --git a/gost/server.go b/gost/server.go index 3d9af9c..4ff03b7 100644 --- a/gost/server.go +++ b/gost/server.go @@ -2,17 +2,127 @@ package gost import ( "crypto/tls" + "io" + "net" + "time" + "net/url" + + "github.com/go-log/log" ) type Server struct { - Addr string `opt:"addr"` // [host]:port - Protocol string `opt:"protocol"` // protocol: http/socks5/ss - TLSConfig *tls.Config - Chain *Chain - Users []url.Userinfo `opt:"user"` // authentication for proxy + l net.Listener + handler Handler } -func (s *Server) Run() error { - return nil +func (s *Server) Handle(h Handler) { + s.handler = h +} + +func (s *Server) Serve(l net.Listener) error { + defer l.Close() + + var tempDelay time.Duration + for { + conn, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + log.Logf("server: Accept error: %v; retrying in %v", e, tempDelay) + time.Sleep(tempDelay) + continue + } + return e + } + tempDelay = 0 + go s.handler.Handle(conn) + } + +} + +type Listener interface { + net.Listener +} + +type tcpListener struct { + net.Listener +} + +func TCPListener(addr string) (Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + return &tcpListener{Listener: &tcpKeepAliveListener{ln.(*net.TCPListener)}}, nil +} + +type Handler interface { + Handle(net.Conn) +} + +type HandlerOptions struct { + Chain *Chain + Users []*url.Userinfo + TLSConfig *tls.Config +} + +type HandlerOption func(opts *HandlerOptions) + +func ChainHandlerOption(chain *Chain) HandlerOption { + return func(opts *HandlerOptions) { + opts.Chain = chain + } +} + +func UsersHandlerOption(users ...*url.Userinfo) HandlerOption { + return func(opts *HandlerOptions) { + opts.Users = users + } +} + +func TLSConfigHandlerOption(config *tls.Config) HandlerOption { + return func(opts *HandlerOptions) { + opts.TLSConfig = config + } +} + +func transport(rw1, rw2 io.ReadWriter) error { + errc := make(chan error, 1) + go func() { + _, err := io.Copy(rw1, rw2) + errc <- err + }() + + go func() { + _, err := io.Copy(rw2, rw1) + errc <- err + }() + + return <-errc +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil } diff --git a/gost/socks.go b/gost/socks.go index 1c079e4..5b0cb06 100644 --- a/gost/socks.go +++ b/gost/socks.go @@ -1,7 +1,6 @@ package gost import ( - "context" "crypto/tls" "errors" "fmt" @@ -23,25 +22,25 @@ const ( CmdUdpTun uint8 = 0xF3 // extended method for udp over tcp ) -type ClientSelector struct { +type clientSelector struct { methods []uint8 User *url.Userinfo TLSConfig *tls.Config } -func (selector *ClientSelector) Methods() []uint8 { +func (selector *clientSelector) Methods() []uint8 { return selector.methods } -func (selector *ClientSelector) AddMethod(methods ...uint8) { +func (selector *clientSelector) AddMethod(methods ...uint8) { selector.methods = append(selector.methods, methods...) } -func (selector *ClientSelector) Select(methods ...uint8) (method uint8) { +func (selector *clientSelector) Select(methods ...uint8) (method uint8) { return } -func (selector *ClientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { +func (selector *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { switch method { case MethodTLS: conn = tls.Client(conn, selector.TLSConfig) @@ -83,6 +82,105 @@ func (selector *ClientSelector) OnSelected(method uint8, conn net.Conn) (net.Con return conn, nil } +type serverSelector struct { + methods []uint8 + Users []*url.Userinfo + TLSConfig *tls.Config +} + +func (selector *serverSelector) Methods() []uint8 { + return selector.methods +} + +func (selector *serverSelector) AddMethod(methods ...uint8) { + selector.methods = append(selector.methods, methods...) +} + +func (selector *serverSelector) Select(methods ...uint8) (method uint8) { + if Debug { + log.Logf("[socks5] %d %d %v", gosocks5.Ver5, len(methods), methods) + } + method = gosocks5.MethodNoAuth + for _, m := range methods { + if m == MethodTLS { + method = m + break + } + } + + // when user/pass is set, auth is mandatory + if len(selector.Users) > 0 { + if method == gosocks5.MethodNoAuth { + method = gosocks5.MethodUserPass + } + if method == MethodTLS { + method = MethodTLSAuth + } + } + + return +} + +func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { + if Debug { + log.Logf("[socks5] %d %d", gosocks5.Ver5, method) + } + switch method { + case MethodTLS: + conn = tls.Server(conn, selector.TLSConfig) + + case gosocks5.MethodUserPass, MethodTLSAuth: + if method == MethodTLSAuth { + conn = tls.Server(conn, selector.TLSConfig) + } + + req, err := gosocks5.ReadUserPassRequest(conn) + if err != nil { + log.Log("[socks5]", err) + return nil, err + } + if Debug { + log.Log("[socks5]", req.String()) + } + valid := false + for _, user := range selector.Users { + username := user.Username() + password, _ := user.Password() + if (req.Username == username && req.Password == password) || + (req.Username == username && password == "") || + (username == "" && req.Password == password) { + valid = true + break + } + } + if len(selector.Users) > 0 && !valid { + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) + if err := resp.Write(conn); err != nil { + log.Log("[socks5]", err) + return nil, err + } + if Debug { + log.Log("[socks5]", resp) + } + log.Log("[socks5] proxy authentication required") + return nil, gosocks5.ErrAuthFailure + } + + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) + if err := resp.Write(conn); err != nil { + log.Log("[socks5]", err) + return nil, err + } + if Debug { + log.Log("[socks5]", resp) + } + case gosocks5.MethodNoAcceptable: + return nil, gosocks5.ErrBadMethod + } + + return conn, nil +} + type socks5Connector struct { User *url.Userinfo } @@ -91,8 +189,8 @@ func SOCKS5Connector(user *url.Userinfo) Connector { return &socks5Connector{User: user} } -func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { - selector := &ClientSelector{ +func (c *socks5Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { + selector := &clientSelector{ TLSConfig: &tls.Config{InsecureSkipVerify: true}, User: c.User, } @@ -148,7 +246,7 @@ func SOCKS4Connector() Connector { return &socks4Connector{} } -func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { +func (c *socks4Connector) Connect(conn net.Conn, addr string) (net.Conn, error) { taddr, err := net.ResolveTCPAddr("tcp4", addr) if err != nil { return nil, err @@ -191,7 +289,7 @@ func SOCKS4AConnector() Connector { return &socks4aConnector{} } -func (c *socks4aConnector) Connect(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { +func (c *socks4aConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -223,3 +321,105 @@ func (c *socks4aConnector) Connect(ctx context.Context, conn net.Conn, addr stri return conn, nil } + +type socks5Handler struct { + selector *serverSelector + options *HandlerOptions +} + +func SOCKS5Handler(opts ...HandlerOption) Handler { + options := &HandlerOptions{ + Chain: new(Chain), + } + for _, opt := range opts { + opt(options) + } + + selector := &serverSelector{ // socks5 server selector + Users: options.Users, + TLSConfig: options.TLSConfig, + } + // methods that socks5 server supported + selector.AddMethod( + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + MethodTLS, + MethodTLSAuth, + ) + return &socks5Handler{ + options: options, + selector: selector, + } +} + +func (h *socks5Handler) Handle(conn net.Conn) { + conn = gosocks5.ServerConn(conn, h.selector) + req, err := gosocks5.ReadRequest(conn) + if err != nil { + log.Log("[socks5]", err) + return + } + + if Debug { + log.Logf("[socks5] %s -> %s\n%s", conn.RemoteAddr(), req.Addr, req) + } + switch req.Cmd { + case gosocks5.CmdConnect: + log.Logf("[socks5-connect] %s -> %s", conn.RemoteAddr(), req.Addr) + h.handleConnect(conn, req) + + case gosocks5.CmdBind: + log.Logf("[socks5-bind] %s - %s", conn.RemoteAddr(), req.Addr) + h.handleBind(conn, req) + + case gosocks5.CmdUdp: + log.Logf("[socks5-udp] %s - %s", conn.RemoteAddr(), req.Addr) + //s.handleUDPRelay(req) + + case CmdUdpTun: + log.Logf("[socks5-rudp] %s - %s", conn.RemoteAddr(), req.Addr) + //s.handleUDPTunnel(req) + + default: + log.Log("[socks5] Unrecognized request:", req.Cmd) + } +} + +func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { + addr := req.Addr.String() + + //! if !s.Base.Node.Can("tcp", addr) { + //! glog.Errorf("Unauthorized to tcp connect to %s", addr) + //! rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) + //! rep.Write(s.conn) + //! return + //! } + + cc, err := h.options.Chain.Dial(addr) + if err != nil { + log.Logf("[socks5-connect] %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + rep := gosocks5.NewReply(gosocks5.HostUnreachable, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + return + } + defer cc.Close() + + rep := gosocks5.NewReply(gosocks5.Succeeded, nil) + if err := rep.Write(conn); err != nil { + log.Logf("[socks5-connect] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) + return + } + if Debug { + log.Logf("[socks5-connect] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + log.Logf("[socks5-connect] %s <-> %s", conn.RemoteAddr(), req.Addr) + transport(conn, cc) + log.Logf("[socks5-connect] %s >-< %s", conn.RemoteAddr(), req.Addr) +} + +func (h *socks5Handler) handleBind(conn net.Conn, req *gosocks5.Request) { + // TODO: socks5 bind +} diff --git a/gost/srv/cert.pem b/gost/srv/cert.pem new file mode 100644 index 0000000..6b402ee --- /dev/null +++ b/gost/srv/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE----- diff --git a/gost/srv/key.pem b/gost/srv/key.pem new file mode 100644 index 0000000..42f26e2 --- /dev/null +++ b/gost/srv/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY----- diff --git a/gost/srv/srv.go b/gost/srv/srv.go new file mode 100644 index 0000000..15ce218 --- /dev/null +++ b/gost/srv/srv.go @@ -0,0 +1,111 @@ +package main + +import ( + "crypto/tls" + "log" + + "net/url" + + "sync" + + "github.com/ginuerzh/gost/gost" +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + gost.Debug = true +} + +func main() { + wg := sync.WaitGroup{} + wg.Add(1) + go httpServer(&wg) + wg.Add(1) + go tlsServer(&wg) + wg.Add(1) + go shadowServer(&wg) + wg.Add(1) + go wsServer(&wg) + wg.Add(1) + go wssServer(&wg) + wg.Wait() +} + +func httpServer(wg *sync.WaitGroup) { + defer wg.Done() + + s := &gost.Server{} + s.Handle(gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + )) + ln, err := gost.TCPListener(":1080") + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} + +func tlsServer(wg *sync.WaitGroup) { + defer wg.Done() + + s := &gost.Server{} + s.Handle(gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + )) + cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") + if err != nil { + log.Fatal(err) + } + ln, err := gost.TLSListener(":1443", &tls.Config{Certificates: []tls.Certificate{cert}}) + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} + +func wsServer(wg *sync.WaitGroup) { + defer wg.Done() + + s := &gost.Server{} + s.Handle(gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + )) + ln, err := gost.WSListener(":8000", nil) + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} + +func wssServer(wg *sync.WaitGroup) { + defer wg.Done() + + s := &gost.Server{} + s.Handle(gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + )) + + cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") + if err != nil { + log.Fatal(err) + } + ln, err := gost.WSSListener(":8443", &gost.WSOptions{TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}}) + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} + +func shadowServer(wg *sync.WaitGroup) { + defer wg.Done() + + s := &gost.Server{} + s.Handle(gost.ShadowHandler( + gost.UsersHandlerOption(url.UserPassword("chacha20", "123456")), + )) + ln, err := gost.TCPListener(":8338") + if err != nil { + log.Fatal(err) + } + log.Fatal(s.Serve(ln)) +} diff --git a/gost/ss.go b/gost/ss.go index 34cdd68..ea4404f 100644 --- a/gost/ss.go +++ b/gost/ss.go @@ -1,11 +1,15 @@ package gost import ( - "context" + "encoding/binary" + "fmt" + "io" "net" "net/url" + "strconv" "time" + "github.com/go-log/log" ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" ) @@ -15,10 +19,6 @@ type shadowConn struct { conn net.Conn } -func ShadowConn(conn net.Conn) net.Conn { - return &shadowConn{conn: conn} -} - func (c *shadowConn) Read(b []byte) (n int, err error) { return c.conn.Read(b) } @@ -61,7 +61,7 @@ func ShadowConnector(cipher *url.Userinfo) Connector { return &shadowConnector{Cipher: cipher} } -func (c *shadowConnector) Connect(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { +func (c *shadowConnector) Connect(conn net.Conn, addr string) (net.Conn, error) { rawaddr, err := ss.RawAddr(addr) if err != nil { return nil, err @@ -84,3 +84,122 @@ func (c *shadowConnector) Connect(ctx context.Context, conn net.Conn, addr strin } return &shadowConn{conn: sc}, nil } + +type shadowHandler struct { + options *HandlerOptions +} + +func ShadowHandler(opts ...HandlerOption) Handler { + h := &shadowHandler{ + options: &HandlerOptions{ + Chain: new(Chain), + }, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *shadowHandler) Handle(conn net.Conn) { + var method, password string + + users := h.options.Users + if len(users) > 0 { + method = users[0].Username() + password, _ = users[0].Password() + } + cipher, err := ss.NewCipher(method, password) + if err != nil { + log.Log("[ss]", err) + return + } + conn = &shadowConn{conn: ss.NewConn(conn, cipher)} + + log.Logf("[ss] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) + + addr, err := h.getRequest(conn) + if err != nil { + log.Logf("[ss] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + log.Logf("[ss] %s -> %s", conn.RemoteAddr(), addr) + + cc, err := h.options.Chain.Dial(addr) + if err != nil { + log.Logf("[ss] %s -> %s : %s", conn.RemoteAddr(), addr, err) + return + } + defer cc.Close() + + log.Logf("[ss] %s <-> %s", conn.RemoteAddr(), addr) + transport(conn, cc) + log.Logf("[ss] %s >-< %s", conn.RemoteAddr(), addr) +} + +const ( + idType = 0 // address type index + idIP0 = 1 // ip addres start index + idDmLen = 1 // domain address length index + idDm0 = 2 // domain address start index + + typeIPv4 = 1 // type is ipv4 address + typeDm = 3 // type is domain address + typeIPv6 = 4 // type is ipv6 address + + lenIPv4 = net.IPv4len + 2 // ipv4 + 2port + lenIPv6 = net.IPv6len + 2 // ipv6 + 2port + lenDmBase = 2 // 1addrLen + 2port, plus addrLen + lenHmacSha1 = 10 +) + +// This function is copied from shadowsocks library with some modification. +func (h *shadowHandler) getRequest(conn net.Conn) (host string, err error) { + // buf size should at least have the same size with the largest possible + // request size (when addrType is 3, domain name has at most 256 bytes) + // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + buf := make([]byte, SmallBufferSize) + + // read till we get possible domain length field + conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil { + return + } + + var reqStart, reqEnd int + addrType := buf[idType] + switch addrType & ss.AddrMask { + case typeIPv4: + reqStart, reqEnd = idIP0, idIP0+lenIPv4 + case typeIPv6: + reqStart, reqEnd = idIP0, idIP0+lenIPv6 + case typeDm: + if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil { + return + } + reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase) + default: + err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask) + return + } + + if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil { + return + } + + // Return string for typeIP is not most efficient, but browsers (Chrome, + // Safari, Firefox) all seems using typeDm exclusively. So this is not a + // big problem. + switch addrType & ss.AddrMask { + case typeIPv4: + host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String() + case typeIPv6: + host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String() + case typeDm: + host = string(buf[idDm0 : idDm0+buf[idDmLen]]) + } + // parse port + port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd]) + host = net.JoinHostPort(host, strconv.Itoa(int(port))) + return +} diff --git a/gost/tls.go b/gost/tls.go index 798fc3e..f4cf824 100644 --- a/gost/tls.go +++ b/gost/tls.go @@ -1,7 +1,6 @@ package gost import ( - "context" "crypto/tls" "net" ) @@ -18,6 +17,18 @@ func (tr *tlsTransporter) Network() string { return "tcp" } -func (tr *tlsTransporter) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) { +func (tr *tlsTransporter) Handshake(conn net.Conn) (net.Conn, error) { return tls.Client(conn, tr.TLSClientConfig), nil } + +type tlsListener struct { + net.Listener +} + +func TLSListener(addr string, config *tls.Config) (Listener, error) { + ln, err := tls.Listen("tcp", addr, config) + if err != nil { + return nil, err + } + return &tlsListener{Listener: ln}, nil +} diff --git a/gost/ws.go b/gost/ws.go new file mode 100644 index 0000000..14db8b4 --- /dev/null +++ b/gost/ws.go @@ -0,0 +1,277 @@ +package gost + +import ( + "crypto/tls" + "net" + "net/http" + "net/http/httputil" + "time" + + "net/url" + + "github.com/go-log/log" + "gopkg.in/gorilla/websocket.v1" +) + +type WSOptions struct { + ReadBufferSize int + WriteBufferSize int + HandshakeTimeout time.Duration + EnableCompression bool + TLSConfig *tls.Config +} + +type websocketConn struct { + conn *websocket.Conn + rb []byte +} + +func websocketClientConn(url string, conn net.Conn, options *WSOptions) (net.Conn, error) { + if options == nil { + options = &WSOptions{} + } + dialer := websocket.Dialer{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + TLSClientConfig: options.TLSConfig, + HandshakeTimeout: options.HandshakeTimeout, + EnableCompression: options.EnableCompression, + NetDial: func(net, addr string) (net.Conn, error) { + return conn, nil + }, + } + c, resp, err := dialer.Dial(url, nil) + if err != nil { + return nil, err + } + resp.Body.Close() + return &websocketConn{conn: c}, nil +} + +func websocketServerConn(conn *websocket.Conn) net.Conn { + // conn.EnableWriteCompression(true) + return &websocketConn{ + conn: conn, + } +} + +func (c *websocketConn) Read(b []byte) (n int, err error) { + if len(c.rb) == 0 { + _, c.rb, err = c.conn.ReadMessage() + } + n = copy(b, c.rb) + c.rb = c.rb[n:] + return +} + +func (c *websocketConn) Write(b []byte) (n int, err error) { + err = c.conn.WriteMessage(websocket.BinaryMessage, b) + n = len(b) + return +} + +func (c *websocketConn) Close() error { + return c.conn.Close() +} + +func (c *websocketConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *websocketConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (conn *websocketConn) SetDeadline(t time.Time) error { + if err := conn.SetReadDeadline(t); err != nil { + return err + } + return conn.SetWriteDeadline(t) +} +func (c *websocketConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *websocketConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +type wsTransporter struct { + addr string + options *WSOptions +} + +func WSTransporter(addr string, opts *WSOptions) Transporter { + return &wsTransporter{ + addr: addr, + options: opts, + } +} + +func (tr *wsTransporter) Network() string { + return "tcp" +} + +func (tr *wsTransporter) Handshake(conn net.Conn) (net.Conn, error) { + url := url.URL{Scheme: "ws", Host: tr.addr, Path: "/ws"} + return websocketClientConn(url.String(), conn, tr.options) +} + +type wssTransporter struct { + addr string + options *WSOptions +} + +func WSSTransporter(addr string, opts *WSOptions) Transporter { + return &wssTransporter{ + addr: addr, + options: opts, + } +} + +func (tr *wssTransporter) Network() string { + return "tcp" +} + +func (tr *wssTransporter) Handshake(conn net.Conn) (net.Conn, error) { + url := url.URL{Scheme: "wss", Host: tr.addr, Path: "/ws"} + return websocketClientConn(url.String(), conn, tr.options) +} + +type wsListener struct { + addr net.Addr + upgrader *websocket.Upgrader + srv *http.Server + connChan chan net.Conn + errChan chan error +} + +func WSListener(addr string, options *WSOptions) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + if options == nil { + options = &WSOptions{} + } + l := &wsListener{ + addr: tcpAddr, + upgrader: &websocket.Upgrader{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: options.EnableCompression, + }, + connChan: make(chan net.Conn, 32), + errChan: make(chan error, 1), + } + + mux := http.NewServeMux() + mux.Handle("/ws", http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{Addr: addr, Handler: mux} + + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + + go func() { + err := l.srv.Serve(tcpKeepAliveListener{ln}) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + select { + case err := <-l.errChan: + return nil, err + default: + } + + return l, nil +} + +func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { + log.Logf("[ws] %s -> %s", r.RemoteAddr, l.addr) + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Log(string(dump)) + } + conn, err := l.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Logf("[ws] %s - %s : %s", r.RemoteAddr, l.addr, err) + return + } + l.connChan <- websocketServerConn(conn) +} + +func (l *wsListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errChan: + } + return +} + +func (l *wsListener) Close() error { + return l.srv.Close() +} + +func (l *wsListener) Addr() net.Addr { + return l.addr +} + +type wssListener struct { + *wsListener +} + +func WSSListener(addr string, options *WSOptions) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + if options == nil { + options = &WSOptions{} + } + l := &wssListener{ + wsListener: &wsListener{ + addr: tcpAddr, + upgrader: &websocket.Upgrader{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: options.EnableCompression, + }, + connChan: make(chan net.Conn, 32), + errChan: make(chan error, 1), + }, + } + + mux := http.NewServeMux() + mux.Handle("/ws", http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{ + Addr: addr, + TLSConfig: options.TLSConfig, + Handler: mux, + } + + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + + go func() { + err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, options.TLSConfig)) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + select { + case err := <-l.errChan: + return nil, err + default: + } + + return l, nil +}