Split set route to Start()

This commit is contained in:
世界 2024-11-06 17:09:04 +08:00
parent 9bcc1ec384
commit 24206c3edd
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 114 additions and 47 deletions

1
tun.go
View file

@ -24,6 +24,7 @@ type Handler interface {
type Tun interface { type Tun interface {
io.ReadWriter io.ReadWriter
N.VectorisedWriter N.VectorisedWriter
Start() error
Close() error Close() error
} }

View file

@ -1,6 +1,7 @@
package tun package tun
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -24,9 +25,10 @@ const PacketOffset = 4
type NativeTun struct { type NativeTun struct {
tunFile *os.File tunFile *os.File
tunWriter N.VectorisedWriter tunWriter N.VectorisedWriter
mtu uint32 options Options
inet4Address [4]byte inet4Address [4]byte
inet6Address [16]byte inet6Address [16]byte
routerSet bool
} }
func New(options Options) (Tun, error) { func New(options Options) (Tun, error) {
@ -54,7 +56,7 @@ func New(options Options) (Tun, error) {
nativeTun := &NativeTun{ nativeTun := &NativeTun{
tunFile: os.NewFile(uintptr(tunFd), "utun"), tunFile: os.NewFile(uintptr(tunFd), "utun"),
mtu: options.MTU, options: options,
} }
if len(options.Inet4Address) > 0 { if len(options.Inet4Address) > 0 {
nativeTun.inet4Address = options.Inet4Address[0].Addr().As4() nativeTun.inet4Address = options.Inet4Address[0].Addr().As4()
@ -70,6 +72,15 @@ func New(options Options) (Tun, error) {
return nativeTun, nil return nativeTun, nil
} }
func (t *NativeTun) Start() error {
return t.setRoutes()
}
func (t *NativeTun) Close() error {
defer flushDNSCache()
return E.Errors(t.unsetRoutes(), t.tunFile.Close())
}
func (t *NativeTun) Read(p []byte) (n int, err error) { func (t *NativeTun) Read(p []byte) (n int, err error) {
return t.tunFile.Read(p) return t.tunFile.Read(p)
} }
@ -93,11 +104,6 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...)) return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
} }
func (t *NativeTun) Close() error {
flushDNSCache()
return t.tunFile.Close()
}
const utunControlName = "com.apple.net.utun_control" const utunControlName = "com.apple.net.utun_control"
const ( const (
@ -239,28 +245,68 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
} }
} }
} }
if options.AutoRoute { return nil
var routeRanges []netip.Prefix }
routeRanges, err = options.BuildAutoRouteRanges(false)
func (t *NativeTun) setRoutes() error {
if t.options.AutoRoute && t.options.FileDescriptor == 0 {
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil { if err != nil {
return err return err
} }
gateway4, gateway6 := options.Inet4GatewayAddr(), options.Inet6GatewayAddr() gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
for _, routeRange := range routeRanges { for _, destination := range routeRanges {
if routeRange.Addr().Is4() { var gateway netip.Addr
err = addRoute(routeRange, gateway4) if destination.Addr().Is4() {
gateway = gateway4
} else { } else {
err = addRoute(routeRange, gateway6) gateway = gateway6
} }
err = execRoute(unix.RTM_ADD, destination, gateway)
if err != nil { if err != nil {
return E.Cause(err, "add route: ", routeRange) if errors.Is(err, unix.EEXIST) {
err = execRoute(unix.RTM_DELETE, destination, gateway)
if err != nil {
return E.Cause(err, "remove existing route: ", destination)
}
err = execRoute(unix.RTM_ADD, destination, gateway)
if err != nil {
return E.Cause(err, "re-add route: ", destination)
}
}
return E.Cause(err, "add route: ", destination)
} }
} }
flushDNSCache() flushDNSCache()
t.routerSet = true
} }
return nil return nil
} }
func (t *NativeTun) unsetRoutes() error {
if !t.routerSet {
return nil
}
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
for _, destination := range routeRanges {
var gateway netip.Addr
if destination.Addr().Is4() {
gateway = gateway4
} else {
gateway = gateway6
}
err = execRoute(unix.RTM_DELETE, destination, gateway)
if err != nil {
err = E.Errors(err, E.Cause(err, "delete route: ", destination))
}
}
return err
}
func useSocket(domain, typ, proto int, block func(socketFd int) error) error { func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
socketFd, err := unix.Socket(domain, typ, proto) socketFd, err := unix.Socket(domain, typ, proto)
if err != nil { if err != nil {
@ -270,13 +316,16 @@ func useSocket(domain, typ, proto int, block func(socketFd int) error) error {
return block(socketFd) return block(socketFd)
} }
func addRoute(destination netip.Prefix, gateway netip.Addr) error { func execRoute(rtmType int, destination netip.Prefix, gateway netip.Addr) error {
routeMessage := route.RouteMessage{ routeMessage := route.RouteMessage{
Type: unix.RTM_ADD, Type: rtmType,
Flags: unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY,
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Flags: unix.RTF_STATIC | unix.RTF_GATEWAY,
Seq: 1, Seq: 1,
} }
if rtmType == unix.RTM_ADD {
routeMessage.Flags |= unix.RTF_UP
}
if gateway.Is4() { if gateway.Is4() {
routeMessage.Addrs = []route.Addr{ routeMessage.Addrs = []route.Addr{
syscall.RTAX_DST: &route.Inet4Addr{IP: destination.Addr().As4()}, syscall.RTAX_DST: &route.Inet4Addr{IP: destination.Addr().As4()},
@ -300,5 +349,5 @@ func addRoute(destination netip.Prefix, gateway netip.Addr) error {
} }
func flushDNSCache() { func flushDNSCache() {
shell.Exec("dscacheutil", "-flushcache").Start() go shell.Exec("dscacheutil", "-flushcache").Run()
} }

View file

@ -24,7 +24,7 @@ type DarwinEndpoint struct {
} }
func (e *DarwinEndpoint) MTU() uint32 { func (e *DarwinEndpoint) MTU() uint32 {
return e.tun.mtu return e.tun.options.MTU
} }
func (e *DarwinEndpoint) SetMTU(mtu uint32) { func (e *DarwinEndpoint) SetMTU(mtu uint32) {
@ -57,7 +57,7 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
} }
func (e *DarwinEndpoint) dispatchLoop() { func (e *DarwinEndpoint) dispatchLoop() {
packetBuffer := make([]byte, e.tun.mtu+PacketOffset) packetBuffer := make([]byte, e.tun.options.MTU+PacketOffset)
for { for {
n, err := e.tun.tunFile.Read(packetBuffer) n, err := e.tun.tunFile.Read(packetBuffer)
if err != nil { if err != nil {

View file

@ -1,6 +1,7 @@
package tun package tun
import ( import (
"errors"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
@ -222,7 +223,7 @@ func open(name string, vnetHdr bool) (int, error) {
func (t *NativeTun) configure(tunLink netlink.Link) error { func (t *NativeTun) configure(tunLink netlink.Link) error {
err := netlink.LinkSetMTU(tunLink, int(t.options.MTU)) err := netlink.LinkSetMTU(tunLink, int(t.options.MTU))
if err == unix.EPERM { if errors.Is(err, unix.EPERM) {
// unprivileged // unprivileged
return nil return nil
} else if err != nil { } else if err != nil {
@ -288,11 +289,23 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
t.txChecksumOffload = true t.txChecksumOffload = true
} }
err = netlink.LinkSetUp(tunLink) return nil
}
func (t *NativeTun) Start() error {
tunLink, err := netlink.LinkByName(t.options.Name)
if err != nil { if err != nil {
return err return err
} }
err = netlink.LinkSetUp(tunLink)
if errors.Is(err, unix.EPERM) {
// unprivileged
return nil
} else if err != nil {
return err
}
if t.options.IPRoute2TableIndex == 0 { if t.options.IPRoute2TableIndex == 0 {
for { for {
t.options.IPRoute2TableIndex = int(rand.Uint32()) t.options.IPRoute2TableIndex = int(rand.Uint32())

View file

@ -106,27 +106,6 @@ func (t *NativeTun) configure() error {
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 { if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
_ = luid.DisableDNSRegistration() _ = luid.DisableDNSRegistration()
} }
if t.options.AutoRoute {
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, gateway4, 0)
} else {
err = luid.AddRoute(routeRange, gateway6, 0)
}
}
if err != nil {
return err
}
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
}
if len(t.options.Inet4Address) > 0 { if len(t.options.Inet4Address) > 0 {
inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET))
if err != nil { if err != nil {
@ -166,8 +145,34 @@ func (t *NativeTun) configure() error {
return E.Cause(err, "set ipv6 options") return E.Cause(err, "set ipv6 options")
} }
} }
return nil
}
if t.options.AutoRoute && t.options.StrictRoute { func (t *NativeTun) Start() error {
if !t.options.AutoRoute {
return nil
}
luid := winipcfg.LUID(t.adapter.LUID())
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, gateway4, 0)
} else {
err = luid.AddRoute(routeRange, gateway6, 0)
}
}
if err != nil {
return err
}
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
if t.options.StrictRoute {
var engine uintptr var engine uintptr
session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC} session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC}
err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine)) err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine))
@ -340,7 +345,6 @@ func (t *NativeTun) configure() error {
} }
} }
} }
return nil return nil
} }