Fix system nat mapping

This commit is contained in:
世界 2023-04-18 21:21:59 +08:00
parent bf7110b1ab
commit 53f50347e0
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 60 additions and 23 deletions

View file

@ -57,7 +57,7 @@ func NewGVisor(
return nil, E.New("gVisor stack is unsupported on current platform") return nil, E.New("gVisor stack is unsupported on current platform")
} }
return &GVisor{ gStack := &GVisor{
ctx: options.Context, ctx: options.Context,
tun: gTun, tun: gTun,
tunMtu: options.MTU, tunMtu: options.MTU,
@ -66,8 +66,11 @@ func NewGVisor(
router: options.Router, router: options.Router,
handler: options.Handler, handler: options.Handler,
logger: options.Logger, logger: options.Logger,
routeMapping: NewRouteMapping(options.UDPTimeout), }
}, nil if gStack.router != nil {
gStack.routeMapping = NewRouteMapping(options.Context, options.UDPTimeout)
}
return gStack, nil
} }
func (t *GVisor) Start() error { func (t *GVisor) Start() error {

View file

@ -1,6 +1,8 @@
package tun package tun
import ( import (
"context"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/cache" "github.com/sagernet/sing/common/cache"
) )
@ -9,9 +11,10 @@ type RouteMapping struct {
status *cache.LruCache[RouteSession, RouteAction] status *cache.LruCache[RouteSession, RouteAction]
} }
func NewRouteMapping(maxAge int64) *RouteMapping { func NewRouteMapping(ctx context.Context, maxAge int64) *RouteMapping {
return &RouteMapping{ return &RouteMapping{
status: cache.New( status: cache.New(
cache.WithContext[RouteSession, RouteAction](ctx),
cache.WithAge[RouteSession, RouteAction](maxAge), cache.WithAge[RouteSession, RouteAction](maxAge),
cache.WithUpdateAgeOnGet[RouteSession, RouteAction](), cache.WithUpdateAgeOnGet[RouteSession, RouteAction](),
cache.WithEvict[RouteSession, RouteAction](func(key RouteSession, conn RouteAction) { cache.WithEvict[RouteSession, RouteAction](func(key RouteSession, conn RouteAction) {

View file

@ -63,7 +63,9 @@ func NewSystem(options StackOptions) (Stack, error) {
inet4Prefixes: options.Inet4Address, inet4Prefixes: options.Inet4Address,
inet6Prefixes: options.Inet6Address, inet6Prefixes: options.Inet6Address,
underPlatform: options.UnderPlatform, underPlatform: options.UnderPlatform,
routeMapping: NewRouteMapping(options.UDPTimeout), }
if stack.router != nil {
stack.routeMapping = NewRouteMapping(options.Context, options.UDPTimeout)
} }
if len(options.Inet4Address) > 0 { if len(options.Inet4Address) > 0 {
if options.Inet4Address[0].Bits() == 32 { if options.Inet4Address[0].Bits() == 32 {
@ -115,7 +117,7 @@ func (s *System) Start() error {
s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
go s.acceptLoop(tcpListener) go s.acceptLoop(tcpListener)
} }
s.tcpNat = NewNat() s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
s.udpNat = udpnat.New[netip.AddrPort](s.ctx, s.udpTimeout, s.handler) s.udpNat = udpnat.New[netip.AddrPort](s.ctx, s.udpTimeout, s.handler)
go s.tunLoop() go s.tunLoop()
return nil return nil
@ -208,13 +210,14 @@ func (s *System) acceptLoop(listener net.Listener) {
} }
} }
go func() { go func() {
s.handler.NewConnection(s.ctx, conn, M.Metadata{ _ = s.handler.NewConnection(s.ctx, conn, M.Metadata{
Source: M.SocksaddrFromNetIP(session.Source), Source: M.SocksaddrFromNetIP(session.Source),
Destination: destination, Destination: destination,
}) })
conn.Close() if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn {
time.Sleep(time.Second) _ = tcpConn.SetLinger(0)
s.tcpNat.Revoke(connPort, session) }
_ = conn.Close()
}() }()
} }
} }

View file

@ -1,8 +1,10 @@
package tun package tun
import ( import (
"context"
"net/netip" "net/netip"
"sync" "sync"
"time"
) )
type TCPNat struct { type TCPNat struct {
@ -16,20 +18,54 @@ type TCPNat struct {
type TCPSession struct { type TCPSession struct {
Source netip.AddrPort Source netip.AddrPort
Destination netip.AddrPort Destination netip.AddrPort
LastActive time.Time
} }
func NewNat() *TCPNat { func NewNat(ctx context.Context, timeout time.Duration) *TCPNat {
return &TCPNat{ natMap := &TCPNat{
portIndex: 10000, portIndex: 10000,
addrMap: make(map[netip.AddrPort]uint16), addrMap: make(map[netip.AddrPort]uint16),
portMap: make(map[uint16]*TCPSession), portMap: make(map[uint16]*TCPSession),
} }
go natMap.loopCheckTimeout(ctx, timeout)
return natMap
}
func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) {
ticker := time.NewTicker(timeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
n.checkTimeout(timeout)
case <-ctx.Done():
return
}
}
}
func (n *TCPNat) checkTimeout(timeout time.Duration) {
now := time.Now()
n.portAccess.Lock()
defer n.portAccess.Unlock()
n.addrAccess.Lock()
defer n.addrAccess.Unlock()
for natPort, session := range n.portMap {
if now.Sub(session.LastActive) > timeout {
delete(n.addrMap, session.Source)
delete(n.portMap, natPort)
}
}
} }
func (n *TCPNat) LookupBack(port uint16) *TCPSession { func (n *TCPNat) LookupBack(port uint16) *TCPSession {
n.portAccess.RLock() n.portAccess.RLock()
defer n.portAccess.RUnlock() session := n.portMap[port]
return n.portMap[port] n.portAccess.RUnlock()
if session != nil {
session.LastActive = time.Now()
}
return session
} }
func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 { func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 {
@ -53,16 +89,8 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint1
n.portMap[nextPort] = &TCPSession{ n.portMap[nextPort] = &TCPSession{
Source: source, Source: source,
Destination: destination, Destination: destination,
LastActive: time.Now(),
} }
n.portAccess.Unlock() n.portAccess.Unlock()
return nextPort return nextPort
} }
func (n *TCPNat) Revoke(natPort uint16, session *TCPSession) {
n.addrAccess.Lock()
delete(n.addrMap, session.Source)
n.addrAccess.Unlock()
n.portAccess.Lock()
delete(n.portMap, natPort)
n.portAccess.Unlock()
}