From f196b4303e31f8b5ceb076b3d19b23b9c7277c89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 18 Apr 2023 09:19:09 +0800 Subject: [PATCH] Fix udpnat timeout --- common/cache/lrucache.go | 53 ++++++++++++++++++++++++++++------------ common/udpnat/service.go | 10 +++++--- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 37fa26b..675f08a 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -3,6 +3,9 @@ package cache // Modified by https://github.com/die-net/lrucache import ( + "context" + "log" + "runtime/debug" "sync" "time" @@ -38,13 +41,14 @@ func WithSize[K comparable, V any](maxSize int) Option[K, V] { } } -func WithStale[K comparable, V any](stale bool) Option[K, V] { +func WithContext[K comparable, V any](ctx context.Context) Option[K, V] { return func(l *LruCache[K, V]) { - l.staleReturn = stale + l.ctx = ctx } } type LruCache[K comparable, V any] struct { + ctx context.Context maxAge int64 maxSize int mu sync.Mutex @@ -64,6 +68,14 @@ func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] { option(lc) } + if lc.maxAge > 0 { + if lc.ctx == nil { + lc.ctx = context.Background() + log.Println("your lru cache is going to leak") + debug.PrintStack() + } + go lc.loopCheckTimeout() + } return lc } @@ -107,8 +119,6 @@ create: e := &entry[K, V]{key: key, value: value, expires: time.Now().Unix() + c.maxAge} c.cache[key] = c.lru.PushBack(e) } - - c.maybeDeleteOldest() return value, false } @@ -146,8 +156,6 @@ create: e := &entry[K, V]{key: key, value: value, expires: time.Now().Unix() + c.maxAge} c.cache[key] = c.lru.PushBack(e) } - - c.maybeDeleteOldest() return value, false } @@ -195,8 +203,6 @@ func (c *LruCache[K, V]) StoreWithExpire(key K, value V, expires time.Time) { } } } - - c.maybeDeleteOldest() } func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) { @@ -234,8 +240,6 @@ func (c *LruCache[K, V]) get(key K) *entry[K, V] { if !c.staleReturn && c.maxAge > 0 && le.Value.expires <= time.Now().Unix() { c.deleteElement(le) - c.maybeDeleteOldest() - return nil } @@ -258,15 +262,34 @@ func (c *LruCache[K, V]) Delete(key K) { c.mu.Unlock() } -func (c *LruCache[K, V]) maybeDeleteOldest() { - if !c.staleReturn && c.maxAge > 0 { - now := time.Now().Unix() - for le := c.lru.Front(); le != nil && le.Value.expires <= now; le = c.lru.Front() { - c.deleteElement(le) +func (c *LruCache[K, V]) loopCheckTimeout() { + ticker := time.NewTicker(time.Second * time.Duration(c.maxAge)) + defer ticker.Stop() + for { + select { + case <-ticker.C: + c.checkTimeout() + case <-c.ctx.Done(): + return } } } +func (c *LruCache[K, V]) checkTimeout() { + c.mu.Lock() + defer c.mu.Unlock() + now := time.Now().Unix() + var toDelete []*list.Element[*entry[K, V]] + for it := c.lru.Front(); it != nil; it = it.Next() { + if it.Value.expires <= now { + toDelete = append(toDelete, it) + } + } + for _, it := range toDelete { + c.deleteElement(it) + } +} + func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) { c.lru.Remove(le) e := le.Value diff --git a/common/udpnat/service.go b/common/udpnat/service.go index be06424..4b978c1 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -25,9 +25,10 @@ type Service[K comparable] struct { handler Handler } -func New[K comparable](maxAge int64, handler Handler) *Service[K] { +func New[K comparable](ctx context.Context, maxAge int64, handler Handler) *Service[K] { return &Service[K]{ nat: cache.New( + cache.WithContext[K, *conn](ctx), cache.WithAge[K, *conn](maxAge), cache.WithUpdateAgeOnGet[K, *conn](), cache.WithEvict[K, *conn](func(key K, conn *conn) { @@ -102,6 +103,10 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu } } +func (s *Service[T]) Close() error { + return common.Close(common.PtrOrNil(s.nat)) +} + type packet struct { data *buf.Buffer destination M.Socksaddr @@ -159,10 +164,9 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *conn) Close() error { select { case <-c.ctx.Done(): - return os.ErrClosed default: + c.cancel(net.ErrClosed) } - c.cancel(net.ErrClosed) if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser { return sourceCloser.Close() }