mirror of
https://github.com/SagerNet/sing-shadowsocks.git
synced 2025-04-03 20:07:40 +03:00
Add support for safe buffer
This commit is contained in:
parent
6fa09df37d
commit
c8403614f5
10 changed files with 72 additions and 64 deletions
|
@ -7,7 +7,6 @@ import (
|
|||
"crypto/sha1"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
|
||||
"github.com/sagernet/sing-shadowsocks"
|
||||
"github.com/sagernet/sing/common"
|
||||
|
@ -73,7 +72,7 @@ func Kdf(key, iv []byte, keyLength int) []byte {
|
|||
info := []byte("ss-subkey")
|
||||
subKey := buf.Make(keyLength)
|
||||
kdf := hkdf.New(sha1.New, key, iv, common.Dup(info))
|
||||
runtime.KeepAlive(info)
|
||||
common.KeepAlive(info)
|
||||
common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
|
||||
return subKey
|
||||
}
|
||||
|
@ -103,20 +102,20 @@ func (m *Method) KeyLength() int {
|
|||
|
||||
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
|
||||
_salt := buf.Make(m.keySaltLength)
|
||||
defer runtime.KeepAlive(_salt)
|
||||
defer common.KeepAlive(_salt)
|
||||
salt := common.Dup(_salt)
|
||||
_, err := io.ReadFull(upstream, salt)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read salt")
|
||||
}
|
||||
key := Kdf(m.key, salt, m.keySaltLength)
|
||||
defer runtime.KeepAlive(key)
|
||||
defer common.KeepAlive(key)
|
||||
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
|
||||
}
|
||||
|
||||
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
|
||||
_salt := buf.Make(m.keySaltLength)
|
||||
defer runtime.KeepAlive(_salt)
|
||||
defer common.KeepAlive(_salt)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||
_, err := upstream.Write(salt)
|
||||
|
@ -151,7 +150,7 @@ func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
|
|||
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
||||
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||
c := m.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
common.KeepAlive(key)
|
||||
c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
buffer.Extend(Overhead)
|
||||
return nil
|
||||
|
@ -163,7 +162,7 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
|
|||
}
|
||||
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||
c := m.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
common.KeepAlive(key)
|
||||
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -182,18 +181,18 @@ type clientConn struct {
|
|||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
_salt := make([]byte, c.method.keySaltLength)
|
||||
_salt := buf.Make(c.method.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||
|
||||
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||
runtime.KeepAlive(_salt)
|
||||
common.KeepAlive(_salt)
|
||||
writer := NewWriter(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
runtime.KeepAlive(key)
|
||||
common.KeepAlive(key)
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
bufferedWriter := writer.BufferedWriter(header.Len())
|
||||
|
@ -229,14 +228,14 @@ func (c *clientConn) readResponse() error {
|
|||
return nil
|
||||
}
|
||||
_salt := buf.Make(c.method.keySaltLength)
|
||||
defer runtime.KeepAlive(_salt)
|
||||
defer common.KeepAlive(_salt)
|
||||
salt := common.Dup(_salt)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||
defer runtime.KeepAlive(key)
|
||||
defer common.KeepAlive(key)
|
||||
c.reader = NewReader(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
|
@ -288,6 +287,7 @@ type clientPacketConn struct {
|
|||
}
|
||||
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
defer buffer.Release()
|
||||
header := buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
|
||||
common.Must1(io.ReadFull(rand.Reader, header[:c.keySaltLength]))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[c.keySaltLength:]), destination)
|
||||
|
@ -334,9 +334,10 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
|
|||
}
|
||||
|
||||
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
_buffer := buf.StackNew()
|
||||
defer runtime.KeepAlive(_buffer)
|
||||
_buffer := buf.StackNewPacket()
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
buffer.WriteRandom(c.keySaltLength)
|
||||
err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
|
||||
if err != nil {
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-shadowsocks"
|
||||
|
@ -85,7 +84,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M
|
|||
|
||||
func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
_header := buf.Make(s.keySaltLength + PacketLengthBufferSize + Overhead)
|
||||
defer runtime.KeepAlive(_header)
|
||||
defer common.KeepAlive(_header)
|
||||
header := common.Dup(_header)
|
||||
|
||||
n, err := conn.Read(header)
|
||||
|
@ -132,14 +131,14 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
|
|||
common.Must1(io.ReadFull(rand.Reader, salt))
|
||||
|
||||
key := Kdf(c.key, salt, c.keySaltLength)
|
||||
runtime.KeepAlive(_salt)
|
||||
common.KeepAlive(_salt)
|
||||
|
||||
writer := NewWriter(
|
||||
c.Conn,
|
||||
c.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
runtime.KeepAlive(key)
|
||||
common.KeepAlive(key)
|
||||
|
||||
header := writer.Buffer()
|
||||
header.Write(salt)
|
||||
|
@ -203,7 +202,7 @@ func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.
|
|||
}
|
||||
key := Kdf(s.key, buffer.To(s.keySaltLength), s.keySaltLength)
|
||||
c := s.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
common.KeepAlive(key)
|
||||
packet, err := c.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(s.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -235,11 +234,12 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks
|
|||
common.Must1(io.ReadFull(rand.Reader, header[:w.keySaltLength]))
|
||||
err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
return err
|
||||
}
|
||||
key := Kdf(w.key, buffer.To(w.keySaltLength), w.keySaltLength)
|
||||
c := w.constructor(common.Dup(key))
|
||||
runtime.KeepAlive(key)
|
||||
common.KeepAlive(key)
|
||||
c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil)
|
||||
buffer.Extend(Overhead)
|
||||
return w.PacketConn.WritePacket(buffer, w.source)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue