diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index 8ecec21..d64feaa 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "io/ioutil" + "net" "net/url" "os" "strings" @@ -273,3 +274,44 @@ func parseHosts(s string) *gost.Hosts { return hosts } + +func parseIPRoutes(s string) (routes []gost.IPRoute) { + if s == "" { + return + } + + file, err := os.Open(s) + if err != nil { + ss := strings.Split(s, ",") + for _, s := range ss { + if _, inet, _ := net.ParseCIDR(strings.TrimSpace(s)); inet != nil { + routes = append(routes, gost.IPRoute{Dest: inet}) + } + } + return + } + + defer file.Close() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.Replace(scanner.Text(), "\t", " ", -1) + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + var route gost.IPRoute + ss := strings.Split(line, " ") + if len(ss) > 0 && ss[0] != "" { + _, route.Dest, _ = net.ParseCIDR(strings.TrimSpace(ss[0])) + if route.Dest == nil { + continue + } + } + if len(ss) > 1 && ss[1] != "" { + route.Gateway = net.ParseIP(ss[1]) + } + routes = append(routes, route) + } + return routes +} diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 184052a..8c554b0 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -311,6 +311,8 @@ func (r *route) GenRouters() ([]router, error) { ttl = time.Duration(node.GetInt("ttl")) * time.Second } + tunRoutes := parseIPRoutes(node.Get("route")) + var ln gost.Listener switch node.Transport { case "tls": @@ -415,7 +417,7 @@ func (r *route) GenRouters() ([]router, error) { Name: node.Get("name"), Addr: node.Get("net"), MTU: node.GetInt("mtu"), - Routes: strings.Split(node.Get("route"), ","), + Routes: tunRoutes, Gateway: node.Get("gw"), } ln, err = gost.TunListener(cfg) @@ -525,6 +527,7 @@ func (r *route) GenRouters() ([]router, error) { gost.NodeHandlerOption(node), gost.IPsHandlerOption(ips), gost.TCPModeHandlerOption(node.GetBool("tcp")), + gost.IPRoutesHandlerOption(tunRoutes...), ) rt := router{ diff --git a/handler.go b/handler.go index 5ee42c6..db53147 100644 --- a/handler.go +++ b/handler.go @@ -41,6 +41,7 @@ type HandlerOptions struct { Host string IPs []string TCPMode bool + IPRoutes []IPRoute } // HandlerOption allows a common way to set handler options. @@ -203,6 +204,13 @@ func TCPModeHandlerOption(b bool) HandlerOption { } } +// IPRoutesHandlerOption sets the IP routes for tun tunnel. +func IPRoutesHandlerOption(routes ...IPRoute) HandlerOption { + return func(opts *HandlerOptions) { + opts.IPRoutes = routes + } +} + type autoHandler struct { options *HandlerOptions } diff --git a/tuntap.go b/tuntap.go index 795fc4e..84b1bbb 100644 --- a/tuntap.go +++ b/tuntap.go @@ -39,11 +39,16 @@ func ipProtocol(p waterutil.IPProtocol) string { return fmt.Sprintf("unknown(%d)", p) } +type IPRoute struct { + Dest *net.IPNet + Gateway net.IP +} + type TunConfig struct { Name string Addr string MTU int - Routes []string + Routes []IPRoute Gateway string } @@ -224,6 +229,20 @@ func (h *tunHandler) initTunnelConn(pc net.PacketConn) (net.PacketConn, error) { return pc, nil } +func (h *tunHandler) findRouteFor(dst net.IP) net.Addr { + for _, route := range h.options.IPRoutes { + if route.Dest.Contains(dst) && route.Gateway != nil { + if v, ok := h.routes.Load(ipToTunRouteKey(route.Gateway)); ok { + return v.(net.Addr) + } + } + } + if v, ok := h.routes.Load(ipToTunRouteKey(dst)); ok { + return v.(net.Addr) + } + return nil +} + func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.Addr) error { errc := make(chan error, 1) @@ -279,10 +298,7 @@ func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.A return err } - var addr net.Addr - if v, ok := h.routes.Load(ipToTunRouteKey(dst)); ok { - addr = v.(net.Addr) - } + addr := h.findRouteFor(dst) if addr == nil { log.Logf("[tun] no route for %s -> %s", src, dst) return nil @@ -361,11 +377,11 @@ func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.A log.Logf("[tun] new route: %s -> %s", src, addr) } - if v, ok := h.routes.Load(ipToTunRouteKey(dst)); ok { + if addr := h.findRouteFor(dst); addr != nil { if Debug { - log.Logf("[tun] find route: %s -> %s", dst, v) + log.Logf("[tun] find route: %s -> %s", dst, addr) } - _, err := conn.WriteTo(b[:n], v.(net.Addr)) + _, err := conn.WriteTo(b[:n], addr) return err } diff --git a/tuntap_darwin.go b/tuntap_darwin.go index 6aa102a..51cdd9e 100644 --- a/tuntap_darwin.go +++ b/tuntap_darwin.go @@ -37,7 +37,7 @@ func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { return } - if err = addRoutes(ifce.Name(), cfg.Routes...); err != nil { + if err = addTunRoutes(ifce.Name(), cfg.Routes...); err != nil { return } @@ -58,12 +58,12 @@ func createTap(cfg TapConfig) (conn net.Conn, itf *net.Interface, err error) { return } -func addRoutes(ifName string, routes ...string) error { +func addTunRoutes(ifName string, routes ...IPRoute) error { for _, route := range routes { - if route == "" { + if route.Dest == nil { continue } - cmd := fmt.Sprintf("route add -net %s -interface %s", route, ifName) + cmd := fmt.Sprintf("route add -net %s -interface %s", route.Dest.String(), ifName) log.Log("[tun]", cmd) args := strings.Split(cmd, " ") if er := exec.Command(args[0], args[1:]...).Run(); er != nil { diff --git a/tuntap_linux.go b/tuntap_linux.go index f2c53fc..e2bd01b 100644 --- a/tuntap_linux.go +++ b/tuntap_linux.go @@ -136,14 +136,14 @@ func createTap(cfg TapConfig) (conn net.Conn, itf *net.Interface, err error) { return } -func addTunRoutes(ifName string, routes ...string) error { +func addTunRoutes(ifName string, routes ...IPRoute) error { for _, route := range routes { - if route == "" { + if route.Dest == nil { continue } - cmd := fmt.Sprintf("ip route add %s dev %s", route, ifName) + cmd := fmt.Sprintf("ip route add %s dev %s", route.Dest.String(), ifName) log.Logf("[tun] %s", cmd) - if err := netlink.AddRoute(route, "", "", ifName); err != nil { + if err := netlink.AddRoute(route.Dest.String(), "", "", ifName); err != nil { return fmt.Errorf("%s: %v", cmd, err) } } diff --git a/tuntap_unix.go b/tuntap_unix.go index 479dc31..a7db978 100644 --- a/tuntap_unix.go +++ b/tuntap_unix.go @@ -96,12 +96,12 @@ func createTap(cfg TapConfig) (conn net.Conn, itf *net.Interface, err error) { return } -func addTunRoutes(ifName string, routes ...string) error { +func addTunRoutes(ifName string, routes ...IPRoute) error { for _, route := range routes { - if route == "" { + if route.Dest == nil { continue } - cmd := fmt.Sprintf("route add -net %s -interface %s", route, ifName) + cmd := fmt.Sprintf("route add -net %s -interface %s", route.Dest.String(), ifName) log.Logf("[tun] %s", cmd) args := strings.Split(cmd, " ") if er := exec.Command(args[0], args[1:]...).Run(); er != nil { diff --git a/tuntap_windows.go b/tuntap_windows.go index e19692a..b5e12d9 100644 --- a/tuntap_windows.go +++ b/tuntap_windows.go @@ -98,16 +98,16 @@ func createTap(cfg TapConfig) (conn net.Conn, itf *net.Interface, err error) { return } -func addTunRoutes(ifName string, gw string, routes ...string) error { +func addTunRoutes(ifName string, gw string, routes ...IPRoute) error { for _, route := range routes { - if route == "" { + if route.Dest == nil { continue } - deleteRoute(ifName, route) + deleteRoute(ifName, route.Dest.String()) cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active", - route, ifName) + route.Dest.String(), ifName) if gw != "" { cmd += " nexthop=" + gw }