diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go index 33dcd92..78a5d1c 100644 --- a/cli/ss-local/main.go +++ b/cli/ss-local/main.go @@ -36,6 +36,8 @@ import ( "github.com/spf13/cobra" ) +const udpTimeout = 5 * 60 + type flags struct { Server string `json:"server"` ServerPort uint16 `json:"server_port"` @@ -256,7 +258,7 @@ func newClient(f *flags) (*client, error) { bind = netip.IPv6Unspecified() } - c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, c) + c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, udpTimeout, c) if f.Bypass != "" { err := geoip.LoadMMDB("Country.mmdb") diff --git a/cli/ss-server/main.go b/cli/ss-server/main.go index c692c5b..26c3b9b 100644 --- a/cli/ss-server/main.go +++ b/cli/ss-server/main.go @@ -31,6 +31,8 @@ import ( "github.com/spf13/cobra" ) +const udpTimeout = 5 * 60 + type flags struct { Server string `json:"server"` ServerPort uint16 `json:"server_port"` @@ -183,15 +185,15 @@ func newServer(f *flags) (*server, error) { } if f.Method == shadowsocks.MethodNone { - s.service = shadowsocks.NewNoneService(s) + s.service = shadowsocks.NewNoneService(udpTimeout, s) } else if common.Contains(shadowaead.List, f.Method) { - service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, s) + service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, udpTimeout, s) if err != nil { return nil, err } s.service = service } else if common.Contains(shadowaead_2022.List, f.Method) { - service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), s) + service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), udpTimeout, s) if err != nil { return nil, err } diff --git a/cli/uot-local/main.go b/cli/uot-local/main.go index 67e32f5..798000b 100644 --- a/cli/uot-local/main.go +++ b/cli/uot-local/main.go @@ -72,7 +72,7 @@ func run(cmd *cobra.Command, args []string) { } client := &localClient{upstream: args[1]} - client.Listener = mixed.NewListener(bind, nil, transproxyMode, client) + client.Listener = mixed.NewListener(bind, nil, transproxyMode, 300, client) err = client.Start() if err != nil { diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 4269d86..454a410 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -3,108 +3,87 @@ package cache // Modified by https://github.com/die-net/lrucache import ( - "container/list" "sync" "time" + + "github.com/sagernet/sing/common/list" ) -// Option is part of Functional Options Pattern -type Option func(*LruCache) - -// EvictCallback is used to get a callback when a cache entry is evicted -type EvictCallback = func(key interface{}, value interface{}) - -// WithEvict set the evict callback -func WithEvict(cb EvictCallback) Option { - return func(l *LruCache) { - l.onEvict = cb - } -} - -// WithUpdateAgeOnGet update expires when Get element -func WithUpdateAgeOnGet() Option { - return func(l *LruCache) { - l.updateAgeOnGet = true - } -} - -// WithAge defined element max age (second) -func WithAge(maxAge int64) Option { - return func(l *LruCache) { - l.maxAge = maxAge - } -} - -// WithSize defined max length of LruCache -func WithSize(maxSize int) Option { - return func(l *LruCache) { - l.maxSize = maxSize - } -} - -// WithStale decide whether Stale return is enabled. -// If this feature is enabled, element will not get Evicted according to `WithAge`. -func WithStale(stale bool) Option { - return func(l *LruCache) { - l.staleReturn = stale - } -} - -// LruCache is a thread-safe, in-memory lru-cache that evicts the -// least recently used entries from memory when (if set) the entries are -// older than maxAge (in seconds). Use the New constructor to create one. -type LruCache struct { +type LruCache[K comparable, V any] struct { maxAge int64 - maxSize int mu sync.Mutex - cache map[interface{}]*list.Element - lru *list.List // Front is least-recent + cache map[K]*list.Element[*entry[K, V]] + lru list.List[*entry[K, V]] // Front is least-recent updateAgeOnGet bool - staleReturn bool - onEvict EvictCallback } -// NewLRUCache creates an LruCache -func NewLRUCache(options ...Option) *LruCache { - lc := &LruCache{ - lru: list.New(), - cache: make(map[interface{}]*list.Element), - } - - for _, option := range options { - option(lc) +func NewLRU[K comparable, V any](maxAge int64, updateAgeOnGet bool) LruCache[K, V] { + lc := LruCache[K, V]{ + maxAge: maxAge, + updateAgeOnGet: updateAgeOnGet, + cache: make(map[K]*list.Element[*entry[K, V]]), } return lc } -// Get returns the interface{} representation of a cached response and a bool -// set to true if the key was found. -func (c *LruCache) Get(key interface{}) (interface{}, bool) { +func (c *LruCache[K, V]) Load(key K) (V, bool) { entry := c.get(key) if entry == nil { - return nil, false + var defaultValue V + return defaultValue, false } value := entry.value return value, true } -// GetWithExpire returns the interface{} representation of a cached response, -// a time.Time Give expected expires, -// and a bool set to true if the key was found. -// This method will NOT check the maxAge of element and will NOT update the expires. -func (c *LruCache) GetWithExpire(key interface{}) (interface{}, time.Time, bool) { +func (c *LruCache[K, V]) LoadOrStore(key K, constructor func() V) (V, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + le, ok := c.cache[key] + if ok { + if c.maxAge > 0 && le.Value.expires <= time.Now().Unix() { + c.deleteElement(le) + goto create + } + + c.lru.MoveToBack(le) + entry := le.Value + if c.maxAge > 0 && c.updateAgeOnGet { + entry.expires = time.Now().Unix() + c.maxAge + } + return entry.value, true + } + +create: + value := constructor() + if le, ok := c.cache[key]; ok { + c.lru.MoveToBack(le) + e := le.Value + e.value = value + e.expires = time.Now().Unix() + c.maxAge + } else { + e := &entry[K, V]{key: key, value: value, expires: time.Now().Unix() + c.maxAge} + c.cache[key] = c.lru.PushBack(e) + } + + c.maybeDeleteOldest() + return value, false +} + +func (c *LruCache[K, V]) LoadWithExpire(key K) (V, time.Time, bool) { entry := c.get(key) if entry == nil { - return nil, time.Time{}, false + var defaultValue V + return defaultValue, time.Time{}, false } return entry.value, time.Unix(entry.expires, 0), true } -// Exist returns if key exist in cache but not put item to the head of linked list -func (c *LruCache) Exist(key interface{}) bool { +func (c *LruCache[K, V]) Exist(key K) bool { c.mu.Lock() defer c.mu.Unlock() @@ -112,58 +91,48 @@ func (c *LruCache) Exist(key interface{}) bool { return ok } -// Set stores the interface{} representation of a response for a given key. -func (c *LruCache) Set(key interface{}, value interface{}) { +func (c *LruCache[K, V]) Store(key K, value V) { expires := int64(0) if c.maxAge > 0 { expires = time.Now().Unix() + c.maxAge } - c.SetWithExpire(key, value, time.Unix(expires, 0)) + c.StoreWithExpire(key, value, time.Unix(expires, 0)) } -// SetWithExpire stores the interface{} representation of a response for a given key and given expires. -// The expires time will round to second. -func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) { +func (c *LruCache[K, V]) StoreWithExpire(key K, value V, expires time.Time) { c.mu.Lock() defer c.mu.Unlock() if le, ok := c.cache[key]; ok { c.lru.MoveToBack(le) - e := le.Value.(*entry) + e := le.Value e.value = value e.expires = expires.Unix() } else { - e := &entry{key: key, value: value, expires: expires.Unix()} + e := &entry[K, V]{key: key, value: value, expires: expires.Unix()} c.cache[key] = c.lru.PushBack(e) - - if c.maxSize > 0 { - if len := c.lru.Len(); len > c.maxSize { - c.deleteElement(c.lru.Front()) - } - } } c.maybeDeleteOldest() } -// CloneTo clone and overwrite elements to another LruCache -func (c *LruCache) CloneTo(n *LruCache) { +func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) { c.mu.Lock() defer c.mu.Unlock() n.mu.Lock() defer n.mu.Unlock() - n.lru = list.New() - n.cache = make(map[interface{}]*list.Element) + n.lru = list.List[*entry[K, V]]{} + n.cache = make(map[K]*list.Element[*entry[K, V]]) for e := c.lru.Front(); e != nil; e = e.Next() { - elm := e.Value.(*entry) + elm := e.Value n.cache[elm.key] = n.lru.PushBack(elm) } } -func (c *LruCache) get(key interface{}) *entry { +func (c *LruCache[K, V]) get(key K) *entry[K, V] { c.mu.Lock() defer c.mu.Unlock() @@ -172,7 +141,7 @@ func (c *LruCache) get(key interface{}) *entry { return nil } - if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { + if c.maxAge > 0 && le.Value.expires <= time.Now().Unix() { c.deleteElement(le) c.maybeDeleteOldest() @@ -180,7 +149,7 @@ func (c *LruCache) get(key interface{}) *entry { } c.lru.MoveToBack(le) - entry := le.Value.(*entry) + entry := le.Value if c.maxAge > 0 && c.updateAgeOnGet { entry.expires = time.Now().Unix() + c.maxAge } @@ -188,7 +157,7 @@ func (c *LruCache) get(key interface{}) *entry { } // Delete removes the value associated with a key. -func (c *LruCache) Delete(key interface{}) { +func (c *LruCache[K, V]) Delete(key K) { c.mu.Lock() if le, ok := c.cache[key]; ok { @@ -198,26 +167,21 @@ func (c *LruCache) Delete(key interface{}) { c.mu.Unlock() } -func (c *LruCache) maybeDeleteOldest() { - if !c.staleReturn && c.maxAge > 0 { - now := time.Now().Unix() - for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { - c.deleteElement(le) - } +func (c *LruCache[K, V]) maybeDeleteOldest() { + now := time.Now().Unix() + for le := c.lru.Front(); le != nil && le.Value.expires <= now; le = c.lru.Front() { + c.deleteElement(le) } } -func (c *LruCache) deleteElement(le *list.Element) { +func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) { c.lru.Remove(le) - e := le.Value.(*entry) + e := le.Value delete(c.cache, e.key) - if c.onEvict != nil { - c.onEvict(e.key, e.value) - } } -type entry struct { - key interface{} - value interface{} +type entry[K comparable, V any] struct { + key K + value V expires int64 } diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go deleted file mode 100644 index 8a04f74..0000000 --- a/common/cache/lrucache_test.go +++ /dev/null @@ -1,183 +0,0 @@ -package cache - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -var entries = []struct { - key string - value string -}{ - {"1", "one"}, - {"2", "two"}, - {"3", "three"}, - {"4", "four"}, - {"5", "five"}, -} - -func TestLRUCache(t *testing.T) { - c := NewLRUCache() - - for _, e := range entries { - c.Set(e.key, e.value) - } - - c.Delete("missing") - _, ok := c.Get("missing") - assert.False(t, ok) - - for _, e := range entries { - value, ok := c.Get(e.key) - if assert.True(t, ok) { - assert.Equal(t, e.value, value.(string)) - } - } - - for _, e := range entries { - c.Delete(e.key) - - _, ok := c.Get(e.key) - assert.False(t, ok) - } -} - -func TestLRUMaxAge(t *testing.T) { - c := NewLRUCache(WithAge(86400)) - - now := time.Now().Unix() - expected := now + 86400 - - // Add one expired entry - c.Set("foo", "bar") - c.lru.Back().Value.(*entry).expires = now - - // Reset - c.Set("foo", "bar") - e := c.lru.Back().Value.(*entry) - assert.True(t, e.expires >= now) - c.lru.Back().Value.(*entry).expires = now - - // Set a few and verify expiration times - for _, s := range entries { - c.Set(s.key, s.value) - e := c.lru.Back().Value.(*entry) - assert.True(t, e.expires >= expected && e.expires <= expected+10) - } - - // Make sure we can get them all - for _, s := range entries { - _, ok := c.Get(s.key) - assert.True(t, ok) - } - - // Expire all entries - for _, s := range entries { - le, ok := c.cache[s.key] - if assert.True(t, ok) { - le.Value.(*entry).expires = now - } - } - - // Get one expired entry, which should clear all expired entries - _, ok := c.Get("3") - assert.False(t, ok) - assert.Equal(t, c.lru.Len(), 0) -} - -func TestLRUpdateOnGet(t *testing.T) { - c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet()) - - now := time.Now().Unix() - expires := now + 86400/2 - - // Add one expired entry - c.Set("foo", "bar") - c.lru.Back().Value.(*entry).expires = expires - - _, ok := c.Get("foo") - assert.True(t, ok) - assert.True(t, c.lru.Back().Value.(*entry).expires > expires) -} - -func TestMaxSize(t *testing.T) { - c := NewLRUCache(WithSize(2)) - // Add one expired entry - c.Set("foo", "bar") - _, ok := c.Get("foo") - assert.True(t, ok) - - c.Set("bar", "foo") - c.Set("baz", "foo") - - _, ok = c.Get("foo") - assert.False(t, ok) -} - -func TestExist(t *testing.T) { - c := NewLRUCache(WithSize(1)) - c.Set(1, 2) - assert.True(t, c.Exist(1)) - c.Set(2, 3) - assert.False(t, c.Exist(1)) -} - -func TestEvict(t *testing.T) { - temp := 0 - evict := func(key interface{}, value interface{}) { - temp = key.(int) + value.(int) - } - - c := NewLRUCache(WithEvict(evict), WithSize(1)) - c.Set(1, 2) - c.Set(2, 3) - - assert.Equal(t, temp, 3) -} - -func TestSetWithExpire(t *testing.T) { - c := NewLRUCache(WithAge(1)) - now := time.Now().Unix() - - tenSecBefore := time.Unix(now-10, 0) - c.SetWithExpire(1, 2, tenSecBefore) - - // res is expected not to exist, and expires should be empty time.Time - res, expires, exist := c.GetWithExpire(1) - assert.Equal(t, nil, res) - assert.Equal(t, time.Time{}, expires) - assert.Equal(t, false, exist) -} - -func TestStale(t *testing.T) { - c := NewLRUCache(WithAge(1), WithStale(true)) - now := time.Now().Unix() - - tenSecBefore := time.Unix(now-10, 0) - c.SetWithExpire(1, 2, tenSecBefore) - - res, expires, exist := c.GetWithExpire(1) - assert.Equal(t, 2, res) - assert.Equal(t, tenSecBefore, expires) - assert.Equal(t, true, exist) -} - -func TestCloneTo(t *testing.T) { - o := NewLRUCache(WithSize(10)) - o.Set("1", 1) - o.Set("2", 2) - - n := NewLRUCache(WithSize(2)) - n.Set("3", 3) - n.Set("4", 4) - - o.CloneTo(n) - - assert.False(t, n.Exist("3")) - assert.True(t, n.Exist("1")) - - n.Set("5", 5) - assert.False(t, n.Exist("1")) -} diff --git a/common/metadata/addr.go b/common/metadata/addr.go index 2ad2e3c..c0575ca 100644 --- a/common/metadata/addr.go +++ b/common/metadata/addr.go @@ -38,6 +38,10 @@ func (ap AddrPort) UDPAddr() *net.UDPAddr { } } +func (ap AddrPort) AddrPort() netip.AddrPort { + return netip.AddrPortFrom(ap.Addr.Addr(), ap.Port) +} + func (ap AddrPort) String() string { return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port))) } diff --git a/common/udpnat/service.go b/common/udpnat/service.go index 7634c43..b8d784e 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -9,8 +9,8 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/gsync" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/protocol/socks" ) @@ -21,17 +21,18 @@ type Handler interface { } type Service[K comparable] struct { - nat gsync.Map[K, *conn] + nat cache.LruCache[K, *conn] handler Handler } -func New[T comparable](handler Handler) *Service[T] { - return &Service[T]{ +func New[K comparable](maxAge int64, handler Handler) Service[K] { + return Service[K]{ + nat: cache.NewLRU[K, *conn](maxAge, true), handler: handler, } } -func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) error { +func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) { c, loaded := s.nat.LoadOrStore(key, func() *conn { c := &conn{ data: make(chan packet), @@ -57,7 +58,6 @@ func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer * } c.data <- p <-ctx.Done() - return nil } type packet struct { diff --git a/protocol/shadowsocks/service.go b/protocol/shadowsocks/service.go index 8c204e7..8a4ebd5 100644 --- a/protocol/shadowsocks/service.go +++ b/protocol/shadowsocks/service.go @@ -3,6 +3,7 @@ package shadowsocks import ( "context" "net" + "net/netip" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" @@ -24,14 +25,14 @@ type Handler interface { type NoneService struct { handler Handler - udp *udpnat.Service[string] + udp udpnat.Service[netip.AddrPort] } -func NewNoneService(handler Handler) Service { +func NewNoneService(udpTimeout int64, handler Handler) Service { s := &NoneService{ handler: handler, } - s.udp = udpnat.New[string](s) + s.udp = udpnat.New[netip.AddrPort](udpTimeout, s) return s } @@ -52,9 +53,10 @@ func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metad } metadata.Protocol = "shadowsocks" metadata.Destination = destination - return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter { + s.udp.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter { return &serverPacketWriter{conn, metadata.Source} }, buffer, metadata) + return nil } type serverPacketWriter struct { diff --git a/protocol/shadowsocks/shadowaead/service.go b/protocol/shadowsocks/shadowaead/service.go index 461302e..bc6baf6 100644 --- a/protocol/shadowsocks/shadowaead/service.go +++ b/protocol/shadowsocks/shadowaead/service.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "io" "net" + "net/netip" "sync" "github.com/sagernet/sing/common" @@ -26,17 +27,17 @@ type Service struct { key []byte secureRNG io.Reader replayFilter replay.Filter - udp *udpnat.Service[string] + udpNat udpnat.Service[netip.AddrPort] handler shadowsocks.Handler } -func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, handler shadowsocks.Handler) (shadowsocks.Service, error) { +func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { s := &Service{ name: method, secureRNG: secureRNG, handler: handler, } - s.udp = udpnat.New[string](s) + s.udpNat = udpnat.New[netip.AddrPort](udpTimeout, s) if replayFilter { s.replayFilter = replay.NewBloomRing() } @@ -190,9 +191,10 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata buffer.Advance(s.keySaltLength) buffer.Truncate(len(packet)) metadata.Protocol = "shadowsocks" - return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter { + s.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter { return &serverPacketWriter{s, conn, metadata.Source} }, buffer, metadata) + return nil } type serverPacketWriter struct { diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index 179e19e..b00ab2c 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -8,14 +8,15 @@ import ( "io" "math" "net" + "net/netip" "sync" "sync/atomic" "time" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/gsync" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" @@ -37,17 +38,18 @@ type Service struct { psk []byte replayFilter replay.Filter handler shadowsocks.Handler - udpNat *udpnat.Service[uint64] - sessions gsync.Map[uint64, *serverUDPSession] + udpNat udpnat.Service[uint64] + sessions cache.LruCache[uint64, *serverUDPSession] } -func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) { +func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) { s := &Service{ name: method, psk: psk, secureRNG: secureRNG, replayFilter: replay.NewCuckoo(60), handler: handler, + sessions: cache.NewLRU[uint64, *serverUDPSession](udpTimeout, true), } if len(psk) != KeySaltSize { @@ -71,7 +73,7 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso s.udpCipher = newXChacha20Poly1305(s.psk) } - s.udpNat = udpnat.New[uint64](s) + s.udpNat = udpnat.New[uint64](udpTimeout, s) return s, nil } @@ -239,53 +241,69 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata session.remoteCipher = s.constructor(common.Dup(key)) } } + session.remoteAddr = metadata.Source.AddrPort() + goto process + +returnErr: + if !loaded { + s.sessions.Delete(sessionId) + } + return err + +process: if !session.filter.ValidateCounter(packetId, math.MaxUint64) { - return ErrPacketIdNotUnique + err = ErrPacketIdNotUnique + goto returnErr } if packetHeader != nil { _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) if err != nil { - return E.Cause(err, "decrypt packet") + err = E.Cause(err, "decrypt packet") + goto returnErr } } var headerType byte headerType, err = buffer.ReadByte() if err != nil { - return err + err = E.Cause(err, "decrypt packet") + goto returnErr } - if headerType != HeaderTypeClient { - return ErrBadHeaderType + err = ErrBadHeaderType + goto returnErr } var epoch uint64 err = binary.Read(buffer, binary.BigEndian, &epoch) if err != nil { - return err + goto returnErr } if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 { - return ErrBadTimestamp + err = ErrBadTimestamp + goto returnErr } var paddingLength uint16 err = binary.Read(buffer, binary.BigEndian, &paddingLength) if err != nil { - return E.Cause(err, "read padding length") + err = E.Cause(err, "read padding length") + goto returnErr } buffer.Advance(int(paddingLength)) destination, err := socks.AddressSerializer.ReadAddrPort(buffer) if err != nil { - return err + goto returnErr } metadata.Destination = destination - return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter { + s.udpNat.NewPacket(sessionId, func() socks.PacketWriter { return &serverPacketWriter{s, conn, session, metadata.Source} }, buffer, metadata) + return nil } type serverPacketWriter struct { @@ -343,6 +361,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr type serverUDPSession struct { sessionId uint64 remoteSessionId uint64 + remoteAddr netip.AddrPort packetId uint64 cipher cipher.AEAD remoteCipher cipher.AEAD diff --git a/transport/mixed/listener.go b/transport/mixed/listener.go index 7e5646a..1376c4d 100644 --- a/transport/mixed/listener.go +++ b/transport/mixed/listener.go @@ -32,10 +32,10 @@ type Listener struct { bindAddr netip.Addr handler Handler authenticator auth.Authenticator - udpNat *udpnat.Service[string] + udpNat udpnat.Service[netip.AddrPort] } -func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener { +func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, udpTimeout int64, handler Handler) *Listener { listener := &Listener{ bindAddr: bind.Addr(), handler: handler, @@ -45,7 +45,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy)) if transproxy == redir.ModeTProxy { listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy)) - listener.udpNat = udpnat.New[string](handler) + listener.udpNat = udpnat.New[netip.AddrPort](udpTimeout, handler) } return listener } @@ -97,9 +97,10 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M. } func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { - return l.udpNat.NewPacket(metadata.Source.String(), func() socks.PacketWriter { + l.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter { return &tproxyPacketWriter{metadata.Source.UDPAddr()} }, buffer, metadata) + return nil } type tproxyPacketWriter struct {