udpnat2: New synced udp nat service

This commit is contained in:
世界 2024-10-21 19:24:36 +08:00
parent 0641c71805
commit 7ec09d6045
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
12 changed files with 307 additions and 40 deletions

View file

@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release() defer buffer.Release()
if destination.IsFqdn() { return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
} }
func (w *ExtendedUDPConn) Upstream() any { func (w *ExtendedUDPConn) Upstream() any {

View file

@ -124,7 +124,7 @@ type UDPHandler interface {
} }
type UDPHandlerEx interface { type UDPHandlerEx interface {
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) error NewPacketEx(buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr)
} }
// Deprecated: Use UDPConnectionHandlerEx instead. // Deprecated: Use UDPConnectionHandlerEx instead.

View file

@ -19,15 +19,27 @@ func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0 return o.FrontHeadroom > 0 || o.RearHeadroom > 0
} }
func (o ReadWaitOptions) Copy(buffer *buf.Buffer) *buf.Buffer {
if o.FrontHeadroom > buffer.Start() ||
o.RearHeadroom > buffer.FreeLen() {
newBuffer := o.newBuffer(buf.UDPBufferSize, false)
newBuffer.Write(buffer.Bytes())
buffer.Release()
return newBuffer
} else {
return buffer
}
}
func (o ReadWaitOptions) NewBuffer() *buf.Buffer { func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
return o.newBuffer(buf.BufferSize) return o.newBuffer(buf.BufferSize, true)
} }
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer { func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
return o.newBuffer(buf.UDPBufferSize) return o.newBuffer(buf.UDPBufferSize, true)
} }
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer { func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buffer {
var bufferSize int var bufferSize int
if o.MTU > 0 { if o.MTU > 0 {
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
@ -38,7 +50,7 @@ func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
if o.FrontHeadroom > 0 { if o.FrontHeadroom > 0 {
buffer.Resize(o.FrontHeadroom, 0) buffer.Resize(o.FrontHeadroom, 0)
} }
if o.RearHeadroom > 0 { if o.RearHeadroom > 0 && reserve {
buffer.Reserve(o.RearHeadroom) buffer.Reserve(o.RearHeadroom)
} }
return buffer return buffer

View file

@ -131,8 +131,6 @@ func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.
s.nat.Delete(key) s.nat.Delete(key)
} }
}() }()
} else {
c.localAddr = source
} }
if common.Done(c.ctx) { if common.Done(c.ctx) {
s.nat.Delete(key) s.nat.Delete(key)
@ -215,10 +213,6 @@ func (c *conn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *conn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *conn) Upstream() any { func (c *conn) Upstream() any {
return c.source return c.source
} }

90
common/udpnat2/conn.go Normal file
View file

@ -0,0 +1,90 @@
package udpnat
import (
"io"
"net"
"os"
"time"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
)
type natConn struct {
writer N.PacketWriter
localAddr M.Socksaddr
packetChan chan *Packet
doneChan chan struct{}
readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions
}
func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
select {
case p := <-c.packetChan:
_, err = buffer.ReadOnceFrom(p.Buffer)
destination := p.Destination
p.Buffer.Release()
PutPacket(p)
return destination, err
case <-c.doneChan:
return M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WritePacket(buffer, destination)
}
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination
PutPacket(packet)
return
case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (c *natConn) Close() error {
select {
case <-c.doneChan:
default:
close(c.doneChan)
}
return nil
}
func (c *natConn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *natConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}
func (c *natConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *natConn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
return nil
}
func (c *natConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}

28
common/udpnat2/packet.go Normal file
View file

@ -0,0 +1,28 @@
package udpnat
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
var packetPool = sync.Pool{
New: func() any {
return new(Packet)
},
}
type Packet struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}
func NewPacket() *Packet {
return packetPool.Get().(*Packet)
}
func PutPacket(packet *Packet) {
*packet = Packet{}
packetPool.Put(packet)
}

93
common/udpnat2/service.go Normal file
View file

@ -0,0 +1,93 @@
package udpnat
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
)
type Service struct {
nat *freelru.LRU[netip.AddrPort, *natConn]
handler N.UDPConnectionHandlerEx
prepare PrepareFunc
metrics Metrics
}
type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc)
type Metrics struct {
Creates uint64
Rejects uint64
Inputs uint64
Drops uint64
}
func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service {
nat := common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
nat.SetLifetime(timeout)
nat.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
select {
case <-conn.doneChan:
return false
default:
return true
}
})
nat.SetOnEvict(func(_ netip.AddrPort, conn *natConn) {
conn.Close()
})
return &Service{
nat: nat,
handler: handler,
prepare: prepare,
}
}
func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) {
conn, loaded := s.nat.Get(source.AddrPort())
if !loaded {
ok, ctx, writer, onClose := s.prepare(source, destination, userData)
if !ok {
s.metrics.Rejects++
return
}
conn = &natConn{
writer: writer,
localAddr: source,
packetChan: make(chan *Packet, 64),
doneChan: make(chan struct{}),
readDeadline: pipe.MakeDeadline(),
}
s.nat.Add(source.AddrPort(), conn)
s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose)
s.metrics.Creates++
}
packet := NewPacket()
buffer := conn.readWaitOptions.NewPacketBuffer()
for _, bufferSlice := range bufferSlices {
buffer.Write(bufferSlice)
}
*packet = Packet{
Buffer: buffer,
Destination: destination,
}
select {
case conn.packetChan <- packet:
s.metrics.Inputs++
default:
packet.Buffer.Release()
PutPacket(packet)
s.metrics.Drops++
}
}
func (s *Service) Metrics() Metrics {
return s.metrics
}

View file

@ -31,6 +31,8 @@ type OnEvictCallback[K comparable, V any] func(K, V)
// HashKeyCallback is the function that creates a hash from the passed key. // HashKeyCallback is the function that creates a hash from the passed key.
type HashKeyCallback[K comparable] func(K) uint32 type HashKeyCallback[K comparable] func(K) uint32
type HealthCheckCallback[K comparable, V any] func(K, V) bool
type element[K comparable, V any] struct { type element[K comparable, V any] struct {
key K key K
value V value V
@ -61,12 +63,13 @@ const emptyBucket = math.MaxUint32
// LRU implements a non-thread safe fixed size LRU cache. // LRU implements a non-thread safe fixed size LRU cache.
type LRU[K comparable, V any] struct { type LRU[K comparable, V any] struct {
buckets []uint32 // contains positions of bucket lists or 'emptyBucket' buckets []uint32 // contains positions of bucket lists or 'emptyBucket'
elements []element[K, V] elements []element[K, V]
onEvict OnEvictCallback[K, V] onEvict OnEvictCallback[K, V]
hash HashKeyCallback[K] hash HashKeyCallback[K]
lifetime time.Duration healthCheck HealthCheckCallback[K, V]
metrics Metrics lifetime time.Duration
metrics Metrics
// used for element clearing after removal or expiration // used for element clearing after removal or expiration
emptyKey K emptyKey K
@ -108,6 +111,10 @@ func (lru *LRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) {
lru.onEvict = onEvict lru.onEvict = onEvict
} }
func (lru *LRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) {
lru.healthCheck = healthCheck
}
// New constructs an LRU with the given capacity of elements. // New constructs an LRU with the given capacity of elements.
// The hash function calculates a hash value from the keys. // The hash function calculates a hash value from the keys.
func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) { func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) {
@ -120,7 +127,8 @@ func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K,
// by reducing the chance of collisions. // by reducing the chance of collisions.
// Size must not be lower than the capacity. // Size must not be lower than the capacity.
func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallback[K]) ( func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallback[K]) (
*LRU[K, V], error) { *LRU[K, V], error,
) {
if capacity == 0 { if capacity == 0 {
return nil, errors.New("capacity must be positive") return nil, errors.New("capacity must be positive")
} }
@ -144,7 +152,8 @@ func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallbac
} }
func initLRU[K comparable, V any](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K], func initLRU[K comparable, V any](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K],
buckets []uint32, elements []element[K, V]) { buckets []uint32, elements []element[K, V],
) {
lru.cap = capacity lru.cap = capacity
lru.size = size lru.size = size
lru.hash = hash lru.hash = hash
@ -294,7 +303,7 @@ func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) {
lru.elements[pos].value = lru.emptyValue lru.elements[pos].value = lru.emptyValue
} }
func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) { func (lru *LRU[K, V]) findKey(hash uint32, key K, updateLifetimeOnGet bool) (uint32, bool) {
_, startPos := lru.hashToPos(hash) _, startPos := lru.hashToPos(hash)
if startPos == emptyBucket { if startPos == emptyBucket {
return emptyBucket, false return emptyBucket, false
@ -303,10 +312,14 @@ func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) {
pos := startPos pos := startPos
for { for {
if key == lru.elements[pos].key { if key == lru.elements[pos].key {
if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= now() { elem := lru.elements[pos]
if (elem.expire != 0 && elem.expire <= now()) || (lru.healthCheck != nil && !lru.healthCheck(key, elem.value)) {
lru.removeAt(pos) lru.removeAt(pos)
return emptyBucket, false return emptyBucket, false
} }
if updateLifetimeOnGet {
lru.elements[pos].expire = expire(lru.lifetime)
}
return pos, true return pos, true
} }
@ -330,7 +343,8 @@ func (lru *LRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (e
} }
func (lru *LRU[K, V]) addWithLifetime(hash uint32, key K, value V, func (lru *LRU[K, V]) addWithLifetime(hash uint32, key K, value V,
lifetime time.Duration) (evicted bool) { lifetime time.Duration,
) (evicted bool) {
bucketPos, startPos := lru.hashToPos(hash) bucketPos, startPos := lru.hashToPos(hash)
if startPos == emptyBucket { if startPos == emptyBucket {
pos := lru.len pos := lru.len
@ -425,11 +439,11 @@ func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) {
// If the found cache item is already expired, the evict function is called // If the found cache item is already expired, the evict function is called
// and the return value indicates that the key was not found. // and the return value indicates that the key was not found.
func (lru *LRU[K, V]) Get(key K) (value V, ok bool) { func (lru *LRU[K, V]) Get(key K) (value V, ok bool) {
return lru.get(lru.hash(key), key) return lru.get(lru.hash(key), key, true)
} }
func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) { func (lru *LRU[K, V]) get(hash uint32, key K, updateLifetime bool) (value V, ok bool) {
if pos, ok := lru.findKey(hash, key); ok { if pos, ok := lru.findKey(hash, key, updateLifetime); ok {
if pos != lru.head { if pos != lru.head {
lru.unlinkElement(pos) lru.unlinkElement(pos)
lru.setHead(pos) lru.setHead(pos)
@ -449,7 +463,7 @@ func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) {
} }
func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) { func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) {
if pos, ok := lru.findKey(hash, key); ok { if pos, ok := lru.findKey(hash, key, false); ok {
return lru.elements[pos].value, ok return lru.elements[pos].value, ok
} }
@ -476,7 +490,7 @@ func (lru *LRU[K, V]) Remove(key K) (removed bool) {
} }
func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) { func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) {
if pos, ok := lru.findKey(hash, key); ok { if pos, ok := lru.findKey(hash, key, false); ok {
lru.removeAt(pos) lru.removeAt(pos)
return ok return ok
} }

View file

@ -0,0 +1,35 @@
package freelru_test
import (
"testing"
"time"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
"github.com/stretchr/testify/require"
)
func TestMyChange0(t *testing.T) {
t.Parallel()
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
require.NoError(t, err)
lru.AddWithLifetime("hello", "world", 2*time.Second)
time.Sleep(time.Second)
lru.Get("hello")
time.Sleep(time.Second + time.Millisecond*100)
_, ok := lru.Get("hello")
require.True(t, ok)
}
func TestMyChange1(t *testing.T) {
t.Parallel()
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
require.NoError(t, err)
lru.AddWithLifetime("hello", "world", 2*time.Second)
time.Sleep(time.Second)
lru.Peek("hello")
time.Sleep(time.Second + time.Millisecond*100)
_, ok := lru.Get("hello")
require.False(t, ok)
}

View file

@ -46,3 +46,8 @@ func (h Hasher[K]) Hash(key K) uint64 {
p := noescape(unsafe.Pointer(&key)) p := noescape(unsafe.Pointer(&key))
return uint64(h.hash(p, h.seed)) return uint64(h.hash(p, h.seed))
} }
func (h Hasher[K]) Hash32(key K) uint32 {
p := noescape(unsafe.Pointer(&key))
return uint32(h.hash(p, h.seed))
}

View file

@ -52,6 +52,7 @@ func newHashSeed() uintptr {
//go:nocheckptr //go:nocheckptr
func noescape(p unsafe.Pointer) unsafe.Pointer { func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p) x := uintptr(p)
//nolint:staticcheck
return unsafe.Pointer(x ^ 0) return unsafe.Pointer(x ^ 0)
} }
@ -91,9 +92,11 @@ type hmap struct {
} }
// go/src/runtime/type.go // go/src/runtime/type.go
type tflag uint8 type (
type nameOff int32 tflag uint8
type typeOff int32 nameOff int32
typeOff int32
)
// go/src/runtime/type.go // go/src/runtime/type.go
type _type struct { type _type struct {

View file

@ -37,7 +37,7 @@ func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error {
Destination: M.SocksaddrFromNet(conn.LocalAddr()), Destination: M.SocksaddrFromNet(conn.LocalAddr()),
}) })
case socks5.Version: case socks5.Version:
return socks5.WriteResponse(conn, socks5.Response{ return socks5.WriteResponse(c.Conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess, ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(conn.LocalAddr()), Bind: M.SocksaddrFromNet(conn.LocalAddr()),
}) })
@ -211,5 +211,5 @@ func (c *LazyAssociatePacketConn) WriterReplaceable() bool {
} }
func (c *LazyAssociatePacketConn) Upstream() any { func (c *LazyAssociatePacketConn) Upstream() any {
return c.underlying return &c.AssociatePacketConn
} }