From fa89ec345ab9bca8ea6e06c2150c7fc7c17b9562 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 7 Sep 2019 11:06:30 +0700 Subject: [PATCH] add support for ChaCha20 header protection --- go.mod | 2 + go.sum | 4 + internal/handshake/aead_test.go | 143 +++---- internal/handshake/handshake_suite_test.go | 51 ++- internal/handshake/header_protector.go | 52 ++- internal/handshake/updatable_aead_test.go | 431 +++++++++++---------- 6 files changed, 398 insertions(+), 285 deletions(-) diff --git a/go.mod b/go.mod index 32a1df7f..41eabadf 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,11 @@ module github.com/lucas-clemente/quic-go go 1.13 require ( + github.com/alangpierce/go-forceexport v0.0.0-20160317203124-8f1d6941cd75 github.com/cheekybits/genny v1.0.0 github.com/golang/mock v1.2.0 github.com/golang/protobuf v1.3.0 + github.com/marten-seemann/chacha20 v0.2.0 github.com/marten-seemann/qpack v0.1.0 github.com/marten-seemann/qtls v0.4.0 github.com/onsi/ginkgo v1.7.0 diff --git a/go.sum b/go.sum index c534a87a..c343aaa8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/alangpierce/go-forceexport v0.0.0-20160317203124-8f1d6941cd75 h1:3ILjVyslFbc4jl1w5TWuvvslFD/nDfR2H8tVaMVLrEY= +github.com/alangpierce/go-forceexport v0.0.0-20160317203124-8f1d6941cd75/go.mod h1:uAXEEpARkRhCZfEvy/y0Jcc888f9tHCc1W7/UeEtreE= github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= @@ -9,6 +11,8 @@ github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/marten-seemann/chacha20 v0.2.0 h1:f40vqzzx+3GdOmzQoItkLX5WLvHgPgyYqFFIO5Gh4hQ= +github.com/marten-seemann/chacha20 v0.2.0/go.mod h1:HSdjFau7GzYRj+ahFNwsO3ouVJr1HFkWoEwNDb4TMtE= github.com/marten-seemann/qpack v0.1.0 h1:/0M7lkda/6mus9B8u34Asqm8ZhHAAt9Ho0vniNuVSVg= github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI= github.com/marten-seemann/qtls v0.4.0 h1:HM9ftULNeuhGiCliIfPKvp5VDJw6pvi/Ghq6PYf7B0E= diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 407da319..c0e5f2a2 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -4,83 +4,92 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "fmt" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("AEAD", func() { - getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err := cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) + for i := range cipherSuites { + cs := cipherSuites[i] - return newLongHeaderSealer(aead, newHeaderProtector(aesSuite, key, true)), - newLongHeaderOpener(aead, newAESHeaderProtector(aesSuite, key, true)) - } + Context(fmt.Sprintf("using %s", cs.name), func() { + suite := cs.suite - Context("message encryption", func() { - msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad := []byte("Donec in velit neque.") + getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) - It("encrypts and decrypts a message", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - opened, err := opener.Open(nil, encrypted, 0x1337, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("fails to open a message if the associated data is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x42, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - }) - - Context("header encryption", func() { - It("encrypts and encrypts the header", func() { - sealer, opener := getSealerAndOpener() - var lastFourBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0xf != 0xb5&0xf { - lastFourBitsDifferent++ - } - Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + return newLongHeaderSealer(aead, newHeaderProtector(suite, key, true)), + newLongHeaderOpener(aead, newHeaderProtector(suite, key, true)) } - Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) - }) - It("fails to decrypt the header when using a different sample", func() { - sealer, opener := getSealerAndOpener() - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sample := make([]byte, 16) - rand.Read(sample) - sealer.EncryptHeader(sample, &header[0], header[9:13]) - rand.Read(sample) // use a different sample - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + Context("message encryption", func() { + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad := []byte("Donec in velit neque.") + + It("encrypts and decrypts a message", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + opened, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("fails to open a message if the packet number is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x42, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + }) + + Context("header encryption", func() { + It("encrypts and encrypts the header", func() { + sealer, opener := getSealerAndOpener() + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) + }) + + It("fails to decrypt the header when using a different sample", func() { + sealer, opener := getSealerAndOpener() + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + sealer.EncryptHeader(sample, &header[0], header[9:13]) + rand.Read(sample) // use a different sample + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + }) + }) }) - }) + } }) diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index 10645e50..158460eb 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -2,7 +2,9 @@ package handshake import ( "crypto" + "crypto/cipher" + "github.com/alangpierce/go-forceexport" "github.com/golang/mock/gomock" "github.com/marten-seemann/qtls" @@ -27,9 +29,48 @@ var _ = AfterEach(func() { mockCtrl.Finish() }) -var aesSuite = &qtls.CipherSuiteTLS13{ - ID: qtls.TLS_AES_128_GCM_SHA256, - KeyLen: 16, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA256, +var aeadChaCha20Poly1305 func(key, nonceMask []byte) cipher.AEAD + +var cipherSuites = []struct { + name string + suite *qtls.CipherSuiteTLS13 +}{ + { + name: "TLS_AES_128_GCM_SHA256", + suite: &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_AES_128_GCM_SHA256, + KeyLen: 16, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA256, + }, + }, + { + name: "TLS_AES_256_GCM_SHA384", + suite: &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_AES_256_GCM_SHA384, + KeyLen: 32, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA384, + }, + }, + { + name: "TLS_CHACHA20_POLY1305_SHA256", + suite: &qtls.CipherSuiteTLS13{ + ID: qtls.TLS_CHACHA20_POLY1305_SHA256, + KeyLen: 32, + AEAD: nil, // will be set by init + Hash: crypto.SHA256, + }, + }, +} + +func init() { + if err := forceexport.GetFunc(&aeadChaCha20Poly1305, "github.com/marten-seemann/qtls.aeadChaCha20Poly1305"); err != nil { + panic(err) + } + for _, s := range cipherSuites { + if s.suite.ID == qtls.TLS_CHACHA20_POLY1305_SHA256 { + s.suite.AEAD = aeadChaCha20Poly1305 + } + } } diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go index 17e8fcd4..019d5703 100644 --- a/internal/handshake/header_protector.go +++ b/internal/handshake/header_protector.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "fmt" + "github.com/marten-seemann/chacha20" "github.com/marten-seemann/qtls" ) @@ -18,8 +19,7 @@ func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLo case qtls.TLS_AES_128_GCM_SHA256, qtls.TLS_AES_256_GCM_SHA384: return newAESHeaderProtector(suite, trafficSecret, isLongHeader) case qtls.TLS_CHACHA20_POLY1305_SHA256: - // TODO: implement ChaCha header protection - fallthrough + return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader) default: panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) } @@ -69,3 +69,51 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by hdrBytes[i] ^= p.mask[i+1] } } + +type chachaHeaderProtector struct { + mask [5]byte + + key [32]byte + sampleBuf [16]byte + isLongHeader bool +} + +var _ headerProtector = &chachaHeaderProtector{} + +func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector { + hpKey := qtls.HkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen) + + p := &chachaHeaderProtector{ + isLongHeader: isLongHeader, + } + copy(p.key[:], hpKey) + return p +} + +func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) < len(p.mask) { + panic("invalid sample size") + } + for i := 0; i < 5; i++ { + p.mask[i] = 0 + } + copy(p.sampleBuf[:], sample) + chacha20.XORKeyStream(p.mask[:], p.mask[:], &p.sampleBuf, &p.key) + + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 80c1721f..ab28d28b 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -2,6 +2,7 @@ package handshake import ( "crypto/rand" + "fmt" "os" "time" @@ -13,236 +14,244 @@ import ( ) var _ = Describe("Updatable AEAD", func() { - getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) { - trafficSecret1 := make([]byte, 16) - trafficSecret2 := make([]byte, 16) - rand.Read(trafficSecret1) - rand.Read(trafficSecret2) + for i := range cipherSuites { + cs := cipherSuites[i] - client = newUpdatableAEAD(rttStats, utils.DefaultLogger) - server = newUpdatableAEAD(rttStats, utils.DefaultLogger) - client.SetReadKey(aesSuite, trafficSecret2) - client.SetWriteKey(aesSuite, trafficSecret1) - server.SetReadKey(aesSuite, trafficSecret1) - server.SetWriteKey(aesSuite, trafficSecret2) - return - } + Context(fmt.Sprintf("using %s", cs.name), func() { + suite := cs.suite - Context("header protection", func() { - It("encrypts and decrypts the header", func() { - server, client := getPeers(&congestion.RTTStats{}) - var lastFiveBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - client.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0x1f != 0xb5&0x1f { - lastFiveBitsDifferent++ - } - Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - server.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) { + trafficSecret1 := make([]byte, 16) + trafficSecret2 := make([]byte, 16) + rand.Read(trafficSecret1) + rand.Read(trafficSecret2) + + client = newUpdatableAEAD(rttStats, utils.DefaultLogger) + server = newUpdatableAEAD(rttStats, utils.DefaultLogger) + client.SetReadKey(suite, trafficSecret2) + client.SetWriteKey(suite, trafficSecret1) + server.SetReadKey(suite, trafficSecret1) + server.SetWriteKey(suite, trafficSecret2) + return } - Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) - }) - }) - Context("message encryption", func() { - var msg, ad []byte - var server, client *updatableAEAD - var rttStats *congestion.RTTStats - - BeforeEach(func() { - rttStats = &congestion.RTTStats{} - server, client = getPeers(rttStats) - msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad = []byte("Donec in velit neque.") - }) - - It("encrypts and decrypts a message", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("fails to open a message if the associated data is not the same", func() { - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - Context("key updates", func() { - Context("receiving key updates", func() { - It("updates keys", func() { - now := time.Now() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted0 := server.Seal(nil, msg, 0x1337, ad) - server.rollKeys(now) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted1 := server.Seal(nil, msg, 0x1337, ad) - Expect(encrypted0).ToNot(Equal(encrypted1)) - // expect opening to fail. The client didn't roll keys yet - _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - client.rollKeys(now) - decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - }) - - It("updates the keys when receiving a packet with the next key phase", func() { - now := time.Now() - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - // send one packet at key phase zero - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _ = server.Seal(nil, msg, 0x1, ad) - // now received a message at key phase one - client.rollKeys(now) - encrypted1 := client.Seal(nil, msg, 0x43, ad) - decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("opens a reordered packet with the old keys after an update", func() { - now := time.Now() - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys(now) - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("drops keys 3 PTOs after a key update", func() { - now := time.Now() - rttStats.UpdateRTT(10*time.Millisecond, 0, now) - pto := rttStats.PTO() - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys(now) - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrKeysDropped)) - }) - - It("errors when the peer starts with key phase 1", func() { - client.rollKeys(time.Now()) - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase")) - }) - - It("errors when the peer updates keys too frequently", func() { - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // now receive a packet at key phase one, before having sent any packets - client.rollKeys(time.Now()) - encrypted1 := client.Seal(nil, msg, 0x42, ad) - _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly")) + Context("header protection", func() { + It("encrypts and decrypts the header", func() { + server, client := getPeers(&congestion.RTTStats{}) + var lastFiveBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + client.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0x1f != 0xb5&0x1f { + lastFiveBitsDifferent++ + } + Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + server.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) }) }) - Context("initiating key updates", func() { - const keyUpdateInterval = 20 + Context("message encryption", func() { + var msg, ad []byte + var server, client *updatableAEAD + var rttStats *congestion.RTTStats BeforeEach(func() { - Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) - server.keyUpdateInterval = keyUpdateInterval + rttStats = &congestion.RTTStats{} + server, client = getPeers(rttStats) + msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad = []byte("Donec in velit neque.") }) - It("initiates a key update after sealing the maximum number of packets", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.SetLargestAcked(0) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + It("encrypts and decrypts a message", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) }) - It("initiates a key update after opening the maximum number of packets", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted := client.Seal(nil, msg, pn, ad) - _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, 1, ad) - server.SetLargestAcked(1) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - }) - - Context("reading the key update env", func() { - AfterEach(func() { - os.Setenv(keyUpdateEnv, "") - setKeyUpdateInterval() + It("fails to open a message if the associated data is not the same", func() { + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) }) - It("uses the default value if the env is not set", func() { - setKeyUpdateInterval() - Expect(keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + It("fails to open a message if the packet number is not the same", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) }) - It("uses the env", func() { - os.Setenv(keyUpdateEnv, "1337") - setKeyUpdateInterval() - Expect(keyUpdateInterval).To(BeEquivalentTo(1337)) - }) + Context("key updates", func() { + Context("receiving key updates", func() { + It("updates keys", func() { + now := time.Now() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted0 := server.Seal(nil, msg, 0x1337, ad) + server.rollKeys(now) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted1 := server.Seal(nil, msg, 0x1337, ad) + Expect(encrypted0).ToNot(Equal(encrypted1)) + // expect opening to fail. The client didn't roll keys yet + _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + client.rollKeys(now) + decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + }) - It("panics when it can't parse the env", func() { - os.Setenv(keyUpdateEnv, "foobar") - Expect(setKeyUpdateInterval).To(Panic()) + It("updates the keys when receiving a packet with the next key phase", func() { + now := time.Now() + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + // send one packet at key phase zero + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _ = server.Seal(nil, msg, 0x1, ad) + // now received a message at key phase one + client.rollKeys(now) + encrypted1 := client.Seal(nil, msg, 0x43, ad) + decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("opens a reordered packet with the old keys after an update", func() { + now := time.Now() + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys(now) + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + rttStats.UpdateRTT(10*time.Millisecond, 0, now) + pto := rttStats.PTO() + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys(now) + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + + It("errors when the peer starts with key phase 1", func() { + client.rollKeys(time.Now()) + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase")) + }) + + It("errors when the peer updates keys too frequently", func() { + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // now receive a packet at key phase one, before having sent any packets + client.rollKeys(time.Now()) + encrypted1 := client.Seal(nil, msg, 0x42, ad) + _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly")) + }) + }) + + Context("initiating key updates", func() { + const keyUpdateInterval = 20 + + BeforeEach(func() { + Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + server.keyUpdateInterval = keyUpdateInterval + }) + + It("initiates a key update after sealing the maximum number of packets", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.SetLargestAcked(0) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("initiates a key update after opening the maximum number of packets", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, 1, ad) + server.SetLargestAcked(1) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + }) + + Context("reading the key update env", func() { + AfterEach(func() { + os.Setenv(keyUpdateEnv, "") + setKeyUpdateInterval() + }) + + It("uses the default value if the env is not set", func() { + setKeyUpdateInterval() + Expect(keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + }) + + It("uses the env", func() { + os.Setenv(keyUpdateEnv, "1337") + setKeyUpdateInterval() + Expect(keyUpdateInterval).To(BeEquivalentTo(1337)) + }) + + It("panics when it can't parse the env", func() { + os.Setenv(keyUpdateEnv, "foobar") + Expect(setKeyUpdateInterval).To(Panic()) + }) + }) }) }) }) - }) + } })