implement header encrytion and decryption for sealers and openers

This commit is contained in:
Marten Seemann 2018-12-14 16:42:45 +06:30
parent 67f923c736
commit 5a68ba0a02
8 changed files with 239 additions and 44 deletions

View file

@ -8,20 +8,28 @@ import (
) )
type sealer struct { type sealer struct {
iv []byte iv []byte
aead cipher.AEAD aead cipher.AEAD
pnEncrypter cipher.Block
// use a single slice to avoid allocations // use a single slice to avoid allocations
nonceBuf []byte nonceBuf []byte
pnMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
} }
var _ Sealer = &sealer{} var _ Sealer = &sealer{}
func newSealer(aead cipher.AEAD, iv []byte) Sealer { func newSealer(aead cipher.AEAD, iv []byte, pnEncrypter cipher.Block, is1RTT bool) Sealer {
return &sealer{ return &sealer{
iv: iv, iv: iv,
aead: aead, aead: aead,
nonceBuf: make([]byte, aead.NonceSize()), nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnEncrypter: pnEncrypter,
pnMask: make([]byte, pnEncrypter.BlockSize()),
} }
} }
@ -30,25 +38,48 @@ func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []by
return s.aead.Seal(dst, s.nonceBuf, src, ad) return s.aead.Seal(dst, s.nonceBuf, src, ad)
} }
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != s.pnEncrypter.BlockSize() {
panic("invalid sample size")
}
s.pnEncrypter.Encrypt(s.pnMask, sample)
if s.is1RTT {
*firstByte ^= s.pnMask[0] & 0x1f
} else {
*firstByte ^= s.pnMask[0] & 0xf
}
for i := range pnBytes {
pnBytes[i] ^= s.pnMask[i+1]
}
}
func (s *sealer) Overhead() int { func (s *sealer) Overhead() int {
return s.aead.Overhead() return s.aead.Overhead()
} }
type opener struct { type opener struct {
iv []byte iv []byte
aead cipher.AEAD aead cipher.AEAD
pnDecrypter cipher.Block
// use a single slice to avoid allocations // use a single slice to avoid allocations
nonceBuf []byte nonceBuf []byte
pnMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
} }
var _ Opener = &opener{} var _ Opener = &opener{}
func newOpener(aead cipher.AEAD, iv []byte) Opener { func newOpener(aead cipher.AEAD, iv []byte, pnDecrypter cipher.Block, is1RTT bool) Opener {
return &opener{ return &opener{
iv: iv, iv: iv,
aead: aead, aead: aead,
nonceBuf: make([]byte, aead.NonceSize()), nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnDecrypter: pnDecrypter,
pnMask: make([]byte, pnDecrypter.BlockSize()),
} }
} }
@ -56,3 +87,18 @@ func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]b
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
return o.aead.Open(dst, o.nonceBuf, src, ad) return o.aead.Open(dst, o.nonceBuf, src, ad)
} }
func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() {
panic("invalid sample size")
}
o.pnDecrypter.Encrypt(o.pnMask, sample)
if o.is1RTT {
*firstByte ^= o.pnMask[0] & 0x1f
} else {
*firstByte ^= o.pnMask[0] & 0xf
}
for i := range pnBytes {
pnBytes[i] ^= o.pnMask[i+1]
}
}

View file

@ -10,42 +10,106 @@ import (
) )
var _ = Describe("AEAD", func() { var _ = Describe("AEAD", func() {
var sealer Sealer getSealerAndOpener := func(is1RTT bool) (Sealer, Opener) {
var opener Opener
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
ad := []byte("Donec in velit neque.")
BeforeEach(func() {
key := make([]byte, 16) key := make([]byte, 16)
pnKey := make([]byte, 16)
rand.Read(key) rand.Read(key)
rand.Read(pnKey)
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
aead, err := cipher.NewGCM(block) aead, err := cipher.NewGCM(block)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
pnBlock, err := aes.NewCipher(pnKey)
Expect(err).ToNot(HaveOccurred())
iv := make([]byte, 12) iv := make([]byte, 12)
rand.Read(iv) rand.Read(iv)
sealer = newSealer(aead, iv) return newSealer(aead, iv, pnBlock, is1RTT), newOpener(aead, iv, pnBlock, is1RTT)
opener = newOpener(aead, iv) }
Context("message encryption", func() {
var (
sealer Sealer
opener Opener
)
BeforeEach(func() {
sealer, opener = getSealerAndOpener(false)
})
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
ad := []byte("Donec in velit neque.")
It("encrypts and decrypts a message", func() {
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
opened, err := opener.Open(nil, encrypted, 0x1337, ad)
Expect(err).ToNot(HaveOccurred())
Expect(opened).To(Equal(msg))
})
It("fails to open a message if the associated data is not the same", func() {
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad"))
Expect(err).To(MatchError("cipher: message authentication failed"))
})
It("fails to open a message if the packet number is not the same", func() {
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x42, ad)
Expect(err).To(MatchError("cipher: message authentication failed"))
})
}) })
It("encrypts and decrypts a message", func() { Context("header encryption", func() {
encrypted := sealer.Seal(nil, msg, 0x1337, ad) It("encrypts and encrypts the header, for long headers", func() {
opened, err := opener.Open(nil, encrypted, 0x1337, ad) sealer, opener := getSealerAndOpener(false)
Expect(err).ToNot(HaveOccurred()) var lastFourBitsDifferent int
Expect(opened).To(Equal(msg)) for i := 0; i < 100; i++ {
}) sample := make([]byte, 16)
rand.Read(sample)
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
sealer.EncryptHeader(sample, &header[0], header[9:13])
if header[0]&0xf != 0xb5&0xf {
lastFourBitsDifferent++
}
Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0)))
Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
opener.DecryptHeader(sample, &header[0], header[9:13])
Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
}
Expect(lastFourBitsDifferent).To(BeNumerically(">", 75))
})
It("fails to open a message if the associated data is not the same", func() { It("encrypts and encrypts the header, for short headers", func() {
encrypted := sealer.Seal(nil, msg, 0x1337, ad) sealer, opener := getSealerAndOpener(true)
_, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) var lastFiveBitsDifferent int
Expect(err).To(MatchError("cipher: message authentication failed")) for i := 0; i < 100; i++ {
}) sample := make([]byte, 16)
rand.Read(sample)
header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
sealer.EncryptHeader(sample, &header[0], header[9:13])
if header[0]&0x1f != 0xb5&0x1f {
lastFiveBitsDifferent++
}
Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0)))
Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
opener.DecryptHeader(sample, &header[0], header[9:13])
Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
}
Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75))
})
It("fails to open a message if the packet number is not the same", func() { It("fails to decrypt the header when using a different sample", func() {
encrypted := sealer.Seal(nil, msg, 0x1337, ad) sealer, opener := getSealerAndOpener(true)
_, err := opener.Open(nil, encrypted, 0x42, ad) header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
Expect(err).To(MatchError("cipher: message authentication failed")) sample := make([]byte, 16)
rand.Read(sample)
sealer.EncryptHeader(sample, &header[0], header[9:13])
rand.Read(sample) // use a different sample
opener.DecryptHeader(sample, &header[0], header[9:13])
Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
})
}) })
}) })

