From 9dc73c0bcc1ddbb9adbf6fd4c173dedf33216b2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 14 Jul 2022 18:53:42 +0800 Subject: [PATCH] Fix monitor --- go.mod | 2 +- go.sum | 4 +- monitor.go | 21 ++++++-- monitor_linux.go | 131 +++++++++++++++++++++++++++++++++------------ monitor_other.go | 12 ++++- monitor_windows.go | 111 ++++++++++++++++++++++++++++++-------- tun_linux.go | 67 ++++++++++------------- 7 files changed, 244 insertions(+), 104 deletions(-) diff --git a/go.mod b/go.mod index eab189a..e4d7596 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/sagernet/sing-tun go 1.18 require ( - github.com/sagernet/sing v0.0.0-20220711062652-4394f7cbbae1 + github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0 github.com/vishvananda/netlink v1.1.0 golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d diff --git a/go.sum b/go.sum index 04c97ed..8211807 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -github.com/sagernet/sing v0.0.0-20220711062652-4394f7cbbae1 h1:gssTBQKiiXd1zALSOzQFZl3qwzCy4O76eSH0YY9A+Po= -github.com/sagernet/sing v0.0.0-20220711062652-4394f7cbbae1/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= +github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0 h1:8tnMLN6jdqKkjPXwgEekwloPaAmvbxQAMMHdWYOiMj8= +github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= diff --git a/monitor.go b/monitor.go index 9711ef7..6ef0de7 100644 --- a/monitor.go +++ b/monitor.go @@ -1,14 +1,27 @@ package tun -import E "github.com/sagernet/sing/common/exceptions" +import ( + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/x/list" +) var ErrNoRoute = E.New("no route to internet") -type InterfaceMonitor interface { +type ( + NetworkUpdateCallback = func() error + DefaultInterfaceUpdateCallback = func() +) + +type NetworkUpdateMonitor interface { + Start() error + Close() error + RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] + UnregisterCallback(element *list.Element[NetworkUpdateCallback]) +} + +type DefaultInterfaceMonitor interface { Start() error Close() error DefaultInterfaceName() string DefaultInterfaceIndex() int } - -type InterfaceMonitorCallback func() diff --git a/monitor_linux.go b/monitor_linux.go index f36d9cc..e2b6b55 100644 --- a/monitor_linux.go +++ b/monitor_linux.go @@ -1,35 +1,69 @@ package tun import ( + "context" + "net" "os" + "sync" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/x/list" "github.com/vishvananda/netlink" ) -type NativeMonitor struct { - defaultInterfaceName string - defaultInterfaceIndex int - update chan netlink.RouteUpdate - close chan struct{} - callback InterfaceMonitorCallback +type networkUpdateMonitor struct { + routeUpdate chan netlink.RouteUpdate + linkUpdate chan netlink.LinkUpdate + close chan struct{} + errorHandler E.Handler + + access sync.Mutex + callbacks list.List[NetworkUpdateCallback] } -func NewMonitor(callback InterfaceMonitorCallback) (InterfaceMonitor, error) { - return &NativeMonitor{ - callback: callback, - update: make(chan netlink.RouteUpdate, 2), - close: make(chan struct{}), +func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) { + return &networkUpdateMonitor{ + routeUpdate: make(chan netlink.RouteUpdate, 2), + linkUpdate: make(chan netlink.LinkUpdate, 2), + close: make(chan struct{}), + errorHandler: errorHandler, }, nil } -func (m *NativeMonitor) Start() error { - err := netlink.RouteSubscribe(m.update, m.close) +func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] { + m.access.Lock() + defer m.access.Unlock() + return m.callbacks.PushBack(callback) +} + +func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) { + m.access.Lock() + defer m.access.Unlock() + m.callbacks.Remove(element) +} + +func (m *networkUpdateMonitor) emit() { + m.access.Lock() + callbacks := m.callbacks.Array() + m.access.Unlock() + for _, callback := range callbacks { + err := callback() + if err != nil { + m.errorHandler.NewError(context.Background(), err) + } + } +} + +func (m *networkUpdateMonitor) Start() error { + err := netlink.RouteSubscribe(m.routeUpdate, m.close) + if err != nil { + return err + } + err = netlink.LinkSubscribe(m.linkUpdate, m.close) if err != nil { return err } - err = m.checkUpdate() if err != nil { return err } @@ -37,32 +71,73 @@ func (m *NativeMonitor) Start() error { return nil } -func (m *NativeMonitor) loopUpdate() { +func (m *networkUpdateMonitor) loopUpdate() { for { select { case <-m.close: return - case <-m.update: - m.checkUpdate() + case <-m.routeUpdate: + case <-m.linkUpdate: } + m.emit() } } -func (m *NativeMonitor) checkUpdate() error { +func (m *networkUpdateMonitor) Close() error { + select { + case <-m.close: + return os.ErrClosed + default: + } + close(m.close) + return nil +} + +type defaultInterfaceMonitor struct { + defaultInterfaceName string + defaultInterfaceIndex int + networkMonitor NetworkUpdateMonitor + element *list.Element[NetworkUpdateCallback] + callback DefaultInterfaceUpdateCallback +} + +func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) { + return &defaultInterfaceMonitor{ + networkMonitor: networkMonitor, + callback: callback, + }, nil +} + +func (m *defaultInterfaceMonitor) Start() error { + err := m.checkUpdate() + if err != nil { + return err + } + m.element = m.networkMonitor.RegisterCallback(m.checkUpdate) + return nil +} + +func (m *defaultInterfaceMonitor) Close() error { + m.networkMonitor.UnregisterCallback(m.element) + return nil +} + +func (m *defaultInterfaceMonitor) checkUpdate() error { routes, err := netlink.RouteList(nil, netlink.FAMILY_V4) if err != nil { return err } for _, route := range routes { - if route.Dst != nil { - continue - } var link netlink.Link link, err = netlink.LinkByIndex(route.LinkIndex) if err != nil { return err } + if link.Attrs().Flags&net.FlagUp == 0 { + continue + } + if link.Type() == "tuntap" { continue } @@ -82,20 +157,10 @@ func (m *NativeMonitor) checkUpdate() error { return E.New("no route to internet") } -func (m *NativeMonitor) Close() error { - select { - case <-m.close: - return os.ErrClosed - default: - } - close(m.close) - return nil -} - -func (m *NativeMonitor) DefaultInterfaceName() string { +func (m *defaultInterfaceMonitor) DefaultInterfaceName() string { return m.defaultInterfaceName } -func (m *NativeMonitor) DefaultInterfaceIndex() int { +func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int { return m.defaultInterfaceIndex } diff --git a/monitor_other.go b/monitor_other.go index b10d336..49b7151 100644 --- a/monitor_other.go +++ b/monitor_other.go @@ -2,8 +2,16 @@ package tun -import "os" +import ( + "os" -func NewMonitor() (InterfaceMonitor, error) { + E "github.com/sagernet/sing/common/exceptions" +) + +func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) { + return nil, os.ErrInvalid +} + +func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) { return nil, os.ErrInvalid } diff --git a/monitor_windows.go b/monitor_windows.go index afca49b..bbf7edb 100644 --- a/monitor_windows.go +++ b/monitor_windows.go @@ -1,40 +1,111 @@ package tun import ( + "context" + "sync" + "github.com/sagernet/sing-tun/internal/winipcfg" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/x/list" "golang.org/x/sys/windows" ) -var _ InterfaceMonitor = (*NativeMonitor)(nil) +type networkUpdateMonitor struct { + routeListener *winipcfg.RouteChangeCallback + interfaceListener *winipcfg.InterfaceChangeCallback + errorHandler E.Handler -type NativeMonitor struct { - listener *winipcfg.RouteChangeCallback - callback InterfaceMonitorCallback - defaultInterfaceName string - defaultInterfaceIndex int + access sync.Mutex + callbacks list.List[NetworkUpdateCallback] } -func NewMonitor(callback InterfaceMonitorCallback) (InterfaceMonitor, error) { - return &NativeMonitor{callback: callback}, nil +func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) { + return &networkUpdateMonitor{ + errorHandler: errorHandler, + }, nil } -func (m *NativeMonitor) Start() error { - err := m.checkUpdate() - if err != nil { - return err +func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] { + m.access.Lock() + defer m.access.Unlock() + return m.callbacks.PushBack(callback) +} + +func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) { + m.access.Lock() + defer m.access.Unlock() + m.callbacks.Remove(element) +} + +func (m *networkUpdateMonitor) emit() { + m.access.Lock() + callbacks := m.callbacks.Array() + m.access.Unlock() + for _, callback := range callbacks { + err := callback() + if err != nil { + m.errorHandler.NewError(context.Background(), err) + } } - listener, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { - m.checkUpdate() +} + +func (m *networkUpdateMonitor) Start() error { + routeListener, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { + m.emit() }) if err != nil { return err } - m.listener = listener + m.routeListener = routeListener + interfaceListener, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + m.emit() + }) + if err != nil { + routeListener.Unregister() + return err + } + m.interfaceListener = interfaceListener return nil } -func (m *NativeMonitor) checkUpdate() error { +func (m *networkUpdateMonitor) Close() error { + return E.Errors( + m.routeListener.Unregister(), + m.interfaceListener.Unregister(), + ) +} + +type defaultInterfaceMonitor struct { + defaultInterfaceName string + defaultInterfaceIndex int + networkMonitor NetworkUpdateMonitor + element *list.Element[NetworkUpdateCallback] + callback DefaultInterfaceUpdateCallback +} + +func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) { + return &defaultInterfaceMonitor{ + networkMonitor: networkMonitor, + callback: callback, + }, nil +} + +func (m *defaultInterfaceMonitor) Start() error { + err := m.checkUpdate() + if err != nil { + return err + } + m.element = m.networkMonitor.RegisterCallback(m.checkUpdate) + return nil +} + +func (m *defaultInterfaceMonitor) Close() error { + m.networkMonitor.UnregisterCallback(m.element) + return nil +} + +func (m *defaultInterfaceMonitor) checkUpdate() error { rows, err := winipcfg.GetIPForwardTable2(windows.AF_INET) if err != nil { return err @@ -85,14 +156,10 @@ func (m *NativeMonitor) checkUpdate() error { return nil } -func (m *NativeMonitor) Close() error { - return m.listener.Unregister() -} - -func (m *NativeMonitor) DefaultInterfaceName() string { +func (m *defaultInterfaceMonitor) DefaultInterfaceName() string { return m.defaultInterfaceName } -func (m *NativeMonitor) DefaultInterfaceIndex() int { +func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int { return m.defaultInterfaceIndex } diff --git a/tun_linux.go b/tun_linux.go index 11890e0..168fe1d 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -45,6 +45,29 @@ func Open(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu return nativeTun, nil } +func (t *NativeTun) routes(tunLink netlink.Link) []netlink.Route { + var routes []netlink.Route + if t.inet4Address.IsValid() { + routes = append(routes, netlink.Route{ + Dst: &net.IPNet{ + IP: net.IPv4zero, + Mask: net.CIDRMask(0, 32), + }, + LinkIndex: tunLink.Attrs().Index, + }) + } + if t.inet6Address.IsValid() { + routes = append(routes, netlink.Route{ + Dst: &net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 128), + }, + LinkIndex: tunLink.Attrs().Index, + }) + } + return routes +} + func (t *NativeTun) configure() error { tunLink, err := netlink.LinkByName(t.name) if err != nil { @@ -77,26 +100,8 @@ func (t *NativeTun) configure() error { } if t.autoRoute { - if t.inet4Address.IsValid() { - err = netlink.RouteAdd(&netlink.Route{ - Dst: &net.IPNet{ - IP: net.IPv4zero, - Mask: net.CIDRMask(0, 32), - }, - LinkIndex: tunLink.Attrs().Index, - }) - if err != nil { - return err - } - } - if t.inet6Address.IsValid() { - err = netlink.RouteAdd(&netlink.Route{ - Dst: &net.IPNet{ - IP: net.IPv6zero, - Mask: net.CIDRMask(0, 128), - }, - LinkIndex: tunLink.Attrs().Index, - }) + for _, route := range t.routes(tunLink) { + err = netlink.RouteAdd(&route) if err != nil { return err } @@ -133,26 +138,8 @@ func (t *NativeTun) Close() error { return err } if t.autoRoute { - if t.inet4Address.IsValid() { - err = netlink.RouteDel(&netlink.Route{ - Dst: &net.IPNet{ - IP: net.IPv4zero, - Mask: net.CIDRMask(0, 32), - }, - LinkIndex: tunLink.Attrs().Index, - }) - if err != nil { - return err - } - } - if t.inet6Address.IsValid() { - err = netlink.RouteDel(&netlink.Route{ - Dst: &net.IPNet{ - IP: net.IPv6zero, - Mask: net.CIDRMask(0, 128), - }, - LinkIndex: tunLink.Attrs().Index, - }) + for _, route := range t.routes(tunLink) { + err = netlink.RouteDel(&route) if err != nil { return err }