From 04cd76f0aab1be7ed7f8fff29853b880469815d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 May 2022 07:20:22 +0800 Subject: [PATCH] Shadowsocks 2022 protocol change --- common/buf/buffer.go | 7 +- common/cond.go | 7 +- protocol/shadowsocks/service.go | 7 ++ protocol/shadowsocks/shadowaead/aead.go | 97 ++++++++++++++--- protocol/shadowsocks/shadowaead/protocol.go | 2 +- protocol/shadowsocks/shadowaead/service.go | 24 +++-- .../shadowsocks/shadowaead_2022/protocol.go | 102 +++++++++--------- .../shadowsocks/shadowaead_2022/service.go | 72 +++++++++---- .../shadowaead_2022/service_multi.go | 38 +++++-- .../shadowaead_2022/service_multi_test.go | 4 +- .../shadowaead_2022/service_test.go | 49 +++++++++ 11 files changed, 294 insertions(+), 115 deletions(-) create mode 100644 protocol/shadowsocks/shadowaead_2022/service_test.go diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 5128d04..5091c0b 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -50,14 +50,9 @@ func From(data []byte) *Buffer { } func As(data []byte) *Buffer { - size := len(data) - max := cap(data) - if size != max { - data = data[:max] - } return &Buffer{ data: data, - end: size, + end: len(data), } } diff --git a/common/cond.go b/common/cond.go index 0dc6c2b..deec778 100644 --- a/common/cond.go +++ b/common/cond.go @@ -3,7 +3,6 @@ package common import ( "context" "io" - "log" "strings" "unsafe" ) @@ -117,20 +116,20 @@ func Error(_ any, err error) error { func Must(errs ...error) { for _, err := range errs { if err != nil { - log.Fatalln(err) + panic(err) } } } func Must1(_ any, err error) { if err != nil { - log.Fatalln(err) + panic(err) } } func Must2(_ any, _ any, err error) { if err != nil { - log.Fatalln(err) + panic(err) } } diff --git a/protocol/shadowsocks/service.go b/protocol/shadowsocks/service.go index 63d4076..20b3da2 100644 --- a/protocol/shadowsocks/service.go +++ b/protocol/shadowsocks/service.go @@ -32,6 +32,13 @@ type ServerConnError struct { Cause error } +func (e *ServerConnError) Close() error { + if tcpConn, ok := e.Conn.(*net.TCPConn); ok { + tcpConn.SetLinger(0) + } + return e.Conn.Close() +} + func (e *ServerConnError) Unwrap() error { return e.Cause } diff --git a/protocol/shadowsocks/shadowaead/aead.go b/protocol/shadowsocks/shadowaead/aead.go index 936c98a..eed70b9 100644 --- a/protocol/shadowsocks/shadowaead/aead.go +++ b/protocol/shadowsocks/shadowaead/aead.go @@ -14,6 +14,18 @@ const ( PacketLengthBufferSize = 2 ) +const ( + // NonceSize + // crypto/cipher.gcmStandardNonceSize + // golang.org/x/crypto/chacha20poly1305.NonceSize + NonceSize = 12 + + // Overhead + // crypto/cipher.gcmTagSize + // golang.org/x/crypto/chacha20poly1305.Overhead + Overhead = 16 +) + type Reader struct { upstream io.Reader cipher cipher.AEAD @@ -27,8 +39,8 @@ func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reade return &Reader{ upstream: upstream, cipher: cipher, - buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2), - nonce: make([]byte, cipher.NonceSize()), + buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), + nonce: make([]byte, NonceSize), } } @@ -62,7 +74,7 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) { n += int64(writeN) } for { - start := PacketLengthBufferSize + r.cipher.Overhead() + start := PacketLengthBufferSize + Overhead _, err = io.ReadFull(r.upstream, r.buffer[:start]) if err != nil { return @@ -73,7 +85,7 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) { } increaseNonce(r.nonce) length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) - end := length + r.cipher.Overhead() + end := length + Overhead _, err = io.ReadFull(r.upstream, r.buffer[:end]) if err != nil { return @@ -92,7 +104,7 @@ func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) { } func (r *Reader) readInternal() (err error) { - start := PacketLengthBufferSize + r.cipher.Overhead() + start := PacketLengthBufferSize + Overhead _, err = io.ReadFull(r.upstream, r.buffer[:start]) if err != nil { return err @@ -103,7 +115,7 @@ func (r *Reader) readInternal() (err error) { } increaseNonce(r.nonce) length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) - end := length + r.cipher.Overhead() + end := length + Overhead _, err = io.ReadFull(r.upstream, r.buffer[:end]) if err != nil { return err @@ -138,7 +150,7 @@ func (r *Reader) Read(b []byte) (n int, err error) { r.index += n return } - start := PacketLengthBufferSize + r.cipher.Overhead() + start := PacketLengthBufferSize + Overhead _, err = io.ReadFull(r.upstream, r.buffer[:start]) if err != nil { return 0, err @@ -149,7 +161,7 @@ func (r *Reader) Read(b []byte) (n int, err error) { } increaseNonce(r.nonce) length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) - end := length + r.cipher.Overhead() + end := length + Overhead if len(b) >= end { data := b[:end] @@ -202,6 +214,59 @@ func (r *Reader) Cached() int { return r.cached } +func (r *Reader) CachedSlice() []byte { + return r.buffer[r.index : r.index+r.cached] +} + +func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error { + _, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) + end := length + Overhead + _, err = io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = length + r.index = 0 + return nil +} + +func (r *Reader) ReadWithLength(length uint16) error { + end := length + Overhead + _, err := io.ReadFull(r.upstream, r.buffer[:end]) + if err != nil { + return err + } + _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = int(length) + r.index = 0 + return nil +} + +func (r *Reader) ReadChunk(chunk []byte) error { + bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil) + if err != nil { + return err + } + increaseNonce(r.nonce) + r.cached = len(bb) + r.index = 0 + return nil +} + type Writer struct { upstream io.Writer cipher cipher.AEAD @@ -214,7 +279,7 @@ func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Write return &Writer{ upstream: upstream, cipher: cipher, - buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+cipher.Overhead()*2), + buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), nonce: make([]byte, cipher.NonceSize()), maxPacketSize: maxPacketSize, } @@ -244,7 +309,7 @@ func (w *Writer) SetWriter(writer io.Writer) { func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) { for { - offset := w.cipher.Overhead() + PacketLengthBufferSize + offset := Overhead + PacketLengthBufferSize readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize]) if readErr != nil { return 0, readErr @@ -271,7 +336,7 @@ func (w *Writer) Write(p []byte) (n int, err error) { binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data))) w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) increaseNonce(w.nonce) - offset := w.cipher.Overhead() + PacketLengthBufferSize + offset := Overhead + PacketLengthBufferSize packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil) increaseNonce(w.nonce) _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) @@ -288,11 +353,17 @@ func (w *Writer) Buffer() *buf.Buffer { return buf.With(w.buffer) } +func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) { + bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil) + buffer.Extend(len(bb)) + increaseNonce(w.nonce) +} + func (w *Writer) BufferedWriter(reversed int) *BufferedWriter { return &BufferedWriter{ upstream: w, reversed: reversed, - data: w.buffer[PacketLengthBufferSize+w.cipher.Overhead() : len(w.buffer)-w.cipher.Overhead()], + data: w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead], } } @@ -340,7 +411,7 @@ func (w *BufferedWriter) Flush() error { binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index)) w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil) increaseNonce(w.upstream.nonce) - offset := w.upstream.cipher.Overhead() + PacketLengthBufferSize + offset := Overhead + PacketLengthBufferSize packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil) increaseNonce(w.upstream.nonce) _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)]) diff --git a/protocol/shadowsocks/shadowaead/protocol.go b/protocol/shadowsocks/shadowaead/protocol.go index 8458f2a..ea8696c 100644 --- a/protocol/shadowsocks/shadowaead/protocol.go +++ b/protocol/shadowsocks/shadowaead/protocol.go @@ -154,7 +154,7 @@ func (m *Method) EncodePacket(buffer *buf.Buffer) error { c := m.constructor(common.Dup(key)) runtime.KeepAlive(key) c.Seal(buffer.Index(m.keySaltLength), rw.ZeroBytes[:c.NonceSize()], buffer.From(m.keySaltLength), nil) - buffer.Extend(c.Overhead()) + buffer.Extend(Overhead) return nil } diff --git a/protocol/shadowsocks/shadowaead/service.go b/protocol/shadowsocks/shadowaead/service.go index 8db96b3..7cc254b 100644 --- a/protocol/shadowsocks/shadowaead/service.go +++ b/protocol/shadowsocks/shadowaead/service.go @@ -20,6 +20,8 @@ import ( "golang.org/x/crypto/chacha20poly1305" ) +var ErrBadHeader = E.New("bad header") + type Service struct { name string keySaltLength int @@ -83,17 +85,25 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M } func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - _salt := buf.Make(s.keySaltLength) - defer runtime.KeepAlive(_salt) - salt := common.Dup(_salt) + _header := buf.Make(s.keySaltLength + PacketLengthBufferSize + Overhead) + defer runtime.KeepAlive(_header) + header := common.Dup(_header) - _, err := io.ReadFull(conn, salt) + n, err := conn.Read(header) if err != nil { - return E.Cause(err, "read salt") + return E.Cause(err, "read header") + } else if n < len(header) { + return ErrBadHeader } - key := Kdf(s.key, salt, s.keySaltLength) + key := Kdf(s.key, header[:s.keySaltLength], s.keySaltLength) reader := NewReader(conn, s.constructor(common.Dup(key)), MaxPacketSize) + + err = reader.ReadWithLengthChunk(header[s.keySaltLength:]) + if err != nil { + return err + } + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) if err != nil { return err @@ -250,6 +260,6 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks c := w.constructor(common.Dup(key)) runtime.KeepAlive(key) c.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:c.NonceSize()], buffer.From(w.keySaltLength), nil) - buffer.Extend(c.Overhead()) + buffer.Extend(Overhead) return w.PacketConn.WritePacket(buffer, w.source) } diff --git a/protocol/shadowsocks/shadowaead_2022/protocol.go b/protocol/shadowsocks/shadowaead_2022/protocol.go index 19b418c..7fc6a11 100644 --- a/protocol/shadowsocks/shadowaead_2022/protocol.go +++ b/protocol/shadowsocks/shadowaead_2022/protocol.go @@ -30,22 +30,12 @@ import ( ) const ( - HeaderTypeClient = 0 - HeaderTypeServer = 1 - MaxPaddingLength = 900 - PacketNonceSize = 24 - MaxPacketSize = 65535 -) - -const ( - // crypto/cipher.gcmStandardNonceSize - // golang.org/x/crypto/chacha20poly1305.NonceSize - nonceSize = 12 - - // Overhead - // crypto/cipher.gcmTagSize - // golang.org/x/crypto/chacha20poly1305.Overhead - overhead = 16 + HeaderTypeClient = 0 + HeaderTypeServer = 1 + MaxPaddingLength = 900 + PacketNonceSize = 24 + MaxPacketSize = 65535 + RequestHeaderFixedChunkLength = 1 + 8 + 2 ) var ( @@ -259,38 +249,32 @@ func (c *clientConn) writeRequest(payload []byte) error { header.Write(salt) c.writeExtendedIdentityHeaders(header, salt) - bufferedWriter := writer.BufferedWriter(header.Len()) - - common.Must(rw.WriteByte(bufferedWriter, HeaderTypeClient)) - common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix()))) - - err := M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) - if err != nil { - return E.Cause(err, "write destination") + var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte + fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:])) + common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) + common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(time.Now().Unix()))) + var paddingLen int + if len(payload) == 0 { + paddingLen = rand.Intn(MaxPaddingLength + 1) } + variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen + len(payload) + common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen))) + writer.WriteChunk(header, fixedLengthBuffer.Slice()) + runtime.KeepAlive(_fixedLengthBuffer) - if len(payload) > 0 { - err = binary.Write(bufferedWriter, binary.BigEndian, uint16(0)) - if err != nil { - return E.Cause(err, "write padding length") - } - _, err = bufferedWriter.Write(payload) - if err != nil { - return E.Cause(err, "write payload") - } + _variableLengthBuffer := buf.Make(variableLengthHeaderLen) + variableLengthBuffer := buf.With(common.Dup(_variableLengthBuffer)) + common.Must(M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination)) + common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen))) + if paddingLen > 0 { + common.Must1(io.CopyN(variableLengthBuffer, c.secureRNG, int64(paddingLen))) } else { - pLen := rand.Intn(MaxPaddingLength + 1) - err = binary.Write(bufferedWriter, binary.BigEndian, uint16(pLen)) - if err != nil { - return E.Cause(err, "write padding length") - } - _, err = io.CopyN(bufferedWriter, c.secureRNG, int64(pLen)) - if err != nil { - return E.Cause(err, "write padding") - } + common.Must1(variableLengthBuffer.Write(payload)) } + writer.WriteChunk(header, variableLengthBuffer.Slice()) + runtime.KeepAlive(_variableLengthBuffer) - err = bufferedWriter.Flush() + err := writer.BufferedWriter(header.Len()).Flush() if err != nil { return E.Cause(err, "client handshake") } @@ -325,6 +309,11 @@ func (c *clientConn) readResponse() error { ) runtime.KeepAlive(key) + err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2)) + if err != nil { + return E.Cause(err, "read response fixed length chunk") + } + headerType, err := rw.ReadByte(reader) if err != nil { return err @@ -356,6 +345,17 @@ func (c *clientConn) readResponse() error { } runtime.KeepAlive(_requestSalt) + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return err + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } + c.requestSalt = nil c.reader = reader @@ -481,11 +481,11 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad } if c.udpCipher != nil { c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) - buffer.Extend(c.udpCipher.Overhead()) + buffer.Extend(shadowaead.Overhead) } else { packetHeader := buffer.To(aes.BlockSize) c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) - buffer.Extend(c.session.cipher.Overhead()) + buffer.Extend(shadowaead.Overhead) c.udpBlockCipher.Encrypt(packetHeader, packetHeader) } return common.Error(c.Write(buffer.Bytes())) @@ -505,7 +505,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { return M.Socksaddr{}, E.Cause(err, "decrypt packet") } buffer.Advance(PacketNonceSize) - buffer.Truncate(buffer.Len() - c.udpCipher.Overhead()) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) } else { packetHeader = buffer.To(aes.BlockSize) c.udpBlockCipher.Decrypt(packetHeader, packetHeader) @@ -536,7 +536,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { if err != nil { return M.Socksaddr{}, E.Cause(err, "decrypt packet") } - buffer.Truncate(buffer.Len() - remoteCipher.Overhead()) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) } var headerType byte @@ -625,9 +625,9 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { destination := M.SocksaddrFromNet(addr) var overHead int if c.udpCipher != nil { - overHead = PacketNonceSize + c.udpCipher.Overhead() + overHead = PacketNonceSize + shadowaead.Overhead } else { - overHead = c.session.cipher.Overhead() + overHead = shadowaead.Overhead } overHead += 16 // packet header pskLen := len(c.pskList) @@ -686,11 +686,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } if c.udpCipher != nil { c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) - buffer.Extend(c.udpCipher.Overhead()) + buffer.Extend(shadowaead.Overhead) } else { packetHeader := buffer.To(aes.BlockSize) c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) - buffer.Extend(c.session.cipher.Overhead()) + buffer.Extend(shadowaead.Overhead) c.udpBlockCipher.Encrypt(packetHeader, packetHeader) } err = common.Error(c.Write(buffer.Bytes())) diff --git a/protocol/shadowsocks/shadowaead_2022/service.go b/protocol/shadowsocks/shadowaead_2022/service.go index 6fdf2c2..4217335 100644 --- a/protocol/shadowsocks/shadowaead_2022/service.go +++ b/protocol/shadowsocks/shadowaead_2022/service.go @@ -28,7 +28,10 @@ import ( wgReplay "golang.zx2c4.com/wireguard/replay" ) -var ErrNoPadding = E.New("bad request: missing payload or padding") +var ( + ErrNoPadding = E.New("bad request: missing payload or padding") + ErrBadPadding = E.New("bad request: damaged padding") +) type Service struct { name string @@ -108,12 +111,17 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M } func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - requestSalt := buf.Make(s.keySaltLength) - _, err := io.ReadFull(conn, requestSalt) + header := buf.Make(s.keySaltLength + shadowaead.Overhead + RequestHeaderFixedChunkLength) + + n, err := conn.Read(header) if err != nil { - return E.Cause(err, "read request salt") + return E.Cause(err, "read header") + } else if n < len(header) { + return shadowaead.ErrBadHeader } + requestSalt := header[:s.keySaltLength] + if !s.replayFilter.Check(requestSalt) { return E.New("salt not unique") } @@ -126,7 +134,12 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M ) runtime.KeepAlive(requestKey) - headerType, err := rw.ReadByte(reader) + err = reader.ReadChunk(header[s.keySaltLength:]) + if err != nil { + return E.Cause(err, "read request fixed length chunk") + } + + headerType, err := reader.ReadByte() if err != nil { return E.Cause(err, "read header") } @@ -138,7 +151,7 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M var epoch uint64 err = binary.Read(reader, binary.BigEndian, &epoch) if err != nil { - return E.Cause(err, "read timestamp") + return err } diff := int(math.Abs(float64(time.Now().Unix() - int64(epoch)))) @@ -146,15 +159,30 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M return ErrBadTimestamp } + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return err + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } + destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) if err != nil { - return E.Cause(err, "read destination") + return err } var paddingLen uint16 err = binary.Read(reader, binary.BigEndian, &paddingLen) if err != nil { - return E.Cause(err, "read padding length") + return err + } + + if uint16(reader.Cached()) < paddingLen { + return ErrNoPadding } if paddingLen > 0 { @@ -201,21 +229,23 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { runtime.KeepAlive(key) header := writer.Buffer() header.Write(salt) - bufferedWriter := writer.BufferedWriter(header.Len()) - common.Must(rw.WriteByte(bufferedWriter, HeaderTypeServer)) - common.Must(binary.Write(bufferedWriter, binary.BigEndian, uint64(time.Now().Unix()))) - common.Must1(bufferedWriter.Write(c.requestSalt[:])) + _headerFixedChunk := buf.Make(1 + 8 + c.keySaltLength + 2) + headerFixedChunk := buf.With(common.Dup(_headerFixedChunk)) + common.Must(headerFixedChunk.WriteByte(HeaderTypeServer)) + common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(time.Now().Unix()))) + common.Must1(headerFixedChunk.Write(c.requestSalt)) + common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(len(payload)))) + + writer.WriteChunk(header, headerFixedChunk.Slice()) + runtime.KeepAlive(_headerFixedChunk) c.requestSalt = nil if len(payload) > 0 { - _, err = bufferedWriter.Write(payload) - if err != nil { - return - } + writer.WriteChunk(header, payload) } - err = bufferedWriter.Flush() + err = writer.BufferedWriter(header.Len()).Flush() if err != nil { return } @@ -287,7 +317,7 @@ func (s *Service) newPacket(conn N.PacketConn, buffer *buf.Buffer, metadata M.Me return E.Cause(err, "decrypt packet header") } buffer.Advance(PacketNonceSize) - buffer.Truncate(buffer.Len() - s.udpCipher.Overhead()) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) } else { packetHeader = buffer.To(aes.BlockSize) s.udpBlockCipher.Decrypt(packetHeader, packetHeader) @@ -332,7 +362,7 @@ process: err = E.Cause(err, "decrypt packet") goto returnErr } - buffer.Truncate(buffer.Len() - session.remoteCipher.Overhead()) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) } var headerType byte @@ -421,11 +451,11 @@ func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socks if w.udpCipher != nil { w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) - buffer.Extend(w.udpCipher.Overhead()) + buffer.Extend(shadowaead.Overhead) } else { packetHeader := buffer.To(aes.BlockSize) w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) - buffer.Extend(w.session.cipher.Overhead()) + buffer.Extend(shadowaead.Overhead) w.udpBlockCipher.Encrypt(packetHeader, packetHeader) } return w.PacketConn.WritePacket(buffer, w.session.remoteAddr) diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi.go b/protocol/shadowsocks/shadowaead_2022/service_multi.go index 6269379..5816935 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi.go @@ -91,22 +91,21 @@ func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, meta } func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - requestSalt := make([]byte, s.keySaltLength) - _, err := io.ReadFull(conn, requestSalt) + requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength) + n, err := conn.Read(requestHeader) if err != nil { - return E.Cause(err, "read request salt") + return err + } else if n < len(requestHeader) { + return shadowaead.ErrBadHeader } - + requestSalt := requestHeader[:s.keySaltLength] if !s.replayFilter.Check(requestSalt) { return E.New("salt not unique") } var _eiHeader [aes.BlockSize]byte eiHeader := common.Dup(_eiHeader[:]) - _, err = io.ReadFull(conn, eiHeader) - if err != nil { - return E.Cause(err, "read extended identity header") - } + copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize]) keyMaterial := buf.Make(s.keySaltLength * 2) copy(keyMaterial, s.psk) @@ -125,6 +124,7 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta } else { return E.New("invalid request") } + runtime.KeepAlive(_eiHeader) requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength) reader := shadowaead.NewReader( @@ -132,7 +132,11 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta s.constructor(common.Dup(requestKey)), MaxPacketSize, ) - runtime.KeepAlive(requestSalt) + + err = reader.ReadChunk(requestHeader[s.keySaltLength+aes.BlockSize:]) + if err != nil { + return err + } headerType, err := rw.ReadByte(reader) if err != nil { @@ -152,6 +156,16 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta if diff > 30 { return ErrBadTimestamp } + var length uint16 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return E.Cause(err, "read length") + } + + err = reader.ReadWithLength(length) + if err != nil { + return err + } destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) if err != nil { @@ -164,7 +178,9 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta return E.Cause(err, "read padding length") } - if paddingLen > 0 { + if reader.Cached() < int(paddingLen) { + return ErrBadPadding + } else if paddingLen > 0 { err = reader.Discard(int(paddingLen)) if err != nil { return E.Cause(err, "discard padding") @@ -257,7 +273,7 @@ process: err = E.Cause(err, "decrypt packet") goto returnErr } - buffer.Truncate(buffer.Len() - session.remoteCipher.Overhead()) + buffer.Truncate(buffer.Len() - shadowaead.Overhead) } var headerType byte diff --git a/protocol/shadowsocks/shadowaead_2022/service_multi_test.go b/protocol/shadowsocks/shadowaead_2022/service_multi_test.go index fb04dba..8be86ae 100644 --- a/protocol/shadowsocks/shadowaead_2022/service_multi_test.go +++ b/protocol/shadowsocks/shadowaead_2022/service_multi_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/random" @@ -40,7 +41,8 @@ func TestMultiService(t *testing.T) { go func() { err := multiService.NewConnection(context.Background(), serverConn, M.Metadata{}) if err != nil { - t.Error(err) + serverConn.Close() + t.Error(E.Cause(err, "server")) return } }() diff --git a/protocol/shadowsocks/shadowaead_2022/service_test.go b/protocol/shadowsocks/shadowaead_2022/service_test.go new file mode 100644 index 0000000..6817152 --- /dev/null +++ b/protocol/shadowsocks/shadowaead_2022/service_test.go @@ -0,0 +1,49 @@ +package shadowaead_2022_test + +import ( + "context" + "net" + "sync" + "testing" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/random" + "github.com/sagernet/sing/protocol/shadowsocks/shadowaead_2022" +) + +func TestService(t *testing.T) { + method := "2022-blake3-aes-128-gcm" + var psk [16]byte + random.Default.Read(psk[:]) + + var wg sync.WaitGroup + + service, err := shadowaead_2022.NewService(method, psk[:], "", random.Default, 500, &multiHandler{t, &wg}) + if err != nil { + t.Fatal(err) + } + + client, err := shadowaead_2022.New(method, [][]byte{psk[:]}, "", random.Default) + if err != nil { + t.Fatal(err) + } + wg.Add(1) + + serverConn, clientConn := net.Pipe() + defer common.Close(serverConn, clientConn) + go func() { + err := service.NewConnection(context.Background(), serverConn, M.Metadata{}) + if err != nil { + serverConn.Close() + t.Error(E.Cause(err, "server")) + return + } + }() + _, err = client.DialConn(clientConn, M.ParseSocksaddr("test.com:443")) + if err != nil { + t.Fatal(err) + } + wg.Wait() +}