Split set route to Start()

This commit is contained in:
世界 2024-11-06 17:09:04 +08:00
parent 9bcc1ec384
commit 24206c3edd
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 114 additions and 47 deletions

1
tun.go
View file

@ -24,6 +24,7 @@ type Handler interface {
type Tun interface {
io.ReadWriter
N.VectorisedWriter
Start() error
Close() error
}

View file

@ -1,6 +1,7 @@
package tun
import (
"errors"
"fmt"
"net"
"net/netip"
@ -24,9 +25,10 @@ const PacketOffset = 4
type NativeTun struct {
tunFile *os.File
tunWriter N.VectorisedWriter
mtu uint32
options Options
inet4Address [4]byte
inet6Address [16]byte
routerSet bool
}
func New(options Options) (Tun, error) {
@ -54,7 +56,7 @@ func New(options Options) (Tun, error) {
nativeTun := &NativeTun{
tunFile: os.NewFile(uintptr(tunFd), "utun"),
mtu: options.MTU,
options: options,
}
if len(options.Inet4Address) > 0 {
nativeTun.inet4Address = options.Inet4Address[0].Addr().As4()
@ -70,6 +72,15 @@ func New(options Options) (Tun, error) {
return nativeTun, nil
}
func (t *NativeTun) Start() error {
return t.setRoutes()
}
func (t *NativeTun) Close() error {
defer flushDNSCache()
return E.Errors(t.unsetRoutes(), t.tunFile.Close())
}
func (t *NativeTun) Read(p []byte) (n int, err error) {
return t.tunFile.Read(p)
}
@ -93,11 +104,6 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
}
func (t *NativeTun) Close() error {
flushDNSCache()
return t.tunFile.Close()
}
const utunControlName = "com.apple.net.utun_control"
const (
@ -239,28 +245,68 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
}
}
}
if options.AutoRoute {
var routeRanges []netip.Prefix
routeRanges, err = options.BuildAutoRouteRanges(false)
return nil
}
func (t *NativeTun) setRoutes() error {
if t.options.AutoRoute && t.options.FileDescriptor == 0 {
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
gateway4, gateway6 := options.Inet4GatewayAddr(), options.Inet6GatewayAddr()
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = addRoute(routeRange, gateway4)
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
for _, destination := range routeRanges {
var gateway netip.Addr
if destination.Addr().Is4() {
gateway = gateway4
} else {
err = addRoute(routeRange, gateway6)
gateway = gateway6
}
err = execRoute(unix.RTM_ADD, destination, gateway)
if err != nil {
return E.Cause(err, "add route: ", routeRange)
if errors.Is(err, unix.EEXIST) {
err = execRoute(unix.RTM_DELETE, destination, gateway)
if err != nil {
return E.Cause(err, "remove existing route: ", destination)
}
err = execRoute(unix.RTM_ADD, destination, gateway)
if err != nil {
return E.Cause(err, "re-add route: ", destination)
}
}
return E.Cause(err, "add route: ", destination)
}
}
flushDNSCache()
t.routerSet = true
}
return nil
}
func (t *NativeTun) unsetRoutes() error {
if !t.routerSet {
return nil
}
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
for _, destination := range routeRanges {
var gateway netip.Addr
if destination.Addr().Is4() {
gateway = gateway4
} else {
gateway = gateway6
}
err = execRoute(unix.RTM_DELETE, destination, gateway)
if err != nil {
err = E.Errors(err, E.Cause(err, "delete route: ", destination))
}
}
return err
}
func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
socketFd, err := unix.Socket(domain, typ, proto)
if err != nil {
@ -270,13 +316,16 @@ func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
return block(socketFd)
}
func addRoute(destination netip.Prefix, gateway netip.Addr) error {
func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error {
routeMessage := route.RouteMessage{
Type: unix.RTM_ADD,
Flags: unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY,
Type: rtmType,
Version: unix.RTM_VERSION,
Flags: unix.RTF_STATIC | unix.RTF_GATEWAY,
Seq: 1,
}
if rtmType == unix.RTM_ADD {
routeMessage.Flags |= unix.RTF_UP
}
if gateway.Is4() {
routeMessage.Addrs = []route.Addr{
syscall.RTAX_DST: &route.Inet4Addr{IP: destination.Addr().As4()},
@ -300,5 +349,5 @@ func addRoute(destination netip.Prefix, gateway netip.Addr) error {
}
func flushDNSCache() {
shell.Exec("dscacheutil", "-flushcache").Start()
go shell.Exec("dscacheutil", "-flushcache").Run()
}

View file

@ -24,7 +24,7 @@ type DarwinEndpoint struct {
}
func (e *DarwinEndpoint) MTU() uint32 {
return e.tun.mtu
return e.tun.options.MTU
}
func (e *DarwinEndpoint) SetMTU(mtu uint32) {
@ -57,7 +57,7 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
func (e *DarwinEndpoint) dispatchLoop() {
packetBuffer := make([]byte, e.tun.mtu+PacketOffset)
packetBuffer := make([]byte, e.tun.options.MTU+PacketOffset)
for {
n, err := e.tun.tunFile.Read(packetBuffer)
if err != nil {

View file

@ -1,6 +1,7 @@
package tun
import (
"errors"
"math/rand"
"net"
"net/netip"
@ -222,7 +223,7 @@ func open(name string, vnetHdr bool) (int, error) {
func (t *NativeTun) configure(tunLink netlink.Link) error {
err := netlink.LinkSetMTU(tunLink, int(t.options.MTU))
if err == unix.EPERM {
if errors.Is(err, unix.EPERM) {
// unprivileged
return nil
} else if err != nil {
@ -288,11 +289,23 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
t.txChecksumOffload = true
}
err = netlink.LinkSetUp(tunLink)
return nil
}
func (t *NativeTun) Start() error {
tunLink, err := netlink.LinkByName(t.options.Name)
if err != nil {
return err
}
err = netlink.LinkSetUp(tunLink)
if errors.Is(err, unix.EPERM) {
// unprivileged
return nil
} else if err != nil {
return err
}
if t.options.IPRoute2TableIndex == 0 {
for {
t.options.IPRoute2TableIndex = int(rand.Uint32())

View file

@ -106,27 +106,6 @@ func (t *NativeTun) configure() error {
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
_ = luid.DisableDNSRegistration()
}
if t.options.AutoRoute {
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, gateway4, 0)
} else {
err = luid.AddRoute(routeRange, gateway6, 0)
}
}
if err != nil {
return err
}
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
}
if len(t.options.Inet4Address) > 0 {
inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET))
if err != nil {
@ -166,8 +145,34 @@ func (t *NativeTun) configure() error {
return E.Cause(err, "set ipv6 options")
}
}
return nil
}
if t.options.AutoRoute && t.options.StrictRoute {
func (t *NativeTun) Start() error {
if !t.options.AutoRoute {
return nil
}
luid := winipcfg.LUID(t.adapter.LUID())
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, gateway4, 0)
} else {
err = luid.AddRoute(routeRange, gateway6, 0)
}
}
if err != nil {
return err
}
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
if t.options.StrictRoute {
var engine uintptr
session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC}
err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine))
@ -340,7 +345,6 @@ func (t *NativeTun) configure() error {
}
}
}
return nil
}