From f1b87be6e43325dd15650d52f733856f85bc5d95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 7 May 2022 17:08:57 +0800 Subject: [PATCH] Fix buffer usage --- cli/libpack/main_linux.go | 14 ++++++---- cli/socks-chk/main.go | 2 ++ cli/ss-local/main.go | 3 +- cli/trojan-local/main.go | 3 +- common/buf/buffer.go | 4 +-- common/network/conn.go | 3 ++ common/rw/copy.go | 4 +++ common/rw/read.go | 2 ++ common/tun/system/tun.go | 2 ++ common/uot/server.go | 3 ++ common/uot/uot_test.go | 2 ++ protocol/shadowsocks/none.go | 28 +++++++++++++------ protocol/shadowsocks/shadowaead/protocol.go | 14 +++++++++- protocol/shadowsocks/shadowaead/service.go | 7 +++++ .../shadowsocks/shadowaead_2022/protocol.go | 27 ++++++++++++++---- .../shadowsocks/shadowaead_2022/service.go | 7 +++++ .../shadowaead_2022/service_multi.go | 5 ++++ protocol/socks5/conn.go | 5 ++++ protocol/trojan/protocol.go | 5 ++++ transport/mixed/listener.go | 2 ++ transport/udp/udp.go | 3 ++ 21 files changed, 121 insertions(+), 24 deletions(-) diff --git a/cli/libpack/main_linux.go b/cli/libpack/main_linux.go index 1dae856..db9a38e 100644 --- a/cli/libpack/main_linux.go +++ b/cli/libpack/main_linux.go @@ -4,9 +4,6 @@ import ( "archive/tar" _ "embed" "encoding/hex" - "github.com/sagernet/sing" - "github.com/sagernet/sing/common/log" - "github.com/spf13/cobra" "io" "os" "os/exec" @@ -15,17 +12,22 @@ import ( "strings" "github.com/klauspost/compress/zstd" + "github.com/sagernet/sing" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/log" "github.com/sirupsen/logrus" + "github.com/spf13/cobra" "github.com/u-root/u-root/pkg/ldd" ) var logger = log.NewLogger("libpack") -var packageName string -var executablePath string -var outputPath string +var ( + packageName string + executablePath string + outputPath string +) func main() { command := &cobra.Command{ diff --git a/cli/socks-chk/main.go b/cli/socks-chk/main.go index 7ee6c53..c43ee88 100644 --- a/cli/socks-chk/main.go +++ b/cli/socks-chk/main.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/netip" + "runtime" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -129,6 +130,7 @@ func testSocksUDP(server M.Socksaddr) error { Port: 53, })) _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) common.Must2(buffer.ReadPacketFrom(assConn)) common.Must(message.Unpack(buffer.Bytes())) diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go index ed3d389..4a709af 100644 --- a/cli/ss-local/main.go +++ b/cli/ss-local/main.go @@ -10,6 +10,7 @@ import ( "net/netip" "os" "os/signal" + "runtime" "runtime/debug" "strings" "syscall" @@ -350,7 +351,7 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me if err != nil { return E.Cause(err, "client handshake") } - + runtime.KeepAlive(_payload) return rw.CopyConn(ctx, serverConn, conn) } diff --git a/cli/trojan-local/main.go b/cli/trojan-local/main.go index 928f550..2ccd3e5 100644 --- a/cli/trojan-local/main.go +++ b/cli/trojan-local/main.go @@ -9,6 +9,7 @@ import ( "net/netip" "os" "os/signal" + "runtime" "syscall" "time" @@ -315,7 +316,7 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me if err != nil { return E.Cause(err, "client handshake") } - + runtime.KeepAlive(_request) return rw.CopyConn(ctx, clientConn, conn) } diff --git a/common/buf/buffer.go b/common/buf/buffer.go index b178b97..5128d04 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -119,7 +119,7 @@ func (b *Buffer) ExtendHeader(size int) []byte { } } -func (b *Buffer) WriteBufferAtFirst(buffer *Buffer) *Buffer { +func (b *Buffer) _WriteBufferAtFirst(buffer *Buffer) *Buffer { size := buffer.Len() if b.start >= size { n := copy(b.data[b.start-size:b.start], buffer.Bytes()) @@ -140,7 +140,7 @@ func (b *Buffer) WriteBufferAtFirst(buffer *Buffer) *Buffer { } } -func (b *Buffer) WriteAtFirst(data []byte) (n int, err error) { +func (b *Buffer) _WriteAtFirst(data []byte) (n int, err error) { size := len(data) if b.start >= size { n = copy(b.data[b.start-size:b.start], data) diff --git a/common/network/conn.go b/common/network/conn.go index 01c6946..c27133a 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -4,6 +4,7 @@ import ( "context" "net" "os" + "runtime" "time" "github.com/sagernet/sing/common" @@ -63,6 +64,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error defer rw.CloseRead(conn) defer rw.CloseWrite(dest) _buffer := buf.StackNewMax() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) for { @@ -81,6 +83,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error defer rw.CloseRead(dest) defer rw.CloseWrite(conn) _buffer := buf.StackNewMax() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) for { diff --git a/common/rw/copy.go b/common/rw/copy.go index 1fb3df1..2c7daa2 100644 --- a/common/rw/copy.go +++ b/common/rw/copy.go @@ -5,6 +5,7 @@ import ( "io" "net" "os" + "runtime" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -69,6 +70,7 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) { return rt.ReadFrom(src) } _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) for { buffer.FullReset() @@ -89,6 +91,7 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) { func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error { return task.Run(ctx, func() error { _buffer := buf.With(make([]byte, buf.UDPBufferSize)) + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) for { n, addr, err := conn.ReadFrom(buffer.FreeBytes()) @@ -104,6 +107,7 @@ func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net. } }, func() error { _buffer := buf.With(make([]byte, buf.UDPBufferSize)) + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) for { n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes()) diff --git a/common/rw/read.go b/common/rw/read.go index 931ae72..b730242 100644 --- a/common/rw/read.go +++ b/common/rw/read.go @@ -2,6 +2,7 @@ package rw import ( "io" + "runtime" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -63,6 +64,7 @@ func ReadFrom0(readerFrom ReaderFromWriter, reader io.Reader) (n int64, err erro func CopyOnce(dest io.Writer, src io.Reader) (n int64, err error) { _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) n, err = buffer.ReadFrom(src) if err != nil { diff --git a/common/tun/system/tun.go b/common/tun/system/tun.go index 04d5c13..535b3cc 100644 --- a/common/tun/system/tun.go +++ b/common/tun/system/tun.go @@ -5,6 +5,7 @@ import ( "net" "net/netip" "os" + "runtime" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -88,6 +89,7 @@ func (t *Stack) Close() error { func (t *Stack) tunLoop() { _buffer := buf.Make(t.tunMtu) + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) for { n, err := t.tunFile.Read(buffer) diff --git a/common/uot/server.go b/common/uot/server.go index 9dbd6d7..8a1f47e 100644 --- a/common/uot/server.go +++ b/common/uot/server.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" "net" + "runtime" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -41,6 +42,7 @@ func (c *ServerConn) RemoteAddr() net.Addr { func (c *ServerConn) loopInput() { _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) for { destination, err := AddrParser.ReadAddrPort(c.inputReader) @@ -74,6 +76,7 @@ func (c *ServerConn) loopInput() { func (c *ServerConn) loopOutput() { _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) for { buffer.FullReset() diff --git a/common/uot/uot_test.go b/common/uot/uot_test.go index 7eb960f..30e276d 100644 --- a/common/uot/uot_test.go +++ b/common/uot/uot_test.go @@ -2,6 +2,7 @@ package uot import ( "net" + "runtime" "testing" "github.com/sagernet/sing/common" @@ -30,6 +31,7 @@ func TestServerConn(t *testing.T) { Port: 53, })) _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) common.Must2(buffer.ReadPacketFrom(clientConn)) common.Must(message.Unpack(buffer.Bytes())) diff --git a/protocol/shadowsocks/none.go b/protocol/shadowsocks/none.go index 4910b4a..a663a53 100644 --- a/protocol/shadowsocks/none.go +++ b/protocol/shadowsocks/none.go @@ -5,6 +5,7 @@ import ( "io" "net" "net/netip" + "runtime" "sync" "github.com/sagernet/sing/common" @@ -95,6 +96,7 @@ func (c *noneConn) Write(b []byte) (n int, err error) { bufN, _ := buffer.Write(b) _, err = c.Conn.Write(buffer.Bytes()) + runtime.KeepAlive(_buffer) if err != nil { return } @@ -141,17 +143,27 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { return socks5.AddressSerializer.ReadAddrPort(buffer) } -func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort M.Socksaddr) error { - defer buffer.Release() - _header := buf.StackNewMax() - header := common.Dup(_header) - err := socks5.AddressSerializer.WriteAddrPort(header, addrPort) +func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + headerLen := socks5.AddressSerializer.AddrPortLen(destination) + var header *buf.Buffer + var writeHeader bool + if buffer.Start() >= headerLen { + header = buf.With(buffer.ExtendHeader(headerLen)) + } else { + _buffer := buf.StackNewSize(buffer.Len() + headerLen) + defer runtime.KeepAlive(_buffer) + header = common.Dup(_buffer) + writeHeader = true + } + err := socks5.AddressSerializer.WriteAddrPort(header, destination) if err != nil { - header.Release() return err } - buffer = buffer.WriteBufferAtFirst(header) - return common.Error(buffer.WriteTo(c)) + if writeHeader { + return common.Error(header.WriteTo(c)) + } else { + return common.Error(buffer.WriteTo(c)) + } } type NoneService struct { diff --git a/protocol/shadowsocks/shadowaead/protocol.go b/protocol/shadowsocks/shadowaead/protocol.go index 79b480e..178654b 100644 --- a/protocol/shadowsocks/shadowaead/protocol.go +++ b/protocol/shadowsocks/shadowaead/protocol.go @@ -6,6 +6,7 @@ import ( "crypto/sha1" "io" "net" + "runtime" "sync" "github.com/sagernet/sing/common" @@ -80,8 +81,10 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader, replay } func Kdf(key, iv []byte, keyLength int) []byte { + info := []byte("ss-subkey") subKey := buf.Make(keyLength) - kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey")) + kdf := hkdf.New(sha1.New, key, iv, common.Dup(info)) + runtime.KeepAlive(info) common.Must1(io.ReadFull(kdf, common.Dup(subKey))) return subKey } @@ -113,6 +116,7 @@ func (m *Method) KeyLength() int { func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) { _salt := buf.Make(m.keySaltLength) + defer runtime.KeepAlive(_salt) salt := common.Dup(_salt) _, err := io.ReadFull(upstream, salt) if err != nil { @@ -124,11 +128,13 @@ func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) { } } key := Kdf(m.key, salt, m.keySaltLength) + defer runtime.KeepAlive(key) return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil } func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) { _salt := buf.Make(m.keySaltLength) + defer runtime.KeepAlive(_salt) salt := common.Dup(_salt) common.Must1(io.ReadFull(m.secureRNG, salt)) _, err := upstream.Write(salt) @@ -163,6 +169,7 @@ func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn { func (m *Method) EncodePacket(buffer *buf.Buffer) error { key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) c := m.constructor(common.Dup(key)) + runtime.KeepAlive(key) c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) buffer.Extend(c.Overhead()) return nil @@ -174,6 +181,7 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error { } key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) c := m.constructor(common.Dup(key)) + runtime.KeepAlive(key) packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) if err != nil { return err @@ -200,11 +208,13 @@ func (c *clientConn) writeRequest(payload []byte) error { common.Must1(io.ReadFull(c.method.secureRNG, salt)) key := Kdf(c.method.key, salt, c.method.keySaltLength) + runtime.KeepAlive(_salt) writer := NewWriter( c.Conn, c.method.constructor(common.Dup(key)), MaxPacketSize, ) + runtime.KeepAlive(key) header := writer.Buffer() header.Write(salt) bufferedWriter := writer.BufferedWriter(header.Len()) @@ -240,6 +250,7 @@ func (c *clientConn) readResponse() error { return nil } _salt := buf.Make(c.method.keySaltLength) + defer runtime.KeepAlive(_salt) salt := common.Dup(_salt) _, err := io.ReadFull(c.Conn, salt) if err != nil { @@ -251,6 +262,7 @@ func (c *clientConn) readResponse() error { } } key := Kdf(c.method.key, salt, c.method.keySaltLength) + defer runtime.KeepAlive(key) c.reader = NewReader( c.Conn, c.method.constructor(common.Dup(key)), diff --git a/protocol/shadowsocks/shadowaead/service.go b/protocol/shadowsocks/shadowaead/service.go index a1ccd2e..0a508e2 100644 --- a/protocol/shadowsocks/shadowaead/service.go +++ b/protocol/shadowsocks/shadowaead/service.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/netip" + "runtime" "sync" "github.com/sagernet/sing/common" @@ -89,6 +90,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { _salt := buf.Make(s.keySaltLength) + defer runtime.KeepAlive(_salt) salt := common.Dup(_salt) _, err := io.ReadFull(conn, salt) @@ -127,11 +129,14 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { common.Must1(io.ReadFull(c.secureRNG, salt)) key := Kdf(c.key, salt, c.keySaltLength) + runtime.KeepAlive(_salt) + writer := NewWriter( c.Conn, c.constructor(common.Dup(key)), MaxPacketSize, ) + runtime.KeepAlive(key) header := writer.Buffer() header.Write(salt) @@ -213,6 +218,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me } key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength) c := s.constructor(common.Dup(key)) + runtime.KeepAlive(key) packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) if err != nil { return err @@ -241,6 +247,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks } key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength) c := w.constructor(common.Dup(key)) + runtime.KeepAlive(key) c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil) buffer.Extend(c.Overhead()) return w.PacketConn.WritePacket(buffer, w.source) diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index d29c97d..4d3ab87 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -9,6 +9,7 @@ import ( "math" "math/rand" "net" + "runtime" "sync" "sync/atomic" "time" @@ -231,6 +232,7 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) header := request.Extend(16) m.blockConstructor(identitySubkey).Encrypt(header, pskHash) + runtime.KeepAlive(_identitySubkey) if debug.Enabled { logger.Trace("encoded ", buf.EncodeHexString(header)) } @@ -252,12 +254,12 @@ func (c *clientConn) writeRequest(payload []byte) error { common.Must1(io.ReadFull(c.method.secureRNG, salt)) key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength) - writer := shadowaead.NewWriter( c.Conn, c.method.constructor(common.Dup(key)), MaxPacketSize, ) + runtime.KeepAlive(key) header := writer.Buffer() header.Write(salt) @@ -344,11 +346,13 @@ func (c *clientConn) readResponse() error { } key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength) + runtime.KeepAlive(_salt) reader := shadowaead.NewReader( c.Conn, c.method.constructor(common.Dup(key)), MaxPacketSize, ) + runtime.KeepAlive(key) headerType, err := rw.ReadByte(reader) if err != nil { @@ -385,6 +389,7 @@ func (c *clientConn) readResponse() error { } return ErrBadRequestSalt } + runtime.KeepAlive(_requestSalt) c.requestSalt = nil c.reader = reader @@ -472,10 +477,21 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad logger.Trace("begin client packet") } - defer buffer.Release() - _header := buf.StackNew() - header := common.Dup(_header) + var hdrLen int + if c.method.udpCipher != nil { + hdrLen = PacketNonceSize + } + hdrLen += 16 // packet header pskLen := len(c.method.pskList) + if c.method.udpCipher == nil && pskLen > 1 { + hdrLen += (pskLen - 1) * aes.BlockSize + } + hdrLen += 1 // header type + hdrLen += 8 // timestamp + hdrLen += 1 // padding length + hdrLen += socks5.AddressSerializer.AddrPortLen(destination) + header := buf.With(buffer.ExtendHeader(hdrLen)) + var dataIndex int if c.method.udpCipher != nil { common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize)) @@ -540,7 +556,6 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad if err != nil { return err } - buffer = buffer.WriteBufferAtFirst(header) if err != nil { return err } @@ -606,6 +621,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { } else { key := Blake3DeriveKey(c.method.psk[:], packetHeader[:8], c.method.keyLength) remoteCipher = c.method.constructor(common.Dup(key)) + runtime.KeepAlive(key) } _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) if err != nil { @@ -717,6 +733,7 @@ func (m *Method) newUDPSession() *udpSession { binary.BigEndian.PutUint64(sessionId, session.sessionId) key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength) session.cipher = m.constructor(common.Dup(key)) + runtime.KeepAlive(key) } return session } diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index 0626720..cf0a702 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -8,6 +8,7 @@ import ( "io" "math" "net" + "runtime" "sync" "sync/atomic" "time" @@ -110,6 +111,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M s.constructor(common.Dup(requestKey)), MaxPacketSize, ) + runtime.KeepAlive(requestKey) headerType, err := rw.ReadByte(reader) if err != nil { @@ -192,11 +194,13 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { salt := common.Dup(_salt[:]) common.Must1(io.ReadFull(c.secureRNG, salt)) key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength) + runtime.KeepAlive(_salt) writer := shadowaead.NewWriter( c.Conn, c.constructor(common.Dup(key)), MaxPacketSize, ) + runtime.KeepAlive(key) header := writer.Buffer() header.Write(salt) bufferedWriter := writer.BufferedWriter(header.Len()) @@ -306,6 +310,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me if packetHeader != nil { key := Blake3DeriveKey(s.psk[:], packetHeader[:8], s.keyLength) session.remoteCipher = s.constructor(common.Dup(key)) + runtime.KeepAlive(key) } } goto process @@ -382,6 +387,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks defer buffer.Release() _header := buf.StackNew() + defer runtime.KeepAlive(_header) header := common.Dup(_header) var dataIndex int @@ -446,6 +452,7 @@ func (m *Service) newUDPSession() *serverUDPSession { binary.BigEndian.PutUint64(sessionId, session.sessionId) key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength) session.cipher = m.constructor(common.Dup(key)) + runtime.KeepAlive(key) } return session } diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi.go b/protocol/shadowsocks/shadowaead_2022/service_multi.go index 7898d09..61626b2 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi.go @@ -7,6 +7,7 @@ import ( "io" "math" "net" + "runtime" "time" "github.com/sagernet/sing/common" @@ -106,6 +107,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta identitySubkey := common.Dup(_identitySubkey) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) + runtime.KeepAlive(_identitySubkey) var user U var uPSK [KeySaltSize]byte @@ -122,6 +124,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta s.constructor(common.Dup(requestKey)), MaxPacketSize, ) + runtime.KeepAlive(requestSalt) headerType, err := rw.ReadByte(reader) if err != nil { @@ -220,6 +223,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad session.remoteSessionId = sessionId key := Blake3DeriveKey(uPSK[:], packetHeader[:8], s.keyLength) session.remoteCipher = s.constructor(common.Dup(key)) + runtime.KeepAlive(key) } goto process @@ -299,5 +303,6 @@ func (m *MultiService[U]) newUDPSession(uPSK [KeySaltSize]byte) *serverUDPSessio binary.BigEndian.PutUint64(sessionId, session.sessionId) key := Blake3DeriveKey(uPSK[:], sessionId, m.keyLength) session.cipher = m.constructor(common.Dup(key)) + runtime.KeepAlive(key) return session } diff --git a/protocol/socks5/conn.go b/protocol/socks5/conn.go index 72b87ca..9a8b91a 100644 --- a/protocol/socks5/conn.go +++ b/protocol/socks5/conn.go @@ -2,6 +2,7 @@ package socks5 import ( "net" + "runtime" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -44,6 +45,7 @@ func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) common.Must(buffer.WriteZeroN(3)) err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) @@ -66,6 +68,7 @@ func (c *AssociateConn) Read(b []byte) (n int, err error) { func (c *AssociateConn) Write(b []byte) (n int, err error) { _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) common.Must(buffer.WriteZeroN(3)) err = AddressSerializer.WriteAddrPort(buffer, c.dest) @@ -134,6 +137,7 @@ func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err erro func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) common.Must(buffer.WriteZeroN(3)) @@ -156,6 +160,7 @@ func (c *AssociatePacketConn) Read(b []byte) (n int, err error) { func (c *AssociatePacketConn) Write(b []byte) (n int, err error) { _buffer := buf.StackNew() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) common.Must(buffer.WriteZeroN(3)) diff --git a/protocol/trojan/protocol.go b/protocol/trojan/protocol.go index 21adb2a..f8c37a2 100644 --- a/protocol/trojan/protocol.go +++ b/protocol/trojan/protocol.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "io" "net" + "runtime" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -149,9 +150,11 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr var writeHeader bool if len(payload) > 0 && headerLen+len(payload) < 65535 { buffer := buf.Make(headerLen + len(payload)) + defer runtime.KeepAlive(buffer) header = buf.With(common.Dup(buffer)) } else { buffer := buf.Make(headerLen) + defer runtime.KeepAlive(buffer) header = buf.With(common.Dup(buffer)) writeHeader = true } @@ -185,6 +188,7 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Soc header = buf.With(payload.ExtendHeader(headerLen)) } else { buffer := buf.Make(headerLen) + defer runtime.KeepAlive(buffer) header = buf.With(common.Dup(buffer)) writeHeader = true } @@ -246,6 +250,7 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) err } else { writeHeader = true _buffer := buf.Make(headerOverload) + defer runtime.KeepAlive(_buffer) header = buf.With(common.Dup(_buffer)) } common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination)) diff --git a/transport/mixed/listener.go b/transport/mixed/listener.go index faf4149..1ba16c6 100644 --- a/transport/mixed/listener.go +++ b/transport/mixed/listener.go @@ -7,6 +7,7 @@ import ( "net" netHttp "net/http" "net/netip" + "runtime" "strings" "github.com/sagernet/sing" @@ -99,6 +100,7 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M. if reader.Buffered() > 0 { _buffer := buf.StackNewSize(reader.Buffered()) + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) _, err = buffer.ReadFullFrom(reader, reader.Buffered()) if err != nil { diff --git a/transport/udp/udp.go b/transport/udp/udp.go index d26f94e..73b3e47 100644 --- a/transport/udp/udp.go +++ b/transport/udp/udp.go @@ -3,6 +3,7 @@ package udp import ( "net" "net/netip" + "runtime" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -91,6 +92,7 @@ func (l *Listener) Close() error { func (l *Listener) loop() { _buffer := buf.StackNewMax() + defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice() if !l.tproxy { @@ -111,6 +113,7 @@ func (l *Listener) loop() { } } else { _oob := make([]byte, 1024) + defer runtime.KeepAlive(_oob) oob := common.Dup(_oob) for { n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(data, oob)