Fix monitor

This commit is contained in:
世界 2022-07-14 18:53:42 +08:00
parent 6c2c28da9d
commit 9dc73c0bcc
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 244 additions and 104 deletions

View file

@ -1,40 +1,111 @@
package tun
import (
"context"
"sync"
"github.com/sagernet/sing-tun/internal/winipcfg"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/sys/windows"
)
var _ InterfaceMonitor = (*NativeMonitor)(nil)
type networkUpdateMonitor struct {
routeListener *winipcfg.RouteChangeCallback
interfaceListener *winipcfg.InterfaceChangeCallback
errorHandler E.Handler
type NativeMonitor struct {
listener *winipcfg.RouteChangeCallback
callback InterfaceMonitorCallback
defaultInterfaceName string
defaultInterfaceIndex int
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
}
func NewMonitor(callback InterfaceMonitorCallback) (InterfaceMonitor, error) {
return &NativeMonitor{callback: callback}, nil
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
errorHandler: errorHandler,
}, nil
}
func (m *NativeMonitor) Start() error {
err := m.checkUpdate()
if err != nil {
return err
func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] {
m.access.Lock()
defer m.access.Unlock()
return m.callbacks.PushBack(callback)
}
func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) {
m.access.Lock()
defer m.access.Unlock()
m.callbacks.Remove(element)
}
func (m *networkUpdateMonitor) emit() {
m.access.Lock()
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
err := callback()
if err != nil {
m.errorHandler.NewError(context.Background(), err)
}
}
listener, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
m.checkUpdate()
}
func (m *networkUpdateMonitor) Start() error {
routeListener, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
m.emit()
})
if err != nil {
return err
}
m.listener = listener
m.routeListener = routeListener
interfaceListener, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
m.emit()
})
if err != nil {
routeListener.Unregister()
return err
}
m.interfaceListener = interfaceListener
return nil
}
func (m *NativeMonitor) checkUpdate() error {
func (m *networkUpdateMonitor) Close() error {
return E.Errors(
m.routeListener.Unregister(),
m.interfaceListener.Unregister(),
)
}
type defaultInterfaceMonitor struct {
defaultInterfaceName string
defaultInterfaceIndex int
networkMonitor NetworkUpdateMonitor
element *list.Element[NetworkUpdateCallback]
callback DefaultInterfaceUpdateCallback
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) {
return &defaultInterfaceMonitor{
networkMonitor: networkMonitor,
callback: callback,
}, nil
}
func (m *defaultInterfaceMonitor) Start() error {
err := m.checkUpdate()
if err != nil {
return err
}
m.element = m.networkMonitor.RegisterCallback(m.checkUpdate)
return nil
}
func (m *defaultInterfaceMonitor) Close() error {
m.networkMonitor.UnregisterCallback(m.element)
return nil
}
func (m *defaultInterfaceMonitor) checkUpdate() error {
rows, err := winipcfg.GetIPForwardTable2(windows.AF_INET)
if err != nil {
return err
@ -85,14 +156,10 @@ func (m *NativeMonitor) checkUpdate() error {
return nil
}
func (m *NativeMonitor) Close() error {
return m.listener.Unregister()
}
func (m *NativeMonitor) DefaultInterfaceName() string {
func (m *defaultInterfaceMonitor) DefaultInterfaceName() string {
return m.defaultInterfaceName
}
func (m *NativeMonitor) DefaultInterfaceIndex() int {
func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int {
return m.defaultInterfaceIndex
}