Fix buffer usage

This commit is contained in:
世界 2022-05-07 17:08:57 +08:00
parent 6089c358c2
commit f1b87be6e4
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
21 changed files with 121 additions and 24 deletions

View file

@ -4,9 +4,6 @@ import (
"archive/tar" "archive/tar"
_ "embed" _ "embed"
"encoding/hex" "encoding/hex"
"github.com/sagernet/sing"
"github.com/sagernet/sing/common/log"
"github.com/spf13/cobra"
"io" "io"
"os" "os"
"os/exec" "os/exec"
@ -15,17 +12,22 @@ import (
"strings" "strings"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/sagernet/sing"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/log"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/u-root/u-root/pkg/ldd" "github.com/u-root/u-root/pkg/ldd"
) )
var logger = log.NewLogger("libpack") var logger = log.NewLogger("libpack")
var packageName string var (
var executablePath string packageName string
var outputPath string executablePath string
outputPath string
)
func main() { func main() {
command := &cobra.Command{ command := &cobra.Command{

View file

@ -6,6 +6,7 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"runtime"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -129,6 +130,7 @@ func testSocksUDP(server M.Socksaddr) error {
Port: 53, Port: 53,
})) }))
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(assConn)) common.Must2(buffer.ReadPacketFrom(assConn))
common.Must(message.Unpack(buffer.Bytes())) common.Must(message.Unpack(buffer.Bytes()))

View file

@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
"runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
"syscall" "syscall"
@ -350,7 +351,7 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
if err != nil { if err != nil {
return E.Cause(err, "client handshake") return E.Cause(err, "client handshake")
} }
runtime.KeepAlive(_payload)
return rw.CopyConn(ctx, serverConn, conn) return rw.CopyConn(ctx, serverConn, conn)
} }

View file

@ -9,6 +9,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
"runtime"
"syscall" "syscall"
"time" "time"
@ -315,7 +316,7 @@ func (c *client) NewConnection(ctx context.Context, conn net.Conn, metadata M.Me
if err != nil { if err != nil {
return E.Cause(err, "client handshake") return E.Cause(err, "client handshake")
} }
runtime.KeepAlive(_request)
return rw.CopyConn(ctx, clientConn, conn) return rw.CopyConn(ctx, clientConn, conn)
} }

View file

@ -119,7 +119,7 @@ func (b *Buffer) ExtendHeader(size int) []byte {
} }
} }
func (b *Buffer) WriteBufferAtFirst(buffer *Buffer) *Buffer { func (b *Buffer) _WriteBufferAtFirst(buffer *Buffer) *Buffer {
size := buffer.Len() size := buffer.Len()
if b.start >= size { if b.start >= size {
n := copy(b.data[b.start-size:b.start], buffer.Bytes()) n := copy(b.data[b.start-size:b.start], buffer.Bytes())
@ -140,7 +140,7 @@ func (b *Buffer) WriteBufferAtFirst(buffer *Buffer) *Buffer {
} }
} }
func (b *Buffer) WriteAtFirst(data []byte) (n int, err error) { func (b *Buffer) _WriteAtFirst(data []byte) (n int, err error) {
size := len(data) size := len(data)
if b.start >= size { if b.start >= size {
n = copy(b.data[b.start-size:b.start], data) n = copy(b.data[b.start-size:b.start], data)

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"net" "net"
"os" "os"
"runtime"
"time" "time"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -63,6 +64,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
defer rw.CloseRead(conn) defer rw.CloseRead(conn)
defer rw.CloseWrite(dest) defer rw.CloseWrite(dest)
_buffer := buf.StackNewMax() _buffer := buf.StackNewMax()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for { for {
@ -81,6 +83,7 @@ func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error
defer rw.CloseRead(dest) defer rw.CloseRead(dest)
defer rw.CloseWrite(conn) defer rw.CloseWrite(conn)
_buffer := buf.StackNewMax() _buffer := buf.StackNewMax()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader) data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader)
for { for {

View file

@ -5,6 +5,7 @@ import (
"io" "io"
"net" "net"
"os" "os"
"runtime"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -69,6 +70,7 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
return rt.ReadFrom(src) return rt.ReadFrom(src)
} }
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
for { for {
buffer.FullReset() buffer.FullReset()
@ -89,6 +91,7 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error { func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
return task.Run(ctx, func() error { return task.Run(ctx, func() error {
_buffer := buf.With(make([]byte, buf.UDPBufferSize)) _buffer := buf.With(make([]byte, buf.UDPBufferSize))
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
for { for {
n, addr, err := conn.ReadFrom(buffer.FreeBytes()) n, addr, err := conn.ReadFrom(buffer.FreeBytes())
@ -104,6 +107,7 @@ func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.
} }
}, func() error { }, func() error {
_buffer := buf.With(make([]byte, buf.UDPBufferSize)) _buffer := buf.With(make([]byte, buf.UDPBufferSize))
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
for { for {
n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes()) n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes())

View file

@ -2,6 +2,7 @@ package rw
import ( import (
"io" "io"
"runtime"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -63,6 +64,7 @@ func ReadFrom0(readerFrom ReaderFromWriter, reader io.Reader) (n int64, err erro
func CopyOnce(dest io.Writer, src io.Reader) (n int64, err error) { func CopyOnce(dest io.Writer, src io.Reader) (n int64, err error) {
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
n, err = buffer.ReadFrom(src) n, err = buffer.ReadFrom(src)
if err != nil { if err != nil {

View file

@ -5,6 +5,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"runtime"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -88,6 +89,7 @@ func (t *Stack) Close() error {
func (t *Stack) tunLoop() { func (t *Stack) tunLoop() {
_buffer := buf.Make(t.tunMtu) _buffer := buf.Make(t.tunMtu)
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
for { for {
n, err := t.tunFile.Read(buffer) n, err := t.tunFile.Read(buffer)

View file

@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
"runtime"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -41,6 +42,7 @@ func (c *ServerConn) RemoteAddr() net.Addr {
func (c *ServerConn) loopInput() { func (c *ServerConn) loopInput() {
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
for { for {
destination, err := AddrParser.ReadAddrPort(c.inputReader) destination, err := AddrParser.ReadAddrPort(c.inputReader)
@ -74,6 +76,7 @@ func (c *ServerConn) loopInput() {
func (c *ServerConn) loopOutput() { func (c *ServerConn) loopOutput() {
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
for { for {
buffer.FullReset() buffer.FullReset()

View file

@ -2,6 +2,7 @@ package uot
import ( import (
"net" "net"
"runtime"
"testing" "testing"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -30,6 +31,7 @@ func TestServerConn(t *testing.T) {
Port: 53, Port: 53,
})) }))
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(clientConn)) common.Must2(buffer.ReadPacketFrom(clientConn))
common.Must(message.Unpack(buffer.Bytes())) common.Must(message.Unpack(buffer.Bytes()))

View file

@ -5,6 +5,7 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -95,6 +96,7 @@ func (c *noneConn) Write(b []byte) (n int, err error) {
bufN, _ := buffer.Write(b) bufN, _ := buffer.Write(b)
_, err = c.Conn.Write(buffer.Bytes()) _, err = c.Conn.Write(buffer.Bytes())
runtime.KeepAlive(_buffer)
if err != nil { if err != nil {
return return
} }
@ -141,17 +143,27 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return socks5.AddressSerializer.ReadAddrPort(buffer) return socks5.AddressSerializer.ReadAddrPort(buffer)
} }
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort M.Socksaddr) error { func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release() headerLen := socks5.AddressSerializer.AddrPortLen(destination)
_header := buf.StackNewMax() var header *buf.Buffer
header := common.Dup(_header) var writeHeader bool
err := socks5.AddressSerializer.WriteAddrPort(header, addrPort) if buffer.Start() >= headerLen {
header = buf.With(buffer.ExtendHeader(headerLen))
} else {
_buffer := buf.StackNewSize(buffer.Len() + headerLen)
defer runtime.KeepAlive(_buffer)
header = common.Dup(_buffer)
writeHeader = true
}
err := socks5.AddressSerializer.WriteAddrPort(header, destination)
if err != nil { if err != nil {
header.Release()
return err return err
} }
buffer = buffer.WriteBufferAtFirst(header) if writeHeader {
return common.Error(buffer.WriteTo(c)) return common.Error(header.WriteTo(c))
} else {
return common.Error(buffer.WriteTo(c))
}
} }
type NoneService struct { type NoneService struct {

View file

@ -6,6 +6,7 @@ import (
"crypto/sha1" "crypto/sha1"
"io" "io"
"net" "net"
"runtime"
"sync" "sync"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -80,8 +81,10 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader, replay
} }
func Kdf(key, iv []byte, keyLength int) []byte { func Kdf(key, iv []byte, keyLength int) []byte {
info := []byte("ss-subkey")
subKey := buf.Make(keyLength) subKey := buf.Make(keyLength)
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey")) kdf := hkdf.New(sha1.New, key, iv, common.Dup(info))
runtime.KeepAlive(info)
common.Must1(io.ReadFull(kdf, common.Dup(subKey))) common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
return subKey return subKey
} }
@ -113,6 +116,7 @@ func (m *Method) KeyLength() int {
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) { func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
_salt := buf.Make(m.keySaltLength) _salt := buf.Make(m.keySaltLength)
defer runtime.KeepAlive(_salt)
salt := common.Dup(_salt) salt := common.Dup(_salt)
_, err := io.ReadFull(upstream, salt) _, err := io.ReadFull(upstream, salt)
if err != nil { if err != nil {
@ -124,11 +128,13 @@ func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
} }
} }
key := Kdf(m.key, salt, m.keySaltLength) key := Kdf(m.key, salt, m.keySaltLength)
defer runtime.KeepAlive(key)
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
} }
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) { func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
_salt := buf.Make(m.keySaltLength) _salt := buf.Make(m.keySaltLength)
defer runtime.KeepAlive(_salt)
salt := common.Dup(_salt) salt := common.Dup(_salt)
common.Must1(io.ReadFull(m.secureRNG, salt)) common.Must1(io.ReadFull(m.secureRNG, salt))
_, err := upstream.Write(salt) _, err := upstream.Write(salt)
@ -163,6 +169,7 @@ func (m *Method) DialPacketConn(conn net.Conn) N.PacketConn {
func (m *Method) EncodePacket(buffer *buf.Buffer) error { func (m *Method) EncodePacket(buffer *buf.Buffer) error {
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key)) c := m.constructor(common.Dup(key))
runtime.KeepAlive(key)
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(c.Overhead()) buffer.Extend(c.Overhead())
return nil return nil
@ -174,6 +181,7 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
} }
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key)) c := m.constructor(common.Dup(key))
runtime.KeepAlive(key)
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
if err != nil { if err != nil {
return err return err
@ -200,11 +208,13 @@ func (c *clientConn) writeRequest(payload []byte) error {
common.Must1(io.ReadFull(c.method.secureRNG, salt)) common.Must1(io.ReadFull(c.method.secureRNG, salt))
key := Kdf(c.method.key, salt, c.method.keySaltLength) key := Kdf(c.method.key, salt, c.method.keySaltLength)
runtime.KeepAlive(_salt)
writer := NewWriter( writer := NewWriter(
c.Conn, c.Conn,
c.method.constructor(common.Dup(key)), c.method.constructor(common.Dup(key)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(key)
header := writer.Buffer() header := writer.Buffer()
header.Write(salt) header.Write(salt)
bufferedWriter := writer.BufferedWriter(header.Len()) bufferedWriter := writer.BufferedWriter(header.Len())
@ -240,6 +250,7 @@ func (c *clientConn) readResponse() error {
return nil return nil
} }
_salt := buf.Make(c.method.keySaltLength) _salt := buf.Make(c.method.keySaltLength)
defer runtime.KeepAlive(_salt)
salt := common.Dup(_salt) salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt) _, err := io.ReadFull(c.Conn, salt)
if err != nil { if err != nil {
@ -251,6 +262,7 @@ func (c *clientConn) readResponse() error {
} }
} }
key := Kdf(c.method.key, salt, c.method.keySaltLength) key := Kdf(c.method.key, salt, c.method.keySaltLength)
defer runtime.KeepAlive(key)
c.reader = NewReader( c.reader = NewReader(
c.Conn, c.Conn,
c.method.constructor(common.Dup(key)), c.method.constructor(common.Dup(key)),

View file

@ -6,6 +6,7 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -89,6 +90,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
_salt := buf.Make(s.keySaltLength) _salt := buf.Make(s.keySaltLength)
defer runtime.KeepAlive(_salt)
salt := common.Dup(_salt) salt := common.Dup(_salt)
_, err := io.ReadFull(conn, salt) _, err := io.ReadFull(conn, salt)
@ -127,11 +129,14 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
common.Must1(io.ReadFull(c.secureRNG, salt)) common.Must1(io.ReadFull(c.secureRNG, salt))
key := Kdf(c.key, salt, c.keySaltLength) key := Kdf(c.key, salt, c.keySaltLength)
runtime.KeepAlive(_salt)
writer := NewWriter( writer := NewWriter(
c.Conn, c.Conn,
c.constructor(common.Dup(key)), c.constructor(common.Dup(key)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(key)
header := writer.Buffer() header := writer.Buffer()
header.Write(salt) header.Write(salt)
@ -213,6 +218,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me
} }
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength) key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
c := s.constructor(common.Dup(key)) c := s.constructor(common.Dup(key))
runtime.KeepAlive(key)
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil) packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil { if err != nil {
return err return err
@ -241,6 +247,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
} }
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength) key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
c := w.constructor(common.Dup(key)) c := w.constructor(common.Dup(key))
runtime.KeepAlive(key)
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil) c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
buffer.Extend(c.Overhead()) buffer.Extend(c.Overhead())
return w.PacketConn.WritePacket(buffer, w.source) return w.PacketConn.WritePacket(buffer, w.source)

View file

@ -9,6 +9,7 @@ import (
"math" "math"
"math/rand" "math/rand"
"net" "net"
"runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -231,6 +232,7 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte)
header := request.Extend(16) header := request.Extend(16)
m.blockConstructor(identitySubkey).Encrypt(header, pskHash) m.blockConstructor(identitySubkey).Encrypt(header, pskHash)
runtime.KeepAlive(_identitySubkey)
if debug.Enabled { if debug.Enabled {
logger.Trace("encoded ", buf.EncodeHexString(header)) logger.Trace("encoded ", buf.EncodeHexString(header))
} }
@ -252,12 +254,12 @@ func (c *clientConn) writeRequest(payload []byte) error {
common.Must1(io.ReadFull(c.method.secureRNG, salt)) common.Must1(io.ReadFull(c.method.secureRNG, salt))
key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength) key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength)
writer := shadowaead.NewWriter( writer := shadowaead.NewWriter(
c.Conn, c.Conn,
c.method.constructor(common.Dup(key)), c.method.constructor(common.Dup(key)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(key)
header := writer.Buffer() header := writer.Buffer()
header.Write(salt) header.Write(salt)
@ -344,11 +346,13 @@ func (c *clientConn) readResponse() error {
} }
key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength) key := Blake3DeriveKey(c.method.psk[:], salt, c.method.keyLength)
runtime.KeepAlive(_salt)
reader := shadowaead.NewReader( reader := shadowaead.NewReader(
c.Conn, c.Conn,
c.method.constructor(common.Dup(key)), c.method.constructor(common.Dup(key)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(key)
headerType, err := rw.ReadByte(reader) headerType, err := rw.ReadByte(reader)
if err != nil { if err != nil {
@ -385,6 +389,7 @@ func (c *clientConn) readResponse() error {
} }
return ErrBadRequestSalt return ErrBadRequestSalt
} }
runtime.KeepAlive(_requestSalt)
c.requestSalt = nil c.requestSalt = nil
c.reader = reader c.reader = reader
@ -472,10 +477,21 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
logger.Trace("begin client packet") logger.Trace("begin client packet")
} }
defer buffer.Release() var hdrLen int
_header := buf.StackNew() if c.method.udpCipher != nil {
header := common.Dup(_header) hdrLen = PacketNonceSize
}
hdrLen += 16 // packet header
pskLen := len(c.method.pskList) pskLen := len(c.method.pskList)
if c.method.udpCipher == nil && pskLen > 1 {
hdrLen += (pskLen - 1) * aes.BlockSize
}
hdrLen += 1 // header type
hdrLen += 8 // timestamp
hdrLen += 1 // padding length
hdrLen += socks5.AddressSerializer.AddrPortLen(destination)
header := buf.With(buffer.ExtendHeader(hdrLen))
var dataIndex int var dataIndex int
if c.method.udpCipher != nil { if c.method.udpCipher != nil {
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize)) common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
@ -540,7 +556,6 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
if err != nil { if err != nil {
return err return err
} }
buffer = buffer.WriteBufferAtFirst(header)
if err != nil { if err != nil {
return err return err
} }
@ -606,6 +621,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
} else { } else {
key := Blake3DeriveKey(c.method.psk[:], packetHeader[:8], c.method.keyLength) key := Blake3DeriveKey(c.method.psk[:], packetHeader[:8], c.method.keyLength)
remoteCipher = c.method.constructor(common.Dup(key)) remoteCipher = c.method.constructor(common.Dup(key))
runtime.KeepAlive(key)
} }
_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
if err != nil { if err != nil {
@ -717,6 +733,7 @@ func (m *Method) newUDPSession() *udpSession {
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength) key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key)
} }
return session return session
} }

View file

@ -8,6 +8,7 @@ import (
"io" "io"
"math" "math"
"net" "net"
"runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -110,6 +111,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M
s.constructor(common.Dup(requestKey)), s.constructor(common.Dup(requestKey)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(requestKey)
headerType, err := rw.ReadByte(reader) headerType, err := rw.ReadByte(reader)
if err != nil { if err != nil {
@ -192,11 +194,13 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
salt := common.Dup(_salt[:]) salt := common.Dup(_salt[:])
common.Must1(io.ReadFull(c.secureRNG, salt)) common.Must1(io.ReadFull(c.secureRNG, salt))
key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength) key := Blake3DeriveKey(c.uPSK[:], salt, c.keyLength)
runtime.KeepAlive(_salt)
writer := shadowaead.NewWriter( writer := shadowaead.NewWriter(
c.Conn, c.Conn,
c.constructor(common.Dup(key)), c.constructor(common.Dup(key)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(key)
header := writer.Buffer() header := writer.Buffer()
header.Write(salt) header.Write(salt)
bufferedWriter := writer.BufferedWriter(header.Len()) bufferedWriter := writer.BufferedWriter(header.Len())
@ -306,6 +310,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me
if packetHeader != nil { if packetHeader != nil {
key := Blake3DeriveKey(s.psk[:], packetHeader[:8], s.keyLength) key := Blake3DeriveKey(s.psk[:], packetHeader[:8], s.keyLength)
session.remoteCipher = s.constructor(common.Dup(key)) session.remoteCipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key)
} }
} }
goto process goto process
@ -382,6 +387,7 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
defer buffer.Release() defer buffer.Release()
_header := buf.StackNew() _header := buf.StackNew()
defer runtime.KeepAlive(_header)
header := common.Dup(_header) header := common.Dup(_header)
var dataIndex int var dataIndex int
@ -446,6 +452,7 @@ func (m *Service) newUDPSession() *serverUDPSession {
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength) key := Blake3DeriveKey(m.psk[:], sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key)
} }
return session return session
} }

View file

@ -7,6 +7,7 @@ import (
"io" "io"
"math" "math"
"net" "net"
"runtime"
"time" "time"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
@ -106,6 +107,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
identitySubkey := common.Dup(_identitySubkey) identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader)
runtime.KeepAlive(_identitySubkey)
var user U var user U
var uPSK [KeySaltSize]byte var uPSK [KeySaltSize]byte
@ -122,6 +124,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta
s.constructor(common.Dup(requestKey)), s.constructor(common.Dup(requestKey)),
MaxPacketSize, MaxPacketSize,
) )
runtime.KeepAlive(requestSalt)
headerType, err := rw.ReadByte(reader) headerType, err := rw.ReadByte(reader)
if err != nil { if err != nil {
@ -220,6 +223,7 @@ func (s *MultiService[U]) newPacket(conn N.PacketConn, buffer *buf.Buffer, metad
session.remoteSessionId = sessionId session.remoteSessionId = sessionId
key := Blake3DeriveKey(uPSK[:], packetHeader[:8], s.keyLength) key := Blake3DeriveKey(uPSK[:], packetHeader[:8], s.keyLength)
session.remoteCipher = s.constructor(common.Dup(key)) session.remoteCipher = s.constructor(common.Dup(key))
runtime.KeepAlive(key)
} }
goto process goto process
@ -299,5 +303,6 @@ func (m *MultiService[U]) newUDPSession(uPSK [KeySaltSize]byte) *serverUDPSessio
binary.BigEndian.PutUint64(sessionId, session.sessionId) binary.BigEndian.PutUint64(sessionId, session.sessionId)
key := Blake3DeriveKey(uPSK[:], sessionId, m.keyLength) key := Blake3DeriveKey(uPSK[:], sessionId, m.keyLength)
session.cipher = m.constructor(common.Dup(key)) session.cipher = m.constructor(common.Dup(key))
runtime.KeepAlive(key)
return session return session
} }

View file

@ -2,6 +2,7 @@ package socks5
import ( import (
"net" "net"
"runtime"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -44,6 +45,7 @@ func (c *AssociateConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *AssociateConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3)) common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) err = AddressSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
@ -66,6 +68,7 @@ func (c *AssociateConn) Read(b []byte) (n int, err error) {
func (c *AssociateConn) Write(b []byte) (n int, err error) { func (c *AssociateConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3)) common.Must(buffer.WriteZeroN(3))
err = AddressSerializer.WriteAddrPort(buffer, c.dest) err = AddressSerializer.WriteAddrPort(buffer, c.dest)
@ -134,6 +137,7 @@ func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err erro
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3)) common.Must(buffer.WriteZeroN(3))
@ -156,6 +160,7 @@ func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) { func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
_buffer := buf.StackNew() _buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
common.Must(buffer.WriteZeroN(3)) common.Must(buffer.WriteZeroN(3))

View file

@ -6,6 +6,7 @@ import (
"encoding/hex" "encoding/hex"
"io" "io"
"net" "net"
"runtime"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -149,9 +150,11 @@ func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr
var writeHeader bool var writeHeader bool
if len(payload) > 0 && headerLen+len(payload) < 65535 { if len(payload) > 0 && headerLen+len(payload) < 65535 {
buffer := buf.Make(headerLen + len(payload)) buffer := buf.Make(headerLen + len(payload))
defer runtime.KeepAlive(buffer)
header = buf.With(common.Dup(buffer)) header = buf.With(common.Dup(buffer))
} else { } else {
buffer := buf.Make(headerLen) buffer := buf.Make(headerLen)
defer runtime.KeepAlive(buffer)
header = buf.With(common.Dup(buffer)) header = buf.With(common.Dup(buffer))
writeHeader = true writeHeader = true
} }
@ -185,6 +188,7 @@ func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Soc
header = buf.With(payload.ExtendHeader(headerLen)) header = buf.With(payload.ExtendHeader(headerLen))
} else { } else {
buffer := buf.Make(headerLen) buffer := buf.Make(headerLen)
defer runtime.KeepAlive(buffer)
header = buf.With(common.Dup(buffer)) header = buf.With(common.Dup(buffer))
writeHeader = true writeHeader = true
} }
@ -246,6 +250,7 @@ func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) err
} else { } else {
writeHeader = true writeHeader = true
_buffer := buf.Make(headerOverload) _buffer := buf.Make(headerOverload)
defer runtime.KeepAlive(_buffer)
header = buf.With(common.Dup(_buffer)) header = buf.With(common.Dup(_buffer))
} }
common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination)) common.Must(socks5.AddressSerializer.WriteAddrPort(header, destination))

View file

@ -7,6 +7,7 @@ import (
"net" "net"
netHttp "net/http" netHttp "net/http"
"net/netip" "net/netip"
"runtime"
"strings" "strings"
"github.com/sagernet/sing" "github.com/sagernet/sing"
@ -99,6 +100,7 @@ func (l *Listener) NewConnection(ctx context.Context, conn net.Conn, metadata M.
if reader.Buffered() > 0 { if reader.Buffered() > 0 {
_buffer := buf.StackNewSize(reader.Buffered()) _buffer := buf.StackNewSize(reader.Buffered())
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
_, err = buffer.ReadFullFrom(reader, reader.Buffered()) _, err = buffer.ReadFullFrom(reader, reader.Buffered())
if err != nil { if err != nil {

View file

@ -3,6 +3,7 @@ package udp
import ( import (
"net" "net"
"net/netip" "net/netip"
"runtime"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -91,6 +92,7 @@ func (l *Listener) Close() error {
func (l *Listener) loop() { func (l *Listener) loop() {
_buffer := buf.StackNewMax() _buffer := buf.StackNewMax()
defer runtime.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice() data := buffer.Cut(buf.ReversedHeader, buf.ReversedHeader).Slice()
if !l.tproxy { if !l.tproxy {
@ -111,6 +113,7 @@ func (l *Listener) loop() {
} }
} else { } else {
_oob := make([]byte, 1024) _oob := make([]byte, 1024)
defer runtime.KeepAlive(_oob)
oob := common.Dup(_oob) oob := common.Dup(_oob)
for { for {
n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(data, oob) n, oobN, _, addr, err := l.ReadMsgUDPAddrPort(data, oob)