diff --git a/monitor_darwin.go b/monitor_darwin.go index 032f329..42f9c36 100644 --- a/monitor_darwin.go +++ b/monitor_darwin.go @@ -42,7 +42,7 @@ func (m *networkUpdateMonitor) loopUpdate() { select { case <-m.done: return - case <-time.After(time.Second): + default: } err := m.loopUpdate0() if err != nil { @@ -67,7 +67,16 @@ func (m *networkUpdateMonitor) loopUpdate1(routeSocketFile *os.File) { defer routeSocketFile.Close() buffer := buf.NewPacket() defer buffer.Release() + done := make(chan struct{}) + go func() { + select { + case <-m.done: + routeSocketFile.Close() + case <-done: + } + }() n, err := routeSocketFile.Read(buffer.FreeBytes()) + close(done) if err != nil { return } @@ -92,57 +101,59 @@ func (m *networkUpdateMonitor) Close() error { } func (m *defaultInterfaceMonitor) checkUpdate() error { - ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0) - if err != nil { - return err - } - routeMessages, err := route.ParseRIB(route.RIBTypeRoute, ribMessage) - if err != nil { - return err - } - var defaultInterface *net.Interface - for _, rawRouteMessage := range routeMessages { - routeMessage := rawRouteMessage.(*route.RouteMessage) - if len(routeMessage.Addrs) <= unix.RTAX_NETMASK { - continue - } - destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr) - if !isIPv4Destination { - continue - } - if destination.IP != netip.IPv4Unspecified().As4() { - continue - } - mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr) - if !isIPv4Mask { - continue - } - ones, _ := net.IPMask(mask.IP[:]).Size() - if ones != 0 { - continue - } - routeInterface, err := net.InterfaceByIndex(routeMessage.Index) + var ( + defaultInterface *net.Interface + err error + ) + if m.options.UnderNetworkExtension { + defaultInterface, err = getDefaultInterfaceBySocket() if err != nil { return err } - if routeMessage.Flags&unix.RTF_UP == 0 { - continue + } else { + ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0) + if err != nil { + return err } - if routeMessage.Flags&unix.RTF_GATEWAY == 0 { - continue + routeMessages, err := route.ParseRIB(route.RIBTypeRoute, ribMessage) + if err != nil { + return err } - if routeMessage.Flags&unix.RTF_IFSCOPE != 0 { - // continue - } - defaultInterface = routeInterface - break - } - if defaultInterface == nil { - if m.options.UnderNetworkExtension { - defaultInterface, err = getDefaultInterfaceBySocket() + for _, rawRouteMessage := range routeMessages { + routeMessage := rawRouteMessage.(*route.RouteMessage) + if len(routeMessage.Addrs) <= unix.RTAX_NETMASK { + continue + } + destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr) + if !isIPv4Destination { + continue + } + if destination.IP != netip.IPv4Unspecified().As4() { + continue + } + mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr) + if !isIPv4Mask { + continue + } + ones, _ := net.IPMask(mask.IP[:]).Size() + if ones != 0 { + continue + } + routeInterface, err := net.InterfaceByIndex(routeMessage.Index) if err != nil { return err } + if routeMessage.Flags&unix.RTF_UP == 0 { + continue + } + if routeMessage.Flags&unix.RTF_GATEWAY == 0 { + continue + } + if routeMessage.Flags&unix.RTF_IFSCOPE != 0 { + // continue + } + defaultInterface = routeInterface + break } } if defaultInterface == nil { diff --git a/monitor_shared.go b/monitor_shared.go index 66c8f1a..f286657 100644 --- a/monitor_shared.go +++ b/monitor_shared.go @@ -6,7 +6,6 @@ import ( "errors" "net" "net/netip" - "runtime" "sync" "time" @@ -44,6 +43,7 @@ type defaultInterfaceMonitor struct { defaultInterfaceIndex int androidVPNEnabled bool networkMonitor NetworkUpdateMonitor + checkUpdateTimer *time.Timer element *list.Element[NetworkUpdateCallback] access sync.Mutex callbacks list.List[DefaultInterfaceUpdateCallback] @@ -72,9 +72,13 @@ func (m *defaultInterfaceMonitor) Start() error { } func (m *defaultInterfaceMonitor) delayCheckUpdate() { - if runtime.GOOS == "android" { - time.Sleep(time.Second) + if m.checkUpdateTimer != nil { + m.checkUpdateTimer.Stop() } + m.checkUpdateTimer = time.AfterFunc(time.Second, m.postCheckUpdate) +} + +func (m *defaultInterfaceMonitor) postCheckUpdate() { err := m.updateInterfaces() if err != nil { m.logger.Error("update interfaces: ", err)