diff --git a/monitor.go b/monitor.go index 6ef0de7..91b3c84 100644 --- a/monitor.go +++ b/monitor.go @@ -9,7 +9,7 @@ var ErrNoRoute = E.New("no route to internet") type ( NetworkUpdateCallback = func() error - DefaultInterfaceUpdateCallback = func() + DefaultInterfaceUpdateCallback = func() error ) type NetworkUpdateMonitor interface { @@ -17,6 +17,7 @@ type NetworkUpdateMonitor interface { Close() error RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] UnregisterCallback(element *list.Element[NetworkUpdateCallback]) + E.Handler } type DefaultInterfaceMonitor interface { @@ -24,4 +25,6 @@ type DefaultInterfaceMonitor interface { Close() error DefaultInterfaceName() string DefaultInterfaceIndex() int + RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback] + UnregisterCallback(element *list.Element[DefaultInterfaceUpdateCallback]) } diff --git a/monitor_android.go b/monitor_android.go index 57f88c8..2d6f183 100644 --- a/monitor_android.go +++ b/monitor_android.go @@ -42,7 +42,7 @@ func (m *defaultInterfaceMonitor) checkUpdate() error { if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { return nil } - m.callback() + m.emit() return nil } diff --git a/monitor_darwin.go b/monitor_darwin.go index 1f1fc6d..46cb98f 100644 --- a/monitor_darwin.go +++ b/monitor_darwin.go @@ -13,6 +13,7 @@ import ( "golang.org/x/net/route" "golang.org/x/sys/unix" + "syscall" ) type networkUpdateMonitor struct { @@ -29,30 +30,6 @@ func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, erro }, nil } -func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] { - m.access.Lock() - defer m.access.Unlock() - return m.callbacks.PushBack(callback) -} - -func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) { - m.access.Lock() - defer m.access.Unlock() - m.callbacks.Remove(element) -} - -func (m *networkUpdateMonitor) emit() { - m.access.Lock() - callbacks := m.callbacks.Array() - m.access.Unlock() - for _, callback := range callbacks { - err := callback() - if err != nil { - m.errorHandler.NewError(context.Background(), err) - } - } -} - func (m *networkUpdateMonitor) Start() error { routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0) if err != nil { @@ -88,7 +65,7 @@ func (m *networkUpdateMonitor) loopUpdate() { } m.emit() } - if !E.IsClosed(err) { + if err != syscall.EAGAIN { m.errorHandler.NewError(context.Background(), err) } } @@ -97,43 +74,6 @@ func (m *networkUpdateMonitor) Close() error { return common.Close(common.PtrOrNil(m.routeSocket)) } -type defaultInterfaceMonitor struct { - defaultInterfaceName string - defaultInterfaceIndex int - networkMonitor NetworkUpdateMonitor - element *list.Element[NetworkUpdateCallback] - callback DefaultInterfaceUpdateCallback -} - -func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) { - return &defaultInterfaceMonitor{ - networkMonitor: networkMonitor, - callback: callback, - }, nil -} - -func (m *defaultInterfaceMonitor) Start() error { - err := m.checkUpdate() - if err != nil { - return err - } - m.element = m.networkMonitor.RegisterCallback(m.checkUpdate) - return nil -} - -func (m *defaultInterfaceMonitor) Close() error { - m.networkMonitor.UnregisterCallback(m.element) - return nil -} - -func (m *defaultInterfaceMonitor) DefaultInterfaceName() string { - return m.defaultInterfaceName -} - -func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int { - return m.defaultInterfaceIndex -} - func (m *defaultInterfaceMonitor) checkUpdate() error { ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0) if err != nil { @@ -163,7 +103,7 @@ func (m *defaultInterfaceMonitor) checkUpdate() error { if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { return nil } - m.callback() + m.emit() return nil } } diff --git a/monitor_linux.go b/monitor_linux.go index 56b2a12..32f4537 100644 --- a/monitor_linux.go +++ b/monitor_linux.go @@ -1,7 +1,6 @@ package tun import ( - "context" "os" "sync" @@ -29,30 +28,6 @@ func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, erro }, nil } -func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] { - m.access.Lock() - defer m.access.Unlock() - return m.callbacks.PushBack(callback) -} - -func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) { - m.access.Lock() - defer m.access.Unlock() - m.callbacks.Remove(element) -} - -func (m *networkUpdateMonitor) emit() { - m.access.Lock() - callbacks := m.callbacks.Array() - m.access.Unlock() - for _, callback := range callbacks { - err := callback() - if err != nil { - m.errorHandler.NewError(context.Background(), err) - } - } -} - func (m *networkUpdateMonitor) Start() error { err := netlink.RouteSubscribe(m.routeUpdate, m.close) if err != nil { @@ -87,40 +62,3 @@ func (m *networkUpdateMonitor) Close() error { close(m.close) return nil } - -type defaultInterfaceMonitor struct { - defaultInterfaceName string - defaultInterfaceIndex int - networkMonitor NetworkUpdateMonitor - element *list.Element[NetworkUpdateCallback] - callback DefaultInterfaceUpdateCallback -} - -func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) { - return &defaultInterfaceMonitor{ - networkMonitor: networkMonitor, - callback: callback, - }, nil -} - -func (m *defaultInterfaceMonitor) Start() error { - err := m.checkUpdate() - if err != nil { - return err - } - m.element = m.networkMonitor.RegisterCallback(m.checkUpdate) - return nil -} - -func (m *defaultInterfaceMonitor) Close() error { - m.networkMonitor.UnregisterCallback(m.element) - return nil -} - -func (m *defaultInterfaceMonitor) DefaultInterfaceName() string { - return m.defaultInterfaceName -} - -func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int { - return m.defaultInterfaceIndex -} diff --git a/monitor_linux_default.go b/monitor_linux_default.go index 050c7f3..b190fef 100644 --- a/monitor_linux_default.go +++ b/monitor_linux_default.go @@ -34,7 +34,7 @@ func (m *defaultInterfaceMonitor) checkUpdate() error { if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex { return nil } - m.callback() + m.emit() return nil } return E.New("no route to internet") diff --git a/monitor_other.go b/monitor_other.go index e1fbd71..d44ff1b 100644 --- a/monitor_other.go +++ b/monitor_other.go @@ -12,6 +12,6 @@ func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, erro return nil, os.ErrInvalid } -func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) { +func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor) (DefaultInterfaceMonitor, error) { return nil, os.ErrInvalid } diff --git a/monitor_shared.go b/monitor_shared.go new file mode 100644 index 0000000..cede3a3 --- /dev/null +++ b/monitor_shared.go @@ -0,0 +1,99 @@ +//go:build linux || windows || darwin + +package tun + +import ( + "context" + "sync" + + "github.com/sagernet/sing/common/x/list" +) + +func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] { + m.access.Lock() + defer m.access.Unlock() + return m.callbacks.PushBack(callback) +} + +func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) { + m.access.Lock() + defer m.access.Unlock() + m.callbacks.Remove(element) +} + +func (m *networkUpdateMonitor) emit() { + m.access.Lock() + callbacks := m.callbacks.Array() + m.access.Unlock() + for _, callback := range callbacks { + err := callback() + if err != nil { + m.NewError(context.Background(), err) + } + } +} + +func (m *networkUpdateMonitor) NewError(ctx context.Context, err error) { + m.errorHandler.NewError(ctx, err) +} + +type defaultInterfaceMonitor struct { + defaultInterfaceName string + defaultInterfaceIndex int + networkMonitor NetworkUpdateMonitor + element *list.Element[NetworkUpdateCallback] + access sync.Mutex + callbacks list.List[DefaultInterfaceUpdateCallback] +} + +func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor) (DefaultInterfaceMonitor, error) { + return &defaultInterfaceMonitor{ + networkMonitor: networkMonitor, + }, nil +} + +func (m *defaultInterfaceMonitor) Start() error { + err := m.checkUpdate() + if err != nil { + return err + } + m.element = m.networkMonitor.RegisterCallback(m.checkUpdate) + return nil +} + +func (m *defaultInterfaceMonitor) Close() error { + m.networkMonitor.UnregisterCallback(m.element) + return nil +} + +func (m *defaultInterfaceMonitor) DefaultInterfaceName() string { + return m.defaultInterfaceName +} + +func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int { + return m.defaultInterfaceIndex +} + +func (m *defaultInterfaceMonitor) RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback] { + m.access.Lock() + defer m.access.Unlock() + return m.callbacks.PushBack(callback) +} + +func (m *defaultInterfaceMonitor) UnregisterCallback(element *list.Element[DefaultInterfaceUpdateCallback]) { + m.access.Lock() + defer m.access.Unlock() + m.callbacks.Remove(element) +} + +func (m *defaultInterfaceMonitor) emit() { + m.access.Lock() + callbacks := m.callbacks.Array() + m.access.Unlock() + for _, callback := range callbacks { + err := callback() + if err != nil { + m.networkMonitor.NewError(context.Background(), err) + } + } +} diff --git a/monitor_windows.go b/monitor_windows.go index bbf7edb..7d81742 100644 --- a/monitor_windows.go +++ b/monitor_windows.go @@ -1,7 +1,6 @@ package tun import ( - "context" "sync" "github.com/sagernet/sing-tun/internal/winipcfg" @@ -26,30 +25,6 @@ func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, erro }, nil } -func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] { - m.access.Lock() - defer m.access.Unlock() - return m.callbacks.PushBack(callback) -} - -func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) { - m.access.Lock() - defer m.access.Unlock() - m.callbacks.Remove(element) -} - -func (m *networkUpdateMonitor) emit() { - m.access.Lock() - callbacks := m.callbacks.Array() - m.access.Unlock() - for _, callback := range callbacks { - err := callback() - if err != nil { - m.errorHandler.NewError(context.Background(), err) - } - } -} - func (m *networkUpdateMonitor) Start() error { routeListener, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { m.emit() @@ -76,35 +51,6 @@ func (m *networkUpdateMonitor) Close() error { ) } -type defaultInterfaceMonitor struct { - defaultInterfaceName string - defaultInterfaceIndex int - networkMonitor NetworkUpdateMonitor - element *list.Element[NetworkUpdateCallback] - callback DefaultInterfaceUpdateCallback -} - -func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) { - return &defaultInterfaceMonitor{ - networkMonitor: networkMonitor, - callback: callback, - }, nil -} - -func (m *defaultInterfaceMonitor) Start() error { - err := m.checkUpdate() - if err != nil { - return err - } - m.element = m.networkMonitor.RegisterCallback(m.checkUpdate) - return nil -} - -func (m *defaultInterfaceMonitor) Close() error { - m.networkMonitor.UnregisterCallback(m.element) - return nil -} - func (m *defaultInterfaceMonitor) checkUpdate() error { rows, err := winipcfg.GetIPForwardTable2(windows.AF_INET) if err != nil { @@ -152,14 +98,6 @@ func (m *defaultInterfaceMonitor) checkUpdate() error { return nil } - m.callback() + m.emit() return nil } - -func (m *defaultInterfaceMonitor) DefaultInterfaceName() string { - return m.defaultInterfaceName -} - -func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int { - return m.defaultInterfaceIndex -}