diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index d5d71c9b..00043113 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -1,32 +1,28 @@ package handshake import ( - "crypto/aes" "crypto/cipher" "encoding/binary" - "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/marten-seemann/qtls" ) type sealer struct { - aead cipher.AEAD - hpEncrypter cipher.Block + aead cipher.AEAD + headerProtector headerProtector // use a single slice to avoid allocations nonceBuf []byte - hpMask []byte } var _ LongHeaderSealer = &sealer{} -func newLongHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) LongHeaderSealer { +func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { return &sealer{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - hpEncrypter: hpEncrypter, - hpMask: make([]byte, hpEncrypter.BlockSize()), + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), } } @@ -38,14 +34,7 @@ func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []by } func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != s.hpEncrypter.BlockSize() { - panic("invalid sample size") - } - s.hpEncrypter.Encrypt(s.hpMask, sample) - *firstByte ^= s.hpMask[0] & 0xf - for i := range pnBytes { - pnBytes[i] ^= s.hpMask[i+1] - } + s.headerProtector.EncryptHeader(sample, firstByte, pnBytes) } func (s *sealer) Overhead() int { @@ -53,22 +42,20 @@ func (s *sealer) Overhead() int { } type longHeaderOpener struct { - aead cipher.AEAD - pnDecrypter cipher.Block + aead cipher.AEAD + headerProtector headerProtector // use a single slice to avoid allocations nonceBuf []byte - hpMask []byte } var _ LongHeaderOpener = &longHeaderOpener{} -func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener { +func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { return &longHeaderOpener{ - aead: aead, - nonceBuf: make([]byte, aead.NonceSize()), - pnDecrypter: pnDecrypter, - hpMask: make([]byte, pnDecrypter.BlockSize()), + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), } } @@ -84,14 +71,7 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad [] } func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != o.pnDecrypter.BlockSize() { - panic("invalid sample size") - } - o.pnDecrypter.Encrypt(o.hpMask, sample) - *firstByte ^= o.hpMask[0] & 0xf - for i := range pnBytes { - pnBytes[i] ^= o.hpMask[i+1] - } + o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) } func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD { @@ -99,12 +79,3 @@ func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD { iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen()) return suite.AEAD(key, iv) } - -func createHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block { - hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen()) - hp, err := aes.NewCipher(hpKey) - if err != nil { - panic(fmt.Sprintf("error creating new AES cipher: %s", err)) - } - return hp -} diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 7498fb5a..15391472 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -24,7 +24,8 @@ var _ = Describe("AEAD", func() { iv := make([]byte, 12) rand.Read(iv) - return newLongHeaderSealer(aead, hpBlock), newLongHeaderOpener(aead, hpBlock) + return newLongHeaderSealer(aead, newAESHeaderProtector(hpBlock, true)), + newLongHeaderOpener(aead, newAESHeaderProtector(hpBlock, true)) } Context("message encryption", func() { diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index c0882381..85b6e1f1 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -498,7 +498,7 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph h.readEncLevel = protocol.EncryptionHandshake h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret), - createHeaderProtector(suite, trafficSecret), + newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true), ) h.logger.Debugf("Installed Handshake Read keys") case qtls.EncryptionApplication: @@ -520,7 +520,7 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip h.writeEncLevel = protocol.EncryptionHandshake h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret), - createHeaderProtector(suite, trafficSecret), + newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true), ) h.logger.Debugf("Installed Handshake Write keys") case qtls.EncryptionApplication: diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go new file mode 100644 index 00000000..b6a7b7ea --- /dev/null +++ b/internal/handshake/header_protector.go @@ -0,0 +1,62 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "fmt" + + "github.com/marten-seemann/qtls" +) + +type headerProtector interface { + EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) + DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) +} + +func createAESHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block { + hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen()) + hp, err := aes.NewCipher(hpKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } + return hp +} + +type aesHeaderProtector struct { + mask []byte + block cipher.Block + isLongHeader bool +} + +var _ headerProtector = &aesHeaderProtector{} + +func newAESHeaderProtector(block cipher.Block, isLongHeader bool) headerProtector { + return &aesHeaderProtector{ + block: block, + mask: make([]byte, block.BlockSize()), + isLongHeader: isLongHeader, + } +} + +func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) != len(p.mask) { + panic("invalid sample size") + } + p.block.Encrypt(p.mask, sample) + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index da794c5c..eb90122b 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -25,16 +25,18 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret) encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) - hpEncrypter, err := aes.NewCipher(myHPKey) + encrypterBlock, err := aes.NewCipher(myHPKey) if err != nil { return nil, nil, err } decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) - hpDecrypter, err := aes.NewCipher(otherHPKey) + decrypterBlock, err := aes.NewCipher(otherHPKey) if err != nil { return nil, nil, err } - return newLongHeaderSealer(encrypter, hpEncrypter), newLongHeaderOpener(decrypter, hpDecrypter), nil + return newLongHeaderSealer(encrypter, newAESHeaderProtector(encrypterBlock, true)), + newLongHeaderOpener(decrypter, newAESHeaderProtector(decrypterBlock, true)), + nil } func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index ff87aeeb..172c3e79 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -67,8 +67,8 @@ type updatableAEAD struct { nextRcvTrafficSecret []byte nextSendTrafficSecret []byte - hpDecrypter cipher.Block - hpEncrypter cipher.Block + headerDecrypter headerProtector + headerEncrypter headerProtector rttStats *congestion.RTTStats @@ -76,7 +76,6 @@ type updatableAEAD struct { // use a single slice to avoid allocations nonceBuf []byte - hpMask []byte } var _ ShortHeaderOpener = &updatableAEAD{} @@ -118,10 +117,9 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte // For the server, this function is called after SetWriteKey. func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { a.rcvAEAD = createAEAD(suite, trafficSecret) - a.hpDecrypter = createHeaderProtector(suite, trafficSecret) + a.headerDecrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false) if a.suite == nil { a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize()) - a.hpMask = make([]byte, a.hpDecrypter.BlockSize()) a.aeadOverhead = a.rcvAEAD.Overhead() a.suite = suite } @@ -134,10 +132,9 @@ func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { // For the server, this function is called before SetWriteKey. func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { a.sendAEAD = createAEAD(suite, trafficSecret) - a.hpEncrypter = createHeaderProtector(suite, trafficSecret) + a.headerEncrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false) if a.suite == nil { a.nonceBuf = make([]byte, a.sendAEAD.NonceSize()) - a.hpMask = make([]byte, a.hpEncrypter.BlockSize()) a.aeadOverhead = a.sendAEAD.Overhead() a.suite = suite } @@ -245,24 +242,10 @@ func (a *updatableAEAD) Overhead() int { return a.aeadOverhead } -func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != len(a.hpMask) { - panic("invalid sample size") - } - a.hpEncrypter.Encrypt(a.hpMask, sample) - *firstByte ^= a.hpMask[0] & 0x1f - for i := range pnBytes { - pnBytes[i] ^= a.hpMask[i+1] - } +func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) } -func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { - if len(sample) != len(a.hpMask) { - panic("invalid sample size") - } - a.hpDecrypter.Encrypt(a.hpMask, sample) - *firstByte ^= a.hpMask[0] & 0x1f - for i := range pnBytes { - pnBytes[i] ^= a.hpMask[i+1] - } +func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) }