diff --git a/tun.go b/tun.go index 1706ef9..f6252da 100644 --- a/tun.go +++ b/tun.go @@ -26,6 +26,7 @@ type tunHandler struct { raddr string options *HandlerOptions ipNet *net.IPNet + routes sync.Map } // TunHandler creates a handler for tun tunnel. @@ -97,7 +98,6 @@ func (h *tunHandler) createTun() (conn net.Conn, err error) { } func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.Addr) error { - var routes sync.Map errc := make(chan error, 1) go func() { @@ -130,7 +130,7 @@ func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.A } addr := raddr - if v, ok := routes.Load(header.Dst.String()); ok { + if v, ok := h.routes.Load(header.Dst.String()); ok { addr = v.(net.Addr) } if addr == nil { @@ -195,11 +195,11 @@ func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.A } if h.ipNet != nil && h.ipNet.IP.Equal(header.Src.Mask(h.ipNet.Mask)) { - if actual, loaded := routes.LoadOrStore(header.Src.String(), addr); loaded { + if actual, loaded := h.routes.LoadOrStore(header.Src.String(), addr); loaded { if actual.(net.Addr).String() != addr.String() { log.Logf("[tun] %s <- %s: update route: %s -> %s (old %s)", tun.LocalAddr(), addr, header.Src, addr, actual.(net.Addr)) - routes.Store(header.Src.String(), addr) + h.routes.Store(header.Src.String(), addr) } } else { log.Logf("[tun] %s: new route: %s -> %s", tun.LocalAddr(), header.Src, addr) diff --git a/tun_windows.go b/tun_windows.go index aa68aa2..3aa71ea 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -38,7 +38,7 @@ func createTun(cfg TunConfig) (conn net.Conn, ipNet *net.IPNet, err error) { return } - if err = addRoutes(ip.String(), cfg.Routes...); err != nil { + if err = addRoutes(ifce.Name(), cfg.Routes...); err != nil { return } @@ -49,18 +49,16 @@ func createTun(cfg TunConfig) (conn net.Conn, ipNet *net.IPNet, err error) { return } -func addRoutes(ifIP string, routes ...string) error { +func addRoutes(ifName string, routes ...string) error { for _, route := range routes { if route == "" { continue } - _, inet, err := net.ParseCIDR(route) - if err != nil { - return err - } - cmd := fmt.Sprintf("route ADD %s MASK %s %s", - inet.IP, ipMask(inet.Mask), ifIP) + deleteRoute(ifName, route) + + cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active", + route, ifName) log.Log("[tun]", cmd) args := strings.Split(cmd, " ") if er := exec.Command(args[0], args[1:]...).Run(); er != nil { @@ -70,6 +68,13 @@ func addRoutes(ifIP string, routes ...string) error { return nil } +func deleteRoute(ifName string, route string) error { + cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=%s store=active", + route, ifName) + args := strings.Split(cmd, " ") + return exec.Command(args[0], args[1:]...).Run() +} + func ipMask(mask net.IPMask) string { return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3]) }