Fix udp nat

This commit is contained in:
世界 2022-04-29 17:43:09 +08:00
parent 8d95ae4cff
commit b4b6c838d1
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 144 additions and 331 deletions

View file

@ -36,6 +36,8 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
const udpTimeout = 5 * 60
type flags struct { type flags struct {
Server string `json:"server"` Server string `json:"server"`
ServerPort uint16 `json:"server_port"` ServerPort uint16 `json:"server_port"`
@ -256,7 +258,7 @@ func newClient(f *flags) (*client, error) {
bind = netip.IPv6Unspecified() bind = netip.IPv6Unspecified()
} }
c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, c) c.Listener = mixed.NewListener(netip.AddrPortFrom(bind, f.LocalPort), nil, transproxyMode, udpTimeout, c)
if f.Bypass != "" { if f.Bypass != "" {
err := geoip.LoadMMDB("Country.mmdb") err := geoip.LoadMMDB("Country.mmdb")

View file

@ -31,6 +31,8 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
const udpTimeout = 5 * 60
type flags struct { type flags struct {
Server string `json:"server"` Server string `json:"server"`
ServerPort uint16 `json:"server_port"` ServerPort uint16 `json:"server_port"`
@ -183,15 +185,15 @@ func newServer(f *flags) (*server, error) {
} }
if f.Method == shadowsocks.MethodNone { if f.Method == shadowsocks.MethodNone {
s.service = shadowsocks.NewNoneService(s) s.service = shadowsocks.NewNoneService(udpTimeout, s)
} else if common.Contains(shadowaead.List, f.Method) { } else if common.Contains(shadowaead.List, f.Method) {
service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, s) service, err := shadowaead.NewService(f.Method, key, []byte(f.Password), random.Blake3KeyedHash(), false, udpTimeout, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.service = service s.service = service
} else if common.Contains(shadowaead_2022.List, f.Method) { } else if common.Contains(shadowaead_2022.List, f.Method) {
service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), s) service, err := shadowaead_2022.NewService(f.Method, key, random.Blake3KeyedHash(), udpTimeout, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -72,7 +72,7 @@ func run(cmd *cobra.Command, args []string) {
} }
client := &localClient{upstream: args[1]} client := &localClient{upstream: args[1]}
client.Listener = mixed.NewListener(bind, nil, transproxyMode, client) client.Listener = mixed.NewListener(bind, nil, transproxyMode, 300, client)
err = client.Start() err = client.Start()
if err != nil { if err != nil {

View file

@ -3,108 +3,87 @@ package cache
// Modified by https://github.com/die-net/lrucache // Modified by https://github.com/die-net/lrucache
import ( import (
"container/list"
"sync" "sync"
"time" "time"
"github.com/sagernet/sing/common/list"
) )
// Option is part of Functional Options Pattern type LruCache[K comparable, V any] struct {
type Option func(*LruCache)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback = func(key interface{}, value interface{})
// WithEvict set the evict callback
func WithEvict(cb EvictCallback) Option {
return func(l *LruCache) {
l.onEvict = cb
}
}
// WithUpdateAgeOnGet update expires when Get element
func WithUpdateAgeOnGet() Option {
return func(l *LruCache) {
l.updateAgeOnGet = true
}
}
// WithAge defined element max age (second)
func WithAge(maxAge int64) Option {
return func(l *LruCache) {
l.maxAge = maxAge
}
}
// WithSize defined max length of LruCache
func WithSize(maxSize int) Option {
return func(l *LruCache) {
l.maxSize = maxSize
}
}
// WithStale decide whether Stale return is enabled.
// If this feature is enabled, element will not get Evicted according to `WithAge`.
func WithStale(stale bool) Option {
return func(l *LruCache) {
l.staleReturn = stale
}
}
// LruCache is a thread-safe, in-memory lru-cache that evicts the
// least recently used entries from memory when (if set) the entries are
// older than maxAge (in seconds). Use the New constructor to create one.
type LruCache struct {
maxAge int64 maxAge int64
maxSize int
mu sync.Mutex mu sync.Mutex
cache map[interface{}]*list.Element cache map[K]*list.Element[*entry[K, V]]
lru *list.List // Front is least-recent lru list.List[*entry[K, V]] // Front is least-recent
updateAgeOnGet bool updateAgeOnGet bool
staleReturn bool
onEvict EvictCallback
} }
// NewLRUCache creates an LruCache func NewLRU[K comparable, V any](maxAge int64, updateAgeOnGet bool) LruCache[K, V] {
func NewLRUCache(options ...Option) *LruCache { lc := LruCache[K, V]{
lc := &LruCache{ maxAge: maxAge,
lru: list.New(), updateAgeOnGet: updateAgeOnGet,
cache: make(map[interface{}]*list.Element), cache: make(map[K]*list.Element[*entry[K, V]]),
}
for _, option := range options {
option(lc)
} }
return lc return lc
} }
// Get returns the interface{} representation of a cached response and a bool func (c *LruCache[K, V]) Load(key K) (V, bool) {
// set to true if the key was found.
func (c *LruCache) Get(key interface{}) (interface{}, bool) {
entry := c.get(key) entry := c.get(key)
if entry == nil { if entry == nil {
return nil, false var defaultValue V
return defaultValue, false
} }
value := entry.value value := entry.value
return value, true return value, true
} }
// GetWithExpire returns the interface{} representation of a cached response, func (c *LruCache[K, V]) LoadOrStore(key K, constructor func() V) (V, bool) {
// a time.Time Give expected expires, c.mu.Lock()
// and a bool set to true if the key was found. defer c.mu.Unlock()
// This method will NOT check the maxAge of element and will NOT update the expires.
func (c *LruCache) GetWithExpire(key interface{}) (interface{}, time.Time, bool) { le, ok := c.cache[key]
if ok {
if c.maxAge > 0 && le.Value.expires <= time.Now().Unix() {
c.deleteElement(le)
goto create
}
c.lru.MoveToBack(le)
entry := le.Value
if c.maxAge > 0 && c.updateAgeOnGet {
entry.expires = time.Now().Unix() + c.maxAge
}
return entry.value, true
}
create:
value := constructor()
if le, ok := c.cache[key]; ok {
c.lru.MoveToBack(le)
e := le.Value
e.value = value
e.expires = time.Now().Unix() + c.maxAge
} else {
e := &entry[K, V]{key: key, value: value, expires: time.Now().Unix() + c.maxAge}
c.cache[key] = c.lru.PushBack(e)
}
c.maybeDeleteOldest()
return value, false
}
func (c *LruCache[K, V]) LoadWithExpire(key K) (V, time.Time, bool) {
entry := c.get(key) entry := c.get(key)
if entry == nil { if entry == nil {
return nil, time.Time{}, false var defaultValue V
return defaultValue, time.Time{}, false
} }
return entry.value, time.Unix(entry.expires, 0), true return entry.value, time.Unix(entry.expires, 0), true
} }
// Exist returns if key exist in cache but not put item to the head of linked list func (c *LruCache[K, V]) Exist(key K) bool {
func (c *LruCache) Exist(key interface{}) bool {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -112,58 +91,48 @@ func (c *LruCache) Exist(key interface{}) bool {
return ok return ok
} }
// Set stores the interface{} representation of a response for a given key. func (c *LruCache[K, V]) Store(key K, value V) {
func (c *LruCache) Set(key interface{}, value interface{}) {
expires := int64(0) expires := int64(0)
if c.maxAge > 0 { if c.maxAge > 0 {
expires = time.Now().Unix() + c.maxAge expires = time.Now().Unix() + c.maxAge
} }
c.SetWithExpire(key, value, time.Unix(expires, 0)) c.StoreWithExpire(key, value, time.Unix(expires, 0))
} }
// SetWithExpire stores the interface{} representation of a response for a given key and given expires. func (c *LruCache[K, V]) StoreWithExpire(key K, value V, expires time.Time) {
// The expires time will round to second.
func (c *LruCache) SetWithExpire(key interface{}, value interface{}, expires time.Time) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if le, ok := c.cache[key]; ok { if le, ok := c.cache[key]; ok {
c.lru.MoveToBack(le) c.lru.MoveToBack(le)
e := le.Value.(*entry) e := le.Value
e.value = value e.value = value
e.expires = expires.Unix() e.expires = expires.Unix()
} else { } else {
e := &entry{key: key, value: value, expires: expires.Unix()} e := &entry[K, V]{key: key, value: value, expires: expires.Unix()}
c.cache[key] = c.lru.PushBack(e) c.cache[key] = c.lru.PushBack(e)
if c.maxSize > 0 {
if len := c.lru.Len(); len > c.maxSize {
c.deleteElement(c.lru.Front())
}
}
} }
c.maybeDeleteOldest() c.maybeDeleteOldest()
} }
// CloneTo clone and overwrite elements to another LruCache func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) {
func (c *LruCache) CloneTo(n *LruCache) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
n.mu.Lock() n.mu.Lock()
defer n.mu.Unlock() defer n.mu.Unlock()
n.lru = list.New() n.lru = list.List[*entry[K, V]]{}
n.cache = make(map[interface{}]*list.Element) n.cache = make(map[K]*list.Element[*entry[K, V]])
for e := c.lru.Front(); e != nil; e = e.Next() { for e := c.lru.Front(); e != nil; e = e.Next() {
elm := e.Value.(*entry) elm := e.Value
n.cache[elm.key] = n.lru.PushBack(elm) n.cache[elm.key] = n.lru.PushBack(elm)
} }
} }
func (c *LruCache) get(key interface{}) *entry { func (c *LruCache[K, V]) get(key K) *entry[K, V] {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -172,7 +141,7 @@ func (c *LruCache) get(key interface{}) *entry {
return nil return nil
} }
if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { if c.maxAge > 0 && le.Value.expires <= time.Now().Unix() {
c.deleteElement(le) c.deleteElement(le)
c.maybeDeleteOldest() c.maybeDeleteOldest()
@ -180,7 +149,7 @@ func (c *LruCache) get(key interface{}) *entry {
} }
c.lru.MoveToBack(le) c.lru.MoveToBack(le)
entry := le.Value.(*entry) entry := le.Value
if c.maxAge > 0 && c.updateAgeOnGet { if c.maxAge > 0 && c.updateAgeOnGet {
entry.expires = time.Now().Unix() + c.maxAge entry.expires = time.Now().Unix() + c.maxAge
} }
@ -188,7 +157,7 @@ func (c *LruCache) get(key interface{}) *entry {
} }
// Delete removes the value associated with a key. // Delete removes the value associated with a key.
func (c *LruCache) Delete(key interface{}) { func (c *LruCache[K, V]) Delete(key K) {
c.mu.Lock() c.mu.Lock()
if le, ok := c.cache[key]; ok { if le, ok := c.cache[key]; ok {
@ -198,26 +167,21 @@ func (c *LruCache) Delete(key interface{}) {
c.mu.Unlock() c.mu.Unlock()
} }
func (c *LruCache) maybeDeleteOldest() { func (c *LruCache[K, V]) maybeDeleteOldest() {
if !c.staleReturn && c.maxAge > 0 { now := time.Now().Unix()
now := time.Now().Unix() for le := c.lru.Front(); le != nil && le.Value.expires <= now; le = c.lru.Front() {
for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { c.deleteElement(le)
c.deleteElement(le)
}
} }
} }
func (c *LruCache) deleteElement(le *list.Element) { func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) {
c.lru.Remove(le) c.lru.Remove(le)
e := le.Value.(*entry) e := le.Value
delete(c.cache, e.key) delete(c.cache, e.key)
if c.onEvict != nil {
c.onEvict(e.key, e.value)
}
} }
type entry struct { type entry[K comparable, V any] struct {
key interface{} key K
value interface{} value V
expires int64 expires int64
} }

View file

@ -1,183 +0,0 @@
package cache
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var entries = []struct {
key string
value string
}{
{"1", "one"},
{"2", "two"},
{"3", "three"},
{"4", "four"},
{"5", "five"},
}
func TestLRUCache(t *testing.T) {
c := NewLRUCache()
for _, e := range entries {
c.Set(e.key, e.value)
}
c.Delete("missing")
_, ok := c.Get("missing")
assert.False(t, ok)
for _, e := range entries {
value, ok := c.Get(e.key)
if assert.True(t, ok) {
assert.Equal(t, e.value, value.(string))
}
}
for _, e := range entries {
c.Delete(e.key)
_, ok := c.Get(e.key)
assert.False(t, ok)
}
}
func TestLRUMaxAge(t *testing.T) {
c := NewLRUCache(WithAge(86400))
now := time.Now().Unix()
expected := now + 86400
// Add one expired entry
c.Set("foo", "bar")
c.lru.Back().Value.(*entry).expires = now
// Reset
c.Set("foo", "bar")
e := c.lru.Back().Value.(*entry)
assert.True(t, e.expires >= now)
c.lru.Back().Value.(*entry).expires = now
// Set a few and verify expiration times
for _, s := range entries {
c.Set(s.key, s.value)
e := c.lru.Back().Value.(*entry)
assert.True(t, e.expires >= expected && e.expires <= expected+10)
}
// Make sure we can get them all
for _, s := range entries {
_, ok := c.Get(s.key)
assert.True(t, ok)
}
// Expire all entries
for _, s := range entries {
le, ok := c.cache[s.key]
if assert.True(t, ok) {
le.Value.(*entry).expires = now
}
}
// Get one expired entry, which should clear all expired entries
_, ok := c.Get("3")
assert.False(t, ok)
assert.Equal(t, c.lru.Len(), 0)
}
func TestLRUpdateOnGet(t *testing.T) {
c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet())
now := time.Now().Unix()
expires := now + 86400/2
// Add one expired entry
c.Set("foo", "bar")
c.lru.Back().Value.(*entry).expires = expires
_, ok := c.Get("foo")
assert.True(t, ok)
assert.True(t, c.lru.Back().Value.(*entry).expires > expires)
}
func TestMaxSize(t *testing.T) {
c := NewLRUCache(WithSize(2))
// Add one expired entry
c.Set("foo", "bar")
_, ok := c.Get("foo")
assert.True(t, ok)
c.Set("bar", "foo")
c.Set("baz", "foo")
_, ok = c.Get("foo")
assert.False(t, ok)
}
func TestExist(t *testing.T) {
c := NewLRUCache(WithSize(1))
c.Set(1, 2)
assert.True(t, c.Exist(1))
c.Set(2, 3)
assert.False(t, c.Exist(1))
}
func TestEvict(t *testing.T) {
temp := 0
evict := func(key interface{}, value interface{}) {
temp = key.(int) + value.(int)
}
c := NewLRUCache(WithEvict(evict), WithSize(1))
c.Set(1, 2)
c.Set(2, 3)
assert.Equal(t, temp, 3)
}
func TestSetWithExpire(t *testing.T) {
c := NewLRUCache(WithAge(1))
now := time.Now().Unix()
tenSecBefore := time.Unix(now-10, 0)
c.SetWithExpire(1, 2, tenSecBefore)
// res is expected not to exist, and expires should be empty time.Time
res, expires, exist := c.GetWithExpire(1)
assert.Equal(t, nil, res)
assert.Equal(t, time.Time{}, expires)
assert.Equal(t, false, exist)
}
func TestStale(t *testing.T) {
c := NewLRUCache(WithAge(1), WithStale(true))
now := time.Now().Unix()
tenSecBefore := time.Unix(now-10, 0)
c.SetWithExpire(1, 2, tenSecBefore)
res, expires, exist := c.GetWithExpire(1)
assert.Equal(t, 2, res)
assert.Equal(t, tenSecBefore, expires)
assert.Equal(t, true, exist)
}
func TestCloneTo(t *testing.T) {
o := NewLRUCache(WithSize(10))
o.Set("1", 1)
o.Set("2", 2)
n := NewLRUCache(WithSize(2))
n.Set("3", 3)
n.Set("4", 4)
o.CloneTo(n)
assert.False(t, n.Exist("3"))
assert.True(t, n.Exist("1"))
n.Set("5", 5)
assert.False(t, n.Exist("1"))
}

View file

@ -38,6 +38,10 @@ func (ap AddrPort) UDPAddr() *net.UDPAddr {
} }
} }
func (ap AddrPort) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(ap.Addr.Addr(), ap.Port)
}
func (ap AddrPort) String() string { func (ap AddrPort) String() string {
return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port))) return net.JoinHostPort(ap.Addr.String(), strconv.Itoa(int(ap.Port)))
} }

View file

@ -9,8 +9,8 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/socks"
) )
@ -21,17 +21,18 @@ type Handler interface {
} }
type Service[K comparable] struct { type Service[K comparable] struct {
nat gsync.Map[K, *conn] nat cache.LruCache[K, *conn]
handler Handler handler Handler
} }
func New[T comparable](handler Handler) *Service[T] { func New[K comparable](maxAge int64, handler Handler) Service[K] {
return &Service[T]{ return Service[K]{
nat: cache.NewLRU[K, *conn](maxAge, true),
handler: handler, handler: handler,
} }
} }
func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) error { func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *buf.Buffer, metadata M.Metadata) {
c, loaded := s.nat.LoadOrStore(key, func() *conn { c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{ c := &conn{
data: make(chan packet), data: make(chan packet),
@ -57,7 +58,6 @@ func (s *Service[T]) NewPacket(key T, writer func() socks.PacketWriter, buffer *
} }
c.data <- p c.data <- p
<-ctx.Done() <-ctx.Done()
return nil
} }
type packet struct { type packet struct {

View file

@ -3,6 +3,7 @@ package shadowsocks
import ( import (
"context" "context"
"net" "net"
"net/netip"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -24,14 +25,14 @@ type Handler interface {
type NoneService struct { type NoneService struct {
handler Handler handler Handler
udp *udpnat.Service[string] udp udpnat.Service[netip.AddrPort]
} }
func NewNoneService(handler Handler) Service { func NewNoneService(udpTimeout int64, handler Handler) Service {
s := &NoneService{ s := &NoneService{
handler: handler, handler: handler,
} }
s.udp = udpnat.New[string](s) s.udp = udpnat.New[netip.AddrPort](udpTimeout, s)
return s return s
} }
@ -52,9 +53,10 @@ func (s *NoneService) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metad
} }
metadata.Protocol = "shadowsocks" metadata.Protocol = "shadowsocks"
metadata.Destination = destination metadata.Destination = destination
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter { s.udp.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
return &serverPacketWriter{conn, metadata.Source} return &serverPacketWriter{conn, metadata.Source}
}, buffer, metadata) }, buffer, metadata)
return nil
} }
type serverPacketWriter struct { type serverPacketWriter struct {

View file

@ -5,6 +5,7 @@ import (
"crypto/cipher" "crypto/cipher"
"io" "io"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -26,17 +27,17 @@ type Service struct {
key []byte key []byte
secureRNG io.Reader secureRNG io.Reader
replayFilter replay.Filter replayFilter replay.Filter
udp *udpnat.Service[string] udpNat udpnat.Service[netip.AddrPort]
handler shadowsocks.Handler handler shadowsocks.Handler
} }
func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewService(method string, key []byte, password []byte, secureRNG io.Reader, replayFilter bool, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{ s := &Service{
name: method, name: method,
secureRNG: secureRNG, secureRNG: secureRNG,
handler: handler, handler: handler,
} }
s.udp = udpnat.New[string](s) s.udpNat = udpnat.New[netip.AddrPort](udpTimeout, s)
if replayFilter { if replayFilter {
s.replayFilter = replay.NewBloomRing() s.replayFilter = replay.NewBloomRing()
} }
@ -190,9 +191,10 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
buffer.Advance(s.keySaltLength) buffer.Advance(s.keySaltLength)
buffer.Truncate(len(packet)) buffer.Truncate(len(packet))
metadata.Protocol = "shadowsocks" metadata.Protocol = "shadowsocks"
return s.udp.NewPacket(metadata.Source.String(), func() socks.PacketWriter { s.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source} return &serverPacketWriter{s, conn, metadata.Source}
}, buffer, metadata) }, buffer, metadata)
return nil
} }
type serverPacketWriter struct { type serverPacketWriter struct {

View file

@ -8,14 +8,15 @@ import (
"io" "io"
"math" "math"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/gsync"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/replay"
"github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/rw"
@ -37,17 +38,18 @@ type Service struct {
psk []byte psk []byte
replayFilter replay.Filter replayFilter replay.Filter
handler shadowsocks.Handler handler shadowsocks.Handler
udpNat *udpnat.Service[uint64] udpNat udpnat.Service[uint64]
sessions gsync.Map[uint64, *serverUDPSession] sessions cache.LruCache[uint64, *serverUDPSession]
} }
func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowsocks.Handler) (shadowsocks.Service, error) { func NewService(method string, psk []byte, secureRNG io.Reader, udpTimeout int64, handler shadowsocks.Handler) (shadowsocks.Service, error) {
s := &Service{ s := &Service{
name: method, name: method,
psk: psk, psk: psk,
secureRNG: secureRNG, secureRNG: secureRNG,
replayFilter: replay.NewCuckoo(60), replayFilter: replay.NewCuckoo(60),
handler: handler, handler: handler,
sessions: cache.NewLRU[uint64, *serverUDPSession](udpTimeout, true),
} }
if len(psk) != KeySaltSize { if len(psk) != KeySaltSize {
@ -71,7 +73,7 @@ func NewService(method string, psk []byte, secureRNG io.Reader, handler shadowso
s.udpCipher = newXChacha20Poly1305(s.psk) s.udpCipher = newXChacha20Poly1305(s.psk)
} }
s.udpNat = udpnat.New[uint64](s) s.udpNat = udpnat.New[uint64](udpTimeout, s)
return s, nil return s, nil
} }
@ -239,53 +241,69 @@ func (s *Service) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata
session.remoteCipher = s.constructor(common.Dup(key)) session.remoteCipher = s.constructor(common.Dup(key))
} }
} }
session.remoteAddr = metadata.Source.AddrPort()
goto process
returnErr:
if !loaded {
s.sessions.Delete(sessionId)
}
return err
process:
if !session.filter.ValidateCounter(packetId, math.MaxUint64) { if !session.filter.ValidateCounter(packetId, math.MaxUint64) {
return ErrPacketIdNotUnique err = ErrPacketIdNotUnique
goto returnErr
} }
if packetHeader != nil { if packetHeader != nil {
_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil { if err != nil {
return E.Cause(err, "decrypt packet") err = E.Cause(err, "decrypt packet")
goto returnErr
} }
} }
var headerType byte var headerType byte
headerType, err = buffer.ReadByte() headerType, err = buffer.ReadByte()
if err != nil { if err != nil {
return err err = E.Cause(err, "decrypt packet")
goto returnErr
} }
if headerType != HeaderTypeClient { if headerType != HeaderTypeClient {
return ErrBadHeaderType err = ErrBadHeaderType
goto returnErr
} }
var epoch uint64 var epoch uint64
err = binary.Read(buffer, binary.BigEndian, &epoch) err = binary.Read(buffer, binary.BigEndian, &epoch)
if err != nil { if err != nil {
return err goto returnErr
} }
if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 { if math.Abs(float64(uint64(time.Now().Unix())-epoch)) > 30 {
return ErrBadTimestamp err = ErrBadTimestamp
goto returnErr
} }
var paddingLength uint16 var paddingLength uint16
err = binary.Read(buffer, binary.BigEndian, &paddingLength) err = binary.Read(buffer, binary.BigEndian, &paddingLength)
if err != nil { if err != nil {
return E.Cause(err, "read padding length") err = E.Cause(err, "read padding length")
goto returnErr
} }
buffer.Advance(int(paddingLength)) buffer.Advance(int(paddingLength))
destination, err := socks.AddressSerializer.ReadAddrPort(buffer) destination, err := socks.AddressSerializer.ReadAddrPort(buffer)
if err != nil { if err != nil {
return err goto returnErr
} }
metadata.Destination = destination metadata.Destination = destination
return s.udpNat.NewPacket(sessionId, func() socks.PacketWriter { s.udpNat.NewPacket(sessionId, func() socks.PacketWriter {
return &serverPacketWriter{s, conn, session, metadata.Source} return &serverPacketWriter{s, conn, session, metadata.Source}
}, buffer, metadata) }, buffer, metadata)
return nil
} }
type serverPacketWriter struct { type serverPacketWriter struct {
@ -343,6 +361,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination *M.Addr
type serverUDPSession struct { type serverUDPSession struct {
sessionId uint64 sessionId uint64
remoteSessionId uint64 remoteSessionId uint64
remoteAddr netip.AddrPort
packetId uint64 packetId uint64
cipher cipher.AEAD cipher cipher.AEAD
remoteCipher cipher.AEAD remoteCipher cipher.AEAD

View file

@ -32,10 +32,10 @@ type Listener struct {
bindAddr netip.Addr bindAddr netip.Addr
handler Handler handler Handler
authenticator auth.Authenticator authenticator auth.Authenticator
udpNat *udpnat.Service[string] udpNat udpnat.Service[netip.AddrPort]
} }
func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, handler Handler) *Listener { func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transproxy redir.TransproxyMode, udpTimeout int64, handler Handler) *Listener {
listener := &Listener{ listener := &Listener{
bindAddr: bind.Addr(), bindAddr: bind.Addr(),
handler: handler, handler: handler,
@ -45,7 +45,7 @@ func NewListener(bind netip.AddrPort, authenticator auth.Authenticator, transpro
listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy)) listener.TCPListener = tcp.NewTCPListener(bind, listener, tcp.WithTransproxyMode(transproxy))
if transproxy == redir.ModeTProxy { if transproxy == redir.ModeTProxy {
listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy)) listener.UDPListener = udp.NewUDPListener(bind, listener, udp.WithTransproxyMode(transproxy))
listener.udpNat = udpnat.New[string](handler) listener.udpNat = udpnat.New[netip.AddrPort](udpTimeout, handler)
} }
return listener return listener
} }
@ -97,9 +97,10 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
} }
func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { func (l *Listener) NewPacket(conn socks.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
return l.udpNat.NewPacket(metadata.Source.String(), func() socks.PacketWriter { l.udpNat.NewPacket(metadata.Source.AddrPort(), func() socks.PacketWriter {
return &tproxyPacketWriter{metadata.Source.UDPAddr()} return &tproxyPacketWriter{metadata.Source.UDPAddr()}
}, buffer, metadata) }, buffer, metadata)
return nil
} }
type tproxyPacketWriter struct { type tproxyPacketWriter struct {