diff --git a/gvisor.go b/gvisor.go index 81e8efd..9e42a7c 100644 --- a/gvisor.go +++ b/gvisor.go @@ -57,7 +57,7 @@ func NewGVisor( return nil, E.New("gVisor stack is unsupported on current platform") } - return &GVisor{ + gStack := &GVisor{ ctx: options.Context, tun: gTun, tunMtu: options.MTU, @@ -66,8 +66,11 @@ func NewGVisor( router: options.Router, handler: options.Handler, logger: options.Logger, - routeMapping: NewRouteMapping(options.UDPTimeout), - }, nil + } + if gStack.router != nil { + gStack.routeMapping = NewRouteMapping(options.Context, options.UDPTimeout) + } + return gStack, nil } func (t *GVisor) Start() error { diff --git a/route_mapping.go b/route_mapping.go index b0ccad8..be89853 100644 --- a/route_mapping.go +++ b/route_mapping.go @@ -1,6 +1,8 @@ package tun import ( + "context" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/cache" ) @@ -9,9 +11,10 @@ type RouteMapping struct { status *cache.LruCache[RouteSession, RouteAction] } -func NewRouteMapping(maxAge int64) *RouteMapping { +func NewRouteMapping(ctx context.Context, maxAge int64) *RouteMapping { return &RouteMapping{ status: cache.New( + cache.WithContext[RouteSession, RouteAction](ctx), cache.WithAge[RouteSession, RouteAction](maxAge), cache.WithUpdateAgeOnGet[RouteSession, RouteAction](), cache.WithEvict[RouteSession, RouteAction](func(key RouteSession, conn RouteAction) { diff --git a/system.go b/system.go index f6e4e04..9705d07 100644 --- a/system.go +++ b/system.go @@ -63,7 +63,9 @@ func NewSystem(options StackOptions) (Stack, error) { inet4Prefixes: options.Inet4Address, inet6Prefixes: options.Inet6Address, underPlatform: options.UnderPlatform, - routeMapping: NewRouteMapping(options.UDPTimeout), + } + if stack.router != nil { + stack.routeMapping = NewRouteMapping(options.Context, options.UDPTimeout) } if len(options.Inet4Address) > 0 { if options.Inet4Address[0].Bits() == 32 { @@ -115,7 +117,7 @@ func (s *System) Start() error { s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port go s.acceptLoop(tcpListener) } - s.tcpNat = NewNat() + s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) s.udpNat = udpnat.New[netip.AddrPort](s.ctx, s.udpTimeout, s.handler) go s.tunLoop() return nil @@ -208,13 +210,14 @@ func (s *System) acceptLoop(listener net.Listener) { } } go func() { - s.handler.NewConnection(s.ctx, conn, M.Metadata{ + _ = s.handler.NewConnection(s.ctx, conn, M.Metadata{ Source: M.SocksaddrFromNetIP(session.Source), Destination: destination, }) - conn.Close() - time.Sleep(time.Second) - s.tcpNat.Revoke(connPort, session) + if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn { + _ = tcpConn.SetLinger(0) + } + _ = conn.Close() }() } } diff --git a/system_nat.go b/system_nat.go index adac1a6..ff80413 100644 --- a/system_nat.go +++ b/system_nat.go @@ -1,8 +1,10 @@ package tun import ( + "context" "net/netip" "sync" + "time" ) type TCPNat struct { @@ -16,20 +18,54 @@ type TCPNat struct { type TCPSession struct { Source netip.AddrPort Destination netip.AddrPort + LastActive time.Time } -func NewNat() *TCPNat { - return &TCPNat{ +func NewNat(ctx context.Context, timeout time.Duration) *TCPNat { + natMap := &TCPNat{ portIndex: 10000, addrMap: make(map[netip.AddrPort]uint16), portMap: make(map[uint16]*TCPSession), } + go natMap.loopCheckTimeout(ctx, timeout) + return natMap +} + +func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) { + ticker := time.NewTicker(timeout) + defer ticker.Stop() + for { + select { + case <-ticker.C: + n.checkTimeout(timeout) + case <-ctx.Done(): + return + } + } +} + +func (n *TCPNat) checkTimeout(timeout time.Duration) { + now := time.Now() + n.portAccess.Lock() + defer n.portAccess.Unlock() + n.addrAccess.Lock() + defer n.addrAccess.Unlock() + for natPort, session := range n.portMap { + if now.Sub(session.LastActive) > timeout { + delete(n.addrMap, session.Source) + delete(n.portMap, natPort) + } + } } func (n *TCPNat) LookupBack(port uint16) *TCPSession { n.portAccess.RLock() - defer n.portAccess.RUnlock() - return n.portMap[port] + session := n.portMap[port] + n.portAccess.RUnlock() + if session != nil { + session.LastActive = time.Now() + } + return session } func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 { @@ -53,16 +89,8 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint1 n.portMap[nextPort] = &TCPSession{ Source: source, Destination: destination, + LastActive: time.Now(), } n.portAccess.Unlock() return nextPort } - -func (n *TCPNat) Revoke(natPort uint16, session *TCPSession) { - n.addrAccess.Lock() - delete(n.addrMap, session.Source) - n.addrAccess.Unlock() - n.portAccess.Lock() - delete(n.portMap, natPort) - n.portAccess.Unlock() -}