diff --git a/monitor.go b/monitor.go index 91b3c84..68bcff9 100644 --- a/monitor.go +++ b/monitor.go @@ -1,6 +1,8 @@ package tun import ( + "net/netip" + E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/x/list" ) @@ -23,8 +25,8 @@ type NetworkUpdateMonitor interface { type DefaultInterfaceMonitor interface { Start() error Close() error - DefaultInterfaceName() string - DefaultInterfaceIndex() int + DefaultInterfaceName(destination netip.Addr) string + DefaultInterfaceIndex(destination netip.Addr) int RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback] UnregisterCallback(element *list.Element[DefaultInterfaceUpdateCallback]) } diff --git a/monitor_shared.go b/monitor_shared.go index eeef602..42c85d8 100644 --- a/monitor_shared.go +++ b/monitor_shared.go @@ -4,9 +4,14 @@ package tun import ( "context" + "net" + "net/netip" "sync" "time" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/x/list" ) @@ -39,6 +44,7 @@ func (m *networkUpdateMonitor) NewError(ctx context.Context, err error) { } type defaultInterfaceMonitor struct { + networkAddresses []networkAddress defaultInterfaceName string defaultInterfaceIndex int networkMonitor NetworkUpdateMonitor @@ -47,6 +53,12 @@ type defaultInterfaceMonitor struct { callbacks list.List[DefaultInterfaceUpdateCallback] } +type networkAddress struct { + interfaceName string + interfaceIndex int + addresses []netip.Prefix +} + func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor) (DefaultInterfaceMonitor, error) { return &defaultInterfaceMonitor{ networkMonitor: networkMonitor, @@ -64,9 +76,41 @@ func (m *defaultInterfaceMonitor) Start() error { func (m *defaultInterfaceMonitor) delayCheckUpdate() error { time.Sleep(time.Second) + err := m.updateInterfaces() + if err != nil { + m.networkMonitor.NewError(context.Background(), E.Cause(err, "update interfaces")) + } return m.checkUpdate() } +func (m *defaultInterfaceMonitor) updateInterfaces() error { + interfaces, err := net.Interfaces() + if err != nil { + return err + } + var addresses []networkAddress + for _, iif := range interfaces { + var netAddresses []net.Addr + netAddresses, err = iif.Addrs() + if err != nil { + return err + } + var address networkAddress + address.interfaceName = iif.Name + address.interfaceIndex = iif.Index + address.addresses = common.Map(common.FilterIsInstance(netAddresses, func(it net.Addr) (*net.IPNet, bool) { + value, loaded := it.(*net.IPNet) + return value, loaded + }), func(it *net.IPNet) netip.Prefix { + bits, _ := it.Mask.Size() + return netip.PrefixFrom(M.AddrFromIP(it.IP), bits) + }) + addresses = append(addresses, address) + } + m.networkAddresses = addresses + return nil +} + func (m *defaultInterfaceMonitor) Close() error { if m.element != nil { m.networkMonitor.UnregisterCallback(m.element) @@ -74,11 +118,25 @@ func (m *defaultInterfaceMonitor) Close() error { return nil } -func (m *defaultInterfaceMonitor) DefaultInterfaceName() string { +func (m *defaultInterfaceMonitor) DefaultInterfaceName(destination netip.Addr) string { + for _, address := range m.networkAddresses { + for _, prefix := range address.addresses { + if prefix.Contains(destination) { + return address.interfaceName + } + } + } return m.defaultInterfaceName } -func (m *defaultInterfaceMonitor) DefaultInterfaceIndex() int { +func (m *defaultInterfaceMonitor) DefaultInterfaceIndex(destination netip.Addr) int { + for _, address := range m.networkAddresses { + for _, prefix := range address.addresses { + if prefix.Contains(destination) { + return address.interfaceIndex + } + } + } return m.defaultInterfaceIndex }