diff --git a/cmd/gost/route.go b/cmd/gost/route.go index aecaea7..f3a259e 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -228,6 +228,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { connector = gost.SNIConnector(node.Get("host")) case "http": connector = gost.HTTPConnector(node.User) + case "relay": + connector = gost.RelayConnector(node.User) default: connector = gost.AutoConnector(node.User) } @@ -529,6 +531,8 @@ func (r *route) GenRouters() ([]router, error) { handler = gost.TapHandler() case "dns": handler = gost.DNSHandler(node.Remote) + case "relay": + handler = gost.RelayHandler(node.Remote) default: // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. if node.Remote != "" { diff --git a/forward.go b/forward.go index 9e25fa3..11ea0ae 100644 --- a/forward.go +++ b/forward.go @@ -24,7 +24,7 @@ func ForwardConnector() Connector { } func (c *forwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { - return conn, nil + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) } func (c *forwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { @@ -74,13 +74,6 @@ func (h *baseForwardHandler) Init(options ...HandlerOption) { n++ } - if len(h.group.Nodes()) == 0 { - h.group.AddNode(Node{ // dummy address - ID: n, - Addr: ":0", - Host: ":0", - }) - } } type tcpDirectForwardHandler struct { @@ -113,6 +106,8 @@ func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) { func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() + log.Logf("[tcp] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) + retries := 1 if h.options.Chain != nil && h.options.Chain.Retries > 0 { retries = h.options.Chain.Retries @@ -125,19 +120,20 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { var node Node var err error for i := 0; i < retries; i++ { - node, err = h.group.Next() - if err != nil { - log.Logf("[tcp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) - return + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + log.Logf("[tcp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } } - log.Logf("[tcp] %s - %s", conn.RemoteAddr(), node.Addr) cc, err = h.options.Chain.Dial(node.Addr, RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), ) if err != nil { - log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err) + log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) node.MarkDead() } else { break @@ -150,9 +146,13 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { node.ResetDead() defer cc.Close() - log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr) + addr := node.Addr + if addr == "" { + addr = conn.LocalAddr().String() + } + log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), addr) transport(conn, cc) - log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), node.Addr) + log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), addr) } type udpDirectForwardHandler struct { @@ -185,24 +185,34 @@ func (h *udpDirectForwardHandler) Init(options ...HandlerOption) { func (h *udpDirectForwardHandler) Handle(conn net.Conn) { defer conn.Close() - node, err := h.group.Next() - if err != nil { - log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) - return + log.Logf("[udp] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) + + var node Node + var err error + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } } cc, err := h.options.Chain.DialContext(context.Background(), "udp", node.Addr) if err != nil { node.MarkDead() - log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) + log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) return } defer cc.Close() node.ResetDead() - log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr) + addr := node.Addr + if addr == "" { + addr = conn.LocalAddr().String() + } + log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), addr) transport(conn, cc) - log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), node.Addr) + log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), addr) } type tcpRemoteForwardHandler struct { @@ -244,10 +254,12 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { var node Node var err error for i := 0; i < retries; i++ { - node, err = h.group.Next() - if err != nil { - log.Logf("[rtcp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) - return + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + log.Logf("[rtcp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + return + } } cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout) if err != nil { @@ -299,10 +311,14 @@ func (h *udpRemoteForwardHandler) Init(options ...HandlerOption) { func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { defer conn.Close() - node, err := h.group.Next() - if err != nil { - log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) - return + var node Node + var err error + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) + return + } } raddr, err := net.ResolveUDPAddr("udp", node.Addr) diff --git a/go.mod b/go.mod index 7394e6d..62f24c4 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/docker/libcontainer v2.2.1+incompatible github.com/ginuerzh/gosocks4 v0.0.1 github.com/ginuerzh/gosocks5 v0.2.0 + github.com/ginuerzh/relay v0.0.0-20200226123819-7f0ae19c2e02 github.com/ginuerzh/tls-dissector v0.0.2-0.20200224064855-24ab2b3a3796 github.com/go-log/log v0.1.0 github.com/gobwas/glob v0.2.3 diff --git a/go.sum b/go.sum index 961e6c0..9311a7b 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/ginuerzh/gosocks4 v0.0.1 h1:ojDKUyz+uaEeRm2usY1cyQiXTqJqrKxfeE6SVBXq4 github.com/ginuerzh/gosocks4 v0.0.1/go.mod h1:8SdwBMKjfJ9+BfP2vDJM1jcrgWUbWV6qxBPHHVrwptY= github.com/ginuerzh/gosocks5 v0.2.0 h1:K0Ua23U9LU3BZrf3XpGDcs0mP8DiEpa6PJE4TA/MU3s= github.com/ginuerzh/gosocks5 v0.2.0/go.mod h1:qp22mr6tH/prEoaN0pFukq76LlScIE+F2rP2ZP5ZHno= +github.com/ginuerzh/relay v0.0.0-20200226123819-7f0ae19c2e02 h1:+TjDSHAHDdhoWaZ/8oArW+ANz4St2Jb8xc5/YtfMBP0= +github.com/ginuerzh/relay v0.0.0-20200226123819-7f0ae19c2e02/go.mod h1:VLcyv9iIjbxHOSJ3bCUPyc/AXSsMf8ZlX2M1xIvAlFQ= github.com/ginuerzh/tls-dissector v0.0.1 h1:yF6fIt78TO4CdjiLLn6R8r0XajQJE1Lbnuq6rP8mGW8= github.com/ginuerzh/tls-dissector v0.0.1/go.mod h1:u/kbBOqIOgJv39gywuUb3VwyzdZG5DKquOqfToKE6lk= github.com/ginuerzh/tls-dissector v0.0.2-0.20200223041816-c0cb3da7ea91 h1:bFBTbZglO4xNVWSLwDEcVKBIurTXGL2sNKi9UuQima4= diff --git a/gost.go b/gost.go index e135113..9e7e421 100644 --- a/gost.go +++ b/gost.go @@ -20,7 +20,7 @@ import ( ) // Version is the gost version. -const Version = "2.10.2-dev" +const Version = "2.11.0-dev" // Debug is a flag that enables the debug log. var Debug bool diff --git a/node.go b/node.go index 0e31451..f64afc4 100644 --- a/node.go +++ b/node.go @@ -111,6 +111,7 @@ func ParseNode(s string) (node Node, err error) { case "tun", "tap": // tun/tap device case "ftcp": // fake TCP case "dns", "dot", "doh": + case "relay": default: node.Protocol = "" } diff --git a/obfs.go b/obfs.go index 684fbac..060f4e5 100644 --- a/obfs.go +++ b/obfs.go @@ -321,6 +321,7 @@ type obfsTLSConn struct { handshakeMutex sync.Mutex } +// ClientObfsTLSConn creates a connection for obfs-tls client. func ClientObfsTLSConn(conn net.Conn, host string) net.Conn { return &obfsTLSConn{ Conn: conn, @@ -329,6 +330,7 @@ func ClientObfsTLSConn(conn net.Conn, host string) net.Conn { } } +// ServerObfsTLSConn creates a connection for obfs-tls server. func ServerObfsTLSConn(conn net.Conn, host string) net.Conn { return &obfsTLSConn{ Conn: conn, diff --git a/relay.go b/relay.go new file mode 100644 index 0000000..22dbdcf --- /dev/null +++ b/relay.go @@ -0,0 +1,331 @@ +package gost + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/url" + "strconv" + "sync" + "time" + + "github.com/ginuerzh/relay" + "github.com/go-log/log" +) + +type relayConnector struct { + user *url.Userinfo + remoteAddr string +} + +// RelayConnector creates a Connector for TCP/UDP data relay. +func RelayConnector(user *url.Userinfo) Connector { + return &relayConnector{ + user: user, + } +} + +func (c *relayConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return conn, nil +} + +func (c *relayConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + var udp bool + if network == "udp" || network == "udp4" || network == "udp6" { + udp = true + } + + req := &relay.Request{ + Version: relay.Version1, + } + if udp { + req.Flags |= relay.FUDP + } + + if c.user != nil { + pwd, _ := c.user.Password() + req.Features = append(req.Features, &relay.UserAuthFeature{ + Username: c.user.Username(), + Password: pwd, + }) + } + if address != "" { + host, port, _ := net.SplitHostPort(address) + nport, _ := strconv.ParseUint(port, 10, 16) + if host == "" { + host = net.IPv4zero.String() + } + + if nport > 0 { + var atype uint8 + ip := net.ParseIP(host) + if ip == nil { + atype = relay.AddrDomain + } else if ip.To4() == nil { + atype = relay.AddrIPv6 + } else { + atype = relay.AddrIPv4 + } + + req.Features = append(req.Features, &relay.TargetAddrFeature{ + AType: atype, + Host: host, + Port: uint16(nport), + }) + } + } + + rc := &relayConn{ + udp: udp, + Conn: conn, + } + + // write the header at once. + if opts.NoDelay { + if _, err := req.WriteTo(rc); err != nil { + return nil, err + } + } else { + if _, err := req.WriteTo(&rc.wbuf); err != nil { + return nil, err + } + } + + return rc, nil +} + +type relayHandler struct { + *baseForwardHandler +} + +// RelayHandler creates a server Handler for TCP/UDP relay server. +func RelayHandler(raddr string, opts ...HandlerOption) Handler { + h := &relayHandler{ + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *relayHandler) Init(options ...HandlerOption) { + h.baseForwardHandler.Init(options...) +} + +func (h *relayHandler) Handle(conn net.Conn) { + defer conn.Close() + + req := &relay.Request{} + if _, err := req.ReadFrom(conn); err != nil { + log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + if req.Version != relay.Version1 { + log.Logf("[relay] %s - %s : bad version", conn.RemoteAddr(), conn.LocalAddr()) + return + } + + var user, pass string + var raddr string + for _, f := range req.Features { + if f.Type() == relay.FeatureUserAuth { + feature := f.(*relay.UserAuthFeature) + user, pass = feature.Username, feature.Password + } + if f.Type() == relay.FeatureTargetAddr { + feature := f.(*relay.TargetAddrFeature) + raddr = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) + } + } + + resp := &relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + if h.options.Authenticator != nil && !h.options.Authenticator.Authenticate(user, pass) { + resp.Status = relay.StatusUnauthorized + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user) + return + } + + if raddr != "" { + if len(h.group.Nodes()) > 0 { + resp.Status = relay.StatusForbidden + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : relay to %s is forbidden", + conn.RemoteAddr(), conn.LocalAddr(), raddr) + return + } + } else { + if len(h.group.Nodes()) == 0 { + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : bad request, target addr is needed", + conn.RemoteAddr(), conn.LocalAddr()) + return + } + } + + udp := (req.Flags & relay.FUDP) == relay.FUDP + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + network := "tcp" + if udp { + network = "udp" + } + + ctx := context.TODO() + var cc net.Conn + var node Node + var err error + for i := 0; i < retries; i++ { + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) + log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + raddr = node.Addr + } + + log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr) + cc, err = h.options.Chain.DialContext(ctx, + network, raddr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) + if err != nil { + log.Logf("[relay] %s -> %s : %s", conn.RemoteAddr(), raddr, err) + node.MarkDead() + } else { + break + } + } + if err != nil { + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) + return + } + + node.ResetDead() + defer cc.Close() + + sc := &relayConn{ + Conn: conn, + isServer: true, + udp: udp, + } + resp.WriteTo(&sc.wbuf) + conn = sc + + log.Logf("[relay] %s <-> %s", conn.RemoteAddr(), raddr) + transport(conn, cc) + log.Logf("[relay] %s >-< %s", conn.RemoteAddr(), raddr) +} + +type relayConn struct { + net.Conn + isServer bool + udp bool + wbuf bytes.Buffer + once sync.Once +} + +func (c *relayConn) Read(b []byte) (n int, err error) { + c.once.Do(func() { + if c.isServer { + return + } + resp := new(relay.Response) + _, err = resp.ReadFrom(c.Conn) + if err != nil { + return + } + if resp.Version != relay.Version1 { + err = relay.ErrBadVersion + return + } + if resp.Status != relay.StatusOK { + err = fmt.Errorf("status %d", resp.Status) + return + } + }) + + if !c.udp { + return c.Conn.Read(b) + } + var bb [2]byte + _, err = io.ReadFull(c.Conn, bb[:]) + if err != nil { + return + } + dlen := int(binary.BigEndian.Uint16(bb[:])) + if len(b) >= dlen { + return io.ReadFull(c.Conn, b[:dlen]) + } + buf := make([]byte, dlen) + _, err = io.ReadFull(c.Conn, buf) + n = copy(b, buf) + return +} + +func (c *relayConn) Write(b []byte) (n int, err error) { + if len(b) > 0xFFFF { + err = errors.New("write: data maximum exceeded") + return + } + n = len(b) // force byte length consistent + if c.wbuf.Len() > 0 { + if c.udp { + var bb [2]byte + binary.BigEndian.PutUint16(bb[:2], uint16(len(b))) + c.wbuf.Write(bb[:]) + } + c.wbuf.Write(b) // append the data to the cached header + // _, err = c.Conn.Write(c.wbuf.Bytes()) + // c.wbuf.Reset() + _, err = c.wbuf.WriteTo(c.Conn) + return + } + + if !c.udp { + return c.Conn.Write(b) + } + buf := make([]byte, 2+len(b)) + binary.BigEndian.PutUint16(buf[:2], uint16(len(b))) + n = copy(buf[2:], b) + _, err = c.Conn.Write(buf) + return +} diff --git a/snapcraft.yaml b/snapcraft.yaml index 0b407b1..3384f2e 100644 --- a/snapcraft.yaml +++ b/snapcraft.yaml @@ -1,6 +1,6 @@ name: gost type: app -version: '2.10.2' +version: '2.11.0' title: GO Simple Tunnel summary: A simple security tunnel written in golang description: | diff --git a/tuntap.go b/tuntap.go index 8f47b41..733c60a 100644 --- a/tuntap.go +++ b/tuntap.go @@ -40,6 +40,7 @@ func ipProtocol(p waterutil.IPProtocol) string { return fmt.Sprintf("unknown(%d)", p) } +// IPRoute is an IP routing entry. type IPRoute struct { Dest *net.IPNet Gateway net.IP