diff --git a/cli/cloudflare-ddns/main.go b/cli/cloudflare-ddns/main.go index 3cb16c0..de9394d 100644 --- a/cli/cloudflare-ddns/main.go +++ b/cli/cloudflare-ddns/main.go @@ -141,6 +141,7 @@ func checkUpdate() { Name: domain, Content: content, Proxied: &overProxy, + TTL: 60, } if addr.Is4() { record.Type = "A" diff --git a/cli/ss-local/main.go b/cli/ss-local/main.go index c10e8e3..880d790 100644 --- a/cli/ss-local/main.go +++ b/cli/ss-local/main.go @@ -192,7 +192,11 @@ func NewLocalClient(f *flags) (*LocalClient, error) { if len(pskList) > 1 { return nil, shadowaead.ErrBadKey } - method, err := shadowaead.New(f.Method, pskList[0], []byte(f.Password), rng, false) + var key []byte + if len(pskList) > 0 { + key = pskList[0] + } + method, err := shadowaead.New(f.Method, key, []byte(f.Password), rng, false) if err != nil { return nil, err } @@ -314,27 +318,21 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error { logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination) ctx := context.Background() - var serverConn net.Conn - payload := buf.New() - err := task.Run(ctx, func() error { - sc, err := c.dialer.DialContext(ctx, "tcp", c.server.String()) - serverConn = sc - if err != nil { - return E.Cause(err, "connect to server") - } - return nil - }, func() error { - err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) - if err != nil { - return err - } - _, err = payload.ReadFrom(conn) - if err != nil && !E.IsTimeout(err) { - return E.Cause(err, "read payload") - } - err = conn.SetReadDeadline(time.Time{}) + serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String()) + if err != nil { + return E.Cause(err, "connect to server") + } + _payload := buf.StackNew() + payload := common.Dup(_payload) + err = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + if err != nil { return err - }) + } + _, err = payload.ReadFrom(conn) + if err != nil && !E.IsTimeout(err) { + return E.Cause(err, "read payload") + } + err = conn.SetReadDeadline(time.Time{}) if err != nil { payload.Release() return err diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 493a0be..05b7f83 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -25,31 +25,15 @@ func New() *Buffer { } } -func NewSize(size int) *Buffer { - if size <= 128 || size > BufferSize { - return &Buffer{ - data: make([]byte, size), - } - } +func StackNew() *Buffer { return &Buffer{ - data: GetBytes(), - start: ReversedHeader, - end: ReversedHeader, - managed: true, + data: make([]byte, BufferSize), } } -func FullNew() *Buffer { +func StackNewSize(size int) *Buffer { return &Buffer{ - data: GetBytes(), - managed: true, - } -} - -func StackNew() Buffer { - return Buffer{ - data: GetBytes(), - managed: true, + data: Make(size), } } @@ -71,20 +55,6 @@ func As(data []byte) *Buffer { } } -func Or(data []byte, size int) *Buffer { - max := cap(data) - if size != max { - data = data[:max] - } - if cap(data) >= size { - return &Buffer{ - data: data, - } - } else { - return NewSize(size) - } -} - func With(data []byte) *Buffer { return &Buffer{ data: data, diff --git a/common/buf/pool.go b/common/buf/pool.go index 37b9d7b..bab31fa 100644 --- a/common/buf/pool.go +++ b/common/buf/pool.go @@ -10,29 +10,41 @@ const ( var pool = sync.Pool{ New: func() any { - var buffer [BufferSize]byte - return buffer[:] + buffer := make([]byte, BufferSize) + return &buffer }, } func GetBytes() []byte { - return pool.Get().([]byte) + return *pool.Get().(*[]byte) } func PutBytes(buffer []byte) { - pool.Put(buffer) + pool.Put(&buffer) } func Make(size int) []byte { var buffer []byte - if size <= 64 { + if size <= 16 { + buffer = make([]byte, 16) + } else if size <= 32 { + buffer = make([]byte, 32) + } else if size <= 64 { buffer = make([]byte, 64) + } else if size <= 128 { + buffer = make([]byte, 128) + } else if size <= 256 { + buffer = make([]byte, 256) + } else if size <= 512 { + buffer = make([]byte, 512) } else if size <= 1024 { buffer = make([]byte, 1024) - } else if size <= 4096 { - buffer = make([]byte, 4096) - } else if size <= 16384 { - buffer = make([]byte, 16384) + } else if size <= 4*1024 { + buffer = make([]byte, 4*1024) + } else if size <= 16*1024 { + buffer = make([]byte, 16*1024) + } else if size <= 20*1024 { + buffer = make([]byte, 20*1024) } else if size <= 65535 { buffer = make([]byte, 65535) } else { diff --git a/common/lowmem/free.go b/common/lowmem/free.go index c7683dd..6aa8fd7 100644 --- a/common/lowmem/free.go +++ b/common/lowmem/free.go @@ -1,13 +1,13 @@ package lowmem import ( - "runtime" + "runtime/debug" ) var Enabled = false func Free() { if Enabled { - runtime.GC() + debug.FreeOSMemory() } } diff --git a/common/rw/copy.go b/common/rw/copy.go index ed47675..bcb2b48 100644 --- a/common/rw/copy.go +++ b/common/rw/copy.go @@ -57,8 +57,8 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error { return task.Run(ctx, func() error { - buffer := buf.FullNew() - defer buffer.Release() + _buffer := buf.With(make([]byte, buf.UDPBufferSize)) + buffer := common.Dup(_buffer) for { n, addr, err := conn.ReadFrom(buffer.FreeBytes()) if err != nil { @@ -72,8 +72,8 @@ func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net. buffer.FullReset() } }, func() error { - buffer := buf.FullNew() - defer buffer.Release() + _buffer := buf.With(make([]byte, buf.UDPBufferSize)) + buffer := common.Dup(_buffer) for { n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes()) if err != nil { diff --git a/common/rw/writev_posix.go b/common/rw/writev_posix.go new file mode 100644 index 0000000..cc7a6e8 --- /dev/null +++ b/common/rw/writev_posix.go @@ -0,0 +1,11 @@ +//go:build !windows + +package rw + +import ( + "golang.org/x/sys/unix" +) + +func WriteV(fd uintptr, data ...[]byte) (int, error) { + return unix.Writev(int(fd), data) +} diff --git a/common/rw/writev_windows.go b/common/rw/writev_windows.go new file mode 100644 index 0000000..d053450 --- /dev/null +++ b/common/rw/writev_windows.go @@ -0,0 +1,16 @@ +package rw + +import "golang.org/x/sys/windows" + +func WriteV(fd uintptr, data ...[]byte) (int, error) { + var n uint32 + buffers := make([]*windows.WSABuf, len(data)) + for i, buf := range data { + buffers[i] = &windows.WSABuf{ + Len: uint32(len(buf)), + Buf: &buf[0], + } + } + err := windows.WSASend(windows.Handle(fd), buffers[0], uint32(len(buffers)), &n, 0, nil, nil) + return int(n), err +} diff --git a/common/uot/server.go b/common/uot/server.go index f9b9b46..3eb815f 100644 --- a/common/uot/server.go +++ b/common/uot/server.go @@ -40,8 +40,8 @@ func (c *ServerConn) RemoteAddr() net.Addr { } func (c *ServerConn) loopInput() { - buffer := buf.New() - defer buffer.Release() + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) for { destination, err := AddrParser.ReadAddrPort(c.inputReader) if err != nil { @@ -73,8 +73,8 @@ func (c *ServerConn) loopInput() { } func (c *ServerConn) loopOutput() { - buffer := buf.New() - defer buffer.Release() + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) for { buffer.FullReset() n, addr, err := buffer.ReadPacketFrom(c) diff --git a/common/uot/uot_test.go b/common/uot/uot_test.go index f5c0f89..389862f 100644 --- a/common/uot/uot_test.go +++ b/common/uot/uot_test.go @@ -29,8 +29,8 @@ func TestServerConn(t *testing.T) { IP: net.IPv4(8, 8, 8, 8), Port: 53, })) - buffer := buf.New() - defer buffer.Release() + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) common.Must2(buffer.ReadPacketFrom(clientConn)) common.Must(message.Unpack(buffer.Bytes())) for _, answer := range message.Answers { diff --git a/protocol/shadowsocks/none.go b/protocol/shadowsocks/none.go index e44d48a..6c543f3 100644 --- a/protocol/shadowsocks/none.go +++ b/protocol/shadowsocks/none.go @@ -81,8 +81,8 @@ func (c *noneConn) Write(b []byte) (n int, err error) { return 0, c.clientHandshake() } - buffer := buf.New() - defer buffer.Release() + _buffer := buf.StackNew() + buffer := common.Dup(_buffer) err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination) if err != nil { @@ -138,7 +138,8 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error { defer buffer.Release() - header := buf.New() + _header := buf.StackNew() + header := common.Dup(_header) err := socks.AddressSerializer.WriteAddrPort(header, addrPort) if err != nil { header.Release() diff --git a/protocol/shadowsocks/shadowaead/protocol.go b/protocol/shadowsocks/shadowaead/protocol.go index 6212b33..64e3790 100644 --- a/protocol/shadowsocks/shadowaead/protocol.go +++ b/protocol/shadowsocks/shadowaead/protocol.go @@ -79,9 +79,9 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader, replay } func Kdf(key, iv []byte, keyLength int) []byte { - subKey := make([]byte, keyLength) + subKey := buf.Make(keyLength) kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey")) - common.Must1(io.ReadFull(kdf, subKey)) + common.Must1(io.ReadFull(kdf, common.Dup(subKey))) return subKey } @@ -111,8 +111,8 @@ func (m *Method) KeyLength() int { } func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) { - saltBuffer := buf.Make(m.keySaltLength) - salt := common.Dup(saltBuffer) + _salt := buf.Make(m.keySaltLength) + salt := common.Dup(_salt) _, err := io.ReadFull(upstream, salt) if err != nil { return nil, E.Cause(err, "read salt") @@ -122,18 +122,20 @@ func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) { return nil, E.New("salt not unique") } } - return NewReader(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil + key := Kdf(m.key, salt, m.keySaltLength) + return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil } func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) { - saltBuffer := buf.Make(m.keySaltLength) - salt := common.Dup(saltBuffer) + _salt := buf.Make(m.keySaltLength) + salt := common.Dup(_salt) common.Must1(io.ReadFull(m.secureRNG, salt)) _, err := upstream.Write(salt) if err != nil { return nil, err } - return NewWriter(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil + key := Kdf(m.key, salt, m.keySaltLength) + return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil } func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) { @@ -154,11 +156,12 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn } func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn { - return &aeadPacketConn{conn, m} + return &clientPacketConn{conn, m} } func (m *Method) EncodePacket(buffer *buf.Buffer) error { - c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)) + key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) + c := m.constructor(common.Dup(key)) c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) buffer.Extend(c.Overhead()) return nil @@ -168,7 +171,8 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error { if buffer.Len() < m.keySaltLength { return E.New("bad packet") } - c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)) + key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength) + c := m.constructor(common.Dup(key)) packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) if err != nil { return err @@ -190,8 +194,8 @@ type clientConn struct { } func (c *clientConn) writeRequest(payload []byte) error { - request := buf.New() - defer request.Release() + _request := buf.StackNew() + request := common.Dup(_request) common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength)) @@ -207,8 +211,8 @@ func (c *clientConn) writeRequest(payload []byte) error { ) if len(payload) > 0 { - header := buf.New() - defer header.Release() + _header := buf.StackNew() + header := common.Dup(_header) writer = &buf.BufferedWriter{ Writer: writer, @@ -240,23 +244,26 @@ func (c *clientConn) writeRequest(payload []byte) error { } func (c *clientConn) readResponse() error { - if c.reader == nil { - salt := make([]byte, c.method.keySaltLength) - _, err := io.ReadFull(c.Conn, salt) - if err != nil { - return err - } - if c.method.replayFilter != nil { - if !c.method.replayFilter.Check(salt) { - return E.New("salt not unique") - } - } - c.reader = NewReader( - c.Conn, - c.method.constructor(Kdf(c.method.key, salt, c.method.keySaltLength)), - MaxPacketSize, - ) + if c.reader != nil { + return nil } + _salt := buf.Make(c.method.keySaltLength) + salt := common.Dup(_salt) + _, err := io.ReadFull(c.Conn, salt) + if err != nil { + return err + } + if c.method.replayFilter != nil { + if !c.method.replayFilter.Check(salt) { + return E.New("salt not unique") + } + } + key := Kdf(c.method.key, salt, c.method.keySaltLength) + c.reader = NewReader( + c.Conn, + c.method.constructor(common.Dup(key)), + MaxPacketSize, + ) return nil } @@ -300,14 +307,14 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) { return c.writer.(io.ReaderFrom).ReadFrom(r) } -type aeadPacketConn struct { +type clientPacketConn struct { net.Conn method *Method } -func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { - defer buffer.Release() - header := buf.New() +func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { + _header := buf.StackNew() + header := common.Dup(_header) common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength)) err := socks.AddressSerializer.WriteAddrPort(header, destination) if err != nil { @@ -321,7 +328,7 @@ func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort return common.Error(c.Write(buffer.Bytes())) } -func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { +func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) { n, err := c.Read(buffer.FreeBytes()) if err != nil { return nil, err diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index e11d727..ec85d42 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -22,21 +22,18 @@ import ( "github.com/sagernet/sing/protocol/shadowsocks" "github.com/sagernet/sing/protocol/shadowsocks/shadowaead" "github.com/sagernet/sing/protocol/socks" - "golang.org/x/crypto/chacha20" "golang.org/x/crypto/chacha20poly1305" wgReplay "golang.zx2c4.com/wireguard/replay" "lukechampine.com/blake3" ) const ( - HeaderTypeClient = 0 - HeaderTypeServer = 1 - MaxPaddingLength = 900 - KeySaltSize = 32 - PacketNonceSize = 24 - MinRequestHeaderSize = 1 + 8 - MinResponseHeaderSize = MinRequestHeaderSize + KeySaltSize - MaxPacketSize = 65535 + shadowaead.PacketLengthBufferSize + nonceSize*2 + HeaderTypeClient = 0 + HeaderTypeServer = 1 + MaxPaddingLength = 900 + KeySaltSize = 32 + PacketNonceSize = 24 + MaxPacketSize = 65535 ) const ( @@ -106,7 +103,6 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth case "2022-blake3-chacha20-poly1305": m.keyLength = 32 m.constructor = newChacha20Poly1305 - m.streamConstructor = newChacha20 m.udpCipher = newXChacha20Poly1305(m.psk) } return m, nil @@ -135,12 +131,6 @@ func newAESGCM(key []byte) cipher.AEAD { return aead } -func newChacha20(key []byte) cipher.Stream { - _nonce := make([]byte, chacha20.NonceSize) - stream, _ := chacha20.NewUnauthenticatedCipher(key, common.Dup(_nonce)) - return stream -} - func newChacha20Poly1305(key []byte) cipher.AEAD { cipher, err := chacha20poly1305.New(key) common.Must(err) @@ -154,18 +144,17 @@ func newXChacha20Poly1305(key []byte) cipher.AEAD { } type Method struct { - name string - keyLength int - constructor func(key []byte) cipher.AEAD - blockConstructor func(key []byte) cipher.Block - streamConstructor func(key []byte) cipher.Stream - udpCipher cipher.AEAD - udpBlockCipher cipher.Block - psk []byte - pskList [][]byte - pskHash []byte - secureRNG io.Reader - replayFilter replay.Filter + name string + keyLength int + constructor func(key []byte) cipher.AEAD + blockConstructor func(key []byte) cipher.Block + udpCipher cipher.AEAD + udpBlockCipher cipher.Block + psk []byte + pskList [][]byte + pskHash []byte + secureRNG io.Reader + replayFilter replay.Filter } func (m *Method) Name() string { @@ -176,30 +165,6 @@ func (m *Method) KeyLength() int { return m.keyLength } -func (m *Method) WriteExtendedIdentityHeaders(request *buf.Buffer, salt []byte) { - pskLen := len(m.pskList) - if pskLen < 2 { - return - } - for i, psk := range m.pskList { - keyMaterial := make([]byte, 2*KeySaltSize) - copy(keyMaterial, psk) - copy(keyMaterial[KeySaltSize:], salt) - _identitySubkey := buf.Make(m.keyLength) - identitySubkey := common.Dup(_identitySubkey) - blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) - pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] - if m.blockConstructor != nil { - m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash) - } else { - m.streamConstructor(identitySubkey).XORKeyStream(request.Extend(16), pskHash) - } - if i == pskLen-2 { - break - } - } -} - func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) { shadowsocksConn := &clientConn{ Conn: conn, @@ -236,18 +201,38 @@ type clientConn struct { writer io.Writer } +func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) { + pskLen := len(m.pskList) + if pskLen < 2 { + return + } + for i, psk := range m.pskList { + keyMaterial := make([]byte, 2*KeySaltSize) + copy(keyMaterial, psk) + copy(keyMaterial[KeySaltSize:], salt) + _identitySubkey := buf.Make(m.keyLength) + identitySubkey := common.Dup(_identitySubkey) + blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) + pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] + m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash) + if i == pskLen-2 { + break + } + } +} + func (c *clientConn) writeRequest(payload []byte) error { - request := buf.New() - defer request.Release() + _request := buf.StackNew() + request := common.Dup(_request) salt := make([]byte, KeySaltSize) common.Must1(io.ReadFull(c.method.secureRNG, salt)) common.Must1(request.Write(salt)) - c.method.WriteExtendedIdentityHeaders(request, salt) + c.method.writeExtendedIdentityHeaders(request, salt) - var writer io.Writer = c.Conn + var writer io.Writer writer = &buf.BufferedWriter{ - Writer: writer, + Writer: c.Conn, Buffer: request, } @@ -258,8 +243,8 @@ func (c *clientConn) writeRequest(payload []byte) error { MaxPacketSize, ) - header := buf.New() - defer header.Release() + _header := buf.StackNew() + header := common.Dup(_header) writer = &buf.BufferedWriter{ Writer: writer, @@ -362,6 +347,7 @@ func (c *clientConn) readResponse() error { return ErrBadRequestSalt } + c.requestSalt = nil c.reader = reader return nil } @@ -417,23 +403,14 @@ type clientPacketConn struct { func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error { defer buffer.Release() - header := buf.New() + _header := buf.StackNew() + header := common.Dup(_header) pskLen := len(c.method.pskList) var dataIndex int if c.method.udpCipher != nil { common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize)) if pskLen > 1 { - for i, psk := range c.method.pskList { - pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] - identityHeader := header.Extend(aes.BlockSize) - for textI := 0; textI < aes.BlockSize; textI++ { - identityHeader[textI] = pskHash[textI] ^ header.Byte(textI) - } - c.method.streamConstructor(psk).XORKeyStream(identityHeader, identityHeader) - if i == pskLen-2 { - break - } - } + panic("unsupported chacha extended header") } dataIndex = buffer.Len() } else {