From b4636469fa71fd1519ffe70e13606a0ff87f70f2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 30 Sep 2020 12:05:33 +0700 Subject: [PATCH 1/3] refactor confirmation of the handshake --- session.go | 8 ++++---- session_test.go | 15 +++++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/session.go b/session.go index 16e189b4..c776f9fb 100644 --- a/session.go +++ b/session.go @@ -688,6 +688,8 @@ func (s *session) handleHandshakeComplete() { s.connIDGenerator.SetHandshakeComplete() if s.perspective == protocol.PerspectiveServer { + s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() ticket, err := s.cryptoStreamHandler.GetSessionTicket() if err != nil { s.closeLocal(err) @@ -1238,6 +1240,8 @@ func (s *session) handleHandshakeDoneFrame() error { if s.perspective == protocol.PerspectiveServer { return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame") } + s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() s.cryptoStreamHandler.DropHandshakeKeys() return nil } @@ -1347,10 +1351,6 @@ func (s *session) handleCloseError(closeErr closeError) { } func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { - if encLevel == protocol.EncryptionHandshake { - s.handshakeConfirmed = true - s.sentPacketHandler.SetHandshakeConfirmed() - } s.sentPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel) if s.tracer != nil { diff --git a/session_test.go b/session_test.go index 2e382991..bf119243 100644 --- a/session_test.go +++ b/session_test.go @@ -1634,6 +1634,7 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().AnyTimes() + sph.EXPECT().SetHandshakeConfirmed() sessionRunner.EXPECT().Retire(clientDestConnID) go func() { defer GinkgoRecover() @@ -1729,14 +1730,17 @@ var _ = Describe("Session", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().HasPacingBudget().Return(true) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SetHandshakeConfirmed() + sph.EXPECT().SentPacket(gomock.Any()) + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sess.sentPacketHandler = sph done := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackCoalescedPacket(gomock.Any()).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { + packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -1748,7 +1752,7 @@ var _ = Describe("Session", func() { buffer: getPacketBuffer(), }, nil }) - packer.EXPECT().PackCoalescedPacket(gomock.Any()).AnyTimes() + packer.EXPECT().PackPacket().AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -2270,6 +2274,9 @@ var _ = Describe("Client Session", func() { }) It("handles HANDSHAKE_DONE frames", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sess.sentPacketHandler = sph + sph.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().DropHandshakeKeys() Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) }) From b9090d71ae112fced687c79c65985f27516d3e7a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 30 Sep 2020 12:14:16 +0700 Subject: [PATCH 2/3] rename cryptoSetup.DropHandshakeKeys() to SetHandshakeConfirmed() --- internal/handshake/crypto_setup.go | 3 ++- internal/handshake/interface.go | 2 +- internal/mocks/crypto_setup.go | 24 ++++++++++++------------ session.go | 6 +++--- session_test.go | 10 +++++----- 5 files changed, 23 insertions(+), 22 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 9ce12088..16798b7a 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -652,7 +652,8 @@ func (h *cryptoSetup) dropInitialKeys() { h.logger.Debugf("Dropping Initial keys.") } -func (h *cryptoSetup) DropHandshakeKeys() { +func (h *cryptoSetup) SetHandshakeConfirmed() { + // drop Handshake keys var dropped bool h.mutex.Lock() if h.handshakeOpener != nil { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 3bf9dc06..b64cd015 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -77,7 +77,7 @@ type CryptoSetup interface { HandleMessage([]byte, protocol.EncryptionLevel) bool SetLargest1RTTAcked(protocol.PacketNumber) error - DropHandshakeKeys() + SetHandshakeConfirmed() ConnectionState() ConnectionState GetInitialOpener() (LongHeaderOpener, error) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index fb247276..ba1d2965 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -76,18 +76,6 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) } -// DropHandshakeKeys mocks base method -func (m *MockCryptoSetup) DropHandshakeKeys() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DropHandshakeKeys") -} - -// DropHandshakeKeys indicates an expected call of DropHandshakeKeys -func (mr *MockCryptoSetupMockRecorder) DropHandshakeKeys() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropHandshakeKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DropHandshakeKeys)) -} - // Get0RTTOpener mocks base method func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { m.ctrl.T.Helper() @@ -249,6 +237,18 @@ func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake)) } +// SetHandshakeConfirmed mocks base method +func (m *MockCryptoSetup) SetHandshakeConfirmed() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetHandshakeConfirmed") +} + +// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed +func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed)) +} + // SetLargest1RTTAcked mocks base method func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error { m.ctrl.T.Helper() diff --git a/session.go b/session.go index c776f9fb..ace8a8c9 100644 --- a/session.go +++ b/session.go @@ -52,7 +52,7 @@ type cryptoStreamHandler interface { RunHandshake() ChangeConnectionID(protocol.ConnectionID) SetLargest1RTTAcked(protocol.PacketNumber) error - DropHandshakeKeys() + SetHandshakeConfirmed() GetSessionTicket() ([]byte, error) io.Closer ConnectionState() handshake.ConnectionState @@ -705,7 +705,7 @@ func (s *session) handleHandshakeComplete() { s.closeLocal(err) } s.queueControlFrame(&wire.NewTokenFrame{Token: token}) - s.cryptoStreamHandler.DropHandshakeKeys() + s.cryptoStreamHandler.SetHandshakeConfirmed() s.queueControlFrame(&wire.HandshakeDoneFrame{}) } } @@ -1242,7 +1242,7 @@ func (s *session) handleHandshakeDoneFrame() error { } s.handshakeConfirmed = true s.sentPacketHandler.SetHandshakeConfirmed() - s.cryptoStreamHandler.DropHandshakeKeys() + s.cryptoStreamHandler.SetHandshakeConfirmed() return nil } diff --git a/session_test.go b/session_test.go index bf119243..cb01fbcf 100644 --- a/session_test.go +++ b/session_test.go @@ -1640,7 +1640,7 @@ var _ = Describe("Session", func() { defer GinkgoRecover() <-finishHandshake cryptoSetup.EXPECT().RunHandshake() - cryptoSetup.EXPECT().DropHandshakeKeys() + cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() close(sess.handshakeCompleteChan) sess.run() @@ -1670,7 +1670,7 @@ var _ = Describe("Session", func() { defer GinkgoRecover() <-finishHandshake cryptoSetup.EXPECT().RunHandshake() - cryptoSetup.EXPECT().DropHandshakeKeys() + cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) close(sess.handshakeCompleteChan) sess.run() @@ -1756,7 +1756,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() - cryptoSetup.EXPECT().DropHandshakeKeys() + cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() mconn.EXPECT().Write(gomock.Any()) close(sess.handshakeCompleteChan) @@ -2030,7 +2030,7 @@ var _ = Describe("Session", func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) - cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1) + cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) close(sess.handshakeCompleteChan) err := sess.run() nerr, ok := err.(net.Error) @@ -2277,7 +2277,7 @@ var _ = Describe("Client Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph sph.EXPECT().SetHandshakeConfirmed() - cryptoSetup.EXPECT().DropHandshakeKeys() + cryptoSetup.EXPECT().SetHandshakeConfirmed() Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) }) From 1c38acd8c97c1e9289ffbdae857ecc92ffd46a35 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 30 Sep 2020 14:12:07 +0700 Subject: [PATCH 3/3] allow the first key update immediately after handshake confirmation --- internal/handshake/crypto_setup.go | 1 + internal/handshake/updatable_aead.go | 56 +++++++-------- internal/handshake/updatable_aead_test.go | 86 +++++++++++++++-------- 3 files changed, 87 insertions(+), 56 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 16798b7a..bad9d164 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -653,6 +653,7 @@ func (h *cryptoSetup) dropInitialKeys() { } func (h *cryptoSetup) SetHandshakeConfirmed() { + h.aead.SetHandshakeConfirmed() // drop Handshake keys var dropped bool h.mutex.Lock() diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index aadf0309..67247ae0 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -22,9 +22,10 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval type updatableAEAD struct { suite *qtls.CipherSuiteTLS13 - keyPhase protocol.KeyPhase - largestAcked protocol.PacketNumber - firstPacketNumber protocol.PacketNumber + keyPhase protocol.KeyPhase + largestAcked protocol.PacketNumber + firstPacketNumber protocol.PacketNumber + handshakeConfirmed bool keyUpdateInterval uint64 invalidPacketLimit uint64 @@ -172,35 +173,24 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac } binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) if kp != a.keyPhase.Bit() { - var receivedWrongInitialKeyPhase bool - if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { - if a.keyPhase == 0 { - // This can only occur when the first packet received has key phase 1. - // This is an error, since the key phase starts at 0, - // and peers are only allowed to update keys after the handshake is confirmed. - // Proceed from here, and only return an error if decryption of the packet succeeds. - receivedWrongInitialKeyPhase = true - } else { - if a.prevRcvAEAD == nil { - return nil, ErrKeysDropped - } - // we updated the key, but the peer hasn't updated yet - dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err != nil { - err = ErrDecryptionFailed - } - return dec, err + if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { + if a.prevRcvAEAD == nil { + return nil, ErrKeysDropped } + // we updated the key, but the peer hasn't updated yet + dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + err = ErrDecryptionFailed + } + return dec, err } // try opening the packet with the next key phase dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err == nil && receivedWrongInitialKeyPhase { - return nil, qerr.NewError(qerr.KeyUpdateError, "wrong initial key phase") - } else if err != nil { + if err != nil { return nil, ErrDecryptionFailed } // Opening succeeded. Check if the peer was allowed to update. - if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { return nil, qerr.NewError(qerr.KeyUpdateError, "keys updated too quickly") } a.rollKeys() @@ -256,10 +246,20 @@ func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { return nil } +func (a *updatableAEAD) SetHandshakeConfirmed() { + a.handshakeConfirmed = true +} + func (a *updatableAEAD) updateAllowed() bool { - return a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && - a.largestAcked != protocol.InvalidPacketNumber && - a.largestAcked >= a.firstSentWithCurrentKey + if !a.handshakeConfirmed { + return false + } + // the first key update is allowed as soon as the handshake is confirmed + return a.keyPhase == 0 || + // subsequent key updates as soon as a packet sent with that key phase has been acknowledged + (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + a.largestAcked != protocol.InvalidPacketNumber && + a.largestAcked >= a.firstSentWithCurrentKey) } func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index ccb785cc..0c115cdc 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -215,11 +215,13 @@ var _ = Describe("Updatable AEAD", func() { Expect(err).To(MatchError(ErrKeysDropped)) }) - It("errors when the peer starts with key phase 1", func() { + It("allows the first key update immediately", func() { + // receive a packet at key phase one, before having sent or received any packets at key phase 0 client.rollKeys() - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError("KEY_UPDATE_ERROR: wrong initial key phase")) + encrypted1 := client.Seal(nil, msg, 0x1337, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) }) It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { @@ -231,14 +233,16 @@ var _ = Describe("Updatable AEAD", func() { }) It("errors when the peer updates keys too frequently", func() { - // receive the first packet at key phase zero + server.rollKeys() + client.rollKeys() + // receive the first packet at key phase one encrypted0 := client.Seal(nil, msg, 0x42, ad) - _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) - // now receive a packet at key phase one, before having sent any packets + // now receive a packet at key phase two, before having sent any packets client.rollKeys() encrypted1 := client.Seal(nil, msg, 0x42, ad) - _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad) + _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) Expect(err).To(MatchError("KEY_UPDATE_ERROR: keys updated too quickly")) }) }) @@ -249,25 +253,40 @@ var _ = Describe("Updatable AEAD", func() { BeforeEach(func() { Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) server.keyUpdateInterval = keyUpdateInterval + server.SetHandshakeConfirmed() }) - It("initiates a key update after sealing the maximum number of packets", func() { + It("initiates a key update after sealing the maximum number of packets, for the first update", func() { for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - // receive an ACK for a packet sent in key phase 0 - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + // the first update is allowed without receiving an acknowledgement serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) + It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // receive an ACK for a packet sent in key phase 0 + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { // First make sure that we update our keys. for i := 0; i < keyUpdateInterval; i++ { @@ -275,14 +294,9 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) } - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - // Now that our keys are updated, send a packet using the new keys. Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // Now that our keys are updated, send a packet using the new keys. const nextPN = keyUpdateInterval + 1 server.Seal(nil, msg, nextPN, ad) // We haven't decrypted any packet in the new key phase yet. @@ -297,7 +311,6 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) } - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) Expect(err).ToNot(HaveOccurred()) @@ -310,7 +323,7 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) }) - It("initiates a key update after opening the maximum number of packets", func() { + It("initiates a key update after opening the maximum number of packets, for the first update", func() { for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) @@ -318,14 +331,30 @@ var _ = Describe("Updatable AEAD", func() { _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, 1, ad) - Expect(server.SetLargestAcked(1)).To(Succeed()) + // the first update is allowed without receiving an acknowledgement serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) + It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, 1, ad) + Expect(server.SetLargestAcked(1)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + It("drops keys 3 PTOs after a key update", func() { now := time.Now() for i := 0; i < keyUpdateInterval; i++ { @@ -415,6 +444,7 @@ var _ = Describe("Updatable AEAD", func() { }) It("drops keys early when we initiate another key update within the 3 PTO period", func() { + server.SetHandshakeConfirmed() // send so many packets that we initiate the first key update for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i)