Fix monitor

This commit is contained in:
世界 2022-07-14 18:53:42 +08:00
parent 6c2c28da9d
commit 9dc73c0bcc
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 244 additions and 104 deletions

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/sagernet/sing-tun
go 1.18 go 1.18
require ( require (
github.com/sagernet/sing v0.0.0-20220711062652-4394f7cbbae1 github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e
gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d gvisor.dev/gvisor v0.0.0-20220711011657-cecae2f4234d

4
go.sum
View file

@ -1,7 +1,7 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/sagernet/sing v0.0.0-20220711062652-4394f7cbbae1 h1:gssTBQKiiXd1zALSOzQFZl3qwzCy4O76eSH0YY9A+Po= github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0 h1:8tnMLN6jdqKkjPXwgEekwloPaAmvbxQAMMHdWYOiMj8=
github.com/sagernet/sing v0.0.0-20220711062652-4394f7cbbae1/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c= github.com/sagernet/sing v0.0.0-20220714145306-09b55ce4b6d0/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c=
github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=

View file

@ -1,14 +1,27 @@
package tun package tun
import E "github.com/sagernet/sing/common/exceptions" import (
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/x/list"
)
var ErrNoRoute = E.New("no route to internet") var ErrNoRoute = E.New("no route to internet")
type InterfaceMonitor interface { type (
NetworkUpdateCallback = func() error
DefaultInterfaceUpdateCallback = func()
)
type NetworkUpdateMonitor interface {
Start() error
Close() error
RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback]
UnregisterCallback(element *list.Element[NetworkUpdateCallback])
}
type DefaultInterfaceMonitor interface {
Start() error Start() error
Close() error Close() error
DefaultInterfaceName() string DefaultInterfaceName() string
DefaultInterfaceIndex() int DefaultInterfaceIndex() int
} }
type InterfaceMonitorCallback func()

View file

@ -1,35 +1,69 @@
package tun package tun
import ( import (
"context"
"net"
"os" "os"
"sync"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/x/list"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
) )
type NativeMonitor struct { type networkUpdateMonitor struct {
defaultInterfaceName string routeUpdate chan netlink.RouteUpdate
defaultInterfaceIndex int linkUpdate chan netlink.LinkUpdate
update chan netlink.RouteUpdate close chan struct{}
close chan struct{} errorHandler E.Handler
callback InterfaceMonitorCallback
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
} }
func NewMonitor(callback InterfaceMonitorCallback) (InterfaceMonitor, error) { func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return &NativeMonitor{ return &networkUpdateMonitor{
callback: callback, routeUpdate: make(chan netlink.RouteUpdate, 2),
update: make(chan netlink.RouteUpdate, 2), linkUpdate: make(chan netlink.LinkUpdate, 2),
close: make(chan struct{}), close: make(chan struct{}),
errorHandler: errorHandler,
}, nil }, nil
} }
func (m *NativeMonitor) Start() error { func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] {
err := netlink.RouteSubscribe(m.update, m.close) m.access.Lock()
defer m.access.Unlock()
return m.callbacks.PushBack(callback)
}
func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) {
m.access.Lock()
defer m.access.Unlock()
m.callbacks.Remove(element)
}
func (m *networkUpdateMonitor) emit() {
m.access.Lock()
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
err := callback()
if err != nil {
m.errorHandler.NewError(context.Background(), err)
}
}
}
func (m *networkUpdateMonitor) Start() error {
err := netlink.RouteSubscribe(m.routeUpdate, m.close)
if err != nil {
return err
}
err = netlink.LinkSubscribe(m.linkUpdate, m.close)
if err != nil { if err != nil {
return err return err
} }
err = m.checkUpdate()
if err != nil { if err != nil {
return err return err
} }
@ -37,32 +71,73 @@ func (m *NativeMonitor) Start() error {
return nil return nil
} }
func (m *NativeMonitor) loopUpdate() { func (m *networkUpdateMonitor) loopUpdate() {
for { for {
select { select {
case <-m.close: case <-m.close:
return return
case <-m.update: case <-m.routeUpdate:
m.checkUpdate() case <-m.linkUpdate:
} }
m.emit()
} }
} }
func (m *NativeMonitor) checkUpdate() error { func (m *networkUpdateMonitor) Close() error {
select {
case <-m.close:
return os.ErrClosed
default:
}
close(m.close)
return nil
}
type defaultInterfaceMonitor struct {
defaultInterfaceName string
defaultInterfaceIndex int
networkMonitor NetworkUpdateMonitor
element *list.Element[NetworkUpdateCallback]
callback DefaultInterfaceUpdateCallback
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) {
return &defaultInterfaceMonitor{
networkMonitor: networkMonitor,
callback: callback,
}, nil
}
func (m *defaultInterfaceMonitor) Start() error {
err := m.checkUpdate()
if err != nil {
return err
}
m.element = m.networkMonitor.RegisterCallback(m.checkUpdate)
return nil
}
func (m *defaultInterfaceMonitor) Close() error {
m.networkMonitor.UnregisterCallback(m.element)
return nil
}
func (m *defaultInterfaceMonitor) checkUpdate() error {
routes, err := netlink.RouteList(nil, netlink.FAMILY_V4) routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
if err != nil { if err != nil {
return err return err
} }
for _, route := range routes { for _, route := range routes {
if route.Dst != nil {
continue
}
var link netlink.Link var link netlink.Link
link, err = netlink.LinkByIndex(route.LinkIndex) link, err = netlink.LinkByIndex(route.LinkIndex)
if err != nil { if err != nil {
return err return err
} }
if link.Attrs().Flags&net.FlagUp == 0 {
continue
}
if link.Type() == "tuntap" { if link.Type() == "tuntap" {
continue continue
} }
@ -82,20 +157,10 @@ func (m *NativeMonitor) checkUpdate() error {
return E.New("no route to internet") return E.New("no route to internet")
} }
func (m *NativeMonitor) Close() error { func (m *defaultInterfaceMonitor) DefaultInterfaceName() string {
select {
case <-m.close:
return os.ErrClosed
default:
}
close(m.close)
return nil
}
func (m *NativeMonitor) DefaultInterfaceName() string {
return m.defaultInterfaceName return m.defaultInterfaceName
} }
func (m *NativeMonitor) DefaultInterfaceIndex() int { func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int {
return m.defaultInterfaceIndex return m.defaultInterfaceIndex
} }

View file

@ -2,8 +2,16 @@
package tun package tun
import "os" import (
"os"
func NewMonitor() (InterfaceMonitor, error) { E "github.com/sagernet/sing/common/exceptions"
)
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return nil, os.ErrInvalid
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }

View file

@ -1,40 +1,111 @@
package tun package tun
import ( import (
"context"
"sync"
"github.com/sagernet/sing-tun/internal/winipcfg" "github.com/sagernet/sing-tun/internal/winipcfg"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
var _ InterfaceMonitor = (*NativeMonitor)(nil) type networkUpdateMonitor struct {
routeListener *winipcfg.RouteChangeCallback
interfaceListener *winipcfg.InterfaceChangeCallback
errorHandler E.Handler
type NativeMonitor struct { access sync.Mutex
listener *winipcfg.RouteChangeCallback callbacks list.List[NetworkUpdateCallback]
callback InterfaceMonitorCallback
defaultInterfaceName string
defaultInterfaceIndex int
} }
func NewMonitor(callback InterfaceMonitorCallback) (InterfaceMonitor, error) { func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return &NativeMonitor{callback: callback}, nil return &networkUpdateMonitor{
errorHandler: errorHandler,
}, nil
} }
func (m *NativeMonitor) Start() error { func (m *networkUpdateMonitor) RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback] {
err := m.checkUpdate() m.access.Lock()
if err != nil { defer m.access.Unlock()
return err return m.callbacks.PushBack(callback)
}
func (m *networkUpdateMonitor) UnregisterCallback(element *list.Element[NetworkUpdateCallback]) {
m.access.Lock()
defer m.access.Unlock()
m.callbacks.Remove(element)
}
func (m *networkUpdateMonitor) emit() {
m.access.Lock()
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
err := callback()
if err != nil {
m.errorHandler.NewError(context.Background(), err)
}
} }
listener, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { }
m.checkUpdate()
func (m *networkUpdateMonitor) Start() error {
routeListener, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
m.emit()
}) })
if err != nil { if err != nil {
return err return err
} }
m.listener = listener m.routeListener = routeListener
interfaceListener, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
m.emit()
})
if err != nil {
routeListener.Unregister()
return err
}
m.interfaceListener = interfaceListener
return nil return nil
} }
func (m *NativeMonitor) checkUpdate() error { func (m *networkUpdateMonitor) Close() error {
return E.Errors(
m.routeListener.Unregister(),
m.interfaceListener.Unregister(),
)
}
type defaultInterfaceMonitor struct {
defaultInterfaceName string
defaultInterfaceIndex int
networkMonitor NetworkUpdateMonitor
element *list.Element[NetworkUpdateCallback]
callback DefaultInterfaceUpdateCallback
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, callback DefaultInterfaceUpdateCallback) (DefaultInterfaceMonitor, error) {
return &defaultInterfaceMonitor{
networkMonitor: networkMonitor,
callback: callback,
}, nil
}
func (m *defaultInterfaceMonitor) Start() error {
err := m.checkUpdate()
if err != nil {
return err
}
m.element = m.networkMonitor.RegisterCallback(m.checkUpdate)
return nil
}
func (m *defaultInterfaceMonitor) Close() error {
m.networkMonitor.UnregisterCallback(m.element)
return nil
}
func (m *defaultInterfaceMonitor) checkUpdate() error {
rows, err := winipcfg.GetIPForwardTable2(windows.AF_INET) rows, err := winipcfg.GetIPForwardTable2(windows.AF_INET)
if err != nil { if err != nil {
return err return err
@ -85,14 +156,10 @@ func (m *NativeMonitor) checkUpdate() error {
return nil return nil
} }
func (m *NativeMonitor) Close() error { func (m *defaultInterfaceMonitor) DefaultInterfaceName() string {
return m.listener.Unregister()
}
func (m *NativeMonitor) DefaultInterfaceName() string {
return m.defaultInterfaceName return m.defaultInterfaceName
} }
func (m *NativeMonitor) DefaultInterfaceIndex() int { func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int {
return m.defaultInterfaceIndex return m.defaultInterfaceIndex
} }

View file

@ -45,6 +45,29 @@ func Open(name string, inet4Address netip.Prefix, inet6Address netip.Prefix, mtu
return nativeTun, nil return nativeTun, nil
} }
func (t *NativeTun) routes(tunLink netlink.Link) []netlink.Route {
var routes []netlink.Route
if t.inet4Address.IsValid() {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
})
}
if t.inet6Address.IsValid() {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
})
}
return routes
}
func (t *NativeTun) configure() error { func (t *NativeTun) configure() error {
tunLink, err := netlink.LinkByName(t.name) tunLink, err := netlink.LinkByName(t.name)
if err != nil { if err != nil {
@ -77,26 +100,8 @@ func (t *NativeTun) configure() error {
} }
if t.autoRoute { if t.autoRoute {
if t.inet4Address.IsValid() { for _, route := range t.routes(tunLink) {
err = netlink.RouteAdd(&netlink.Route{ err = netlink.RouteAdd(&route)
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
})
if err != nil {
return err
}
}
if t.inet6Address.IsValid() {
err = netlink.RouteAdd(&netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
})
if err != nil { if err != nil {
return err return err
} }
@ -133,26 +138,8 @@ func (t *NativeTun) Close() error {
return err return err
} }
if t.autoRoute { if t.autoRoute {
if t.inet4Address.IsValid() { for _, route := range t.routes(tunLink) {
err = netlink.RouteDel(&netlink.Route{ err = netlink.RouteDel(&route)
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
})
if err != nil {
return err
}
}
if t.inet6Address.IsValid() {
err = netlink.RouteDel(&netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
})
if err != nil { if err != nil {
return err return err
} }