mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
move header protection to a separate struct
This commit is contained in:
parent
75932b2dcd
commit
63c079e234
6 changed files with 93 additions and 74 deletions
|
@ -1,32 +1,28 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type sealer struct {
|
||||
aead cipher.AEAD
|
||||
hpEncrypter cipher.Block
|
||||
aead cipher.AEAD
|
||||
headerProtector headerProtector
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
hpMask []byte
|
||||
}
|
||||
|
||||
var _ LongHeaderSealer = &sealer{}
|
||||
|
||||
func newLongHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) LongHeaderSealer {
|
||||
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||
return &sealer{
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
hpEncrypter: hpEncrypter,
|
||||
hpMask: make([]byte, hpEncrypter.BlockSize()),
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -38,14 +34,7 @@ func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []by
|
|||
}
|
||||
|
||||
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != s.hpEncrypter.BlockSize() {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
s.hpEncrypter.Encrypt(s.hpMask, sample)
|
||||
*firstByte ^= s.hpMask[0] & 0xf
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= s.hpMask[i+1]
|
||||
}
|
||||
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
|
||||
func (s *sealer) Overhead() int {
|
||||
|
@ -53,22 +42,20 @@ func (s *sealer) Overhead() int {
|
|||
}
|
||||
|
||||
type longHeaderOpener struct {
|
||||
aead cipher.AEAD
|
||||
pnDecrypter cipher.Block
|
||||
aead cipher.AEAD
|
||||
headerProtector headerProtector
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
hpMask []byte
|
||||
}
|
||||
|
||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||
|
||||
func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener {
|
||||
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||
return &longHeaderOpener{
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
pnDecrypter: pnDecrypter,
|
||||
hpMask: make([]byte, pnDecrypter.BlockSize()),
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,14 +71,7 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []
|
|||
}
|
||||
|
||||
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != o.pnDecrypter.BlockSize() {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
o.pnDecrypter.Encrypt(o.hpMask, sample)
|
||||
*firstByte ^= o.hpMask[0] & 0xf
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= o.hpMask[i+1]
|
||||
}
|
||||
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
|
||||
func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD {
|
||||
|
@ -99,12 +79,3 @@ func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD {
|
|||
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
|
||||
return suite.AEAD(key, iv)
|
||||
}
|
||||
|
||||
func createHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block {
|
||||
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
|
||||
hp, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
return hp
|
||||
}
|
||||
|
|
|
@ -24,7 +24,8 @@ var _ = Describe("AEAD", func() {
|
|||
|
||||
iv := make([]byte, 12)
|
||||
rand.Read(iv)
|
||||
return newLongHeaderSealer(aead, hpBlock), newLongHeaderOpener(aead, hpBlock)
|
||||
return newLongHeaderSealer(aead, newAESHeaderProtector(hpBlock, true)),
|
||||
newLongHeaderOpener(aead, newAESHeaderProtector(hpBlock, true))
|
||||
}
|
||||
|
||||
Context("message encryption", func() {
|
||||
|
|
|
@ -498,7 +498,7 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
|
|||
h.readEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeOpener = newLongHeaderOpener(
|
||||
createAEAD(suite, trafficSecret),
|
||||
createHeaderProtector(suite, trafficSecret),
|
||||
newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true),
|
||||
)
|
||||
h.logger.Debugf("Installed Handshake Read keys")
|
||||
case qtls.EncryptionApplication:
|
||||
|
@ -520,7 +520,7 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
|
|||
h.writeEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeSealer = newLongHeaderSealer(
|
||||
createAEAD(suite, trafficSecret),
|
||||
createHeaderProtector(suite, trafficSecret),
|
||||
newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true),
|
||||
)
|
||||
h.logger.Debugf("Installed Handshake Write keys")
|
||||
case qtls.EncryptionApplication:
|
||||
|
|
62
internal/handshake/header_protector.go
Normal file
62
internal/handshake/header_protector.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type headerProtector interface {
|
||||
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
}
|
||||
|
||||
func createAESHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block {
|
||||
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
|
||||
hp, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
return hp
|
||||
}
|
||||
|
||||
type aesHeaderProtector struct {
|
||||
mask []byte
|
||||
block cipher.Block
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &aesHeaderProtector{}
|
||||
|
||||
func newAESHeaderProtector(block cipher.Block, isLongHeader bool) headerProtector {
|
||||
return &aesHeaderProtector{
|
||||
block: block,
|
||||
mask: make([]byte, block.BlockSize()),
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) != len(p.mask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
p.block.Encrypt(p.mask, sample)
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
|
@ -25,16 +25,18 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
|
|||
otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret)
|
||||
|
||||
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
|
||||
hpEncrypter, err := aes.NewCipher(myHPKey)
|
||||
encrypterBlock, err := aes.NewCipher(myHPKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
|
||||
hpDecrypter, err := aes.NewCipher(otherHPKey)
|
||||
decrypterBlock, err := aes.NewCipher(otherHPKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return newLongHeaderSealer(encrypter, hpEncrypter), newLongHeaderOpener(decrypter, hpDecrypter), nil
|
||||
return newLongHeaderSealer(encrypter, newAESHeaderProtector(encrypterBlock, true)),
|
||||
newLongHeaderOpener(decrypter, newAESHeaderProtector(decrypterBlock, true)),
|
||||
nil
|
||||
}
|
||||
|
||||
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
|
|
|
@ -67,8 +67,8 @@ type updatableAEAD struct {
|
|||
nextRcvTrafficSecret []byte
|
||||
nextSendTrafficSecret []byte
|
||||
|
||||
hpDecrypter cipher.Block
|
||||
hpEncrypter cipher.Block
|
||||
headerDecrypter headerProtector
|
||||
headerEncrypter headerProtector
|
||||
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
|
@ -76,7 +76,6 @@ type updatableAEAD struct {
|
|||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
hpMask []byte
|
||||
}
|
||||
|
||||
var _ ShortHeaderOpener = &updatableAEAD{}
|
||||
|
@ -118,10 +117,9 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
|
|||
// For the server, this function is called after SetWriteKey.
|
||||
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
||||
a.rcvAEAD = createAEAD(suite, trafficSecret)
|
||||
a.hpDecrypter = createHeaderProtector(suite, trafficSecret)
|
||||
a.headerDecrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false)
|
||||
if a.suite == nil {
|
||||
a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize())
|
||||
a.hpMask = make([]byte, a.hpDecrypter.BlockSize())
|
||||
a.aeadOverhead = a.rcvAEAD.Overhead()
|
||||
a.suite = suite
|
||||
}
|
||||
|
@ -134,10 +132,9 @@ func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
|||
// For the server, this function is called before SetWriteKey.
|
||||
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
||||
a.sendAEAD = createAEAD(suite, trafficSecret)
|
||||
a.hpEncrypter = createHeaderProtector(suite, trafficSecret)
|
||||
a.headerEncrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false)
|
||||
if a.suite == nil {
|
||||
a.nonceBuf = make([]byte, a.sendAEAD.NonceSize())
|
||||
a.hpMask = make([]byte, a.hpEncrypter.BlockSize())
|
||||
a.aeadOverhead = a.sendAEAD.Overhead()
|
||||
a.suite = suite
|
||||
}
|
||||
|
@ -245,24 +242,10 @@ func (a *updatableAEAD) Overhead() int {
|
|||
return a.aeadOverhead
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != len(a.hpMask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
a.hpEncrypter.Encrypt(a.hpMask, sample)
|
||||
*firstByte ^= a.hpMask[0] & 0x1f
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= a.hpMask[i+1]
|
||||
}
|
||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
if len(sample) != len(a.hpMask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
a.hpDecrypter.Encrypt(a.hpMask, sample)
|
||||
*firstByte ^= a.hpMask[0] & 0x1f
|
||||
for i := range pnBytes {
|
||||
pnBytes[i] ^= a.hpMask[i+1]
|
||||
}
|
||||
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue