refactor: Extract services form router

This commit is contained in:
世界 2024-11-10 16:46:59 +08:00
parent a1be455202
commit 9afe75586a
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
27 changed files with 314 additions and 464 deletions

View file

@ -93,7 +93,18 @@ func New(ctx context.Context, options option.CacheFileOptions) *CacheFile {
}
}
func (c *CacheFile) start() error {
func (c *CacheFile) Name() string {
return "cache-file"
}
func (c *CacheFile) Dependencies() []string {
return nil
}
func (c *CacheFile) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateInitialize {
return nil
}
const fileMode = 0o666
options := bbolt.Options{Timeout: time.Second}
var (
@ -151,14 +162,6 @@ func (c *CacheFile) start() error {
return nil
}
func (c *CacheFile) PreStart() error {
return c.start()
}
func (c *CacheFile) Start() error {
return nil
}
func (c *CacheFile) Close() error {
if c.DB == nil {
return nil

View file

@ -133,45 +133,50 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op
return s, nil
}
func (s *Server) PreStart() error {
cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
if cacheFile != nil {
mode := cacheFile.LoadMode()
if common.Any(s.modeList, func(it string) bool {
return strings.EqualFold(it, mode)
}) {
s.mode = mode
}
}
return nil
func (s *Server) Name() string {
return "clash server"
}
func (s *Server) Start() error {
if s.externalController {
s.checkAndDownloadExternalUI()
var (
listener net.Listener
err error
)
for i := 0; i < 3; i++ {
listener, err = net.Listen("tcp", s.httpServer.Addr)
if runtime.GOOS == "android" && errors.Is(err, syscall.EADDRINUSE) {
time.Sleep(100 * time.Millisecond)
continue
func (s *Server) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateStart:
cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
if cacheFile != nil {
mode := cacheFile.LoadMode()
if common.Any(s.modeList, func(it string) bool {
return strings.EqualFold(it, mode)
}) {
s.mode = mode
}
break
}
if err != nil {
return E.Cause(err, "external controller listen error")
}
s.logger.Info("restful api listening at ", listener.Addr())
go func() {
err = s.httpServer.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
s.logger.Error("external controller serve error: ", err)
case adapter.StartStateStarted:
if s.externalController {
s.checkAndDownloadExternalUI()
var (
listener net.Listener
err error
)
for i := 0; i < 3; i++ {
listener, err = net.Listen("tcp", s.httpServer.Addr)
if runtime.GOOS == "android" && errors.Is(err, syscall.EADDRINUSE) {
time.Sleep(100 * time.Millisecond)
continue
}
break
}
}()
if err != nil {
return E.Cause(err, "external controller listen error")
}
s.logger.Info("restful api listening at ", listener.Addr())
go func() {
err = s.httpServer.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
s.logger.Error("external controller serve error: ", err)
}
}()
}
}
return nil
}
@ -233,14 +238,12 @@ func (s *Server) TrafficManager() *trafficontrol.Manager {
return s.trafficManager
}
func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) {
tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule)
return tracker, tracker
func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn {
return trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule, matchOutbound)
}
func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) {
tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule)
return tracker, tracker
func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn {
return trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule, matchOutbound)
}
func authentication(serverSecret string) func(next http.Handler) http.Handler {

View file

@ -5,7 +5,6 @@ import (
"time"
"github.com/sagernet/sing-box/adapter"
R "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio"
@ -88,7 +87,6 @@ func (t TrackerMetadata) MarshalJSON() ([]byte, error) {
}
type Tracker interface {
adapter.Tracker
Metadata() TrackerMetadata
Close() error
}
@ -108,10 +106,6 @@ func (tt *TCPConn) Close() error {
return tt.ExtendedConn.Close()
}
func (tt *TCPConn) Leave() {
tt.manager.Leave(tt)
}
func (tt *TCPConn) Upstream() any {
return tt.ExtendedConn
}
@ -124,7 +118,7 @@ func (tt *TCPConn) WriterReplaceable() bool {
return true
}
func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *TCPConn {
func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, matchRule adapter.Rule, matchOutbound adapter.Outbound) *TCPConn {
id, _ := uuid.NewV4()
var (
chain []string
@ -132,12 +126,8 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont
outbound string
outboundType string
)
var action adapter.RuleAction
if rule != nil {
action = rule.Action()
}
if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction {
next = routeAction.Outbound
if matchOutbound != nil {
next = matchOutbound.Tag()
} else {
next = outboundManager.Default().Tag()
}
@ -172,7 +162,7 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont
Upload: upload,
Download: download,
Chain: common.Reverse(chain),
Rule: rule,
Rule: matchRule,
Outbound: outbound,
OutboundType: outboundType,
},
@ -197,10 +187,6 @@ func (ut *UDPConn) Close() error {
return ut.PacketConn.Close()
}
func (ut *UDPConn) Leave() {
ut.manager.Leave(ut)
}
func (ut *UDPConn) Upstream() any {
return ut.PacketConn
}
@ -213,7 +199,7 @@ func (ut *UDPConn) WriterReplaceable() bool {
return true
}
func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *UDPConn {
func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, matchRule adapter.Rule, matchOutbound adapter.Outbound) *UDPConn {
id, _ := uuid.NewV4()
var (
chain []string
@ -221,12 +207,8 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound
outbound string
outboundType string
)
var action adapter.RuleAction
if rule != nil {
action = rule.Action()
}
if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction {
next = routeAction.Outbound
if matchOutbound != nil {
next = matchOutbound.Tag()
} else {
next = outboundManager.Default().Tag()
}
@ -261,7 +243,7 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound
Upload: upload,
Download: download,
Chain: common.Reverse(chain),
Rule: rule,
Rule: matchRule,
Outbound: outbound,
OutboundType: outboundType,
},

View file

@ -38,11 +38,7 @@ func (s *CommandServer) handleSetClashMode(conn net.Conn) error {
if service == nil {
return writeError(conn, E.New("service not ready"))
}
clashServer := service.instance.Router().ClashServer()
if clashServer == nil {
return writeError(conn, E.New("Clash API disabled"))
}
clashServer.(*clashapi.Server).SetMode(newMode)
service.clashServer.(*clashapi.Server).SetMode(newMode)
return writeError(conn, nil)
}
@ -69,18 +65,14 @@ func (s *CommandServer) handleModeConn(conn net.Conn) error {
return ctx.Err()
}
}
clashServer := s.service.instance.Router().ClashServer()
if clashServer == nil {
return binary.Write(conn, binary.BigEndian, uint16(0))
}
err := writeClashModeList(conn, clashServer)
err := writeClashModeList(conn, s.service.clashServer)
if err != nil {
return err
}
for {
select {
case <-s.modeUpdate:
err = varbin.Write(conn, binary.BigEndian, clashServer.Mode())
err = varbin.Write(conn, binary.BigEndian, s.service.clashServer.Mode())
if err != nil {
return err
}

View file

@ -45,11 +45,7 @@ func (s *CommandServer) handleCloseConnection(conn net.Conn) error {
if service == nil {
return writeError(conn, E.New("service not ready"))
}
clashServer := service.instance.Router().ClashServer()
if clashServer == nil {
return writeError(conn, E.New("Clash API disabled"))
}
targetConn := clashServer.(*clashapi.Server).TrafficManager().Connection(uuid.FromStringOrNil(connId))
targetConn := service.clashServer.(*clashapi.Server).TrafficManager().Connection(uuid.FromStringOrNil(connId))
if targetConn == nil {
return writeError(conn, E.New("connection already closed"))
}

View file

@ -49,11 +49,7 @@ func (s *CommandServer) handleConnectionsConn(conn net.Conn) error {
for {
service := s.service
if service != nil {
clashServer := service.instance.Router().ClashServer()
if clashServer == nil {
return E.New("Clash API disabled")
}
trafficManager = clashServer.(*clashapi.Server).TrafficManager()
trafficManager = service.clashServer.(*clashapi.Server).TrafficManager()
break
}
select {

View file

@ -60,7 +60,7 @@ func NewCommandServer(handler CommandServerHandler, maxLines int32) *CommandServ
func (s *CommandServer) SetService(newService *BoxService) {
if newService != nil {
service.PtrFromContext[urltest.HistoryStorage](newService.ctx).SetHook(s.urlTestUpdate)
newService.instance.Router().ClashServer().(*clashapi.Server).SetModeUpdateHook(s.modeUpdate)
newService.clashServer.(*clashapi.Server).SetModeUpdateHook(s.modeUpdate)
}
s.service = newService
s.notifyURLTestUpdate()

View file

@ -31,12 +31,10 @@ func (s *CommandServer) readStatus() StatusMessage {
message.ConnectionsOut = int32(conntrack.Count())
if s.service != nil {
if clashServer := s.service.instance.Router().ClashServer(); clashServer != nil {
message.TrafficAvailable = true
trafficManager := clashServer.(*clashapi.Server).TrafficManager()
message.UplinkTotal, message.DownlinkTotal = trafficManager.Total()
message.ConnectionsIn = int32(trafficManager.ConnectionsLen())
}
message.TrafficAvailable = true
trafficManager := s.service.clashServer.(*clashapi.Server).TrafficManager()
message.UplinkTotal, message.DownlinkTotal = trafficManager.Total()
message.ConnectionsIn = int32(trafficManager.ConnectionsLen())
}
return message

View file

@ -34,17 +34,18 @@ import (
type BoxService struct {
ctx context.Context
cancel context.CancelFunc
instance *box.Box
pauseManager pause.Manager
urlTestHistoryStorage *urltest.HistoryStorage
instance *box.Box
clashServer adapter.ClashServer
pauseManager pause.Manager
servicePauseFields
}
func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) {
ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry())
ctx = service.ContextWith[deprecated.Manager](ctx, new(deprecatedManager))
ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID)
service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager))
options, err := parseConfig(ctx, configContent)
if err != nil {
return nil, err
@ -54,7 +55,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box
urlTestHistoryStorage := urltest.NewHistoryStorage()
ctx = service.ContextWithPtr(ctx, urlTestHistoryStorage)
platformWrapper := &platformInterfaceWrapper{iif: platformInterface, useProcFS: platformInterface.UseProcFS()}
ctx = service.ContextWith[platform.Interface](ctx, platformWrapper)
service.MustRegister[platform.Interface](ctx, platformWrapper)
instance, err := box.New(box.Options{
Context: ctx,
Options: options,
@ -71,6 +72,7 @@ func NewService(configContent string, platformInterface PlatformInterface) (*Box
instance: instance,
urlTestHistoryStorage: urlTestHistoryStorage,
pauseManager: service.FromContext[pause.Manager](ctx),
clashServer: service.FromContext[adapter.ClashServer](ctx),
}, nil
}

View file

@ -44,7 +44,14 @@ func NewServer(logger log.Logger, options option.V2RayAPIOptions) (adapter.V2Ray
return server, nil
}
func (s *Server) Start() error {
func (s *Server) Name() string {
return "v2ray server"
}
func (s *Server) Start(stage adapter.StartStage) error {
if stage != adapter.StartStatePostStart {
return nil
}
listener, err := net.Listen("tcp", s.listen)
if err != nil {
return err
@ -70,6 +77,6 @@ func (s *Server) Close() error {
)
}
func (s *Server) StatsService() adapter.V2RayStatsService {
func (s *Server) StatsService() adapter.ConnectionTracker {
return s.statsService
}

View file

@ -22,7 +22,7 @@ func init() {
}
var (
_ adapter.V2RayStatsService = (*StatsService)(nil)
_ adapter.ConnectionTracker = (*StatsService)(nil)
_ StatsServiceServer = (*StatsService)(nil)
)
@ -60,7 +60,10 @@ func NewStatsService(options option.V2RayStatsServiceOptions) *StatsService {
}
}
func (s *StatsService) RoutedConnection(inbound string, outbound string, user string, conn net.Conn) net.Conn {
func (s *StatsService) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn {
inbound := metadata.Inbound
user := metadata.User
outbound := matchOutbound.Tag()
var readCounter []*atomic.Int64
var writeCounter []*atomic.Int64
countInbound := inbound != "" && s.inbounds[inbound]
@ -86,7 +89,10 @@ func (s *StatsService) RoutedConnection(inbound string, outbound string, user st
return bufio.NewInt64CounterConn(conn, readCounter, writeCounter)
}
func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn {
func (s *StatsService) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn {
inbound := metadata.Inbound
user := metadata.User
outbound := matchOutbound.Tag()
var readCounter []*atomic.Int64
var writeCounter []*atomic.Int64
countInbound := inbound != "" && s.inbounds[inbound]