refactor: Extract services form router

This commit is contained in:
世界 2024-11-10 16:46:59 +08:00
parent a1be455202
commit 9afe75586a
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
27 changed files with 314 additions and 464 deletions

View file

@ -91,16 +91,12 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if err != nil {
return err
}
var (
// selectedOutbound adapter.Outbound
selectedDialer N.Dialer
selectedTag string
selectedDescription string
)
var selectedOutbound adapter.Outbound
if selectedRule != nil {
switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute:
selectedOutbound, loaded := r.outboundManager.Outbound(action.Outbound)
var loaded bool
selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound)
if !loaded {
buf.ReleaseMulti(buffers)
return E.New("outbound not found: ", action.Outbound)
@ -109,12 +105,6 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag())
}
selectedDialer = selectedOutbound
selectedTag = selectedOutbound.Tag()
selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
case *rule.RuleActionDirect:
selectedDialer = action.Dialer
selectedDescription = action.String()
case *rule.RuleActionReject:
buf.ReleaseMulti(buffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@ -133,25 +123,16 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
buf.ReleaseMulti(buffers)
return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag())
}
selectedDialer = defaultOutbound
selectedTag = defaultOutbound.Tag()
selectedDescription = F.ToString("outbound/", defaultOutbound.Type(), "[", defaultOutbound.Tag(), "]")
selectedOutbound = defaultOutbound
}
for _, buffer := range buffers {
conn = bufio.NewCachedConn(conn, buffer)
}
if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, selectedRule)
defer tracker.Leave()
conn = trackerConn
if r.tracker != nil {
conn = r.tracker.RoutedConnection(ctx, conn, metadata, selectedRule, selectedOutbound)
}
if r.v2rayServer != nil {
if statsService := r.v2rayServer.StatsService(); statsService != nil {
conn = statsService.RoutedConnection(metadata.Inbound, selectedTag, metadata.User, conn)
}
}
legacyOutbound, isLegacy := selectedDialer.(adapter.ConnectionHandler)
legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler)
if isLegacy {
err = legacyOutbound.NewConnection(ctx, conn, metadata)
if err != nil {
@ -159,7 +140,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
if onClose != nil {
onClose(err)
}
return E.Cause(err, selectedDescription)
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
} else {
if onClose != nil {
onClose(nil)
@ -168,13 +149,13 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad
return nil
}
// TODO
err = outbound.NewConnection(ctx, selectedDialer, conn, metadata)
err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata)
if err != nil {
conn.Close()
if onClose != nil {
onClose(err)
}
return E.Cause(err, selectedDescription)
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
} else {
if onClose != nil {
onClose(nil)
@ -246,16 +227,13 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
if err != nil {
return err
}
var (
selectedDialer N.Dialer
selectedTag string
selectedDescription string
)
var selectedOutbound adapter.Outbound
var selectReturn bool
if selectedRule != nil {
switch action := selectedRule.Action().(type) {
case *rule.RuleActionRoute:
selectedOutbound, loaded := r.outboundManager.Outbound(action.Outbound)
var loaded bool
selectedOutbound, loaded = r.outboundManager.Outbound(action.Outbound)
if !loaded {
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("outbound not found: ", action.Outbound)
@ -264,12 +242,6 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag())
}
selectedDialer = selectedOutbound
selectedTag = selectedOutbound.Tag()
selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]")
case *rule.RuleActionDirect:
selectedDialer = action.Dialer
selectedDescription = action.String()
case *rule.RuleActionReject:
N.ReleaseMultiPacketBuffer(packetBuffers)
N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx))
@ -285,41 +257,32 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m
N.ReleaseMultiPacketBuffer(packetBuffers)
return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag())
}
selectedDialer = defaultOutbound
selectedTag = defaultOutbound.Tag()
selectedDescription = F.ToString("outbound/", defaultOutbound.Type(), "[", defaultOutbound.Tag(), "]")
selectedOutbound = defaultOutbound
}
for _, buffer := range packetBuffers {
conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination)
N.PutPacketBuffer(buffer)
}
if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedPacketConnection(ctx, conn, metadata, selectedRule)
defer tracker.Leave()
conn = trackerConn
}
if r.v2rayServer != nil {
if statsService := r.v2rayServer.StatsService(); statsService != nil {
conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedTag, metadata.User, conn)
}
if r.tracker != nil {
conn = r.tracker.RoutedPacketConnection(ctx, conn, metadata, selectedRule, selectedOutbound)
}
if metadata.FakeIP {
conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
}
legacyOutbound, isLegacy := selectedDialer.(adapter.PacketConnectionHandler)
legacyOutbound, isLegacy := selectedOutbound.(adapter.PacketConnectionHandler)
if isLegacy {
err = legacyOutbound.NewPacketConnection(ctx, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil {
return E.Cause(err, selectedDescription)
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
}
return nil
}
// TODO
err = outbound.NewPacketConnection(ctx, selectedDialer, conn, metadata)
err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata)
N.CloseOnHandshakeFailure(conn, onClose, err)
if err != nil {
return E.Cause(err, selectedDescription)
return E.Cause(err, F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]"))
}
return nil
}

View file

@ -27,7 +27,6 @@ import (
F "github.com/sagernet/sing/common/format"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
@ -63,16 +62,14 @@ type Router struct {
dnsReverseMapping *DNSReverseMapping
fakeIPStore adapter.FakeIPStore
processSearcher process.Searcher
timeService *ntp.Service
pauseManager pause.Manager
clashServer adapter.ClashServer
v2rayServer adapter.V2RayServer
tracker adapter.ConnectionTracker
platformInterface platform.Interface
needWIFIState bool
started bool
}
func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions, ntpOptions option.NTPOptions) (*Router, error) {
func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions) (*Router, error) {
router := &Router{
ctx: ctx,
logger: logFactory.NewLogger("router"),
@ -94,7 +91,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
platformInterface: service.FromContext[platform.Interface](ctx),
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
}
ctx = service.ContextWith[adapter.Router](ctx, router)
service.MustRegister[adapter.Router](ctx, router)
router.dnsClient = dns.NewClient(dns.ClientOptions{
DisableCache: dnsOptions.DNSClientOptions.DisableCache,
DisableExpire: dnsOptions.DNSClientOptions.DisableExpire,
@ -290,23 +287,6 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route
}
router.fakeIPStore = fakeip.NewStore(ctx, router.logger, inet4Range, inet6Range)
}
if ntpOptions.Enabled {
ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
if err != nil {
return nil, E.Cause(err, "create NTP service")
}
timeService := ntp.NewService(ntp.Options{
Context: ctx,
Dialer: ntpDialer,
Logger: logFactory.NewLogger("ntp"),
Server: ntpOptions.ServerOptions.Build(),
Interval: time.Duration(ntpOptions.Interval),
WriteToSystem: ntpOptions.WriteToSystem,
})
service.MustRegister[ntp.TimeService](ctx, timeService)
router.timeService = timeService
}
return router, nil
}
@ -380,14 +360,6 @@ func (r *Router) Start(stage adapter.StartStage) error {
return E.Cause(err, "initialize DNS server[", i, "]")
}
}
if r.timeService != nil {
monitor.Start("initialize time service")
err := r.timeService.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize time service")
}
}
case adapter.StartStatePostStart:
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 {
@ -502,13 +474,6 @@ func (r *Router) Close() error {
})
monitor.Finish()
}
if r.timeService != nil {
monitor.Start("close time service")
err = E.Append(err, r.timeService.Close(), func(err error) error {
return E.Cause(err, "close time service")
})
monitor.Finish()
}
if r.fakeIPStore != nil {
monitor.Start("close fakeip store")
err = E.Append(err, r.fakeIPStore.Close(), func(err error) error {
@ -536,29 +501,8 @@ func (r *Router) Rules() []adapter.Rule {
return r.rules
}
func (r *Router) ClashServer() adapter.ClashServer {
return r.clashServer
}
func (r *Router) SetClashServer(server adapter.ClashServer) {
r.clashServer = server
}
func (r *Router) V2RayServer() adapter.V2RayServer {
return r.v2rayServer
}
func (r *Router) SetV2RayServer(server adapter.V2RayServer) {
r.v2rayServer = server
}
func (r *Router) NewError(ctx context.Context, err error) {
common.Close(err)
if E.IsClosedOrCanceled(err) {
r.logger.DebugContext(ctx, "connection closed: ", err)
return
}
r.logger.ErrorContext(ctx, err)
func (r *Router) SetTracker(tracker adapter.ConnectionTracker) {
r.tracker = tracker
}
func (r *Router) ResetNetwork() {

View file

@ -219,7 +219,7 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio
rule.allItems = append(rule.allItems, item)
}
if options.ClashMode != "" {
item := NewClashModeItem(router, options.ClashMode)
item := NewClashModeItem(ctx, options.ClashMode)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}

View file

@ -216,7 +216,7 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op
rule.allItems = append(rule.allItems, item)
}
if options.ClashMode != "" {
item := NewClashModeItem(router, options.ClashMode)
item := NewClashModeItem(ctx, options.ClashMode)
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}

View file

@ -1,31 +1,38 @@
package rule
import (
"context"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/service"
)
var _ RuleItem = (*ClashModeItem)(nil)
type ClashModeItem struct {
router adapter.Router
mode string
ctx context.Context
clashServer adapter.ClashServer
mode string
}
func NewClashModeItem(router adapter.Router, mode string) *ClashModeItem {
func NewClashModeItem(ctx context.Context, mode string) *ClashModeItem {
return &ClashModeItem{
router: router,
mode: mode,
ctx: ctx,
mode: mode,
}
}
func (r *ClashModeItem) Start() error {
r.clashServer = service.FromContext[adapter.ClashServer](r.ctx)
return nil
}
func (r *ClashModeItem) Match(metadata *adapter.InboundContext) bool {
clashServer := r.router.ClashServer()
if clashServer == nil {
if r.clashServer == nil {
return false
}
return strings.EqualFold(clashServer.Mode(), r.mode)
return strings.EqualFold(r.clashServer.Mode(), r.mode)
}
func (r *ClashModeItem) String() string {