From a4eb7fa900c79a3b28005016efc2a77bccc7a8bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= <i@sekai.icu>
Date: Wed, 23 Oct 2024 13:30:48 +0800
Subject: [PATCH] udpnat2: Add SetHandler

---
 common/bufio/cache.go     |  4 +++-
 common/network/conn.go    |  7 +------
 common/network/packet.go  | 35 +++++++++++++++++++++++++++++++
 common/udpnat2/conn.go    | 44 ++++++++++++++++++++++++++-------------
 common/udpnat2/packet.go  | 28 -------------------------
 common/udpnat2/service.go | 25 +++++++++++++---------
 6 files changed, 84 insertions(+), 59 deletions(-)
 create mode 100644 common/network/packet.go
 delete mode 100644 common/udpnat2/packet.go

diff --git a/common/bufio/cache.go b/common/bufio/cache.go
index ace7259..ce62d4d 100644
--- a/common/bufio/cache.go
+++ b/common/bufio/cache.go
@@ -184,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
 	if buffer != nil {
 		buffer.DecRef()
 	}
-	return &N.PacketBuffer{
+	packet := N.NewPacketBuffer()
+	*packet = N.PacketBuffer{
 		Buffer:      buffer,
 		Destination: c.destination,
 	}
+	return packet
 }
 
 func (c *CachedPacketConn) Upstream() any {
diff --git a/common/network/conn.go b/common/network/conn.go
index c795a19..c289bf6 100644
--- a/common/network/conn.go
+++ b/common/network/conn.go
@@ -124,7 +124,7 @@ type UDPHandler interface {
 }
 
 type UDPHandlerEx interface {
-	NewPacketEx(buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr)
+	NewPacketEx(buffer *buf.Buffer, source M.Socksaddr)
 }
 
 // Deprecated: Use UDPConnectionHandlerEx instead.
@@ -146,11 +146,6 @@ type CachedPacketReader interface {
 	ReadCachedPacket() *PacketBuffer
 }
 
-type PacketBuffer struct {
-	Buffer      *buf.Buffer
-	Destination M.Socksaddr
-}
-
 type WithUpstreamReader interface {
 	UpstreamReader() any
 }
diff --git a/common/network/packet.go b/common/network/packet.go
new file mode 100644
index 0000000..5b85214
--- /dev/null
+++ b/common/network/packet.go
@@ -0,0 +1,35 @@
+package network
+
+import (
+	"sync"
+
+	"github.com/sagernet/sing/common/buf"
+	M "github.com/sagernet/sing/common/metadata"
+)
+
+type PacketBuffer struct {
+	Buffer      *buf.Buffer
+	Destination M.Socksaddr
+}
+
+var packetPool = sync.Pool{
+	New: func() any {
+		return new(PacketBuffer)
+	},
+}
+
+func NewPacketBuffer() *PacketBuffer {
+	return packetPool.Get().(*PacketBuffer)
+}
+
+func PutPacketBuffer(packet *PacketBuffer) {
+	*packet = PacketBuffer{}
+	packetPool.Put(packet)
+}
+
+func ReleaseMultiPacketBuffer(packetBuffers []*PacketBuffer) {
+	for _, packet := range packetBuffers {
+		packet.Buffer.Release()
+		PutPacketBuffer(packet)
+	}
+}
diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go
index a5ca8ac..a96f4c8 100644
--- a/common/udpnat2/conn.go
+++ b/common/udpnat2/conn.go
@@ -12,22 +12,23 @@ import (
 	"github.com/sagernet/sing/common/pipe"
 )
 
-type natConn struct {
+type Conn struct {
 	writer          N.PacketWriter
 	localAddr       M.Socksaddr
-	packetChan      chan *Packet
+	handler         N.UDPHandlerEx
+	packetChan      chan *N.PacketBuffer
 	doneChan        chan struct{}
 	readDeadline    pipe.Deadline
 	readWaitOptions N.ReadWaitOptions
 }
 
-func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
+func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
 	select {
 	case p := <-c.packetChan:
 		_, err = buffer.ReadOnceFrom(p.Buffer)
 		destination := p.Destination
 		p.Buffer.Release()
-		PutPacket(p)
+		N.PutPacketBuffer(p)
 		return destination, err
 	case <-c.doneChan:
 		return M.Socksaddr{}, io.ErrClosedPipe
@@ -36,21 +37,36 @@ func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
 	}
 }
 
-func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
 	return c.writer.WritePacket(buffer, destination)
 }
 
-func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
+func (c *Conn) SetHandler(handler N.UDPHandlerEx) {
+	c.handler = handler
+fetch:
+	for {
+		select {
+		case packet := <-c.packetChan:
+			c.handler.NewPacketEx(packet.Buffer, packet.Destination)
+			N.PutPacketBuffer(packet)
+			continue fetch
+		default:
+			break fetch
+		}
+	}
+}
+
+func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
 	c.readWaitOptions = options
 	return false
 }
 
-func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
+func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
 	select {
 	case packet := <-c.packetChan:
 		buffer = c.readWaitOptions.Copy(packet.Buffer)
 		destination = packet.Destination
-		PutPacket(packet)
+		N.PutPacketBuffer(packet)
 		return
 	case <-c.doneChan:
 		return nil, M.Socksaddr{}, io.ErrClosedPipe
@@ -59,7 +75,7 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr,
 	}
 }
 
-func (c *natConn) Close() error {
+func (c *Conn) Close() error {
 	select {
 	case <-c.doneChan:
 	default:
@@ -68,23 +84,23 @@ func (c *natConn) Close() error {
 	return nil
 }
 
-func (c *natConn) LocalAddr() net.Addr {
+func (c *Conn) LocalAddr() net.Addr {
 	return c.localAddr
 }
 
-func (c *natConn) RemoteAddr() net.Addr {
+func (c *Conn) RemoteAddr() net.Addr {
 	return M.Socksaddr{}
 }
 
-func (c *natConn) SetDeadline(t time.Time) error {
+func (c *Conn) SetDeadline(t time.Time) error {
 	return os.ErrInvalid
 }
 
-func (c *natConn) SetReadDeadline(t time.Time) error {
+func (c *Conn) SetReadDeadline(t time.Time) error {
 	c.readDeadline.Set(t)
 	return nil
 }
 
-func (c *natConn) SetWriteDeadline(t time.Time) error {
+func (c *Conn) SetWriteDeadline(t time.Time) error {
 	return os.ErrInvalid
 }
diff --git a/common/udpnat2/packet.go b/common/udpnat2/packet.go
deleted file mode 100644
index 1d56ff4..0000000
--- a/common/udpnat2/packet.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package udpnat
-
-import (
-	"sync"
-
-	"github.com/sagernet/sing/common/buf"
-	M "github.com/sagernet/sing/common/metadata"
-)
-
-var packetPool = sync.Pool{
-	New: func() any {
-		return new(Packet)
-	},
-}
-
-type Packet struct {
-	Buffer      *buf.Buffer
-	Destination M.Socksaddr
-}
-
-func NewPacket() *Packet {
-	return packetPool.Get().(*Packet)
-}
-
-func PutPacket(packet *Packet) {
-	*packet = Packet{}
-	packetPool.Put(packet)
-}
diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go
index 85b3641..8c8afc9 100644
--- a/common/udpnat2/service.go
+++ b/common/udpnat2/service.go
@@ -14,7 +14,7 @@ import (
 )
 
 type Service struct {
-	nat     *freelru.LRU[netip.AddrPort, *natConn]
+	nat     *freelru.LRU[netip.AddrPort, *Conn]
 	handler N.UDPConnectionHandlerEx
 	prepare PrepareFunc
 	metrics Metrics
@@ -30,9 +30,9 @@ type Metrics struct {
 }
 
 func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service {
-	nat := common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
+	nat := common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
 	nat.SetLifetime(timeout)
-	nat.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
+	nat.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool {
 		select {
 		case <-conn.doneChan:
 			return false
@@ -40,7 +40,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur
 			return true
 		}
 	})
-	nat.SetOnEvict(func(_ netip.AddrPort, conn *natConn) {
+	nat.SetOnEvict(func(_ netip.AddrPort, conn *Conn) {
 		conn.Close()
 	})
 	return &Service{
@@ -55,26 +55,31 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
 	if !loaded {
 		ok, ctx, writer, onClose := s.prepare(source, destination, userData)
 		if !ok {
+			println(2)
 			s.metrics.Rejects++
 			return
 		}
-		conn = &natConn{
+		conn = &Conn{
 			writer:       writer,
 			localAddr:    source,
-			packetChan:   make(chan *Packet, 64),
+			packetChan:   make(chan *N.PacketBuffer, 64),
 			doneChan:     make(chan struct{}),
 			readDeadline: pipe.MakeDeadline(),
 		}
 		s.nat.Add(source.AddrPort(), conn)
-		s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose)
+		go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose)
 		s.metrics.Creates++
 	}
-	packet := NewPacket()
 	buffer := conn.readWaitOptions.NewPacketBuffer()
 	for _, bufferSlice := range bufferSlices {
 		buffer.Write(bufferSlice)
 	}
-	*packet = Packet{
+	if conn.handler != nil {
+		conn.handler.NewPacketEx(buffer, destination)
+		return
+	}
+	packet := N.NewPacketBuffer()
+	*packet = N.PacketBuffer{
 		Buffer:      buffer,
 		Destination: destination,
 	}
@@ -83,7 +88,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
 		s.metrics.Inputs++
 	default:
 		packet.Buffer.Release()
-		PutPacket(packet)
+		N.PutPacketBuffer(packet)
 		s.metrics.Drops++
 	}
 }