udpnat2: Fix concurrency

This commit is contained in:
世界 2024-11-28 13:16:33 +08:00
parent 66034ab8ea
commit b21dbb9f3c
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 16 additions and 34 deletions

View file

@ -3,6 +3,7 @@ package udpnat
import ( import (
"io" "io"
"net" "net"
"net/netip"
"os" "os"
"sync" "sync"
"time" "time"
@ -12,6 +13,7 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe" "github.com/sagernet/sing/common/pipe"
"github.com/sagernet/sing/contrab/freelru"
) )
type Conn interface { type Conn interface {
@ -23,7 +25,7 @@ type Conn interface {
var _ Conn = (*natConn)(nil) var _ Conn = (*natConn)(nil)
type natConn struct { type natConn struct {
service *Service cache freelru.Cache[netip.AddrPort, *natConn]
writer N.PacketWriter writer N.PacketWriter
localAddr M.Socksaddr localAddr M.Socksaddr
handler N.UDPHandlerEx handler N.UDPHandlerEx
@ -93,7 +95,7 @@ fetch:
} }
func (c *natConn) Timeout() time.Duration { func (c *natConn) Timeout() time.Duration {
rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort()) rawConn, lifetime, loaded := c.cache.PeekWithLifetime(c.localAddr.AddrPort())
if !loaded || rawConn != c { if !loaded || rawConn != c {
return 0 return 0
} }
@ -101,7 +103,7 @@ func (c *natConn) Timeout() time.Duration {
} }
func (c *natConn) SetTimeout(timeout time.Duration) bool { func (c *natConn) SetTimeout(timeout time.Duration) bool {
return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout) return c.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
} }
func (c *natConn) Close() error { func (c *natConn) Close() error {

View file

@ -17,18 +17,10 @@ type Service struct {
cache freelru.Cache[netip.AddrPort, *natConn] cache freelru.Cache[netip.AddrPort, *natConn]
handler N.UDPConnectionHandlerEx handler N.UDPConnectionHandlerEx
prepare PrepareFunc prepare PrepareFunc
metrics Metrics
} }
type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc)
type Metrics struct {
Creates uint64
Rejects uint64
Inputs uint64
Drops uint64
}
func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service {
if timeout == 0 { if timeout == 0 {
panic("invalid timeout") panic("invalid timeout")
@ -40,7 +32,6 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur
cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
} }
cache.SetLifetime(timeout) cache.SetLifetime(timeout)
cache.SetUpdateLifetimeOnGet(true)
cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
select { select {
case <-conn.doneChan: case <-conn.doneChan:
@ -60,25 +51,26 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur
} }
func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) {
conn, loaded := s.cache.Get(source.AddrPort()) conn, loaded := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) {
if !loaded {
ok, ctx, writer, onClose := s.prepare(source, destination, userData) ok, ctx, writer, onClose := s.prepare(source, destination, userData)
if !ok { if !ok {
s.metrics.Rejects++ return nil, false
return
} }
conn = &natConn{ newConn := &natConn{
service: s, cache: s.cache,
writer: writer, writer: writer,
localAddr: source, localAddr: source,
packetChan: make(chan *N.PacketBuffer, 64), packetChan: make(chan *N.PacketBuffer, 64),
doneChan: make(chan struct{}), doneChan: make(chan struct{}),
readDeadline: pipe.MakeDeadline(), readDeadline: pipe.MakeDeadline(),
} }
s.PurgeExpired() go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose)
s.cache.Add(source.AddrPort(), conn) return newConn, true
go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) })
s.metrics.Creates++ if !loaded {
if conn == nil {
return
}
} }
buffer := conn.readWaitOptions.NewPacketBuffer() buffer := conn.readWaitOptions.NewPacketBuffer()
for _, bufferSlice := range bufferSlices { for _, bufferSlice := range bufferSlices {
@ -95,11 +87,9 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
} }
select { select {
case conn.packetChan <- packet: case conn.packetChan <- packet:
s.metrics.Inputs++
default: default:
packet.Buffer.Release() packet.Buffer.Release()
N.PutPacketBuffer(packet) N.PutPacketBuffer(packet)
s.metrics.Drops++
} }
} }
@ -110,13 +100,3 @@ func (s *Service) Purge() {
func (s *Service) PurgeExpired() { func (s *Service) PurgeExpired() {
s.cache.PurgeExpired() s.cache.PurgeExpired()
} }
func (s *Service) Metrics() Metrics {
return s.metrics
}
func (s *Service) ResetMetrics() Metrics {
metrics := s.metrics
s.metrics = Metrics{}
return metrics
}