diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index df2644b1..e5de2f8a 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -112,7 +112,9 @@ func (a *updatableAEAD) rollKeys() { } func (a *updatableAEAD) startKeyDropTimer(now time.Time) { - a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO(true)) + d := 3 * a.rttStats.PTO(true) + a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d) + a.prevRcvAEADExpiry = now.Add(d) } func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { @@ -152,6 +154,7 @@ func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { a.prevRcvAEAD = nil + a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) a.prevRcvAEADExpiry = time.Time{} if a.tracer != nil { a.tracer.DroppedKey(a.keyPhase - 1) @@ -211,7 +214,10 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { // We initiated the key updated, and now we received the first packet protected with the new key phase. // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. - a.startKeyDropTimer(rcvTime) + if a.keyPhase > 0 { + a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) + a.startKeyDropTimer(rcvTime) + } a.firstRcvdWithCurrentKey = pn } return dec, err diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 9cfbd505..246fc0b7 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -298,6 +298,26 @@ var _ = Describe("Updatable AEAD", func() { _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) Expect(err).To(MatchError(ErrKeysDropped)) }) + + It("doesn't drop the first key generation too early", func() { + now := time.Now() + data1 := client.Seal(nil, msg, 1, ad) + _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + server.SetLargestAcked(pn) + } + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // The server never received a packet at key phase 1. + // Make sure the key phase 0 is still there at a much later point. + data2 := client.Seal(nil, msg, 1, ad) + _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + }) }) Context("reading the key update env", func() {