From fa7811b49c4140bf11d52f2e97c2dc8ed21d6bcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 27 May 2022 14:45:41 +0800 Subject: [PATCH] Minimal dependencies & Fixes --- .github/update_dependencies.sh | 6 ++ go.mod | 6 +- go.sum | 14 +--- none.go | 55 +++++-------- shadowaead/aead.go | 12 +-- shadowaead/protocol.go | 3 +- shadowaead_2022/protocol.go | 117 ++++++++++++++++------------ shadowaead_2022/relay.go | 29 ++++--- shadowaead_2022/service.go | 48 ++++++++---- shadowaead_2022/service_multi.go | 21 ++++- shadowaead_2022/wg_replay/replay.go | 62 +++++++++++++++ shadowimpl/fetcher.go | 2 +- shadowstream/protocol.go | 53 +------------ 13 files changed, 237 insertions(+), 191 deletions(-) create mode 100755 .github/update_dependencies.sh create mode 100644 shadowaead_2022/wg_replay/replay.go diff --git a/.github/update_dependencies.sh b/.github/update_dependencies.sh new file mode 100755 index 0000000..2ec62da --- /dev/null +++ b/.github/update_dependencies.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +PROJECTS=$(dirname "$0")/../.. + +go get -v github.com/sagernet/sing@$(git -C $PROJECTS/sing rev-parse HEAD) +go mod tidy diff --git a/go.mod b/go.mod index a3c06a3..c7ea058 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,8 @@ module github.com/sagernet/sing-shadowsocks go 1.18 require ( - github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d - github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34 - golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898 - golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d + github.com/sagernet/sing v0.0.0-20220527064225-7abc98fdea29 + golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e lukechampine.com/blake3 v1.1.7 ) diff --git a/go.sum b/go.sum index 45dd320..2ed76f9 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,11 @@ -github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d h1:CPqTNIigGweVPT4CYb+OO2E6XyRKFOmvTHwWRLgCAlE= -github.com/dgryski/go-camellia v0.0.0-20191119043421-69a8a13fb23d/go.mod h1:QX5ZVULjAfZJux/W62Y91HvCh9hyW6enAwcrrv/sLj0= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.12 h1:p9dKCg8i4gmOxtv35DvrYoWqYzQrvEVdjQ762Y0OqZE= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= -github.com/sagernet/sing v0.0.0-20220525062603-53c607b13ff2 h1:x7E53uloX7pU3rWOzb81IBCAmwMtE2u9x4ZJvJXaCnM= -github.com/sagernet/sing v0.0.0-20220525062603-53c607b13ff2/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg= -github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34 h1:/FfHfteLZo5mOtZbYOx/9ymDEYxlwBuM5iHO9reVe/E= -github.com/sagernet/sing v0.0.0-20220525123305-5b49366f6e34/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg= -golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898 h1:SLP7Q4Di66FONjDJbCYrCRrh97focO6sLogHO7/g8F0= -golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +github.com/sagernet/sing v0.0.0-20220527064225-7abc98fdea29 h1:dOXoZvzUbNir2yxCP0Uj1+o7KY/5wtW/83ri8QAIGoQ= +github.com/sagernet/sing v0.0.0-20220527064225-7abc98fdea29/go.mod h1:w2HnJzXKHpD6F5Z/9XlSD4qbcpHY2RSZuQnFzqgELMg= +golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= +golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d h1:q4JksJ2n0fmbXC0Aj0eOs6E0AcPqnKglxWXWFqGD6x0= -golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d/go.mod h1:bVQfyl2sCM/QIIGHpWbFGfHPuDvqnCNkT6MQLTCjO/U= lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0= lukechampine.com/blake3 v1.1.7/go.mod h1:tkKEOtDkNtklkXtLNEOGNq5tcV90tJiA1vAA12R78LA= diff --git a/none.go b/none.go index e286f1d..335d178 100644 --- a/none.go +++ b/none.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" @@ -61,6 +62,9 @@ type noneConn struct { } func (c *noneConn) clientHandshake() error { + c.access.Lock() + defer c.access.Unlock() + err := M.SocksaddrSerializer.WriteAddrPort(c.Conn, c.destination) if err != nil { return err @@ -71,59 +75,38 @@ func (c *noneConn) clientHandshake() error { func (c *noneConn) Write(b []byte) (n int, err error) { if c.handshake { - goto direct + return c.Conn.Write(b) } c.access.Lock() - defer c.access.Unlock() - if c.handshake { - goto direct + c.access.Unlock() + return c.Conn.Write(b) } - { - if len(b) == 0 { - return 0, c.clientHandshake() - } - - _buffer := buf.StackNew() - buffer := common.Dup(_buffer) - - err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination) - if err != nil { - return - } - - bufN, _ := buffer.Write(b) - _, err = c.Conn.Write(buffer.Bytes()) - runtime.KeepAlive(_buffer) - if err != nil { - return - } - - if bufN < len(b) { - _, err = c.Conn.Write(b[bufN:]) - if err != nil { - return - } - } - - n = len(b) + err = M.SocksaddrSerializer.WriteAddrPort(c.Conn, c.destination) + if err != nil { + return } - -direct: + c.handshake = true + c.access.Unlock() return c.Conn.Write(b) } func (c *noneConn) ReadFrom(r io.Reader) (n int64, err error) { if !c.handshake { - return rw.ReadFrom0(c, r) + c.access.Lock() + if !c.handshake { + c.access.Unlock() + return rw.ReadFrom0(c, r) + } + c.access.Unlock() } return c.Conn.(io.ReaderFrom).ReadFrom(r) } func (c *noneConn) WriteTo(w io.Writer) (n int64, err error) { - return io.Copy(w, c.Conn) + return bufio.Copy(w, c.Conn) } func (c *noneConn) RemoteAddr() net.Addr { diff --git a/shadowaead/aead.go b/shadowaead/aead.go index 658cf61..d9a0ce6 100644 --- a/shadowaead/aead.go +++ b/shadowaead/aead.go @@ -15,11 +15,6 @@ const ( ) const ( - // NonceSize - // crypto/cipher.gcmStandardNonceSize - // golang.org/x/crypto/chacha20poly1305.NonceSize - NonceSize = 12 - // Overhead // crypto/cipher.gcmTagSize // golang.org/x/crypto/chacha20poly1305.Overhead @@ -40,7 +35,7 @@ func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reade upstream: upstream, cipher: cipher, buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), - nonce: make([]byte, NonceSize), + nonce: make([]byte, cipher.NonceSize()), } } @@ -233,7 +228,7 @@ func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error { } func (r *Reader) ReadWithLength(length uint16) error { - end := length + Overhead + end := int(length) + Overhead _, err := io.ReadFull(r.upstream, r.buffer[:end]) if err != nil { return err @@ -321,9 +316,10 @@ func (w *Writer) Write(p []byte) (n int, err error) { if pLen > w.maxPacketSize { data = p[:w.maxPacketSize] p = p[w.maxPacketSize:] + pLen -= w.maxPacketSize } else { data = p - p = nil + pLen = 0 } binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data))) w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) diff --git a/shadowaead/protocol.go b/shadowaead/protocol.go index 1271952..815bc5d 100644 --- a/shadowaead/protocol.go +++ b/shadowaead/protocol.go @@ -319,7 +319,7 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) if err != nil { return } - b := buf.With(p[:n]) + b := buf.As(p[:n]) err = c.DecodePacket(b) if err != nil { return @@ -337,6 +337,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { _buffer := buf.StackNew() defer runtime.KeepAlive(_buffer) buffer := common.Dup(_buffer) + buffer.WriteRandom(c.keySaltLength) err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr)) if err != nil { return diff --git a/shadowaead_2022/protocol.go b/shadowaead_2022/protocol.go index f990d0b..ebb639e 100644 --- a/shadowaead_2022/protocol.go +++ b/shadowaead_2022/protocol.go @@ -20,6 +20,7 @@ import ( "github.com/sagernet/sing-shadowsocks" "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing-shadowsocks/shadowaead_2022/wg_replay" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" @@ -28,7 +29,6 @@ import ( "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" "golang.org/x/crypto/chacha20poly1305" - wgReplay "golang.zx2c4.com/wireguard/replay" "lukechampine.com/blake3" ) @@ -83,18 +83,18 @@ func New(method string, pskList [][]byte) (shadowsocks.Method, error) { switch method { case "2022-blake3-aes-128-gcm": m.keySaltLength = 16 - m.constructor = newAESGCM - m.blockConstructor = newAES + m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) + m.blockConstructor = aes.NewCipher case "2022-blake3-aes-256-gcm": m.keySaltLength = 32 - m.constructor = newAESGCM - m.blockConstructor = newAES + m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) + m.blockConstructor = aes.NewCipher case "2022-blake3-chacha20-poly1305": if len(pskList) > 1 { return nil, os.ErrInvalid } m.keySaltLength = 32 - m.constructor = newChacha20Poly1305 + m.constructor = chacha20poly1305.New } if len(pskList) == 0 { @@ -121,11 +121,15 @@ func New(method string, pskList [][]byte) (shadowsocks.Method, error) { m.pskHash = pskHash } + var err error switch method { case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": - m.udpBlockCipher = newAES(pskList[0]) + m.udpBlockCipher, err = aes.NewCipher(pskList[0]) case "2022-blake3-chacha20-poly1305": - m.udpCipher = newXChacha20Poly1305(pskList[0]) + m.udpCipher, err = chacha20poly1305.NewX(pskList[0]) + } + if err != nil { + return nil, err } m.pskList = pskList @@ -146,37 +150,21 @@ func SessionKey(psk []byte, salt []byte, keyLength int) []byte { return outKey } -func newAES(key []byte) cipher.Block { - block, err := aes.NewCipher(key) - common.Must(err) - return block -} - -func newAESGCM(key []byte) cipher.AEAD { - block, err := aes.NewCipher(key) - common.Must(err) - aead, err := cipher.NewGCM(block) - common.Must(err) - return aead -} - -func newChacha20Poly1305(key []byte) cipher.AEAD { - cipher, err := chacha20poly1305.New(key) - common.Must(err) - return cipher -} - -func newXChacha20Poly1305(key []byte) cipher.AEAD { - cipher, err := chacha20poly1305.NewX(key) - common.Must(err) - return cipher +func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) { + return func(key []byte) (cipher.AEAD, error) { + b, err := block(key) + if err != nil { + return nil, err + } + return aead(b) + } } type Method struct { name string keySaltLength int - constructor func(key []byte) cipher.AEAD - blockConstructor func(key []byte) cipher.Block + constructor func(key []byte) (cipher.AEAD, error) + blockConstructor func(key []byte) (cipher.Block, error) udpCipher cipher.AEAD udpBlockCipher cipher.Block pskList [][]byte @@ -222,10 +210,10 @@ type clientConn struct { writer *shadowaead.Writer } -func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) { +func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error { pskLen := len(m.pskList) if pskLen < 2 { - return + return nil } for i, psk := range m.pskList { keyMaterial := buf.Make(m.keySaltLength * 2) @@ -238,12 +226,17 @@ func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] header := request.Extend(16) - m.blockConstructor(identitySubkey).Encrypt(header, pskHash) + b, err := m.blockConstructor(identitySubkey) + if err != nil { + return err + } + b.Encrypt(header, pskHash) runtime.KeepAlive(_identitySubkey) if i == pskLen-2 { break } } + return nil } func (c *clientConn) writeRequest(payload []byte) error { @@ -251,16 +244,24 @@ func (c *clientConn) writeRequest(payload []byte) error { common.Must1(io.ReadFull(rand.Reader, salt)) key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) + writeCipher, err := c.constructor(common.Dup(key)) + if err != nil { + return err + } writer := shadowaead.NewWriter( c.Conn, - c.constructor(common.Dup(key)), + writeCipher, MaxPacketSize, ) runtime.KeepAlive(key) header := writer.Buffer() header.Write(salt) - c.writeExtendedIdentityHeaders(header, salt) + + err = c.writeExtendedIdentityHeaders(header, salt) + if err != nil { + return err + } var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte fixedLengthBuffer := buf.With(common.Dup(_fixedLengthBuffer[:])) @@ -287,7 +288,7 @@ func (c *clientConn) writeRequest(payload []byte) error { writer.WriteChunk(header, variableLengthBuffer.Slice()) runtime.KeepAlive(_variableLengthBuffer) - err := writer.BufferedWriter(header.Len()).Flush() + err = writer.BufferedWriter(header.Len()).Flush() if err != nil { return E.Cause(err, "client handshake") } @@ -315,9 +316,13 @@ func (c *clientConn) readResponse() error { key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) runtime.KeepAlive(_salt) + readCipher, err := c.constructor(common.Dup(key)) + if err != nil { + return err + } reader := shadowaead.NewReader( c.Conn, - c.constructor(common.Dup(key)), + readCipher, MaxPacketSize, ) runtime.KeepAlive(key) @@ -458,7 +463,11 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad for textI := 0; textI < aes.BlockSize; textI++ { identityHeader[textI] = pskHash[textI] ^ header.Byte(textI) } - c.blockConstructor(psk).Encrypt(identityHeader, identityHeader) + b, err := c.blockConstructor(psk) + if err != nil { + return err + } + b.Encrypt(identityHeader, identityHeader) if i == pskLen-2 { break @@ -524,7 +533,10 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { remoteCipher = c.session.lastRemoteCipher } else { key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength) - remoteCipher = c.constructor(common.Dup(key)) + remoteCipher, err = c.constructor(common.Dup(key)) + if err != nil { + return M.Socksaddr{}, err + } runtime.KeepAlive(key) } _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) @@ -573,7 +585,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { c.session.lastFilter = c.session.filter c.session.lastRemoteSeen = time.Now().Unix() c.session.lastRemoteCipher = c.session.remoteCipher - c.session.filter = wgReplay.Filter{} + c.session.filter = wg_replay.Filter{} } } c.session.remoteSessionId = sessionId @@ -663,7 +675,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { for textI := 0; textI < aes.BlockSize; textI++ { identityHeader[textI] = pskHash[textI] ^ buffer.Byte(textI) } - c.blockConstructor(psk).Encrypt(identityHeader, identityHeader) + b, err := c.blockConstructor(psk) + if err != nil { + return 0, err + } + b.Encrypt(identityHeader, identityHeader) if i == pskLen-2 { break @@ -679,6 +695,7 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if err != nil { return } + common.Must1(buffer.Write(p)) if c.udpCipher != nil { c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) buffer.Extend(shadowaead.Overhead) @@ -705,8 +722,8 @@ type udpSession struct { cipher cipher.AEAD remoteCipher cipher.AEAD lastRemoteCipher cipher.AEAD - filter wgReplay.Filter - lastFilter wgReplay.Filter + filter wg_replay.Filter + lastFilter wg_replay.Filter rng io.Reader } @@ -727,7 +744,11 @@ func (m *Method) newUDPSession() *udpSession { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength) - session.cipher = m.constructor(common.Dup(key)) + var err error + session.cipher, err = m.constructor(common.Dup(key)) + if err != nil { + return nil + } runtime.KeepAlive(key) } return session diff --git a/shadowaead_2022/relay.go b/shadowaead_2022/relay.go index 61da950..ca859bb 100644 --- a/shadowaead_2022/relay.go +++ b/shadowaead_2022/relay.go @@ -29,8 +29,8 @@ type Relay[U comparable] struct { keySaltLength int handler shadowsocks.Handler - constructor func(key []byte) cipher.AEAD - blockConstructor func(key []byte) cipher.Block + constructor func(key []byte) (cipher.AEAD, error) + blockConstructor func(key []byte) (cipher.Block, error) udpBlockCipher cipher.Block iPSK []byte @@ -60,9 +60,9 @@ func (s *Relay[U]) AddUser(user U, key []byte, destination M.Socksaddr) error { s.uPSKHash[user] = uPSKHash s.uPSKHashR[uPSKHash] = user s.uDestination[user] = destination - s.uCipher[user] = s.blockConstructor(key) - - return nil + var err error + s.uCipher[user], err = s.blockConstructor(key) + return err } func (s *Relay[U]) RemoveUser(user U) { @@ -94,12 +94,12 @@ func NewRelay[U comparable](method string, psk []byte, secureRNG io.Reader, udpT switch method { case "2022-blake3-aes-128-gcm": s.keySaltLength = 16 - s.constructor = newAESGCM - s.blockConstructor = newAES + s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) + s.blockConstructor = aes.NewCipher case "2022-blake3-aes-256-gcm": s.keySaltLength = 32 - s.constructor = newAESGCM - s.blockConstructor = newAES + s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) + s.blockConstructor = aes.NewCipher default: return nil, os.ErrInvalid } @@ -110,8 +110,9 @@ func NewRelay[U comparable](method string, psk []byte, secureRNG io.Reader, udpT psk = Key(psk, s.keySaltLength) } } - s.udpBlockCipher = s.blockConstructor(psk) - return s, nil + var err error + s.udpBlockCipher, err = s.blockConstructor(psk) + return s, err } func (s *Relay[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { @@ -143,7 +144,11 @@ func (s *Relay[U]) newConnection(ctx context.Context, conn net.Conn, metadata M. _identitySubkey := buf.Make(s.keySaltLength) identitySubkey := common.Dup(_identitySubkey) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) - s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) + b, err := s.blockConstructor(identitySubkey) + if err != nil { + return err + } + b.Decrypt(eiHeader, eiHeader) runtime.KeepAlive(_identitySubkey) var user U diff --git a/shadowaead_2022/service.go b/shadowaead_2022/service.go index 6a4c750..a0f43cb 100644 --- a/shadowaead_2022/service.go +++ b/shadowaead_2022/service.go @@ -18,6 +18,7 @@ import ( "github.com/sagernet/sing-shadowsocks" "github.com/sagernet/sing-shadowsocks/shadowaead" + "github.com/sagernet/sing-shadowsocks/shadowaead_2022/wg_replay" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/cache" @@ -27,7 +28,7 @@ import ( "github.com/sagernet/sing/common/replay" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/udpnat" - wgReplay "golang.zx2c4.com/wireguard/replay" + "golang.org/x/crypto/chacha20poly1305" ) var ( @@ -41,8 +42,8 @@ type Service struct { keySaltLength int handler shadowsocks.Handler - constructor func(key []byte) cipher.AEAD - blockConstructor func(key []byte) cipher.Block + constructor func(key []byte) (cipher.AEAD, error) + blockConstructor func(key []byte) (cipher.Block, error) udpCipher cipher.AEAD udpBlockCipher cipher.Block psk []byte @@ -79,15 +80,15 @@ func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks switch method { case "2022-blake3-aes-128-gcm": s.keySaltLength = 16 - s.constructor = newAESGCM - s.blockConstructor = newAES + s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) + s.blockConstructor = aes.NewCipher case "2022-blake3-aes-256-gcm": s.keySaltLength = 32 - s.constructor = newAESGCM - s.blockConstructor = newAES + s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) + s.blockConstructor = aes.NewCipher case "2022-blake3-chacha20-poly1305": s.keySaltLength = 32 - s.constructor = newChacha20Poly1305 + s.constructor = chacha20poly1305.New default: return nil, os.ErrInvalid } @@ -102,11 +103,15 @@ func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks } } + var err error switch method { case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": - s.udpBlockCipher = newAES(psk) + s.udpBlockCipher, err = aes.NewCipher(psk) case "2022-blake3-chacha20-poly1305": - s.udpCipher = newXChacha20Poly1305(psk) + s.udpCipher, err = chacha20poly1305.NewX(psk) + } + if err != nil { + return nil, err } s.psk = psk @@ -138,9 +143,13 @@ func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.M } requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength) + readCipher, err := s.constructor(common.Dup(requestKey)) + if err != nil { + return err + } reader := shadowaead.NewReader( conn, - s.constructor(common.Dup(requestKey)), + readCipher, MaxPacketSize, ) runtime.KeepAlive(requestKey) @@ -232,9 +241,13 @@ func (c *serverConn) writeResponse(payload []byte) (n int, err error) { common.Must1(io.ReadFull(rand.Reader, salt)) key := SessionKey(c.uPSK, salt, c.keySaltLength) runtime.KeepAlive(_salt) + writeCipher, err := c.constructor(common.Dup(key)) + if err != nil { + return + } writer := shadowaead.NewWriter( c.Conn, - c.constructor(common.Dup(key)), + writeCipher, MaxPacketSize, ) runtime.KeepAlive(key) @@ -331,7 +344,10 @@ func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf. session.remoteSessionId = sessionId if packetHeader != nil { key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength) - session.remoteCipher = s.constructor(common.Dup(key)) + session.remoteCipher, err = s.constructor(common.Dup(key)) + if err != nil { + return err + } runtime.KeepAlive(key) } } @@ -461,7 +477,7 @@ type serverUDPSession struct { packetId uint64 cipher cipher.AEAD remoteCipher cipher.AEAD - filter wgReplay.Filter + filter wg_replay.Filter rng io.Reader } @@ -482,7 +498,9 @@ func (m *Service) newUDPSession() *serverUDPSession { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) key := SessionKey(m.psk, sessionId, m.keySaltLength) - session.cipher = m.constructor(common.Dup(key)) + var err error + session.cipher, err = m.constructor(common.Dup(key)) + common.Must(err) runtime.KeepAlive(key) } return session diff --git a/shadowaead_2022/service_multi.go b/shadowaead_2022/service_multi.go index c98c6b9..9d7ba32 100644 --- a/shadowaead_2022/service_multi.go +++ b/shadowaead_2022/service_multi.go @@ -137,7 +137,11 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta _identitySubkey := buf.Make(s.keySaltLength) identitySubkey := common.Dup(_identitySubkey) blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) - s.blockConstructor(identitySubkey).Decrypt(eiHeader, eiHeader) + b, err := s.blockConstructor(identitySubkey) + if err != nil { + return err + } + b.Decrypt(eiHeader, eiHeader) runtime.KeepAlive(_identitySubkey) var user U @@ -151,9 +155,13 @@ func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, meta runtime.KeepAlive(_eiHeader) requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength) + readCipher, err := s.constructor(common.Dup(requestKey)) + if err != nil { + return err + } reader := shadowaead.NewReader( conn, - s.constructor(common.Dup(requestKey)), + readCipher, MaxPacketSize, ) @@ -273,7 +281,10 @@ func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buff if !loaded { session.remoteSessionId = sessionId key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength) - session.remoteCipher = s.constructor(common.Dup(key)) + session.remoteCipher, err = s.constructor(common.Dup(key)) + if err != nil { + return err + } runtime.KeepAlive(key) } @@ -359,7 +370,9 @@ func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { sessionId := make([]byte, 8) binary.BigEndian.PutUint64(sessionId, session.sessionId) key := SessionKey(uPSK, sessionId, s.keySaltLength) - session.cipher = s.constructor(common.Dup(key)) + var err error + session.cipher, err = s.constructor(common.Dup(key)) + common.Must(err) runtime.KeepAlive(key) return session } diff --git a/shadowaead_2022/wg_replay/replay.go b/shadowaead_2022/wg_replay/replay.go new file mode 100644 index 0000000..19e93ce --- /dev/null +++ b/shadowaead_2022/wg_replay/replay.go @@ -0,0 +1,62 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +// Package wg_replay implements an efficient anti-replay algorithm as specified in RFC 6479. +package wg_replay + +type block uint64 + +const ( + blockBitLog = 6 // 1<<6 == 64 bits + blockBits = 1 << blockBitLog // must be power of 2 + ringBlocks = 1 << 7 // must be power of 2 + windowSize = (ringBlocks - 1) * blockBits + blockMask = ringBlocks - 1 + bitMask = blockBits - 1 +) + +// A Filter rejects replayed messages by checking if message counter value is +// within a sliding window of previously received messages. +// The zero value for Filter is an empty filter ready to use. +// Filters are unsafe for concurrent use. +type Filter struct { + last uint64 + ring [ringBlocks]block +} + +// Reset resets the filter to empty state. +func (f *Filter) Reset() { + f.last = 0 + f.ring[0] = 0 +} + +// ValidateCounter checks if the counter should be accepted. +// Overlimit counters (>= limit) are always rejected. +func (f *Filter) ValidateCounter(counter, limit uint64) bool { + if counter >= limit { + return false + } + indexBlock := counter >> blockBitLog + if counter > f.last { // move window forward + current := f.last >> blockBitLog + diff := indexBlock - current + if diff > ringBlocks { + diff = ringBlocks // cap diff to clear the whole ring + } + for i := current + 1; i <= current+diff; i++ { + f.ring[i&blockMask] = 0 + } + f.last = counter + } else if f.last-counter > windowSize { // behind current window + return false + } + // check and set bit + indexBlock &= blockMask + indexBit := counter & bitMask + old := f.ring[indexBlock] + new := old | 1<