Refine buffer

This commit is contained in:
世界 2022-04-22 17:11:24 +08:00
parent 603c62165e
commit 63ef20617a
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
13 changed files with 179 additions and 186 deletions

View file

@ -81,8 +81,8 @@ func (c *noneConn) Write(b []byte) (n int, err error) {
return 0, c.clientHandshake()
}
buffer := buf.New()
defer buffer.Release()
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
@ -138,7 +138,8 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
_header := buf.StackNew()
header := common.Dup(_header)
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
if err != nil {
header.Release()

View file

@ -79,9 +79,9 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader, replay
}
func Kdf(key, iv []byte, keyLength int) []byte {
subKey := make([]byte, keyLength)
subKey := buf.Make(keyLength)
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
common.Must1(io.ReadFull(kdf, subKey))
common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
return subKey
}
@ -111,8 +111,8 @@ func (m *Method) KeyLength() int {
}
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
saltBuffer := buf.Make(m.keySaltLength)
salt := common.Dup(saltBuffer)
_salt := buf.Make(m.keySaltLength)
salt := common.Dup(_salt)
_, err := io.ReadFull(upstream, salt)
if err != nil {
return nil, E.Cause(err, "read salt")
@ -122,18 +122,20 @@ func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
return nil, E.New("salt not unique")
}
}
return NewReader(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil
key := Kdf(m.key, salt, m.keySaltLength)
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
saltBuffer := buf.Make(m.keySaltLength)
salt := common.Dup(saltBuffer)
_salt := buf.Make(m.keySaltLength)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(m.secureRNG, salt))
_, err := upstream.Write(salt)
if err != nil {
return nil, err
}
return NewWriter(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil
key := Kdf(m.key, salt, m.keySaltLength)
return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
@ -154,11 +156,12 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
return &aeadPacketConn{conn, m}
return &clientPacketConn{conn, m}
}
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength))
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(c.Overhead())
return nil
@ -168,7 +171,8 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
if buffer.Len() < m.keySaltLength {
return E.New("bad packet")
}
c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength))
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
if err != nil {
return err
@ -190,8 +194,8 @@ type clientConn struct {
}
func (c *clientConn) writeRequest(payload []byte) error {
request := buf.New()
defer request.Release()
_request := buf.StackNew()
request := common.Dup(_request)
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
@ -207,8 +211,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
)
if len(payload) > 0 {
header := buf.New()
defer header.Release()
_header := buf.StackNew()
header := common.Dup(_header)
writer = &buf.BufferedWriter{
Writer: writer,
@ -240,23 +244,26 @@ func (c *clientConn) writeRequest(payload []byte) error {
}
func (c *clientConn) readResponse() error {
if c.reader == nil {
salt := make([]byte, c.method.keySaltLength)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
if c.method.replayFilter != nil {
if !c.method.replayFilter.Check(salt) {
return E.New("salt not unique")
}
}
c.reader = NewReader(
c.Conn,
c.method.constructor(Kdf(c.method.key, salt, c.method.keySaltLength)),
MaxPacketSize,
)
if c.reader != nil {
return nil
}
_salt := buf.Make(c.method.keySaltLength)
salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
if c.method.replayFilter != nil {
if !c.method.replayFilter.Check(salt) {
return E.New("salt not unique")
}
}
key := Kdf(c.method.key, salt, c.method.keySaltLength)
c.reader = NewReader(
c.Conn,
c.method.constructor(common.Dup(key)),
MaxPacketSize,
)
return nil
}
@ -300,14 +307,14 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
return c.writer.(io.ReaderFrom).ReadFrom(r)
}
type aeadPacketConn struct {
type clientPacketConn struct {
net.Conn
method *Method
}
func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
_header := buf.StackNew()
header := common.Dup(_header)
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
@ -321,7 +328,7 @@ func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort
return common.Error(c.Write(buffer.Bytes()))
}
func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, err

View file

@ -22,21 +22,18 @@ import (
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"golang.org/x/crypto/chacha20"
"golang.org/x/crypto/chacha20poly1305"
wgReplay "golang.zx2c4.com/wireguard/replay"
"lukechampine.com/blake3"
)
const (
HeaderTypeClient = 0
HeaderTypeServer = 1
MaxPaddingLength = 900
KeySaltSize = 32
PacketNonceSize = 24
MinRequestHeaderSize = 1 + 8
MinResponseHeaderSize = MinRequestHeaderSize + KeySaltSize
MaxPacketSize = 65535 + shadowaead.PacketLengthBufferSize + nonceSize*2
HeaderTypeClient = 0
HeaderTypeServer = 1
MaxPaddingLength = 900
KeySaltSize = 32
PacketNonceSize = 24
MaxPacketSize = 65535
)
const (
@ -106,7 +103,6 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
case "2022-blake3-chacha20-poly1305":
m.keyLength = 32
m.constructor = newChacha20Poly1305
m.streamConstructor = newChacha20
m.udpCipher = newXChacha20Poly1305(m.psk)
}
return m, nil
@ -135,12 +131,6 @@ func newAESGCM(key []byte) cipher.AEAD {
return aead
}
func newChacha20(key []byte) cipher.Stream {
_nonce := make([]byte, chacha20.NonceSize)
stream, _ := chacha20.NewUnauthenticatedCipher(key, common.Dup(_nonce))
return stream
}
func newChacha20Poly1305(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.New(key)
common.Must(err)
@ -154,18 +144,17 @@ func newXChacha20Poly1305(key []byte) cipher.AEAD {
}
type Method struct {
name string
keyLength int
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
streamConstructor func(key []byte) cipher.Stream
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
psk []byte
pskList [][]byte
pskHash []byte
secureRNG io.Reader
replayFilter replay.Filter
name string
keyLength int
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
psk []byte
pskList [][]byte
pskHash []byte
secureRNG io.Reader
replayFilter replay.Filter
}
func (m *Method) Name() string {
@ -176,30 +165,6 @@ func (m *Method) KeyLength() int {
return m.keyLength
}
func (m *Method) WriteExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
pskLen := len(m.pskList)
if pskLen < 2 {
return
}
for i, psk := range m.pskList {
keyMaterial := make([]byte, 2*KeySaltSize)
copy(keyMaterial, psk)
copy(keyMaterial[KeySaltSize:], salt)
_identitySubkey := buf.Make(m.keyLength)
identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
if m.blockConstructor != nil {
m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash)
} else {
m.streamConstructor(identitySubkey).XORKeyStream(request.Extend(16), pskHash)
}
if i == pskLen-2 {
break
}
}
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
@ -236,18 +201,38 @@ type clientConn struct {
writer io.Writer
}
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
pskLen := len(m.pskList)
if pskLen < 2 {
return
}
for i, psk := range m.pskList {
keyMaterial := make([]byte, 2*KeySaltSize)
copy(keyMaterial, psk)
copy(keyMaterial[KeySaltSize:], salt)
_identitySubkey := buf.Make(m.keyLength)
identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash)
if i == pskLen-2 {
break
}
}
}
func (c *clientConn) writeRequest(payload []byte) error {
request := buf.New()
defer request.Release()
_request := buf.StackNew()
request := common.Dup(_request)
salt := make([]byte, KeySaltSize)
common.Must1(io.ReadFull(c.method.secureRNG, salt))
common.Must1(request.Write(salt))
c.method.WriteExtendedIdentityHeaders(request, salt)
c.method.writeExtendedIdentityHeaders(request, salt)
var writer io.Writer = c.Conn
var writer io.Writer
writer = &buf.BufferedWriter{
Writer: writer,
Writer: c.Conn,
Buffer: request,
}
@ -258,8 +243,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
MaxPacketSize,
)
header := buf.New()
defer header.Release()
_header := buf.StackNew()
header := common.Dup(_header)
writer = &buf.BufferedWriter{
Writer: writer,
@ -362,6 +347,7 @@ func (c *clientConn) readResponse() error {
return ErrBadRequestSalt
}
c.requestSalt = nil
c.reader = reader
return nil
}
@ -417,23 +403,14 @@ type clientPacketConn struct {
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
_header := buf.StackNew()
header := common.Dup(_header)
pskLen := len(c.method.pskList)
var dataIndex int
if c.method.udpCipher != nil {
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
if pskLen > 1 {
for i, psk := range c.method.pskList {
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
identityHeader := header.Extend(aes.BlockSize)
for textI := 0; textI < aes.BlockSize; textI++ {
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
}
c.method.streamConstructor(psk).XORKeyStream(identityHeader, identityHeader)
if i == pskLen-2 {
break
}
}
panic("unsupported chacha extended header")
}
dataIndex = buffer.Len()
} else {