refactor: Modular inbound/outbound manager

This commit is contained in:
世界 2024-11-09 21:16:11 +08:00
parent 28ec898a8c
commit 19fb214226
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
69 changed files with 1186 additions and 754 deletions

View file

@ -77,7 +77,6 @@ func exchangeDNSPacket(ctx context.Context, router *Router, conn N.PacketConn, b
return err
}
err = conn.WritePacket(responseBuffer, destination)
responseBuffer.Release()
return err
}

View file

@ -87,7 +87,7 @@ type Router struct {
v2rayServer adapter.V2RayServer
platformInterface platform.Interface
needWIFIState bool
needPackageManager bool
enforcePackageManager bool
wifiState adapter.WIFIState
started bool
}
@ -123,7 +123,7 @@ func NewRouter(
pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[platform.Interface](ctx),
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
needPackageManager: common.Any(inbounds, func(inbound option.Inbound) bool {
enforcePackageManager: common.Any(inbounds, func(inbound option.Inbound) bool {
if tunOptions, isTUN := inbound.Options.(*option.TunInboundOptions); isTUN && tunOptions.AutoRoute {
return true
}
@ -191,7 +191,8 @@ func NewRouter(
transportTags[i] = tag
transportTagMap[tag] = true
}
ctx = adapter.ContextWithRouter(ctx, router)
ctx = service.ContextWith[adapter.Router](ctx, router)
outboundManager := service.FromContext[adapter.OutboundManager](ctx)
for {
lastLen := len(dummyTransportMap)
for i, server := range dnsOptions.Servers {
@ -201,9 +202,9 @@ func NewRouter(
}
var detour N.Dialer
if server.Detour == "" {
detour = dialer.NewRouter(router)
detour = dialer.NewDefaultOutbound(outboundManager)
} else {
detour = dialer.NewDetour(router, server.Detour)
detour = dialer.NewDetour(outboundManager, server.Detour)
}
var serverProtocol string
switch server.Address {
@ -327,7 +328,7 @@ func NewRouter(
}
usePlatformDefaultInterfaceMonitor := router.platformInterface != nil && router.platformInterface.UsePlatformDefaultInterfaceMonitor()
needInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool {
enforceInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool {
if httpMixedOptions, isHTTPMixed := inbound.Options.(*option.HTTPMixedInboundOptions); isHTTPMixed && httpMixedOptions.SetSystemProxy {
return true
}
@ -339,7 +340,7 @@ func NewRouter(
if !usePlatformDefaultInterfaceMonitor {
networkMonitor, err := tun.NewNetworkUpdateMonitor(router.logger)
if !((err != nil && !needInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) {
if !((err != nil && !enforceInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) {
if err != nil {
return nil, err
}
@ -362,7 +363,7 @@ func NewRouter(
}
if ntpOptions.Enabled {
ntpDialer, err := dialer.New(router, ntpOptions.DialerOptions)
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
if err != nil {
return nil, E.Cause(err, "create NTP service")
}
@ -380,73 +381,6 @@ func NewRouter(
return router, nil
}
func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error {
inboundByTag := make(map[string]adapter.Inbound)
for _, inbound := range inbounds {
inboundByTag[inbound.Tag()] = inbound
}
outboundByTag := make(map[string]adapter.Outbound)
for _, detour := range outbounds {
outboundByTag[detour.Tag()] = detour
}
var defaultOutboundForConnection adapter.Outbound
var defaultOutboundForPacketConnection adapter.Outbound
if r.defaultDetour != "" {
detour, loaded := outboundByTag[r.defaultDetour]
if !loaded {
return E.New("default detour not found: ", r.defaultDetour)
}
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
}
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
}
}
if defaultOutboundForConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
break
}
}
}
if defaultOutboundForPacketConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
break
}
}
}
if defaultOutboundForConnection == nil || defaultOutboundForPacketConnection == nil {
detour := defaultOutbound()
if defaultOutboundForConnection == nil {
defaultOutboundForConnection = detour
}
if defaultOutboundForPacketConnection == nil {
defaultOutboundForPacketConnection = detour
}
outbounds = append(outbounds, detour)
outboundByTag[detour.Tag()] = detour
}
r.inboundByTag = inboundByTag
r.outbounds = outbounds
r.defaultOutboundForConnection = defaultOutboundForConnection
r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection
r.outboundByTag = outboundByTag
for i, rule := range r.rules {
routeAction, isRoute := rule.Action().(*R.RuleActionRoute)
if !isRoute {
continue
}
if _, loaded := outboundByTag[routeAction.Outbound]; !loaded {
return E.New("outbound not found for rule[", i, "]: ", routeAction.Outbound)
}
}
return nil
}
func (r *Router) Outbounds() []adapter.Outbound {
if !r.started {
return nil
@ -454,140 +388,240 @@ func (r *Router) Outbounds() []adapter.Outbound {
return r.outbounds
}
func (r *Router) PreStart() error {
func (r *Router) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
if r.interfaceMonitor != nil {
monitor.Start("initialize interface monitor")
err := r.interfaceMonitor.Start()
monitor.Finish()
if err != nil {
return err
}
}
if r.networkMonitor != nil {
monitor.Start("initialize network monitor")
err := r.networkMonitor.Start()
monitor.Finish()
if err != nil {
return err
}
}
if r.fakeIPStore != nil {
monitor.Start("initialize fakeip store")
err := r.fakeIPStore.Start()
monitor.Finish()
if err != nil {
return err
}
}
return nil
}
func (r *Router) Start() error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
if r.needGeoIPDatabase {
monitor.Start("initialize geoip database")
err := r.prepareGeoIPDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
monitor.Start("initialize geosite database")
err := r.prepareGeositeDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
for _, rule := range r.rules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
for _, rule := range r.dnsRules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
err := common.Close(r.geositeReader)
if err != nil {
return err
}
r.geositeCache = nil
r.geositeReader = nil
}
if runtime.GOOS == "windows" {
powerListener, err := winpowrprof.NewEventListener(r.notifyWindowsPowerEvent)
if err == nil {
r.powerListener = powerListener
} else {
r.logger.Warn("initialize power listener: ", err)
}
}
if r.powerListener != nil {
monitor.Start("start power listener")
err := r.powerListener.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start power listener")
}
}
monitor.Start("initialize DNS client")
r.dnsClient.Start()
monitor.Finish()
if C.IsAndroid && r.platformInterface == nil {
monitor.Start("initialize package manager")
packageManager, err := tun.NewPackageManager(tun.PackageManagerOptions{
Callback: r,
Logger: r.logger,
})
monitor.Finish()
if err != nil {
return E.Cause(err, "create package manager")
}
if r.needPackageManager {
monitor.Start("start package manager")
err = packageManager.Start()
switch stage {
case adapter.StartStateInitialize:
if r.interfaceMonitor != nil {
monitor.Start("initialize interface monitor")
err := r.interfaceMonitor.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
return err
}
}
r.packageManager = packageManager
}
if r.networkMonitor != nil {
monitor.Start("initialize network monitor")
err := r.networkMonitor.Start()
monitor.Finish()
if err != nil {
return err
}
}
if r.fakeIPStore != nil {
monitor.Start("initialize fakeip store")
err := r.fakeIPStore.Start()
monitor.Finish()
if err != nil {
return err
}
}
case adapter.StartStateStart:
if r.needGeoIPDatabase {
monitor.Start("initialize geoip database")
err := r.prepareGeoIPDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
monitor.Start("initialize geosite database")
err := r.prepareGeositeDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
for _, rule := range r.rules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
for _, rule := range r.dnsRules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
err := common.Close(r.geositeReader)
if err != nil {
return err
}
r.geositeCache = nil
r.geositeReader = nil
}
for i, rule := range r.dnsRules {
monitor.Start("initialize DNS rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS rule[", i, "]")
if runtime.GOOS == "windows" {
powerListener, err := winpowrprof.NewEventListener(r.notifyWindowsPowerEvent)
if err == nil {
r.powerListener = powerListener
} else {
r.logger.Warn("initialize power listener: ", err)
}
}
}
for i, transport := range r.transports {
monitor.Start("initialize DNS transport[", i, "]")
err := transport.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS server[", i, "]")
if r.powerListener != nil {
monitor.Start("start power listener")
err := r.powerListener.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start power listener")
}
}
}
if r.timeService != nil {
monitor.Start("initialize time service")
err := r.timeService.Start()
monitor.Start("initialize DNS client")
r.dnsClient.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize time service")
if C.IsAndroid && r.platformInterface == nil {
monitor.Start("initialize package manager")
packageManager, err := tun.NewPackageManager(tun.PackageManagerOptions{
Callback: r,
Logger: r.logger,
})
monitor.Finish()
if err != nil {
return E.Cause(err, "create package manager")
}
if r.enforcePackageManager {
monitor.Start("start package manager")
err = packageManager.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
}
}
r.packageManager = packageManager
}
for i, rule := range r.dnsRules {
monitor.Start("initialize DNS rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS rule[", i, "]")
}
}
for i, transport := range r.transports {
monitor.Start("initialize DNS transport[", i, "]")
err := transport.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS server[", i, "]")
}
}
if r.timeService != nil {
monitor.Start("initialize time service")
err := r.timeService.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize time service")
}
}
case adapter.StartStatePostStart:
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set")
cacheContext = adapter.NewHTTPStartContext()
var ruleSetStartGroup task.Group
for i, ruleSet := range r.ruleSets {
ruleSetInPlace := ruleSet
ruleSetStartGroup.Append0(func(ctx context.Context) error {
err := ruleSetInPlace.StartContext(ctx, cacheContext)
if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]")
}
return nil
})
}
ruleSetStartGroup.Concurrency(5)
ruleSetStartGroup.FastFail()
err := ruleSetStartGroup.Run(r.ctx)
monitor.Finish()
if err != nil {
return err
}
}
if cacheContext != nil {
cacheContext.Close()
}
needFindProcess := r.needFindProcess
needWIFIState := r.needWIFIState
for _, ruleSet := range r.ruleSets {
metadata := ruleSet.Metadata()
if metadata.ContainsProcessRule {
needFindProcess = true
}
if metadata.ContainsWIFIRule {
needWIFIState = true
}
}
if C.IsAndroid && r.platformInterface == nil && !r.enforcePackageManager {
if needFindProcess {
monitor.Start("start package manager")
err := r.packageManager.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
}
} else {
r.packageManager = nil
}
}
if needFindProcess {
if r.platformInterface != nil {
r.processSearcher = r.platformInterface
} else {
monitor.Start("initialize process searcher")
searcher, err := process.NewSearcher(process.Config{
Logger: r.logger,
PackageManager: r.packageManager,
})
monitor.Finish()
if err != nil {
if err != os.ErrInvalid {
r.logger.Warn(E.Cause(err, "create process searcher"))
}
} else {
r.processSearcher = searcher
}
}
}
if needWIFIState && r.platformInterface != nil {
monitor.Start("initialize WIFI state")
r.needWIFIState = true
r.interfaceMonitor.RegisterCallback(func(_ int) {
r.updateWIFIState()
})
r.updateWIFIState()
monitor.Finish()
}
for i, rule := range r.rules {
monitor.Start("initialize rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize rule[", i, "]")
}
}
for _, ruleSet := range r.ruleSets {
monitor.Start("post start rule_set[", ruleSet.Name(), "]")
err := ruleSet.PostStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]")
}
}
r.started = true
return nil
case adapter.StartStateStarted:
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
}
return nil
}
@ -668,113 +702,6 @@ func (r *Router) Close() error {
return err
}
func (r *Router) PostStart() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set")
cacheContext = adapter.NewHTTPStartContext()
var ruleSetStartGroup task.Group
for i, ruleSet := range r.ruleSets {
ruleSetInPlace := ruleSet
ruleSetStartGroup.Append0(func(ctx context.Context) error {
err := ruleSetInPlace.StartContext(ctx, cacheContext)
if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]")
}
return nil
})
}
ruleSetStartGroup.Concurrency(5)
ruleSetStartGroup.FastFail()
err := ruleSetStartGroup.Run(r.ctx)
monitor.Finish()
if err != nil {
return err
}
}
if cacheContext != nil {
cacheContext.Close()
}
needFindProcess := r.needFindProcess
needWIFIState := r.needWIFIState
for _, ruleSet := range r.ruleSets {
metadata := ruleSet.Metadata()
if metadata.ContainsProcessRule {
needFindProcess = true
}
if metadata.ContainsWIFIRule {
needWIFIState = true
}
}
if C.IsAndroid && r.platformInterface == nil && !r.needPackageManager {
if needFindProcess {
monitor.Start("start package manager")
err := r.packageManager.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
}
} else {
r.packageManager = nil
}
}
if needFindProcess {
if r.platformInterface != nil {
r.processSearcher = r.platformInterface
} else {
monitor.Start("initialize process searcher")
searcher, err := process.NewSearcher(process.Config{
Logger: r.logger,
PackageManager: r.packageManager,
})
monitor.Finish()
if err != nil {
if err != os.ErrInvalid {
r.logger.Warn(E.Cause(err, "create process searcher"))
}
} else {
r.processSearcher = searcher
}
}
}
if needWIFIState && r.platformInterface != nil {
monitor.Start("initialize WIFI state")
r.needWIFIState = true
r.interfaceMonitor.RegisterCallback(func(_ int) {
r.updateWIFIState()
})
r.updateWIFIState()
monitor.Finish()
}
for i, rule := range r.rules {
monitor.Start("initialize rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize rule[", i, "]")
}
}
for _, ruleSet := range r.ruleSets {
monitor.Start("post start rule_set[", ruleSet.Name(), "]")
err := ruleSet.PostStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]")
}
}
r.started = true
return nil
}
func (r *Router) Cleanup() error {
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
return nil
}
func (r *Router) Outbound(tag string) (adapter.Outbound, bool) {
outbound, loaded := r.outboundByTag[tag]
return outbound, loaded

View file

@ -22,7 +22,7 @@ import (
N "github.com/sagernet/sing/common/network"
)
func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) {
func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) {
switch action.Action {
case "":
return nil, nil
@ -36,7 +36,7 @@ func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action op
UDPConnect: action.RouteOptionsOptions.UDPConnect,
}, nil
case C.RuleActionTypeDirect:
directDialer, err := dialer.New(router, option.DialerOptions(action.DirectOptions))
directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions))
if err != nil {
return nil, err
}

