Improve udpnat write back

This commit is contained in:
世界 2022-06-01 15:52:42 +08:00
parent ee9be8af94
commit e4d76a44eb
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 40 additions and 64 deletions

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowsocks
go 1.18 go 1.18
require ( require (
github.com/sagernet/sing v0.0.0-20220601033944-4e04bbd3d84d github.com/sagernet/sing v0.0.0-20220601075130-066830bbec3b
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
lukechampine.com/blake3 v1.1.7 lukechampine.com/blake3 v1.1.7
) )

4
go.sum
View file

@ -1,8 +1,8 @@
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE= github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE=
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/sagernet/sing v0.0.0-20220601033944-4e04bbd3d84d h1:BNhKTknI2tBPUOPDV3lcgwOX6iZimL7K3TPdTdp5hiA= github.com/sagernet/sing v0.0.0-20220601075130-066830bbec3b h1:N/26MV/2ijp9wb5FzrybKevxIyGRoFlqbJhFlyTr9G0=
github.com/sagernet/sing v0.0.0-20220601033944-4e04bbd3d84d/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg= github.com/sagernet/sing v0.0.0-20220601075130-066830bbec3b/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k=

22
none.go
View file

@ -162,14 +162,14 @@ func (c *nonePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
type NoneService struct { type NoneService struct {
handler Handler handler Handler
udp *udpnat.Service[netip.AddrPort] udpNat *udpnat.Service[netip.AddrPort]
} }
func NewNoneService(udpTimeout int64, handler Handler) Service { func NewNoneService(udpTimeout int64, handler Handler) Service {
s := &NoneService{ s := &NoneService{
handler: handler, handler: handler,
} }
s.udp = udpnat.New[netip.AddrPort](udpTimeout, s) s.udpNat = udpnat.New[netip.AddrPort](udpTimeout, handler)
return s return s
} }
@ -193,29 +193,29 @@ func (s *NoneService) NewPacket(ctx context.Context, conn N.PacketConn, buffer *
} }
metadata.Protocol = "shadowsocks" metadata.Protocol = "shadowsocks"
metadata.Destination = destination metadata.Destination = destination
s.udp.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter { s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
return &nonePacketWriter{conn, metadata.Source} return &nonePacketWriter{conn, natConn}
}, buffer, metadata) })
return nil return nil
} }
type nonePacketWriter struct { type nonePacketWriter struct {
N.PacketConn source N.PacketConn
sourceAddr M.Socksaddr nat N.PacketConn
} }
func (s *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (w *nonePacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))) header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination)))
err := M.SocksaddrSerializer.WriteAddrPort(header, destination) err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
if err != nil { if err != nil {
buffer.Release() buffer.Release()
return err return err
} }
return s.PacketConn.WritePacket(buffer, s.sourceAddr) return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr()))
} }
func (s *NoneService) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { func (w *nonePacketWriter) Upstream() any {
return s.handler.NewPacketConnection(ctx, conn, metadata) return w.source
} }
func (s *NoneService) HandleError(err error) { func (s *NoneService) HandleError(err error) {

View file

@ -225,16 +225,16 @@ func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.
metadata.Protocol = "shadowsocks" metadata.Protocol = "shadowsocks"
metadata.Destination = destination metadata.Destination = destination
s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), func() N.PacketWriter { s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
return &serverPacketWriter{s, conn, metadata.Source} return &serverPacketWriter{s, conn, natConn}
}, buffer, metadata) })
return nil return nil
} }
type serverPacketWriter struct { type serverPacketWriter struct {
*Service *Service
N.PacketConn source N.PacketConn
source M.Socksaddr nat N.PacketConn
} }
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
@ -250,5 +250,9 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
common.KeepAlive(key) common.KeepAlive(key)
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil) c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
buffer.Extend(Overhead) buffer.Extend(Overhead)
return w.PacketConn.WritePacket(buffer, w.source) return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr()))
}
func (w *serverPacketWriter) Upstream() any {
return w.source
} }

View file

@ -14,7 +14,6 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -38,7 +37,6 @@ type Relay[U comparable] struct {
uDestination map[U]M.Socksaddr uDestination map[U]M.Socksaddr
uCipher map[U]cipher.Block uCipher map[U]cipher.Block
udpNat *udpnat.Service[uint64] udpNat *udpnat.Service[uint64]
udpSessions *cache.LruCache[uint64, *relayUDPSession]
} }
func (s *Relay[U]) AddUser(user U, key []byte, destination M.Socksaddr) error { func (s *Relay[U]) AddUser(user U, key []byte, destination M.Socksaddr) error {
@ -84,10 +82,6 @@ func NewRelay[U comparable](method string, psk []byte, secureRNG io.Reader, udpT
uCipher: make(map[U]cipher.Block), uCipher: make(map[U]cipher.Block),
udpNat: udpnat.New[uint64](udpTimeout, handler), udpNat: udpnat.New[uint64](udpTimeout, handler),
udpSessions: cache.New(
cache.WithAge[uint64, *relayUDPSession](udpTimeout),
cache.WithUpdateAgeOnGet[uint64, *relayUDPSession](),
),
} }
switch method { switch method {
@ -207,35 +201,17 @@ func (s *Relay[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf
return E.New("invalid request") return E.New("invalid request")
} }
session, _ := s.udpSessions.LoadOrStore(sessionId, func() *relayUDPSession {
return new(relayUDPSession)
})
session.sourceAddr = metadata.Source
s.uCipher[user].Encrypt(packetHeader, packetHeader) s.uCipher[user].Encrypt(packetHeader, packetHeader)
copy(buffer.Range(aes.BlockSize, 2*aes.BlockSize), packetHeader) copy(buffer.Range(aes.BlockSize, 2*aes.BlockSize), packetHeader)
buffer.Advance(aes.BlockSize) buffer.Advance(aes.BlockSize)
metadata.Protocol = "shadowsocks-relay" metadata.Protocol = "shadowsocks-relay"
metadata.Destination = s.uDestination[user] metadata.Destination = s.uDestination[user]
s.udpNat.NewContextPacket(ctx, sessionId, func() (context.Context, N.PacketWriter) { s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return &shadowsocks.UserContext[U]{ return &shadowsocks.UserContext[U]{
ctx, ctx,
user, user,
}, &relayPacketWriter[U]{conn, session} }, &udpnat.DirectBackWriter{Source: conn, Nat: natConn}
}, buffer, metadata) })
return nil return nil
} }
type relayUDPSession struct {
sourceAddr M.Socksaddr
}
type relayPacketWriter[U comparable] struct {
N.PacketConn
session *relayUDPSession
}
func (w *relayPacketWriter[U]) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
return w.PacketConn.WritePacket(buffer, w.session.sourceAddr)
}

View file

@ -410,11 +410,9 @@ process:
goto returnErr goto returnErr
} }
metadata.Destination = destination metadata.Destination = destination
s.udpNat.NewPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
session.remoteAddr = metadata.Source return &serverPacketWriter{s, conn, natConn, session}
s.udpNat.NewPacket(ctx, sessionId, func() N.PacketWriter { })
return &serverPacketWriter{s, conn, session}
}, buffer, metadata)
return nil return nil
} }
@ -424,14 +422,11 @@ func (s *Service) HandleError(err error) {
type serverPacketWriter struct { type serverPacketWriter struct {
*Service *Service
N.PacketConn source N.PacketConn
nat N.PacketConn
session *serverUDPSession session *serverUDPSession
} }
func (w *serverPacketWriter) Upstream() any {
return w.PacketConn
}
func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
var hdrLen int var hdrLen int
if w.udpCipher != nil { if w.udpCipher != nil {
@ -477,13 +472,16 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
buffer.Extend(shadowaead.Overhead) buffer.Extend(shadowaead.Overhead)
w.udpBlockCipher.Encrypt(packetHeader, packetHeader) w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
} }
return w.PacketConn.WritePacket(buffer, w.session.remoteAddr) return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr()))
}
func (w *serverPacketWriter) Upstream() any {
return w.source
} }
type serverUDPSession struct { type serverUDPSession struct {
sessionId uint64 sessionId uint64
remoteSessionId uint64 remoteSessionId uint64
remoteAddr M.Socksaddr
packetId uint64 packetId uint64
cipher cipher.AEAD cipher cipher.AEAD
remoteCipher cipher.AEAD remoteCipher cipher.AEAD

View file

@ -349,14 +349,12 @@ process:
} }
metadata.Destination = destination metadata.Destination = destination
session.remoteAddr = metadata.Source s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
s.udpNat.NewContextPacket(ctx, sessionId, func() (context.Context, N.PacketWriter) {
return &shadowsocks.UserContext[U]{ return &shadowsocks.UserContext[U]{
ctx, ctx,
user, user,
}, &serverPacketWriter{s.Service, conn, session} }, &serverPacketWriter{s.Service, conn, natConn, session}
}, buffer, metadata) })
return nil return nil
} }