View file

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"crypto/aes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -407,7 +408,17 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) { func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen()) key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen()) iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
opener := newOpener(suite.AEAD(key, iv), iv) pnKey := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "pn", suite.KeyLen())
pnDecrypter, err := aes.NewCipher(pnKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
opener := newOpener(
suite.AEAD(key, iv),
iv,
pnDecrypter,
h.readEncLevel == protocol.Encryption1RTT,
)
switch h.readEncLevel { switch h.readEncLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
@ -427,7 +438,17 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) { func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen()) key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen()) iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
sealer := newSealer(suite.AEAD(key, iv), iv) pnKey := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "pn", suite.KeyLen())
pnEncrypter, err := aes.NewCipher(pnKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
sealer := newSealer(
suite.AEAD(key, iv),
iv,
pnEncrypter,
h.writeEncLevel == protocol.Encryption1RTT,
)
switch h.writeEncLevel { switch h.writeEncLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:

View file

@ -21,8 +21,8 @@ func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Se
mySecret = serverSecret mySecret = serverSecret
otherSecret = clientSecret otherSecret = clientSecret
} }
myKey, _, myIV := computeInitialKeyAndIV(mySecret) myKey, myPNKey, myIV := computeInitialKeyAndIV(mySecret)
otherKey, _, otherIV := computeInitialKeyAndIV(otherSecret) otherKey, otherPNKey, otherIV := computeInitialKeyAndIV(otherSecret)
encrypterCipher, err := aes.NewCipher(myKey) encrypterCipher, err := aes.NewCipher(myKey)
if err != nil { if err != nil {
@ -32,6 +32,10 @@ func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Se
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
pnEncrypter, err := aes.NewCipher(myPNKey)
if err != nil {
return nil, nil, err
}
decrypterCipher, err := aes.NewCipher(otherKey) decrypterCipher, err := aes.NewCipher(otherKey)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -40,7 +44,11 @@ func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Se
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return newSealer(encrypter, myIV), newOpener(decrypter, otherIV), nil pnDecrypter, err := aes.NewCipher(otherPNKey)
if err != nil {
return nil, nil, err
}
return newSealer(encrypter, myIV, pnEncrypter, false), newOpener(decrypter, otherIV, pnDecrypter, false), nil
} }
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {

View file

@ -1,6 +1,8 @@
package handshake package handshake
import ( import (
"math/rand"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -64,7 +66,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
}) })
It("seals and opens", func() { It("seals and opens", func() {
connectionID := protocol.ConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}) connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}
clientSealer, clientOpener, err := newInitialAEAD(connectionID, protocol.PerspectiveClient) clientSealer, clientOpener, err := newInitialAEAD(connectionID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serverSealer, serverOpener, err := newInitialAEAD(connectionID, protocol.PerspectiveServer) serverSealer, serverOpener, err := newInitialAEAD(connectionID, protocol.PerspectiveServer)
@ -81,8 +83,8 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
}) })
It("doesn't work if initialized with different connection IDs", func() { It("doesn't work if initialized with different connection IDs", func() {
c1 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1}) c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1}
c2 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2}) c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2}
clientSealer, _, err := newInitialAEAD(c1, protocol.PerspectiveClient) clientSealer, _, err := newInitialAEAD(c1, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, serverOpener, err := newInitialAEAD(c2, protocol.PerspectiveServer) _, serverOpener, err := newInitialAEAD(c2, protocol.PerspectiveServer)
@ -92,4 +94,36 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
_, err = serverOpener.Open(nil, clientMessage, 42, []byte("aad")) _, err = serverOpener.Open(nil, clientMessage, 42, []byte("aad"))
Expect(err).To(MatchError("cipher: message authentication failed")) Expect(err).To(MatchError("cipher: message authentication failed"))
}) })
It("encrypts und decrypts the header", func() {
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
clientSealer, clientOpener, err := newInitialAEAD(connID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
serverSealer, serverOpener, err := newInitialAEAD(connID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
// the first byte and the last 4 bytes should be encrypted
header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef}
sample := make([]byte, 16)
rand.Read(sample)
clientSealer.EncryptHeader(sample, &header[0], header[6:10])
// only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified
Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0)))
Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4}))
Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
serverOpener.DecryptHeader(sample, &header[0], header[6:10])
Expect(header[0]).To(Equal(byte(0x5e)))
Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4}))
Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
serverSealer.EncryptHeader(sample, &header[0], header[6:10])
// only the last 4 bits of the first byte are encrypted. Check that the first 4 bits are unmodified
Expect(header[0] & 0xf0).To(Equal(byte(0x5e & 0xf0)))
Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4}))
Expect(header[6:10]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
clientOpener.DecryptHeader(sample, &header[0], header[6:10])
Expect(header[0]).To(Equal(byte(0x5e)))
Expect(header[1:6]).To(Equal([]byte{0, 1, 2, 3, 4}))
Expect(header[6:10]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
})
}) })

View file

@ -11,11 +11,13 @@ import (
// Opener opens a packet // Opener opens a packet
type Opener interface { type Opener interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
} }
// Sealer seals a packet // Sealer seals a packet
type Sealer interface { type Sealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
Overhead() int Overhead() int
} }

View file

@ -34,6 +34,16 @@ func (m *MockOpener) EXPECT() *MockOpenerMockRecorder {
return m.recorder return m.recorder
} }
// DecryptHeader mocks base method
func (m *MockOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method // Open mocks base method
func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)

View file

@ -34,6 +34,16 @@ func (m *MockSealer) EXPECT() *MockSealerMockRecorder {
return m.recorder return m.recorder
} }
// EncryptHeader mocks base method
func (m *MockSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2)
}
// EncryptHeader indicates an expected call of EncryptHeader
func (mr *MockSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockSealer)(nil).EncryptHeader), arg0, arg1, arg2)
}
// Overhead mocks base method // Overhead mocks base method
func (m *MockSealer) Overhead() int { func (m *MockSealer) Overhead() int {
ret := m.ctrl.Call(m, "Overhead") ret := m.ctrl.Call(m, "Overhead")