View file

@ -52,7 +52,7 @@ type RuleItem interface {
}
func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) {
action, err := NewRuleAction(router, logger, options.RuleAction)
action, err := NewRuleAction(ctx, logger, options.RuleAction)
if err != nil {
return nil, E.Cause(err, "action")
}
@ -254,7 +254,7 @@ type LogicalRule struct {
}
func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
action, err := NewRuleAction(router, logger, options.RuleAction)
action, err := NewRuleAction(ctx, logger, options.RuleAction)
if err != nil {
return nil, E.Cause(err, "action")
}

View file

@ -33,23 +33,24 @@ import (
var _ adapter.RuleSet = (*RemoteRuleSet)(nil)
type RemoteRuleSet struct {
ctx context.Context
cancel context.CancelFunc
router adapter.Router
logger logger.ContextLogger
options option.RuleSet
metadata adapter.RuleSetMetadata
updateInterval time.Duration
dialer N.Dialer
rules []adapter.HeadlessRule
lastUpdated time.Time
lastEtag string
updateTicker *time.Ticker
cacheFile adapter.CacheFile
pauseManager pause.Manager
callbackAccess sync.Mutex
callbacks list.List[adapter.RuleSetUpdateCallback]
refs atomic.Int32
ctx context.Context
cancel context.CancelFunc
router adapter.Router
outboundManager adapter.OutboundManager
logger logger.ContextLogger
options option.RuleSet
metadata adapter.RuleSetMetadata
updateInterval time.Duration
dialer N.Dialer
rules []adapter.HeadlessRule
lastUpdated time.Time
lastEtag string
updateTicker *time.Ticker
cacheFile adapter.CacheFile
pauseManager pause.Manager
callbackAccess sync.Mutex
callbacks list.List[adapter.RuleSetUpdateCallback]
refs atomic.Int32
}
func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet {
@ -61,13 +62,14 @@ func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.
updateInterval = 24 * time.Hour
}
return &RemoteRuleSet{
ctx: ctx,
cancel: cancel,
router: router,
logger: logger,
options: options,
updateInterval: updateInterval,
pauseManager: service.FromContext[pause.Manager](ctx),
ctx: ctx,
cancel: cancel,
router: router,
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger,
options: options,
updateInterval: updateInterval,
pauseManager: service.FromContext[pause.Manager](ctx),
}
}
@ -83,17 +85,13 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.
s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)
var dialer N.Dialer
if s.options.RemoteOptions.DownloadDetour != "" {
outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour)
outbound, loaded := s.outboundManager.Outbound(s.options.RemoteOptions.DownloadDetour)
if !loaded {
return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour)
}
dialer = outbound
} else {
outbound, err := s.router.DefaultOutbound(N.NetworkTCP)
if err != nil {
return err
}
dialer = outbound
dialer = s.outboundManager.Default()
}
s.dialer = dialer
if s.cacheFile != nil {