move header protection to a separate struct

This commit is contained in:
Marten Seemann 2019-09-06 15:45:41 +07:00
parent 75932b2dcd
commit 63c079e234
6 changed files with 93 additions and 74 deletions

View file

@ -1,32 +1,28 @@
package handshake package handshake
import ( import (
"crypto/aes"
"crypto/cipher" "crypto/cipher"
"encoding/binary" "encoding/binary"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls" "github.com/marten-seemann/qtls"
) )
type sealer struct { type sealer struct {
aead cipher.AEAD aead cipher.AEAD
hpEncrypter cipher.Block headerProtector headerProtector
// use a single slice to avoid allocations // use a single slice to avoid allocations
nonceBuf []byte nonceBuf []byte
hpMask []byte
} }
var _ LongHeaderSealer = &sealer{} var _ LongHeaderSealer = &sealer{}
func newLongHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) LongHeaderSealer { func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
return &sealer{ return &sealer{
aead: aead, aead: aead,
nonceBuf: make([]byte, aead.NonceSize()), headerProtector: headerProtector,
hpEncrypter: hpEncrypter, nonceBuf: make([]byte, aead.NonceSize()),
hpMask: make([]byte, hpEncrypter.BlockSize()),
} }
} }
@ -38,14 +34,7 @@ func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []by
} }
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != s.hpEncrypter.BlockSize() { s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
panic("invalid sample size")
}
s.hpEncrypter.Encrypt(s.hpMask, sample)
*firstByte ^= s.hpMask[0] & 0xf
for i := range pnBytes {
pnBytes[i] ^= s.hpMask[i+1]
}
} }
func (s *sealer) Overhead() int { func (s *sealer) Overhead() int {
@ -53,22 +42,20 @@ func (s *sealer) Overhead() int {
} }
type longHeaderOpener struct { type longHeaderOpener struct {
aead cipher.AEAD aead cipher.AEAD
pnDecrypter cipher.Block headerProtector headerProtector
// use a single slice to avoid allocations // use a single slice to avoid allocations
nonceBuf []byte nonceBuf []byte
hpMask []byte
} }
var _ LongHeaderOpener = &longHeaderOpener{} var _ LongHeaderOpener = &longHeaderOpener{}
func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener { func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
return &longHeaderOpener{ return &longHeaderOpener{
aead: aead, aead: aead,
nonceBuf: make([]byte, aead.NonceSize()), headerProtector: headerProtector,
pnDecrypter: pnDecrypter, nonceBuf: make([]byte, aead.NonceSize()),
hpMask: make([]byte, pnDecrypter.BlockSize()),
} }
} }
@ -84,14 +71,7 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []
} }
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() { o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
panic("invalid sample size")
}
o.pnDecrypter.Encrypt(o.hpMask, sample)
*firstByte ^= o.hpMask[0] & 0xf
for i := range pnBytes {
pnBytes[i] ^= o.hpMask[i+1]
}
} }
func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD { func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD {
@ -99,12 +79,3 @@ func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD {
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen()) iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
return suite.AEAD(key, iv) return suite.AEAD(key, iv)
} }
func createHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block {
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
hp, err := aes.NewCipher(hpKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
return hp
}

View file

@ -24,7 +24,8 @@ var _ = Describe("AEAD", func() {
iv := make([]byte, 12) iv := make([]byte, 12)
rand.Read(iv) rand.Read(iv)
return newLongHeaderSealer(aead, hpBlock), newLongHeaderOpener(aead, hpBlock) return newLongHeaderSealer(aead, newAESHeaderProtector(hpBlock, true)),
newLongHeaderOpener(aead, newAESHeaderProtector(hpBlock, true))
} }
Context("message encryption", func() { Context("message encryption", func() {

View file

@ -498,7 +498,7 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
h.readEncLevel = protocol.EncryptionHandshake h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = newLongHeaderOpener( h.handshakeOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret), createAEAD(suite, trafficSecret),
createHeaderProtector(suite, trafficSecret), newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true),
) )
h.logger.Debugf("Installed Handshake Read keys") h.logger.Debugf("Installed Handshake Read keys")
case qtls.EncryptionApplication: case qtls.EncryptionApplication:
@ -520,7 +520,7 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
h.writeEncLevel = protocol.EncryptionHandshake h.writeEncLevel = protocol.EncryptionHandshake
h.handshakeSealer = newLongHeaderSealer( h.handshakeSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret), createAEAD(suite, trafficSecret),
createHeaderProtector(suite, trafficSecret), newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true),
) )
h.logger.Debugf("Installed Handshake Write keys") h.logger.Debugf("Installed Handshake Write keys")
case qtls.EncryptionApplication: case qtls.EncryptionApplication:

