mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
move header protection to a separate struct
This commit is contained in:
parent
75932b2dcd
commit
63c079e234
6 changed files with 93 additions and 74 deletions
|
@ -1,32 +1,28 @@
|
||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/marten-seemann/qtls"
|
"github.com/marten-seemann/qtls"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sealer struct {
|
type sealer struct {
|
||||||
aead cipher.AEAD
|
aead cipher.AEAD
|
||||||
hpEncrypter cipher.Block
|
headerProtector headerProtector
|
||||||
|
|
||||||
// use a single slice to avoid allocations
|
// use a single slice to avoid allocations
|
||||||
nonceBuf []byte
|
nonceBuf []byte
|
||||||
hpMask []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ LongHeaderSealer = &sealer{}
|
var _ LongHeaderSealer = &sealer{}
|
||||||
|
|
||||||
func newLongHeaderSealer(aead cipher.AEAD, hpEncrypter cipher.Block) LongHeaderSealer {
|
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||||
return &sealer{
|
return &sealer{
|
||||||
aead: aead,
|
aead: aead,
|
||||||
nonceBuf: make([]byte, aead.NonceSize()),
|
headerProtector: headerProtector,
|
||||||
hpEncrypter: hpEncrypter,
|
nonceBuf: make([]byte, aead.NonceSize()),
|
||||||
hpMask: make([]byte, hpEncrypter.BlockSize()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,14 +34,7 @@ func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []by
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||||
if len(sample) != s.hpEncrypter.BlockSize() {
|
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
|
||||||
panic("invalid sample size")
|
|
||||||
}
|
|
||||||
s.hpEncrypter.Encrypt(s.hpMask, sample)
|
|
||||||
*firstByte ^= s.hpMask[0] & 0xf
|
|
||||||
for i := range pnBytes {
|
|
||||||
pnBytes[i] ^= s.hpMask[i+1]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sealer) Overhead() int {
|
func (s *sealer) Overhead() int {
|
||||||
|
@ -53,22 +42,20 @@ func (s *sealer) Overhead() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
type longHeaderOpener struct {
|
type longHeaderOpener struct {
|
||||||
aead cipher.AEAD
|
aead cipher.AEAD
|
||||||
pnDecrypter cipher.Block
|
headerProtector headerProtector
|
||||||
|
|
||||||
// use a single slice to avoid allocations
|
// use a single slice to avoid allocations
|
||||||
nonceBuf []byte
|
nonceBuf []byte
|
||||||
hpMask []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||||
|
|
||||||
func newLongHeaderOpener(aead cipher.AEAD, pnDecrypter cipher.Block) LongHeaderOpener {
|
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||||
return &longHeaderOpener{
|
return &longHeaderOpener{
|
||||||
aead: aead,
|
aead: aead,
|
||||||
nonceBuf: make([]byte, aead.NonceSize()),
|
headerProtector: headerProtector,
|
||||||
pnDecrypter: pnDecrypter,
|
nonceBuf: make([]byte, aead.NonceSize()),
|
||||||
hpMask: make([]byte, pnDecrypter.BlockSize()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,14 +71,7 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||||
if len(sample) != o.pnDecrypter.BlockSize() {
|
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
|
||||||
panic("invalid sample size")
|
|
||||||
}
|
|
||||||
o.pnDecrypter.Encrypt(o.hpMask, sample)
|
|
||||||
*firstByte ^= o.hpMask[0] & 0xf
|
|
||||||
for i := range pnBytes {
|
|
||||||
pnBytes[i] ^= o.hpMask[i+1]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD {
|
func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD {
|
||||||
|
@ -99,12 +79,3 @@ func createAEAD(suite cipherSuite, trafficSecret []byte) cipher.AEAD {
|
||||||
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
|
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
|
||||||
return suite.AEAD(key, iv)
|
return suite.AEAD(key, iv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block {
|
|
||||||
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
|
|
||||||
hp, err := aes.NewCipher(hpKey)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
|
||||||
}
|
|
||||||
return hp
|
|
||||||
}
|
|
||||||
|
|
|
@ -24,7 +24,8 @@ var _ = Describe("AEAD", func() {
|
||||||
|
|
||||||
iv := make([]byte, 12)
|
iv := make([]byte, 12)
|
||||||
rand.Read(iv)
|
rand.Read(iv)
|
||||||
return newLongHeaderSealer(aead, hpBlock), newLongHeaderOpener(aead, hpBlock)
|
return newLongHeaderSealer(aead, newAESHeaderProtector(hpBlock, true)),
|
||||||
|
newLongHeaderOpener(aead, newAESHeaderProtector(hpBlock, true))
|
||||||
}
|
}
|
||||||
|
|
||||||
Context("message encryption", func() {
|
Context("message encryption", func() {
|
||||||
|
|
|
@ -498,7 +498,7 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
|
||||||
h.readEncLevel = protocol.EncryptionHandshake
|
h.readEncLevel = protocol.EncryptionHandshake
|
||||||
h.handshakeOpener = newLongHeaderOpener(
|
h.handshakeOpener = newLongHeaderOpener(
|
||||||
createAEAD(suite, trafficSecret),
|
createAEAD(suite, trafficSecret),
|
||||||
createHeaderProtector(suite, trafficSecret),
|
newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true),
|
||||||
)
|
)
|
||||||
h.logger.Debugf("Installed Handshake Read keys")
|
h.logger.Debugf("Installed Handshake Read keys")
|
||||||
case qtls.EncryptionApplication:
|
case qtls.EncryptionApplication:
|
||||||
|
@ -520,7 +520,7 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
|
||||||
h.writeEncLevel = protocol.EncryptionHandshake
|
h.writeEncLevel = protocol.EncryptionHandshake
|
||||||
h.handshakeSealer = newLongHeaderSealer(
|
h.handshakeSealer = newLongHeaderSealer(
|
||||||
createAEAD(suite, trafficSecret),
|
createAEAD(suite, trafficSecret),
|
||||||
createHeaderProtector(suite, trafficSecret),
|
newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), true),
|
||||||
)
|
)
|
||||||
h.logger.Debugf("Installed Handshake Write keys")
|
h.logger.Debugf("Installed Handshake Write keys")
|
||||||
case qtls.EncryptionApplication:
|
case qtls.EncryptionApplication:
|
||||||
|
|
62
internal/handshake/header_protector.go
Normal file
62
internal/handshake/header_protector.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/marten-seemann/qtls"
|
||||||
|
)
|
||||||
|
|
||||||
|
type headerProtector interface {
|
||||||
|
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||||
|
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAESHeaderProtector(suite cipherSuite, trafficSecret []byte) cipher.Block {
|
||||||
|
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
|
||||||
|
hp, err := aes.NewCipher(hpKey)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||||
|
}
|
||||||
|
return hp
|
||||||
|
}
|
||||||
|
|
||||||
|
type aesHeaderProtector struct {
|
||||||
|
mask []byte
|
||||||
|
block cipher.Block
|
||||||
|
isLongHeader bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ headerProtector = &aesHeaderProtector{}
|
||||||
|
|
||||||
|
func newAESHeaderProtector(block cipher.Block, isLongHeader bool) headerProtector {
|
||||||
|
return &aesHeaderProtector{
|
||||||
|
block: block,
|
||||||
|
mask: make([]byte, block.BlockSize()),
|
||||||
|
isLongHeader: isLongHeader,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||||
|
p.apply(sample, firstByte, hdrBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||||
|
p.apply(sample, firstByte, hdrBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||||
|
if len(sample) != len(p.mask) {
|
||||||
|
panic("invalid sample size")
|
||||||
|
}
|
||||||
|
p.block.Encrypt(p.mask, sample)
|
||||||
|
if p.isLongHeader {
|
||||||
|
*firstByte ^= p.mask[0] & 0xf
|
||||||
|
} else {
|
||||||
|
*firstByte ^= p.mask[0] & 0x1f
|
||||||
|
}
|
||||||
|
for i := range hdrBytes {
|
||||||
|
hdrBytes[i] ^= p.mask[i+1]
|
||||||
|
}
|
||||||
|
}
|
|
@ -25,16 +25,18 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
|
||||||
otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret)
|
otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret)
|
||||||
|
|
||||||
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
|
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
|
||||||
hpEncrypter, err := aes.NewCipher(myHPKey)
|
encrypterBlock, err := aes.NewCipher(myHPKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
|
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
|
||||||
hpDecrypter, err := aes.NewCipher(otherHPKey)
|
decrypterBlock, err := aes.NewCipher(otherHPKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return newLongHeaderSealer(encrypter, hpEncrypter), newLongHeaderOpener(decrypter, hpDecrypter), nil
|
return newLongHeaderSealer(encrypter, newAESHeaderProtector(encrypterBlock, true)),
|
||||||
|
newLongHeaderOpener(decrypter, newAESHeaderProtector(decrypterBlock, true)),
|
||||||
|
nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||||
|
|
|
@ -67,8 +67,8 @@ type updatableAEAD struct {
|
||||||
nextRcvTrafficSecret []byte
|
nextRcvTrafficSecret []byte
|
||||||
nextSendTrafficSecret []byte
|
nextSendTrafficSecret []byte
|
||||||
|
|
||||||
hpDecrypter cipher.Block
|
headerDecrypter headerProtector
|
||||||
hpEncrypter cipher.Block
|
headerEncrypter headerProtector
|
||||||
|
|
||||||
rttStats *congestion.RTTStats
|
rttStats *congestion.RTTStats
|
||||||
|
|
||||||
|
@ -76,7 +76,6 @@ type updatableAEAD struct {
|
||||||
|
|
||||||
// use a single slice to avoid allocations
|
// use a single slice to avoid allocations
|
||||||
nonceBuf []byte
|
nonceBuf []byte
|
||||||
hpMask []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ShortHeaderOpener = &updatableAEAD{}
|
var _ ShortHeaderOpener = &updatableAEAD{}
|
||||||
|
@ -118,10 +117,9 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
|
||||||
// For the server, this function is called after SetWriteKey.
|
// For the server, this function is called after SetWriteKey.
|
||||||
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
||||||
a.rcvAEAD = createAEAD(suite, trafficSecret)
|
a.rcvAEAD = createAEAD(suite, trafficSecret)
|
||||||
a.hpDecrypter = createHeaderProtector(suite, trafficSecret)
|
a.headerDecrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false)
|
||||||
if a.suite == nil {
|
if a.suite == nil {
|
||||||
a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize())
|
a.nonceBuf = make([]byte, a.rcvAEAD.NonceSize())
|
||||||
a.hpMask = make([]byte, a.hpDecrypter.BlockSize())
|
|
||||||
a.aeadOverhead = a.rcvAEAD.Overhead()
|
a.aeadOverhead = a.rcvAEAD.Overhead()
|
||||||
a.suite = suite
|
a.suite = suite
|
||||||
}
|
}
|
||||||
|
@ -134,10 +132,9 @@ func (a *updatableAEAD) SetReadKey(suite cipherSuite, trafficSecret []byte) {
|
||||||
// For the server, this function is called before SetWriteKey.
|
// For the server, this function is called before SetWriteKey.
|
||||||
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
||||||
a.sendAEAD = createAEAD(suite, trafficSecret)
|
a.sendAEAD = createAEAD(suite, trafficSecret)
|
||||||
a.hpEncrypter = createHeaderProtector(suite, trafficSecret)
|
a.headerEncrypter = newAESHeaderProtector(createAESHeaderProtector(suite, trafficSecret), false)
|
||||||
if a.suite == nil {
|
if a.suite == nil {
|
||||||
a.nonceBuf = make([]byte, a.sendAEAD.NonceSize())
|
a.nonceBuf = make([]byte, a.sendAEAD.NonceSize())
|
||||||
a.hpMask = make([]byte, a.hpEncrypter.BlockSize())
|
|
||||||
a.aeadOverhead = a.sendAEAD.Overhead()
|
a.aeadOverhead = a.sendAEAD.Overhead()
|
||||||
a.suite = suite
|
a.suite = suite
|
||||||
}
|
}
|
||||||
|
@ -245,24 +242,10 @@ func (a *updatableAEAD) Overhead() int {
|
||||||
return a.aeadOverhead
|
return a.aeadOverhead
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||||
if len(sample) != len(a.hpMask) {
|
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
|
||||||
panic("invalid sample size")
|
|
||||||
}
|
|
||||||
a.hpEncrypter.Encrypt(a.hpMask, sample)
|
|
||||||
*firstByte ^= a.hpMask[0] & 0x1f
|
|
||||||
for i := range pnBytes {
|
|
||||||
pnBytes[i] ^= a.hpMask[i+1]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||||
if len(sample) != len(a.hpMask) {
|
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
|
||||||
panic("invalid sample size")
|
|
||||||
}
|
|
||||||
a.hpDecrypter.Encrypt(a.hpMask, sample)
|
|
||||||
*firstByte ^= a.hpMask[0] & 0x1f
|
|
||||||
for i := range pnBytes {
|
|
||||||
pnBytes[i] ^= a.hpMask[i+1]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue