Fix udp nat

This commit is contained in:
世界 2022-04-29 17:43:09 +08:00
parent 8d95ae4cff
commit b4b6c838d1
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 144 additions and 331 deletions

View file

@ -36,6 +36,8 @@ import (
"github.com/spf13/cobra"
)
const udpTimeout = 5 * 60
type flags struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
@ -256,7 +258,7 @@ func newClient(f *flags) (*client, error) {
bind = netip.IPv6Unspecified()
}
c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, c)
c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, udpTimeout, c)
if f.Bypass != "" {
err := geoip.LoadMMDB("Country.mmdb")

View file

@ -31,6 +31,8 @@ import (
"github.com/spf13/cobra"
)
const udpTimeout = 5 * 60
type flags struct {
Server string `json:"server"`
ServerPort uint16 `json:"server_port"`
@ -183,15 +185,15 @@ func newServer(f *flags) (*server, error) {
}
if f.Method == shadowsocks.MethodNone {
s.service = shadowsocks.NewNoneService(s)
s.service = shadowsocks.NewNoneService(udpTimeout, s)
} else if common.Contains(shadowaead.List, f.Method) {
service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, s)
service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, udpTimeout, s)
if err != nil {
return nil, err
}
s.service = service
} else if common.Contains(shadowaead_2022.List, f.Method) {
service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), s)
service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), udpTimeout, s)
if err != nil {
return nil, err
}

View file

@ -72,7 +72,7 @@ func run(cmd *cobra.Command, args []string) {
}
client := &localClient{upstream: args[1]}
client.Listener = mixed.NewListener(bind, nil, transproxyMode, client)
client.Listener = mixed.NewListener(bind, nil, transproxyMode, 300, client)
err = client.Start()
if err != nil {

View file

@ -3,108 +3,87 @@ package cache
// Modified by https://github.com/die-net/lrucache
import (
"container/list"
"sync"
"time"
"github.com/sagernet/sing/common/list"
)
// Option is part of Functional Options Pattern
type Option func(*LruCache)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback = func(key interface{}, value interface{})
// WithEvict set the evict callback
func WithEvict(cb EvictCallback) Option {
return func(l *LruCache) {
l.onEvict = cb
}
}
// WithUpdateAgeOnGet update expires when Get element
func WithUpdateAgeOnGet() Option {
return func(l *LruCache) {
l.updateAgeOnGet = true
}
}
// WithAge defined element max age (second)
func WithAge(maxAge int64) Option {
return func(l *LruCache) {
l.maxAge = maxAge
}
}
// WithSize defined max length of LruCache
func WithSize(maxSize int) Option {
return func(l *LruCache) {
l.maxSize = maxSize
}
}
// WithStale decide whether Stale return is enabled.
// If this feature is enabled, element will not get Evicted according to `WithAge`.
func WithStale(stale bool) Option {
return func(l *LruCache) {
l.staleReturn = stale
}
}
// LruCache is a thread-safe, in-memory lru-cache that evicts the
// least recently used entries from memory when (if set) the entries are
// older than maxAge (in seconds). Use the New constructor to create one.
type LruCache struct {
type LruCache[K comparable, V any] struct {
maxAge int64
maxSize int
mu sync.Mutex
cache map[interface{}]*list.Element
lru *list.List // Front is least-recent
cache map[K]*list.Element[*entry[K, V]]
lru list.List[*entry[K, V]] // Front is least-recent
updateAgeOnGet bool
staleReturn bool
onEvict EvictCallback
}
// NewLRUCache creates an LruCache
func NewLRUCache(options ...Option) *LruCache {
lc := &LruCache{
lru: list.New(),
cache: make(map[interface{}]*list.Element),
}
for _, option := range options {
option(lc)
func NewLRU[K comparable, V any](maxAge int64, updateAgeOnGet bool) LruCache[K, V] {
lc := LruCache[K, V]{
maxAge: maxAge,
updateAgeOnGet: updateAgeOnGet,
cache: make(map[K]*list.Element[*entry[K, V]]),
}
return lc
}
// Get returns the interface{} representation of a cached response and a bool
// set to true if the key was found.
func (c *LruCache) Get(key interface{}) (interface{}, bool) {
func (c *LruCache[K, V]) Load(key K) (V, bool) {
entry := c.get(key)
if entry == nil {
return nil, false
var defaultValue V
return defaultValue, false
}
value := entry.value
return value, true
}
// GetWithExpire returns the interface{} representation of a cached response,
// a time.Time Give expected expires,
// and a bool set to true if the key was found.
// This method will NOT check the maxAge of element and will NOT update the expires.
func (c *LruCache) GetWithExpire(key interface{}) (interface{}, time.Time, bool) {
func (c *LruCache[K, V]) LoadOrStore(key K, constructor func() V) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()
le, ok := c.cache[key]
if ok {
if c.maxAge > 0 && le.Value.expires <= time.Now().Unix() {
c.deleteElement(le)
goto create
}
c.lru.MoveToBack(le)
entry := le.Value
if c.maxAge > 0 && c.updateAgeOnGet {
entry.expires = time.Now().Unix() + c.maxAge
}
return entry.value, true
}
create:
value := constructor()
if le, ok := c.cache[key]; ok {
c.lru.MoveToBack(le)
e := le.Value
e.value = value
e.expires = time.Now().Unix() + c.maxAge
} else {
e := &entry[K, V]{key: key, value: value, expires: time.Now().Unix() + c.maxAge}
c.cache[key] = c.lru.PushBack(e)
}
c.maybeDeleteOldest()
return value, false
}
func (c *LruCache[K, V]) LoadWithExpire(key K) (V, time.Time, bool) {
entry := c.get(key)
if entry == nil {
return nil, time.Time{}, false
var defaultValue V
return defaultValue, time.Time{}, false
}
return entry.value, time.Unix(entry.expires, 0), true
}
// Exist returns if key exist in cache but not put item to the head of linked list
func (c *LruCache) Exist(key interface{}) bool {
func (c *LruCache[K, V]) Exist(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
@ -112,58 +91,48 @@ func (c *LruCache) Exist(key interface{}) bool {
return ok
}
// Set stores the interface{} representation of a response for a given key.
func (c *LruCache) Set(key interface{}, value interface{}) {
func (c *LruCache[K, V]) Store(key K, value V) {
expires := int64(0)
if c.maxAge > 0 {
expires = time.Now().Unix() + c.maxAge
}
c.SetWithExpire(key, value, time.Unix(expires, 0))
c.StoreWithExpire(key, value, time.Unix(expires, 0))
}
// SetWithExpire stores the interface{} representation of a response for a given key and given expires.
// The expires time will round to second.
func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) {
func (c *LruCache[K, V]) StoreWithExpire(key K, value V, expires time.Time) {
c.mu.Lock()
defer c.mu.Unlock()
if le, ok := c.cache[key]; ok {
c.lru.MoveToBack(le)
e := le.Value.(*entry)
e := le.Value
e.value = value
e.expires = expires.Unix()
} else {
e := &entry{key: key, value: value, expires: expires.Unix()}
e := &entry[K, V]{key: key, value: value, expires: expires.Unix()}
c.cache[key] = c.lru.PushBack(e)
if c.maxSize > 0 {
if len := c.lru.Len(); len > c.maxSize {
c.deleteElement(c.lru.Front())
}
}
}
c.maybeDeleteOldest()
}
// CloneTo clone and overwrite elements to another LruCache
func (c *LruCache) CloneTo(n *LruCache) {
func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) {
c.mu.Lock()
defer c.mu.Unlock()
n.mu.Lock()
defer n.mu.Unlock()
n.lru = list.New()
n.cache = make(map[interface{}]*list.Element)
n.lru = list.List[*entry[K, V]]{}
n.cache = make(map[K]*list.Element[*entry[K, V]])
for e := c.lru.Front(); e != nil; e = e.Next() {
elm := e.Value.(*entry)
elm := e.Value
n.cache[elm.key] = n.lru.PushBack(elm)
}
}
func (c *LruCache) get(key interface{}) *entry {
func (c *LruCache[K, V]) get(key K) *entry[K, V] {
c.mu.Lock()
defer c.mu.Unlock()
@ -172,7 +141,7 @@ func (c *LruCache) get(key interface{}) *entry {
return nil
}
if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() {
if c.maxAge > 0 && le.Value.expires <= time.Now().Unix() {
c.deleteElement(le)
c.maybeDeleteOldest()
@ -180,7 +149,7 @@ func (c *LruCache) get(key interface{}) *entry {
}
c.lru.MoveToBack(le)
entry := le.Value.(*entry)
entry := le.Value
if c.maxAge > 0 && c.updateAgeOnGet {
entry.expires = time.Now().Unix() + c.maxAge
}
@ -188,7 +157,7 @@ func (c *LruCache) get(key interface{}) *entry {
}
// Delete removes the value associated with a key.
func (c *LruCache) Delete(key interface{}) {
func (c *LruCache[K, V]) Delete(key K) {
c.mu.Lock()
if le, ok := c.cache[key]; ok {
@ -198,26 +167,21 @@ func (c *LruCache) Delete(key interface{}) {
c.mu.Unlock()
}
func (c *LruCache) maybeDeleteOldest() {
if !c.staleReturn && c.maxAge > 0 {
now := time.Now().Unix()
for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() {
c.deleteElement(le)
}
func (c *LruCache[K, V]) maybeDeleteOldest() {
now := time.Now().Unix()
for le := c.lru.Front(); le != nil && le.Value.expires <= now; le = c.lru.Front() {
c.deleteElement(le)
}
}
func (c *LruCache) deleteElement(le *list.Element) {
func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) {
c.lru.Remove(le)
e := le.Value.(*entry)
e := le.Value
delete(c.cache, e.key)
if c.onEvict != nil {
c.onEvict(e.key, e.value)
}
}
type entry struct {
key interface{}
value interface{}
type entry[K comparable, V any] struct {
key K
value V
expires int64
}

View file

@ -1,183 +0,0 @@
package cache
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var entries = []struct {
key string
value string
}{
{"1", "one"},
{"2", "two"},
{"3", "three"},
{"4", "four"},
{"5", "five"},
}
func TestLRUCache(t *testing.T) {
c := NewLRUCache()
for _, e := range entries {
c.Set(e.key, e.value)
}
c.Delete("missing")
_, ok := c.Get("missing")
assert.False(t, ok)
for _, e := range entries {
value, ok := c.Get(e.key)
if assert.True(t, ok) {
assert.Equal(t, e.value, value.(string))
}
}
for _, e := range entries {
c.Delete(e.key)
_, ok := c.Get(e.key)
assert.False(t, ok)
}
}
func TestLRUMaxAge(t *testing.T) {
c := NewLRUCache(WithAge(86400))
now := time.Now().Unix()
expected := now + 86400
// Add one expired entry
c.Set("foo", "bar")
c.lru.Back().Value.(*entry).expires = now
// Reset
c.Set("foo", "bar")
e := c.lru.Back().Value.(*entry)
assert.True(t, e.expires >= now)
c.lru.Back().Value.(*entry).expires = now
// Set a few and verify expiration times
for _, s := range entries {
c.Set(s.key, s.value)
e := c.lru.Back().Value.(*entry)
assert.True(t, e.expires >= expected && e.expires <= expected+10)
}
// Make sure we can get them all
for _, s := range entries {
_, ok := c.Get(s.key)
assert.True(t, ok)
}
// Expire all entries
for _, s := range entries {
le, ok := c.cache[s.key]
if assert.True(t, ok) {
le.Value.(*entry).expires = now
}
}
// Get one expired entry, which should clear all expired entries
_, ok := c.Get("3")
assert.False(t, ok)
assert.Equal(t, c.lru.Len(), 0)
}
func TestLRUpdateOnGet(t *testing.T) {
c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet())
now := time.Now().Unix()
expires := now + 86400/2
// Add one expired entry
c.Set("foo", "bar")
c.lru.Back().Value.(*entry).expires = expires
_, ok := c.Get("foo")
assert.True(t, ok)
assert.True(t, c.lru.Back().Value.(*entry).expires > expires)
}
func TestMaxSize(t *testing.T) {
c := NewLRUCache(WithSize(2))
// Add one expired entry
c.Set("foo", "bar")
_, ok := c.Get("foo")
assert.True(t, ok)
c.Set("bar", "foo")
c.Set("baz", "foo")
_, ok = c.Get("foo")
assert.False(t, ok)
}
func TestExist(t *testing.T) {
c := NewLRUCache(WithSize(1))
c.Set(1, 2)
assert.True(t, c.Exist(1))
c.Set(2, 3)
assert.False(t, c.Exist(1))
}
func TestEvict(t *testing.T) {
temp := 0
evict := func(key interface{}, value interface{}) {
temp = key.(int) + value.(int)
}
c := NewLRUCache(WithEvict(evict), WithSize(1))
c.Set(1, 2)
c.Set(2, 3)
assert.Equal(t, temp, 3)
}
func TestSetWithExpire(t *testing.T) {
c := NewLRUCache(WithAge(1))
now := time.Now().Unix()
tenSecBefore := time.Unix(now-10, 0)
c.SetWithExpire(1, 2, tenSecBefore)
// res is expected not to exist, and expires should be empty time.Time
res, expires, exist := c.GetWithExpire(1)
assert.Equal(t, nil, res)
assert.Equal(t, time.Time{}, expires)
assert.Equal(t, false, exist)
}
func TestStale(t *testing.T) {
c := NewLRUCache(WithAge(1), WithStale(true))
now := time.Now().Unix()
tenSecBefore := time.Unix(now-10, 0)
c.SetWithExpire(1, 2, tenSecBefore)
res, expires, exist := c.GetWithExpire(1)
assert.Equal(t, 2, res)
assert.Equal(t, tenSecBefore, expires)
assert.Equal(t, true, exist)
}
func TestCloneTo(t *testing.T) {
o := NewLRUCache(WithSize(10))
o.Set("1", 1)
o.Set("2", 2)
n := NewLRUCache(WithSize(2))
n.Set("3", 3)
n.Set("4", 4)
o.CloneTo(n)
assert.False(t, n.Exist("3"))
assert.True(t, n.Exist("1"))
n.Set("5", 5)
assert.False(t, n.Exist("1"))
}

View file

@ -38,6 +38,10 @@ func (ap AddrPort) UDPAddr() *net.UDPAddr {
}
}
func (ap AddrPort) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(ap.Addr.Addr(), ap.Port)
}
func (ap AddrPort) String() string {
return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port)))
}

View file

@ -9,8 +9,8 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks"
)
@ -21,17 +21,18 @@ type Handler interface {
}
type Service[K comparable] struct {
nat gsync.Map[K, *conn]
nat cache.LruCache[K, *conn]
handler Handler
}
func New[T comparable](handler Handler) *Service[T] {
return &Service[T]{
func New[K comparable](maxAge int64, handler Handler) Service[K] {
return Service[K]{
nat: cache.NewLRU[K, *conn](maxAge, true),
handler: handler,
}
}
func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) error {
func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) {
c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{
data: make(chan packet),
@ -57,7 +58,6 @@ func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *
}
c.data <- p
<-ctx.Done()
return nil
}
type packet struct {

View file

@ -3,6 +3,7 @@ package shadowsocks
import (
"context"
"net"
"net/netip"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
@ -24,14 +25,14 @@ type Handler interface {
type NoneService struct {
handler Handler
udp *udpnat.Service[string]
udp udpnat.Service[netip.AddrPort]
}
func NewNoneService(handler Handler) Service {
func NewNoneService(udpTimeout int64, handler Handler) Service {
s := &NoneService{
handler: handler,
}
s.udp = udpnat.New[string](s)
s.udp = udpnat.New[netip.AddrPort](udpTimeout, s)
return s
}
@ -52,9 +53,10 @@ func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metad
}
metadata.Protocol = "shadowsocks"
metadata.Destination = destination
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
s.udp.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
return &serverPacketWriter{conn, metadata.Source}
}, buffer, metadata)
return nil
}
type serverPacketWriter struct {

View file

@ -5,6 +5,7 @@ import (
"crypto/cipher"
"io"
"net"
"net/netip"
"sync"
"github.com/sagernet/sing/common"
@ -26,17 +27,17 @@ type Service struct {
key []byte
secureRNG io.Reader
replayFilter replay.Filter
udp *udpnat.Service[string]
udpNat udpnat.Service[netip.AddrPort]
handler shadowsocks.Handler
}
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, handler shadowsocks.Handler) (shadowsocks.Service, error) {
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{
name: method,
secureRNG: secureRNG,
handler: handler,
}
s.udp = udpnat.New[string](s)
s.udpNat = udpnat.New[netip.AddrPort](udpTimeout, s)
if replayFilter {
s.replayFilter = replay.NewBloomRing()
}
@ -190,9 +191,10 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet))
metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
s.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source}
}, buffer, metadata)
return nil
}
type serverPacketWriter struct {

View file

@ -8,14 +8,15 @@ import (
"io"
"math"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw"
@ -37,17 +38,18 @@ type Service struct {
psk []byte
replayFilter replay.Filter
handler shadowsocks.Handler
udpNat *udpnat.Service[uint64]
sessions gsync.Map[uint64, *serverUDPSession]
udpNat udpnat.Service[uint64]
sessions cache.LruCache[uint64, *serverUDPSession]
}
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) {
func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{
name: method,
psk: psk,
secureRNG: secureRNG,
replayFilter: replay.NewCuckoo(60),
handler: handler,
sessions: cache.NewLRU[uint64, *serverUDPSession](udpTimeout, true),
}
if len(psk) != KeySaltSize {
@ -71,7 +73,7 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso
s.udpCipher = newXChacha20Poly1305(s.psk)
}
s.udpNat = udpnat.New[uint64](s)
s.udpNat = udpnat.New[uint64](udpTimeout, s)
return s, nil
}
@ -239,53 +241,69 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
session.remoteCipher = s.constructor(common.Dup(key))
}
}
session.remoteAddr = metadata.Source.AddrPort()
goto process
returnErr:
if !loaded {
s.sessions.Delete(sessionId)
}
return err
process:
if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
return ErrPacketIdNotUnique
err = ErrPacketIdNotUnique
goto returnErr
}
if packetHeader != nil {
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil {
return E.Cause(err, "decrypt packet")
err = E.Cause(err, "decrypt packet")
goto returnErr
}
}
var headerType byte
headerType, err = buffer.ReadByte()
if err != nil {
return err
err = E.Cause(err, "decrypt packet")
goto returnErr
}
if headerType != HeaderTypeClient {
return ErrBadHeaderType
err = ErrBadHeaderType
goto returnErr
}
var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil {
return err
goto returnErr
}
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
return ErrBadTimestamp
err = ErrBadTimestamp
goto returnErr
}
var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil {
return E.Cause(err, "read padding length")
err = E.Cause(err, "read padding length")
goto returnErr
}
buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
if err != nil {
return err
goto returnErr
}
metadata.Destination = destination
return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
return &serverPacketWriter{s, conn, session, metadata.Source}
}, buffer, metadata)
return nil
}
type serverPacketWriter struct {
@ -343,6 +361,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr
type serverUDPSession struct {
sessionId uint64
remoteSessionId uint64
remoteAddr netip.AddrPort
packetId uint64
cipher cipher.AEAD
remoteCipher cipher.AEAD

View file

@ -32,10 +32,10 @@ type Listener struct {
bindAddr netip.Addr
handler Handler
authenticator auth.Authenticator
udpNat *udpnat.Service[string]
udpNat udpnat.Service[netip.AddrPort]
}
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener {
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, udpTimeout int64, handler Handler) *Listener {
listener := &Listener{
bindAddr: bind.Addr(),
handler: handler,
@ -45,7 +45,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro
listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy))
if transproxy == redir.ModeTProxy {
listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy))
listener.udpNat = udpnat.New[string](handler)
listener.udpNat = udpnat.New[netip.AddrPort](udpTimeout, handler)
}
return listener
}
@ -97,9 +97,10 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
}
func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
return l.udpNat.NewPacket(metadata.Source.String(), func() socks.PacketWriter {
l.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
return &tproxyPacketWriter{metadata.Source.UDPAddr()}
}, buffer, metadata)
return nil
}
type tproxyPacketWriter struct {