mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-04 04:17:38 +03:00
Refine buffer
This commit is contained in:
parent
603c62165e
commit
63ef20617a
13 changed files with 179 additions and 186 deletions
|
@ -141,6 +141,7 @@ func checkUpdate() {
|
|||
Name: domain,
|
||||
Content: content,
|
||||
Proxied: &overProxy,
|
||||
TTL: 60,
|
||||
}
|
||||
if addr.Is4() {
|
||||
record.Type = "A"
|
||||
|
|
|
@ -192,7 +192,11 @@ func NewLocalClient(f *flags) (*LocalClient, error) {
|
|||
if len(pskList) > 1 {
|
||||
return nil, shadowaead.ErrBadKey
|
||||
}
|
||||
method, err := shadowaead.New(f.Method, pskList[0], []byte(f.Password), rng, false)
|
||||
var key []byte
|
||||
if len(pskList) > 0 {
|
||||
key = pskList[0]
|
||||
}
|
||||
method, err := shadowaead.New(f.Method, key, []byte(f.Password), rng, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -314,27 +318,21 @@ func (c *LocalClient) NewConnection(conn net.Conn, metadata M.Metadata) error {
|
|||
logrus.Info("CONNECT ", conn.RemoteAddr(), " ==> ", metadata.Destination)
|
||||
ctx := context.Background()
|
||||
|
||||
var serverConn net.Conn
|
||||
payload := buf.New()
|
||||
err := task.Run(ctx, func() error {
|
||||
sc, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
|
||||
serverConn = sc
|
||||
if err != nil {
|
||||
return E.Cause(err, "connect to server")
|
||||
}
|
||||
return nil
|
||||
}, func() error {
|
||||
err := conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = payload.ReadFrom(conn)
|
||||
if err != nil && !E.IsTimeout(err) {
|
||||
return E.Cause(err, "read payload")
|
||||
}
|
||||
err = conn.SetReadDeadline(time.Time{})
|
||||
serverConn, err := c.dialer.DialContext(ctx, "tcp", c.server.String())
|
||||
if err != nil {
|
||||
return E.Cause(err, "connect to server")
|
||||
}
|
||||
_payload := buf.StackNew()
|
||||
payload := common.Dup(_payload)
|
||||
err = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
if err != nil {
|
||||
return err
|
||||
})
|
||||
}
|
||||
_, err = payload.ReadFrom(conn)
|
||||
if err != nil && !E.IsTimeout(err) {
|
||||
return E.Cause(err, "read payload")
|
||||
}
|
||||
err = conn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
payload.Release()
|
||||
return err
|
||||
|
|
|
@ -25,31 +25,15 @@ func New() *Buffer {
|
|||
}
|
||||
}
|
||||
|
||||
func NewSize(size int) *Buffer {
|
||||
if size <= 128 || size > BufferSize {
|
||||
return &Buffer{
|
||||
data: make([]byte, size),
|
||||
}
|
||||
}
|
||||
func StackNew() *Buffer {
|
||||
return &Buffer{
|
||||
data: GetBytes(),
|
||||
start: ReversedHeader,
|
||||
end: ReversedHeader,
|
||||
managed: true,
|
||||
data: make([]byte, BufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
func FullNew() *Buffer {
|
||||
func StackNewSize(size int) *Buffer {
|
||||
return &Buffer{
|
||||
data: GetBytes(),
|
||||
managed: true,
|
||||
}
|
||||
}
|
||||
|
||||
func StackNew() Buffer {
|
||||
return Buffer{
|
||||
data: GetBytes(),
|
||||
managed: true,
|
||||
data: Make(size),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,20 +55,6 @@ func As(data []byte) *Buffer {
|
|||
}
|
||||
}
|
||||
|
||||
func Or(data []byte, size int) *Buffer {
|
||||
max := cap(data)
|
||||
if size != max {
|
||||
data = data[:max]
|
||||
}
|
||||
if cap(data) >= size {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
}
|
||||
} else {
|
||||
return NewSize(size)
|
||||
}
|
||||
}
|
||||
|
||||
func With(data []byte) *Buffer {
|
||||
return &Buffer{
|
||||
data: data,
|
||||
|
|
|
@ -10,29 +10,41 @@ const (
|
|||
|
||||
var pool = sync.Pool{
|
||||
New: func() any {
|
||||
var buffer [BufferSize]byte
|
||||
return buffer[:]
|
||||
buffer := make([]byte, BufferSize)
|
||||
return &buffer
|
||||
},
|
||||
}
|
||||
|
||||
func GetBytes() []byte {
|
||||
return pool.Get().([]byte)
|
||||
return *pool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
func PutBytes(buffer []byte) {
|
||||
pool.Put(buffer)
|
||||
pool.Put(&buffer)
|
||||
}
|
||||
|
||||
func Make(size int) []byte {
|
||||
var buffer []byte
|
||||
if size <= 64 {
|
||||
if size <= 16 {
|
||||
buffer = make([]byte, 16)
|
||||
} else if size <= 32 {
|
||||
buffer = make([]byte, 32)
|
||||
} else if size <= 64 {
|
||||
buffer = make([]byte, 64)
|
||||
} else if size <= 128 {
|
||||
buffer = make([]byte, 128)
|
||||
} else if size <= 256 {
|
||||
buffer = make([]byte, 256)
|
||||
} else if size <= 512 {
|
||||
buffer = make([]byte, 512)
|
||||
} else if size <= 1024 {
|
||||
buffer = make([]byte, 1024)
|
||||
} else if size <= 4096 {
|
||||
buffer = make([]byte, 4096)
|
||||
} else if size <= 16384 {
|
||||
buffer = make([]byte, 16384)
|
||||
} else if size <= 4*1024 {
|
||||
buffer = make([]byte, 4*1024)
|
||||
} else if size <= 16*1024 {
|
||||
buffer = make([]byte, 16*1024)
|
||||
} else if size <= 20*1024 {
|
||||
buffer = make([]byte, 20*1024)
|
||||
} else if size <= 65535 {
|
||||
buffer = make([]byte, 65535)
|
||||
} else {
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package lowmem
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
var Enabled = false
|
||||
|
||||
func Free() {
|
||||
if Enabled {
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,8 +57,8 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
|||
|
||||
func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.PacketConn) error {
|
||||
return task.Run(ctx, func() error {
|
||||
buffer := buf.FullNew()
|
||||
defer buffer.Release()
|
||||
_buffer := buf.With(make([]byte, buf.UDPBufferSize))
|
||||
buffer := common.Dup(_buffer)
|
||||
for {
|
||||
n, addr, err := conn.ReadFrom(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
|
@ -72,8 +72,8 @@ func CopyPacketConn(ctx context.Context, conn net.PacketConn, outPacketConn net.
|
|||
buffer.FullReset()
|
||||
}
|
||||
}, func() error {
|
||||
buffer := buf.FullNew()
|
||||
defer buffer.Release()
|
||||
_buffer := buf.With(make([]byte, buf.UDPBufferSize))
|
||||
buffer := common.Dup(_buffer)
|
||||
for {
|
||||
n, addr, err := outPacketConn.ReadFrom(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
|
|
11
common/rw/writev_posix.go
Normal file
11
common/rw/writev_posix.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
//go:build !windows
|
||||
|
||||
package rw
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func WriteV(fd uintptr, data ...[]byte) (int, error) {
|
||||
return unix.Writev(int(fd), data)
|
||||
}
|
16
common/rw/writev_windows.go
Normal file
16
common/rw/writev_windows.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package rw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
func WriteV(fd uintptr, data ...[]byte) (int, error) {
|
||||
var n uint32
|
||||
buffers := make([]*windows.WSABuf, len(data))
|
||||
for i, buf := range data {
|
||||
buffers[i] = &windows.WSABuf{
|
||||
Len: uint32(len(buf)),
|
||||
Buf: &buf[0],
|
||||
}
|
||||
}
|
||||
err := windows.WSASend(windows.Handle(fd), buffers[0], uint32(len(buffers)), &n, 0, nil, nil)
|
||||
return int(n), err
|
||||
}
|
|
@ -40,8 +40,8 @@ func (c *ServerConn) RemoteAddr() net.Addr {
|
|||
}
|
||||
|
||||
func (c *ServerConn) loopInput() {
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
for {
|
||||
destination, err := AddrParser.ReadAddrPort(c.inputReader)
|
||||
if err != nil {
|
||||
|
@ -73,8 +73,8 @@ func (c *ServerConn) loopInput() {
|
|||
}
|
||||
|
||||
func (c *ServerConn) loopOutput() {
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
for {
|
||||
buffer.FullReset()
|
||||
n, addr, err := buffer.ReadPacketFrom(c)
|
||||
|
|
|
@ -29,8 +29,8 @@ func TestServerConn(t *testing.T) {
|
|||
IP: net.IPv4(8, 8, 8, 8),
|
||||
Port: 53,
|
||||
}))
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
common.Must2(buffer.ReadPacketFrom(clientConn))
|
||||
common.Must(message.Unpack(buffer.Bytes()))
|
||||
for _, answer := range message.Answers {
|
||||
|
|
|
@ -81,8 +81,8 @@ func (c *noneConn) Write(b []byte) (n int, err error) {
|
|||
return 0, c.clientHandshake()
|
||||
}
|
||||
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
_buffer := buf.StackNew()
|
||||
buffer := common.Dup(_buffer)
|
||||
|
||||
err = socks.AddressSerializer.WriteAddrPort(buffer, c.destination)
|
||||
if err != nil {
|
||||
|
@ -138,7 +138,8 @@ func (c *nonePacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
|||
|
||||
func (c *nonePacketConn) WritePacket(buffer *buf.Buffer, addrPort *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, addrPort)
|
||||
if err != nil {
|
||||
header.Release()
|
||||
|
|
|
@ -79,9 +79,9 @@ func New(method string, key []byte, password []byte, secureRNG io.Reader, replay
|
|||
}
|
||||
|
||||
func Kdf(key, iv []byte, keyLength int) []byte {
|
||||
subKey := make([]byte, keyLength)
|
||||
subKey := buf.Make(keyLength)
|
||||
kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
|
||||
common.Must1(io.ReadFull(kdf, subKey))
|
||||
common.Must1(io.ReadFull(kdf, common.Dup(subKey)))
|
||||
return subKey
|
||||
}
|
||||
|
||||
|
@ -111,8 +111,8 @@ func (m *Method) KeyLength() int {
|
|||
}
|
||||
|
||||
func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
|
||||
saltBuffer := buf.Make(m.keySaltLength)
|
||||
salt := common.Dup(saltBuffer)
|
||||
_salt := buf.Make(m.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
_, err := io.ReadFull(upstream, salt)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read salt")
|
||||
|
@ -122,18 +122,20 @@ func (m *Method) ReadRequest(upstream io.Reader) (io.Reader, error) {
|
|||
return nil, E.New("salt not unique")
|
||||
}
|
||||
}
|
||||
return NewReader(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil
|
||||
key := Kdf(m.key, salt, m.keySaltLength)
|
||||
return NewReader(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
|
||||
}
|
||||
|
||||
func (m *Method) WriteResponse(upstream io.Writer) (io.Writer, error) {
|
||||
saltBuffer := buf.Make(m.keySaltLength)
|
||||
salt := common.Dup(saltBuffer)
|
||||
_salt := buf.Make(m.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
common.Must1(io.ReadFull(m.secureRNG, salt))
|
||||
_, err := upstream.Write(salt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewWriter(upstream, m.constructor(Kdf(m.key, salt, m.keySaltLength)), MaxPacketSize), nil
|
||||
key := Kdf(m.key, salt, m.keySaltLength)
|
||||
return NewWriter(upstream, m.constructor(common.Dup(key)), MaxPacketSize), nil
|
||||
}
|
||||
|
||||
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
|
@ -154,11 +156,12 @@ func (m *Method) DialEarlyConn(conn net.Conn, destination *M.AddrPort) net.Conn
|
|||
}
|
||||
|
||||
func (m *Method) DialPacketConn(conn net.Conn) socks.PacketConn {
|
||||
return &aeadPacketConn{conn, m}
|
||||
return &clientPacketConn{conn, m}
|
||||
}
|
||||
|
||||
func (m *Method) EncodePacket(buffer *buf.Buffer) error {
|
||||
c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength))
|
||||
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||
c := m.constructor(common.Dup(key))
|
||||
c.Seal(buffer.From(m.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
buffer.Extend(c.Overhead())
|
||||
return nil
|
||||
|
@ -168,7 +171,8 @@ func (m *Method) DecodePacket(buffer *buf.Buffer) error {
|
|||
if buffer.Len() < m.keySaltLength {
|
||||
return E.New("bad packet")
|
||||
}
|
||||
c := m.constructor(Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength))
|
||||
key := Kdf(m.key, buffer.To(m.keySaltLength), m.keySaltLength)
|
||||
c := m.constructor(common.Dup(key))
|
||||
packet, err := c.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -190,8 +194,8 @@ type clientConn struct {
|
|||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
request := buf.New()
|
||||
defer request.Release()
|
||||
_request := buf.StackNew()
|
||||
request := common.Dup(_request)
|
||||
|
||||
common.Must1(request.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
|
||||
|
@ -207,8 +211,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
)
|
||||
|
||||
if len(payload) > 0 {
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
|
@ -240,23 +244,26 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
}
|
||||
|
||||
func (c *clientConn) readResponse() error {
|
||||
if c.reader == nil {
|
||||
salt := make([]byte, c.method.keySaltLength)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.method.replayFilter != nil {
|
||||
if !c.method.replayFilter.Check(salt) {
|
||||
return E.New("salt not unique")
|
||||
}
|
||||
}
|
||||
c.reader = NewReader(
|
||||
c.Conn,
|
||||
c.method.constructor(Kdf(c.method.key, salt, c.method.keySaltLength)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
if c.reader != nil {
|
||||
return nil
|
||||
}
|
||||
_salt := buf.Make(c.method.keySaltLength)
|
||||
salt := common.Dup(_salt)
|
||||
_, err := io.ReadFull(c.Conn, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.method.replayFilter != nil {
|
||||
if !c.method.replayFilter.Check(salt) {
|
||||
return E.New("salt not unique")
|
||||
}
|
||||
}
|
||||
key := Kdf(c.method.key, salt, c.method.keySaltLength)
|
||||
c.reader = NewReader(
|
||||
c.Conn,
|
||||
c.method.constructor(common.Dup(key)),
|
||||
MaxPacketSize,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -300,14 +307,14 @@ func (c *clientConn) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
return c.writer.(io.ReaderFrom).ReadFrom(r)
|
||||
}
|
||||
|
||||
type aeadPacketConn struct {
|
||||
type clientPacketConn struct {
|
||||
net.Conn
|
||||
method *Method
|
||||
}
|
||||
|
||||
func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, c.method.keySaltLength))
|
||||
err := socks.AddressSerializer.WriteAddrPort(header, destination)
|
||||
if err != nil {
|
||||
|
@ -321,7 +328,7 @@ func (c *aeadPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort
|
|||
return common.Error(c.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *aeadPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (*M.AddrPort, error) {
|
||||
n, err := c.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -22,21 +22,18 @@ import (
|
|||
"github.com/sagernet/sing/protocol/shadowsocks"
|
||||
"github.com/sagernet/sing/protocol/shadowsocks/shadowaead"
|
||||
"github.com/sagernet/sing/protocol/socks"
|
||||
"golang.org/x/crypto/chacha20"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
wgReplay "golang.zx2c4.com/wireguard/replay"
|
||||
"lukechampine.com/blake3"
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderTypeClient = 0
|
||||
HeaderTypeServer = 1
|
||||
MaxPaddingLength = 900
|
||||
KeySaltSize = 32
|
||||
PacketNonceSize = 24
|
||||
MinRequestHeaderSize = 1 + 8
|
||||
MinResponseHeaderSize = MinRequestHeaderSize + KeySaltSize
|
||||
MaxPacketSize = 65535 + shadowaead.PacketLengthBufferSize + nonceSize*2
|
||||
HeaderTypeClient = 0
|
||||
HeaderTypeServer = 1
|
||||
MaxPaddingLength = 900
|
||||
KeySaltSize = 32
|
||||
PacketNonceSize = 24
|
||||
MaxPacketSize = 65535
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -106,7 +103,6 @@ func New(method string, pskList [][]byte, secureRNG io.Reader) (shadowsocks.Meth
|
|||
case "2022-blake3-chacha20-poly1305":
|
||||
m.keyLength = 32
|
||||
m.constructor = newChacha20Poly1305
|
||||
m.streamConstructor = newChacha20
|
||||
m.udpCipher = newXChacha20Poly1305(m.psk)
|
||||
}
|
||||
return m, nil
|
||||
|
@ -135,12 +131,6 @@ func newAESGCM(key []byte) cipher.AEAD {
|
|||
return aead
|
||||
}
|
||||
|
||||
func newChacha20(key []byte) cipher.Stream {
|
||||
_nonce := make([]byte, chacha20.NonceSize)
|
||||
stream, _ := chacha20.NewUnauthenticatedCipher(key, common.Dup(_nonce))
|
||||
return stream
|
||||
}
|
||||
|
||||
func newChacha20Poly1305(key []byte) cipher.AEAD {
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
common.Must(err)
|
||||
|
@ -154,18 +144,17 @@ func newXChacha20Poly1305(key []byte) cipher.AEAD {
|
|||
}
|
||||
|
||||
type Method struct {
|
||||
name string
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
blockConstructor func(key []byte) cipher.Block
|
||||
streamConstructor func(key []byte) cipher.Stream
|
||||
udpCipher cipher.AEAD
|
||||
udpBlockCipher cipher.Block
|
||||
psk []byte
|
||||
pskList [][]byte
|
||||
pskHash []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
name string
|
||||
keyLength int
|
||||
constructor func(key []byte) cipher.AEAD
|
||||
blockConstructor func(key []byte) cipher.Block
|
||||
udpCipher cipher.AEAD
|
||||
udpBlockCipher cipher.Block
|
||||
psk []byte
|
||||
pskList [][]byte
|
||||
pskHash []byte
|
||||
secureRNG io.Reader
|
||||
replayFilter replay.Filter
|
||||
}
|
||||
|
||||
func (m *Method) Name() string {
|
||||
|
@ -176,30 +165,6 @@ func (m *Method) KeyLength() int {
|
|||
return m.keyLength
|
||||
}
|
||||
|
||||
func (m *Method) WriteExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
||||
pskLen := len(m.pskList)
|
||||
if pskLen < 2 {
|
||||
return
|
||||
}
|
||||
for i, psk := range m.pskList {
|
||||
keyMaterial := make([]byte, 2*KeySaltSize)
|
||||
copy(keyMaterial, psk)
|
||||
copy(keyMaterial[KeySaltSize:], salt)
|
||||
_identitySubkey := buf.Make(m.keyLength)
|
||||
identitySubkey := common.Dup(_identitySubkey)
|
||||
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
||||
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||
if m.blockConstructor != nil {
|
||||
m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash)
|
||||
} else {
|
||||
m.streamConstructor(identitySubkey).XORKeyStream(request.Extend(16), pskHash)
|
||||
}
|
||||
if i == pskLen-2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Method) DialConn(conn net.Conn, destination *M.AddrPort) (net.Conn, error) {
|
||||
shadowsocksConn := &clientConn{
|
||||
Conn: conn,
|
||||
|
@ -236,18 +201,38 @@ type clientConn struct {
|
|||
writer io.Writer
|
||||
}
|
||||
|
||||
func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) {
|
||||
pskLen := len(m.pskList)
|
||||
if pskLen < 2 {
|
||||
return
|
||||
}
|
||||
for i, psk := range m.pskList {
|
||||
keyMaterial := make([]byte, 2*KeySaltSize)
|
||||
copy(keyMaterial, psk)
|
||||
copy(keyMaterial[KeySaltSize:], salt)
|
||||
_identitySubkey := buf.Make(m.keyLength)
|
||||
identitySubkey := common.Dup(_identitySubkey)
|
||||
blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
|
||||
pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||
m.blockConstructor(identitySubkey).Encrypt(request.Extend(16), pskHash)
|
||||
if i == pskLen-2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientConn) writeRequest(payload []byte) error {
|
||||
request := buf.New()
|
||||
defer request.Release()
|
||||
_request := buf.StackNew()
|
||||
request := common.Dup(_request)
|
||||
|
||||
salt := make([]byte, KeySaltSize)
|
||||
common.Must1(io.ReadFull(c.method.secureRNG, salt))
|
||||
common.Must1(request.Write(salt))
|
||||
c.method.WriteExtendedIdentityHeaders(request, salt)
|
||||
c.method.writeExtendedIdentityHeaders(request, salt)
|
||||
|
||||
var writer io.Writer = c.Conn
|
||||
var writer io.Writer
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
Writer: c.Conn,
|
||||
Buffer: request,
|
||||
}
|
||||
|
||||
|
@ -258,8 +243,8 @@ func (c *clientConn) writeRequest(payload []byte) error {
|
|||
MaxPacketSize,
|
||||
)
|
||||
|
||||
header := buf.New()
|
||||
defer header.Release()
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
|
||||
writer = &buf.BufferedWriter{
|
||||
Writer: writer,
|
||||
|
@ -362,6 +347,7 @@ func (c *clientConn) readResponse() error {
|
|||
return ErrBadRequestSalt
|
||||
}
|
||||
|
||||
c.requestSalt = nil
|
||||
c.reader = reader
|
||||
return nil
|
||||
}
|
||||
|
@ -417,23 +403,14 @@ type clientPacketConn struct {
|
|||
|
||||
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination *M.AddrPort) error {
|
||||
defer buffer.Release()
|
||||
header := buf.New()
|
||||
_header := buf.StackNew()
|
||||
header := common.Dup(_header)
|
||||
pskLen := len(c.method.pskList)
|
||||
var dataIndex int
|
||||
if c.method.udpCipher != nil {
|
||||
common.Must1(header.ReadFullFrom(c.method.secureRNG, PacketNonceSize))
|
||||
if pskLen > 1 {
|
||||
for i, psk := range c.method.pskList {
|
||||
pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
|
||||
identityHeader := header.Extend(aes.BlockSize)
|
||||
for textI := 0; textI < aes.BlockSize; textI++ {
|
||||
identityHeader[textI] = pskHash[textI] ^ header.Byte(textI)
|
||||
}
|
||||
c.method.streamConstructor(psk).XORKeyStream(identityHeader, identityHeader)
|
||||
if i == pskLen-2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
panic("unsupported chacha extended header")
|
||||
}
|
||||
dataIndex = buffer.Len()
|
||||
} else {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue