Add support for safe buffer

This commit is contained in:
世界 2022-05-27 19:41:26 +08:00
parent 6fa09df37d
commit c8403614f5
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 72 additions and 64 deletions

View file

@ -7,7 +7,6 @@ import (
"crypto/sha1"
"io"
"net"
"runtime"
"github.com/sagernet/sing-shadowsocks"
"github.com/sagernet/sing/common"
@ -73,7 +72,7 @@ func Kdf(key, iv []byte, keyLength int) []byte {
info := []byte("ss-subkey")
subKey := buf.Make(keyLength)
kdf := hkdf.New(sha1.New, key, iv, common.Dup(info))
runtime.KeepAlive(info)
common.KeepAlive(info)
common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
return subKey
}
@ -103,20 +102,20 @@ func (m *Method) KeyLength() int {
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
_salt := buf.Make(m.keySaltLength)
defer runtime.KeepAlive(_salt)
defer common.KeepAlive(_salt)
salt := common.Dup(_salt)
_, err := io.ReadFull(upstream, salt)
if err != nil {
return nil, E.Cause(err, "read salt")
}
key := Kdf(m.key, salt, m.keySaltLength)
defer runtime.KeepAlive(key)
defer common.KeepAlive(key)
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
_salt := buf.Make(m.keySaltLength)
defer runtime.KeepAlive(_salt)
defer common.KeepAlive(_salt)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(rand.Reader, salt))
_, err := upstream.Write(salt)
@ -151,7 +150,7 @@ func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
runtime.KeepAlive(key)
common.KeepAlive(key)
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(Overhead)
return nil
@ -163,7 +162,7 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
}
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
runtime.KeepAlive(key)
common.KeepAlive(key)
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
if err != nil {
return err
@ -182,18 +181,18 @@ type clientConn struct {
}
func (c *clientConn) writeRequest(payload []byte) error {
_salt := make([]byte, c.method.keySaltLength)
_salt := buf.Make(c.method.keySaltLength)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(rand.Reader, salt))
key := Kdf(c.method.key, salt, c.method.keySaltLength)
runtime.KeepAlive(_salt)
common.KeepAlive(_salt)
writer := NewWriter(
c.Conn,
c.method.constructor(common.Dup(key)),
MaxPacketSize,
)
runtime.KeepAlive(key)
common.KeepAlive(key)
header := writer.Buffer()
header.Write(salt)
bufferedWriter := writer.BufferedWriter(header.Len())
@ -229,14 +228,14 @@ func (c *clientConn) readResponse() error {
return nil
}
_salt := buf.Make(c.method.keySaltLength)
defer runtime.KeepAlive(_salt)
defer common.KeepAlive(_salt)
salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
key := Kdf(c.method.key, salt, c.method.keySaltLength)
defer runtime.KeepAlive(key)
defer common.KeepAlive(key)
c.reader = NewReader(
c.Conn,
c.method.constructor(common.Dup(key)),
@ -288,6 +287,7 @@ type clientPacketConn struct {
}
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
header := buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
common.Must1(io.ReadFull(rand.Reader, header[:c.keySaltLength]))
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
@ -334,9 +334,10 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
}
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
_buffer := buf.StackNew()
defer runtime.KeepAlive(_buffer)
_buffer := buf.StackNewPacket()
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
buffer.WriteRandom(c.keySaltLength)
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
if err != nil {

View file

@ -7,7 +7,6 @@ import (
"io"
"net"
"net/netip"
"runtime"
"sync"
"github.com/sagernet/sing-shadowsocks"
@ -85,7 +84,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
_header := buf.Make(s.keySaltLength + PacketLengthBufferSize + Overhead)
defer runtime.KeepAlive(_header)
defer common.KeepAlive(_header)
header := common.Dup(_header)
n, err := conn.Read(header)
@ -132,14 +131,14 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
common.Must1(io.ReadFull(rand.Reader, salt))
key := Kdf(c.key, salt, c.keySaltLength)
runtime.KeepAlive(_salt)
common.KeepAlive(_salt)
writer := NewWriter(
c.Conn,
c.constructor(common.Dup(key)),
MaxPacketSize,
)
runtime.KeepAlive(key)
common.KeepAlive(key)
header := writer.Buffer()
header.Write(salt)
@ -203,7 +202,7 @@ func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.
}
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
c := s.constructor(common.Dup(key))
runtime.KeepAlive(key)
common.KeepAlive(key)
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
if err != nil {
return err
@ -235,11 +234,12 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
common.Must1(io.ReadFull(rand.Reader, header[:w.keySaltLength]))
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
if err != nil {
buffer.Release()
return err
}
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
c := w.constructor(common.Dup(key))
runtime.KeepAlive(key)
common.KeepAlive(key)
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
buffer.Extend(Overhead)
return w.PacketConn.WritePacket(buffer, w.source)