View file

@ -0,0 +1,62 @@
package handshake
import (
"crypto/aes"
"crypto/cipher"
"fmt"
"github.com/marten-seemann/qtls"
)
type headerProtector interface {
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
}
func createAESHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block {
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
hp, err := aes.NewCipher(hpKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
return hp
}
type aesHeaderProtector struct {
mask []byte
block cipher.Block
isLongHeader bool
}
var _ headerProtector = &aesHeaderProtector{}
func newAESHeaderProtector(block cipher.Block, isLongHeader bool) headerProtector {
return &aesHeaderProtector{
block: block,
mask: make([]byte, block.BlockSize()),
isLongHeader: isLongHeader,
}
}
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != len(p.mask) {
panic("invalid sample size")
}
p.block.Encrypt(p.mask, sample)
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}

View file

@ -25,16 +25,18 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret) otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret)
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
hpEncrypter, err := aes.NewCipher(myHPKey) encrypterBlock, err := aes.NewCipher(myHPKey)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
hpDecrypter, err := aes.NewCipher(otherHPKey) decrypterBlock, err := aes.NewCipher(otherHPKey)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return newLongHeaderSealer(encrypter, hpEncrypter), newLongHeaderOpener(decrypter, hpDecrypter), nil return newLongHeaderSealer(encrypter, newAESHeaderProtector(encrypterBlock, true)),
newLongHeaderOpener(decrypter, newAESHeaderProtector(decrypterBlock, true)),
nil
} }
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {

View file

@ -67,8 +67,8 @@ type updatableAEAD struct {
nextRcvTrafficSecret []byte nextRcvTrafficSecret []byte
nextSendTrafficSecret []byte nextSendTrafficSecret []byte
hpDecrypter cipher.Block headerDecrypter headerProtector
hpEncrypter cipher.Block headerEncrypter headerProtector
rttStats *congestion.RTTStats rttStats *congestion.RTTStats
@ -76,7 +76,6 @@ type updatableAEAD struct {
// use a single slice to avoid allocations // use a single slice to avoid allocations
nonceBuf []byte nonceBuf []byte
hpMask []byte
} }
var _ ShortHeaderOpener = &updatableAEAD{} var _ ShortHeaderOpener = &updatableAEAD{}
@ -118,10 +117,9 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
// For the server, this function is called after SetWriteKey. // For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) { func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
a.rcvAEAD = createAEAD(suite, trafficSecret) a.rcvAEAD = createAEAD(suite, trafficSecret)
a.hpDecrypter = createHeaderProtector(suite, trafficSecret) a.headerDecrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false)
if a.suite == nil { if a.suite == nil {
a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize()) a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize())
a.hpMask = make([]byte, a.hpDecrypter.BlockSize())
a.aeadOverhead = a.rcvAEAD.Overhead() a.aeadOverhead = a.rcvAEAD.Overhead()
a.suite = suite a.suite = suite
} }
@ -134,10 +132,9 @@ func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
// For the server, this function is called before SetWriteKey. // For the server, this function is called before SetWriteKey.
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
a.sendAEAD = createAEAD(suite, trafficSecret) a.sendAEAD = createAEAD(suite, trafficSecret)
a.hpEncrypter = createHeaderProtector(suite, trafficSecret) a.headerEncrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false)
if a.suite == nil { if a.suite == nil {
a.nonceBuf = make([]byte, a.sendAEAD.NonceSize()) a.nonceBuf = make([]byte, a.sendAEAD.NonceSize())
a.hpMask = make([]byte, a.hpEncrypter.BlockSize())
a.aeadOverhead = a.sendAEAD.Overhead() a.aeadOverhead = a.sendAEAD.Overhead()
a.suite = suite a.suite = suite
} }
@ -245,24 +242,10 @@ func (a *updatableAEAD) Overhead() int {
return a.aeadOverhead return a.aeadOverhead
} }
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != len(a.hpMask) { a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
panic("invalid sample size")
}
a.hpEncrypter.Encrypt(a.hpMask, sample)
*firstByte ^= a.hpMask[0] & 0x1f
for i := range pnBytes {
pnBytes[i] ^= a.hpMask[i+1]
}
} }
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != len(a.hpMask) { a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
panic("invalid sample size")
}
a.hpDecrypter.Encrypt(a.hpMask, sample)
*firstByte ^= a.hpMask[0] & 0x1f
for i := range pnBytes {
pnBytes[i] ^= a.hpMask[i+1]
}
} }