udpnat2: Add SetHandler

This commit is contained in:
世界 2024-10-23 13:30:48 +08:00
parent 7ec09d6045
commit a4eb7fa900
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
6 changed files with 84 additions and 59 deletions

View file

@ -184,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
if buffer != nil { if buffer != nil {
buffer.DecRef() buffer.DecRef()
} }
return &N.PacketBuffer{ packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer, Buffer: buffer,
Destination: c.destination, Destination: c.destination,
} }
return packet
} }
func (c *CachedPacketConn) Upstream() any { func (c *CachedPacketConn) Upstream() any {

View file

@ -124,7 +124,7 @@ type UDPHandler interface {
} }
type UDPHandlerEx interface { type UDPHandlerEx interface {
NewPacketEx(buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr)
} }
// Deprecated: Use UDPConnectionHandlerEx instead. // Deprecated: Use UDPConnectionHandlerEx instead.
@ -146,11 +146,6 @@ type CachedPacketReader interface {
ReadCachedPacket() *PacketBuffer ReadCachedPacket() *PacketBuffer
} }
type PacketBuffer struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}
type WithUpstreamReader interface { type WithUpstreamReader interface {
UpstreamReader() any UpstreamReader() any
} }

35
common/network/packet.go Normal file
View file

@ -0,0 +1,35 @@
package network
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type PacketBuffer struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}
var packetPool = sync.Pool{
New: func() any {
return new(PacketBuffer)
},
}
func NewPacketBuffer() *PacketBuffer {
return packetPool.Get().(*PacketBuffer)
}
func PutPacketBuffer(packet *PacketBuffer) {
*packet = PacketBuffer{}
packetPool.Put(packet)
}
func ReleaseMultiPacketBuffer(packetBuffers []*PacketBuffer) {
for _, packet := range packetBuffers {
packet.Buffer.Release()
PutPacketBuffer(packet)
}
}

View file

@ -12,22 +12,23 @@ import (
"github.com/sagernet/sing/common/pipe" "github.com/sagernet/sing/common/pipe"
) )
type natConn struct { type Conn struct {
writer N.PacketWriter writer N.PacketWriter
localAddr M.Socksaddr localAddr M.Socksaddr
packetChan chan *Packet handler N.UDPHandlerEx
packetChan chan *N.PacketBuffer
doneChan chan struct{} doneChan chan struct{}
readDeadline pipe.Deadline readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions readWaitOptions N.ReadWaitOptions
} }
func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
select { select {
case p := <-c.packetChan: case p := <-c.packetChan:
_, err = buffer.ReadOnceFrom(p.Buffer) _, err = buffer.ReadOnceFrom(p.Buffer)
destination := p.Destination destination := p.Destination
p.Buffer.Release() p.Buffer.Release()
PutPacket(p) N.PutPacketBuffer(p)
return destination, err return destination, err
case <-c.doneChan: case <-c.doneChan:
return M.Socksaddr{}, io.ErrClosedPipe return M.Socksaddr{}, io.ErrClosedPipe
@ -36,21 +37,36 @@ func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
} }
} }
func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WritePacket(buffer, destination) return c.writer.WritePacket(buffer, destination)
} }
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { func (c *Conn) SetHandler(handler N.UDPHandlerEx) {
c.handler = handler
fetch:
for {
select {
case packet := <-c.packetChan:
c.handler.NewPacketEx(packet.Buffer, packet.Destination)
N.PutPacketBuffer(packet)
continue fetch
default:
break fetch
}
}
}
func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options c.readWaitOptions = options
return false return false
} }
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select { select {
case packet := <-c.packetChan: case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer) buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination destination = packet.Destination
PutPacket(packet) N.PutPacketBuffer(packet)
return return
case <-c.doneChan: case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe return nil, M.Socksaddr{}, io.ErrClosedPipe
@ -59,7 +75,7 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr,
} }
} }
func (c *natConn) Close() error { func (c *Conn) Close() error {
select { select {
case <-c.doneChan: case <-c.doneChan:
default: default:
@ -68,23 +84,23 @@ func (c *natConn) Close() error {
return nil return nil
} }
func (c *natConn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {
return c.localAddr return c.localAddr
} }
func (c *natConn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return M.Socksaddr{} return M.Socksaddr{}
} }
func (c *natConn) SetDeadline(t time.Time) error { func (c *Conn) SetDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *natConn) SetReadDeadline(t time.Time) error { func (c *Conn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t) c.readDeadline.Set(t)
return nil return nil
} }
func (c *natConn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }

View file

@ -1,28 +0,0 @@
package udpnat
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
var packetPool = sync.Pool{
New: func() any {
return new(Packet)
},
}
type Packet struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}
func NewPacket() *Packet {
return packetPool.Get().(*Packet)
}
func PutPacket(packet *Packet) {
*packet = Packet{}
packetPool.Put(packet)
}

View file

@ -14,7 +14,7 @@ import (
) )
type Service struct { type Service struct {
nat *freelru.LRU[netip.AddrPort, *natConn] nat *freelru.LRU[netip.AddrPort, *Conn]
handler N.UDPConnectionHandlerEx handler N.UDPConnectionHandlerEx
prepare PrepareFunc prepare PrepareFunc
metrics Metrics metrics Metrics
@ -30,9 +30,9 @@ type Metrics struct {
} }
func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service { func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service {
nat := common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) nat := common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
nat.SetLifetime(timeout) nat.SetLifetime(timeout)
nat.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { nat.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool {
select { select {
case <-conn.doneChan: case <-conn.doneChan:
return false return false
@ -40,7 +40,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur
return true return true
} }
}) })
nat.SetOnEvict(func(_ netip.AddrPort, conn *natConn) { nat.SetOnEvict(func(_ netip.AddrPort, conn *Conn) {
conn.Close() conn.Close()
}) })
return &Service{ return &Service{
@ -55,26 +55,31 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
if !loaded { if !loaded {
ok, ctx, writer, onClose := s.prepare(source, destination, userData) ok, ctx, writer, onClose := s.prepare(source, destination, userData)
if !ok { if !ok {
println(2)
s.metrics.Rejects++ s.metrics.Rejects++
return return
} }
conn = &natConn{ conn = &Conn{
writer: writer, writer: writer,
localAddr: source, localAddr: source,
packetChan: make(chan *Packet, 64), packetChan: make(chan *N.PacketBuffer, 64),
doneChan: make(chan struct{}), doneChan: make(chan struct{}),
readDeadline: pipe.MakeDeadline(), readDeadline: pipe.MakeDeadline(),
} }
s.nat.Add(source.AddrPort(), conn) s.nat.Add(source.AddrPort(), conn)
s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose)
s.metrics.Creates++ s.metrics.Creates++
} }
packet := NewPacket()
buffer := conn.readWaitOptions.NewPacketBuffer() buffer := conn.readWaitOptions.NewPacketBuffer()
for _, bufferSlice := range bufferSlices { for _, bufferSlice := range bufferSlices {
buffer.Write(bufferSlice) buffer.Write(bufferSlice)
} }
*packet = Packet{ if conn.handler != nil {
conn.handler.NewPacketEx(buffer, destination)
return
}
packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer, Buffer: buffer,
Destination: destination, Destination: destination,
} }
@ -83,7 +88,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
s.metrics.Inputs++ s.metrics.Inputs++
default: default:
packet.Buffer.Release() packet.Buffer.Release()
PutPacket(packet) N.PutPacketBuffer(packet)
s.metrics.Drops++ s.metrics.Drops++
} }
} }