Refine buffer

This commit is contained in:
世界 2022-04-22 17:11:24 +08:00
parent 603c62165e
commit 63ef20617a
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
13 changed files with 179 additions and 186 deletions

View file

@ -141,6 +141,7 @@ func checkUpdate() {
Name: domain,
Content: content,
Proxied: &overProxy,
TTL: 60,
}
if addr.Is4() {
record.Type = "A"

View file

@ -192,7 +192,11 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
if len(pskList) > 1 {
return nil, shadowaead.ErrBadKey
}
method, err := shadowaead.New(f.Method, pskList[0], []byte(f.Password), rng, false)
var key []byte
if len(pskList) > 0 {
key = pskList[0]
}
method, err := shadowaead.New(f.Method, key, []byte(f.Password), rng, false)
if err != nil {
return nil, err
}
@ -314,27 +318,21 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
ctx := context.Background()
var serverConn net.Conn
payload := buf.New()
err := task.Run(ctx, func() error {
sc, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
serverConn = sc
if err != nil {
return E.Cause(err, "connect to server")
}
return nil
}, func() error {
err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
if err != nil {
return err
}
_, err = payload.ReadFrom(conn)
if err != nil && !E.IsTimeout(err) {
return E.Cause(err, "read payload")
}
err = conn.SetReadDeadline(time.Time{})
serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
if err != nil {
return E.Cause(err, "connect to server")
}
_payload := buf.StackNew()
payload := common.Dup(_payload)
err = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
if err != nil {
return err
})
}
_, err = payload.ReadFrom(conn)
if err != nil && !E.IsTimeout(err) {
return E.Cause(err, "read payload")
}
err = conn.SetReadDeadline(time.Time{})
if err != nil {
payload.Release()
return err

View file

@ -25,31 +25,15 @@ func New() *Buffer {
}
}
func NewSize(size int) *Buffer {
if size <= 128 || size > BufferSize {
return &Buffer{
data: make([]byte, size),
}
}
func StackNew() *Buffer {
return &Buffer{
data: GetBytes(),
start: ReversedHeader,
end: ReversedHeader,
managed: true,
data: make([]byte, BufferSize),
}
}
func FullNew() *Buffer {
func StackNewSize(size int) *Buffer {
return &Buffer{
data: GetBytes(),
managed: true,
}
}
func StackNew() Buffer {
return Buffer{
data: GetBytes(),
managed: true,
data: Make(size),
}
}
@ -71,20 +55,6 @@ func As(data []byte) *Buffer {
}
}
func Or(data []byte, size int) *Buffer {
max := cap(data)
if size != max {
data = data[:max]
}
if cap(data) >= size {
return &Buffer{
data: data,
}
} else {
return NewSize(size)
}
}
func With(data []byte) *Buffer {
return &Buffer{
data: data,

View file

@ -10,29 +10,41 @@ const (
var pool = sync.Pool{
New: func() any {
var buffer [BufferSize]byte
return buffer[:]
buffer := make([]byte, BufferSize)
return &buffer
},
}
func GetBytes() []byte {
return pool.Get().([]byte)
return *pool.Get().(*[]byte)
}
func PutBytes(buffer []byte) {
pool.Put(buffer)
pool.Put(&buffer)
}
func Make(size int) []byte {
var buffer []byte
if size <= 64 {
if size <= 16 {
buffer = make([]byte, 16)
} else if size <= 32 {
buffer = make([]byte, 32)
} else if size <= 64 {
buffer = make([]byte, 64)
} else if size <= 128 {
buffer = make([]byte, 128)
} else if size <= 256 {
buffer = make([]byte, 256)
} else if size <= 512 {
buffer = make([]byte, 512)
} else if size <= 1024 {
buffer = make([]byte, 1024)
} else if size <= 4096 {
buffer = make([]byte, 4096)
} else if size <= 16384 {
buffer = make([]byte, 16384)
} else if size <= 4*1024 {
buffer = make([]byte, 4*1024)
} else if size <= 16*1024 {
buffer = make([]byte, 16*1024)
} else if size <= 20*1024 {
buffer = make([]byte, 20*1024)
} else if size <= 65535 {
buffer = make([]byte, 65535)
} else {

View file

@ -1,13 +1,13 @@
package lowmem
import (
"runtime"
"runtime/debug"
)
var Enabled = false
func Free() {
if Enabled {
runtime.GC()
debug.FreeOSMemory()
}
}

View file

@ -57,8 +57,8 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
return task.Run(ctx, func() error {
buffer := buf.FullNew()
defer buffer.Release()
_buffer := buf.With(make([]byte, buf.UDPBufferSize))
buffer := common.Dup(_buffer)
for {
n, addr, err := conn.ReadFrom(buffer.FreeBytes())
if err != nil {
@ -72,8 +72,8 @@ func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.
buffer.FullReset()
}
}, func() error {
buffer := buf.FullNew()
defer buffer.Release()
_buffer := buf.With(make([]byte, buf.UDPBufferSize))
buffer := common.Dup(_buffer)
for {
n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes())
if err != nil {

11
common/rw/writev_posix.go Normal file
View file

@ -0,0 +1,11 @@
//go:build !windows
package rw
import (
"golang.org/x/sys/unix"
)
func WriteV(fd uintptr, data ...[]byte) (int, error) {
return unix.Writev(int(fd), data)
}

View file

@ -0,0 +1,16 @@
package rw
import "golang.org/x/sys/windows"
func WriteV(fd uintptr, data ...[]byte) (int, error) {
var n uint32
buffers := make([]*windows.WSABuf, len(data))
for i, buf := range data {
buffers[i] = &windows.WSABuf{
Len: uint32(len(buf)),
Buf: &buf[0],
}
}
err := windows.WSASend(windows.Handle(fd), buffers[0], uint32(len(buffers)), &n, 0, nil, nil)
return int(n), err
}

View file

@ -40,8 +40,8 @@ func (c *ServerConn) RemoteAddr() net.Addr {
}
func (c *ServerConn) loopInput() {
buffer := buf.New()
defer buffer.Release()
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
for {
destination, err := AddrParser.ReadAddrPort(c.inputReader)
if err != nil {
@ -73,8 +73,8 @@ func (c *ServerConn) loopInput() {
}
func (c *ServerConn) loopOutput() {
buffer := buf.New()
defer buffer.Release()
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
for {
buffer.FullReset()
n, addr, err := buffer.ReadPacketFrom(c)

View file

@ -29,8 +29,8 @@ func TestServerConn(t *testing.T) {
IP: net.IPv4(8, 8, 8, 8),
Port: 53,
}))
buffer := buf.New()
defer buffer.Release()
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
common.Must2(buffer.ReadPacketFrom(clientConn))
common.Must(message.Unpack(buffer.Bytes()))
for _, answer := range message.Answers {

View file

@ -81,8 +81,8 @@ func (c *noneConn) Write(b []byte) (n int, err error) {
return 0, c.clientHandshake()
}
buffer := buf.New()
defer buffer.Release()
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
if err != nil {
@ -138,7 +138,8 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
_header := buf.StackNew()
header := common.Dup(_header)
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
if err != nil {
header.Release()

View file

@ -79,9 +79,9 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader, replay
}
func Kdf(key, iv []byte, keyLength int) []byte {
subKey := make([]byte, keyLength)
subKey := buf.Make(keyLength)
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
common.Must1(io.ReadFull(kdf, subKey))
common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
return subKey
}
@ -111,8 +111,8 @@ func (m *Method) KeyLength() int {
}
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
saltBuffer := buf.Make(m.keySaltLength)
salt := common.Dup(saltBuffer)
_salt := buf.Make(m.keySaltLength)
salt := common.Dup(_salt)
_, err := io.ReadFull(upstream, salt)
if err != nil {
return nil, E.Cause(err, "read salt")
@ -122,18 +122,20 @@ func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
return nil, E.New("salt not unique")
}
}
return NewReader(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil
key := Kdf(m.key, salt, m.keySaltLength)
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
saltBuffer := buf.Make(m.keySaltLength)
salt := common.Dup(saltBuffer)
_salt := buf.Make(m.keySaltLength)
salt := common.Dup(_salt)
common.Must1(io.ReadFull(m.secureRNG, salt))
_, err := upstream.Write(salt)
if err != nil {
return nil, err
}
return NewWriter(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil
key := Kdf(m.key, salt, m.keySaltLength)
return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
@ -154,11 +156,12 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
}
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
return &aeadPacketConn{conn, m}
return &clientPacketConn{conn, m}
}
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength))
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
buffer.Extend(c.Overhead())
return nil
@ -168,7 +171,8 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
if buffer.Len() < m.keySaltLength {
return E.New("bad packet")
}
c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength))
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
c := m.constructor(common.Dup(key))
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
if err != nil {
return err
@ -190,8 +194,8 @@ type clientConn struct {
}
func (c *clientConn) writeRequest(payload []byte) error {
request := buf.New()
defer request.Release()
_request := buf.StackNew()
request := common.Dup(_request)
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
@ -207,8 +211,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
)
if len(payload) > 0 {
header := buf.New()
defer header.Release()
_header := buf.StackNew()
header := common.Dup(_header)
writer = &buf.BufferedWriter{
Writer: writer,
@ -240,23 +244,26 @@ func (c *clientConn) writeRequest(payload []byte) error {
}
func (c *clientConn) readResponse() error {
if c.reader == nil {
salt := make([]byte, c.method.keySaltLength)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
if c.method.replayFilter != nil {
if !c.method.replayFilter.Check(salt) {
return E.New("salt not unique")
}
}
c.reader = NewReader(
c.Conn,
c.method.constructor(Kdf(c.method.key, salt, c.method.keySaltLength)),
MaxPacketSize,
)
if c.reader != nil {
return nil
}
_salt := buf.Make(c.method.keySaltLength)
salt := common.Dup(_salt)
_, err := io.ReadFull(c.Conn, salt)
if err != nil {
return err
}
if c.method.replayFilter != nil {
if !c.method.replayFilter.Check(salt) {
return E.New("salt not unique")
}
}
key := Kdf(c.method.key, salt, c.method.keySaltLength)
c.reader = NewReader(
c.Conn,
c.method.constructor(common.Dup(key)),
MaxPacketSize,
)
return nil
}
@ -300,14 +307,14 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
return c.writer.(io.ReaderFrom).ReadFrom(r)
}
type aeadPacketConn struct {
type clientPacketConn struct {
net.Conn
method *Method
}
func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
_header := buf.StackNew()
header := common.Dup(_header)
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
err := socks.AddressSerializer.WriteAddrPort(header, destination)
if err != nil {
@ -321,7 +328,7 @@ func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort
return common.Error(c.Write(buffer.Bytes()))
}
func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
n, err := c.Read(buffer.FreeBytes())
if err != nil {
return nil, err

View file

@ -22,21 +22,18 @@ import (
"github.com/sagernet/sing/protocol/shadowsocks"
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
"github.com/sagernet/sing/protocol/socks"
"golang.org/x/crypto/chacha20"
"golang.org/x/crypto/chacha20poly1305"
wgReplay "golang.zx2c4.com/wireguard/replay"
"lukechampine.com/blake3"
)
const (
HeaderTypeClient = 0
HeaderTypeServer = 1
MaxPaddingLength = 900
KeySaltSize = 32
PacketNonceSize = 24
MinRequestHeaderSize = 1 + 8
MinResponseHeaderSize = MinRequestHeaderSize + KeySaltSize
MaxPacketSize = 65535 + shadowaead.PacketLengthBufferSize + nonceSize*2
HeaderTypeClient = 0
HeaderTypeServer = 1
MaxPaddingLength = 900
KeySaltSize = 32
PacketNonceSize = 24
MaxPacketSize = 65535
)
const (
@ -106,7 +103,6 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
case "2022-blake3-chacha20-poly1305":
m.keyLength = 32
m.constructor = newChacha20Poly1305
m.streamConstructor = newChacha20
m.udpCipher = newXChacha20Poly1305(m.psk)
}
return m, nil
@ -135,12 +131,6 @@ func newAESGCM(key []byte) cipher.AEAD {
return aead
}
func newChacha20(key []byte) cipher.Stream {
_nonce := make([]byte, chacha20.NonceSize)
stream, _ := chacha20.NewUnauthenticatedCipher(key, common.Dup(_nonce))
return stream
}
func newChacha20Poly1305(key []byte) cipher.AEAD {
cipher, err := chacha20poly1305.New(key)
common.Must(err)
@ -154,18 +144,17 @@ func newXChacha20Poly1305(key []byte) cipher.AEAD {
}
type Method struct {
name string
keyLength int
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
streamConstructor func(key []byte) cipher.Stream
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
psk []byte
pskList [][]byte
pskHash []byte
secureRNG io.Reader
replayFilter replay.Filter
name string
keyLength int
constructor func(key []byte) cipher.AEAD
blockConstructor func(key []byte) cipher.Block
udpCipher cipher.AEAD
udpBlockCipher cipher.Block
psk []byte
pskList [][]byte
pskHash []byte
secureRNG io.Reader
replayFilter replay.Filter
}
func (m *Method) Name() string {
@ -176,30 +165,6 @@ func (m *Method) KeyLength() int {
return m.keyLength
}
func (m *Method) WriteExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
pskLen := len(m.pskList)
if pskLen < 2 {
return
}
for i, psk := range m.pskList {
keyMaterial := make([]byte, 2*KeySaltSize)
copy(keyMaterial, psk)
copy(keyMaterial[KeySaltSize:], salt)
_identitySubkey := buf.Make(m.keyLength)
identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
if m.blockConstructor != nil {
m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash)
} else {
m.streamConstructor(identitySubkey).XORKeyStream(request.Extend(16), pskHash)
}
if i == pskLen-2 {
break
}
}
}
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
shadowsocksConn := &clientConn{
Conn: conn,
@ -236,18 +201,38 @@ type clientConn struct {
writer io.Writer
}
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
pskLen := len(m.pskList)
if pskLen < 2 {
return
}
for i, psk := range m.pskList {
keyMaterial := make([]byte, 2*KeySaltSize)
copy(keyMaterial, psk)
copy(keyMaterial[KeySaltSize:], salt)
_identitySubkey := buf.Make(m.keyLength)
identitySubkey := common.Dup(_identitySubkey)
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash)
if i == pskLen-2 {
break
}
}
}
func (c *clientConn) writeRequest(payload []byte) error {
request := buf.New()
defer request.Release()
_request := buf.StackNew()
request := common.Dup(_request)
salt := make([]byte, KeySaltSize)
common.Must1(io.ReadFull(c.method.secureRNG, salt))
common.Must1(request.Write(salt))
c.method.WriteExtendedIdentityHeaders(request, salt)
c.method.writeExtendedIdentityHeaders(request, salt)
var writer io.Writer = c.Conn
var writer io.Writer
writer = &buf.BufferedWriter{
Writer: writer,
Writer: c.Conn,
Buffer: request,
}
@ -258,8 +243,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
MaxPacketSize,
)
header := buf.New()
defer header.Release()
_header := buf.StackNew()
header := common.Dup(_header)
writer = &buf.BufferedWriter{
Writer: writer,
@ -362,6 +347,7 @@ func (c *clientConn) readResponse() error {
return ErrBadRequestSalt
}
c.requestSalt = nil
c.reader = reader
return nil
}
@ -417,23 +403,14 @@ type clientPacketConn struct {
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
defer buffer.Release()
header := buf.New()
_header := buf.StackNew()
header := common.Dup(_header)
pskLen := len(c.method.pskList)
var dataIndex int
if c.method.udpCipher != nil {
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
if pskLen > 1 {
for i, psk := range c.method.pskList {
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
identityHeader := header.Extend(aes.BlockSize)
for textI := 0; textI < aes.BlockSize; textI++ {
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
}
c.method.streamConstructor(psk).XORKeyStream(identityHeader, identityHeader)
if i == pskLen-2 {
break
}
}
panic("unsupported chacha extended header")
}
dataIndex = buffer.Len()
} else {