mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-02 03:17:37 +03:00
udpnat2: Fix concurrency
This commit is contained in:
parent
6edd2ce0ea
commit
39040e06dc
2 changed files with 16 additions and 34 deletions
|
@ -3,6 +3,7 @@ package udpnat
|
|||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -12,6 +13,7 @@ import (
|
|||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/pipe"
|
||||
"github.com/sagernet/sing/contrab/freelru"
|
||||
)
|
||||
|
||||
type Conn interface {
|
||||
|
@ -23,7 +25,7 @@ type Conn interface {
|
|||
var _ Conn = (*natConn)(nil)
|
||||
|
||||
type natConn struct {
|
||||
service *Service
|
||||
cache freelru.Cache[netip.AddrPort, *natConn]
|
||||
writer N.PacketWriter
|
||||
localAddr M.Socksaddr
|
||||
handler N.UDPHandlerEx
|
||||
|
@ -93,7 +95,7 @@ fetch:
|
|||
}
|
||||
|
||||
func (c *natConn) Timeout() time.Duration {
|
||||
rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort())
|
||||
rawConn, lifetime, loaded := c.cache.PeekWithLifetime(c.localAddr.AddrPort())
|
||||
if !loaded || rawConn != c {
|
||||
return 0
|
||||
}
|
||||
|
@ -101,7 +103,7 @@ func (c *natConn) Timeout() time.Duration {
|
|||
}
|
||||
|
||||
func (c *natConn) SetTimeout(timeout time.Duration) bool {
|
||||
return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
|
||||
return c.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
|
||||
}
|
||||
|
||||
func (c *natConn) Close() error {
|
||||
|
|
|
@ -17,18 +17,10 @@ type Service struct {
|
|||
cache freelru.Cache[netip.AddrPort, *natConn]
|
||||
handler N.UDPConnectionHandlerEx
|
||||
prepare PrepareFunc
|
||||
metrics Metrics
|
||||
}
|
||||
|
||||
type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc)
|
||||
|
||||
type Metrics struct {
|
||||
Creates uint64
|
||||
Rejects uint64
|
||||
Inputs uint64
|
||||
Drops uint64
|
||||
}
|
||||
|
||||
func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service {
|
||||
if timeout == 0 {
|
||||
panic("invalid timeout")
|
||||
|
@ -40,7 +32,6 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur
|
|||
cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
|
||||
}
|
||||
cache.SetLifetime(timeout)
|
||||
cache.SetUpdateLifetimeOnGet(true)
|
||||
cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
|
||||
select {
|
||||
case <-conn.doneChan:
|
||||
|
@ -60,25 +51,26 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur
|
|||
}
|
||||
|
||||
func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) {
|
||||
conn, loaded := s.cache.Get(source.AddrPort())
|
||||
if !loaded {
|
||||
conn, loaded := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) {
|
||||
ok, ctx, writer, onClose := s.prepare(source, destination, userData)
|
||||
if !ok {
|
||||
s.metrics.Rejects++
|
||||
return
|
||||
return nil, false
|
||||
}
|
||||
conn = &natConn{
|
||||
service: s,
|
||||
newConn := &natConn{
|
||||
cache: s.cache,
|
||||
writer: writer,
|
||||
localAddr: source,
|
||||
packetChan: make(chan *N.PacketBuffer, 64),
|
||||
doneChan: make(chan struct{}),
|
||||
readDeadline: pipe.MakeDeadline(),
|
||||
}
|
||||
s.PurgeExpired()
|
||||
s.cache.Add(source.AddrPort(), conn)
|
||||
go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose)
|
||||
s.metrics.Creates++
|
||||
go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose)
|
||||
return newConn, true
|
||||
})
|
||||
if !loaded {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
buffer := conn.readWaitOptions.NewPacketBuffer()
|
||||
for _, bufferSlice := range bufferSlices {
|
||||
|
@ -95,11 +87,9 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
|
|||
}
|
||||
select {
|
||||
case conn.packetChan <- packet:
|
||||
s.metrics.Inputs++
|
||||
default:
|
||||
packet.Buffer.Release()
|
||||
N.PutPacketBuffer(packet)
|
||||
s.metrics.Drops++
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,13 +100,3 @@ func (s *Service) Purge() {
|
|||
func (s *Service) PurgeExpired() {
|
||||
s.cache.PurgeExpired()
|
||||
}
|
||||
|
||||
func (s *Service) Metrics() Metrics {
|
||||
return s.metrics
|
||||
}
|
||||
|
||||
func (s *Service) ResetMetrics() Metrics {
|
||||
metrics := s.metrics
|
||||
s.metrics = Metrics{}
|
||||
return metrics
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue