diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index f5485af..492a04e 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -3,6 +3,7 @@ package udpnat import ( "context" "net/netip" + "sync" "time" "github.com/sagernet/sing/common" @@ -18,6 +19,10 @@ type Service struct { handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics + + timeout time.Duration + closeOnce sync.Once + doneChan chan struct{} } type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) @@ -50,12 +55,38 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur conn.Close() }) return &Service{ - cache: cache, - handler: handler, - prepare: prepare, + cache: cache, + handler: handler, + prepare: prepare, + timeout: timeout, + doneChan: make(chan struct{}), } } +func (s *Service) Start() error { + ticker := time.NewTicker(s.timeout) + go func() { + defer ticker.Stop() + for { + select { + case <-ticker.C: + s.PurgeExpired() + case <-s.doneChan: + s.Purge() + return + } + } + }() + return nil +} + +func (s *Service) Close() error { + s.closeOnce.Do(func() { + close(s.doneChan) + }) + return nil +} + func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { conn, loaded := s.cache.Get(source.AddrPort()) if !loaded {