diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 16798b7a..bad9d164 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -653,6 +653,7 @@ func (h *cryptoSetup) dropInitialKeys() { } func (h *cryptoSetup) SetHandshakeConfirmed() { + h.aead.SetHandshakeConfirmed() // drop Handshake keys var dropped bool h.mutex.Lock() diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index aadf0309..67247ae0 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -22,9 +22,10 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval type updatableAEAD struct { suite *qtls.CipherSuiteTLS13 - keyPhase protocol.KeyPhase - largestAcked protocol.PacketNumber - firstPacketNumber protocol.PacketNumber + keyPhase protocol.KeyPhase + largestAcked protocol.PacketNumber + firstPacketNumber protocol.PacketNumber + handshakeConfirmed bool keyUpdateInterval uint64 invalidPacketLimit uint64 @@ -172,35 +173,24 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac } binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) if kp != a.keyPhase.Bit() { - var receivedWrongInitialKeyPhase bool - if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { - if a.keyPhase == 0 { - // This can only occur when the first packet received has key phase 1. - // This is an error, since the key phase starts at 0, - // and peers are only allowed to update keys after the handshake is confirmed. - // Proceed from here, and only return an error if decryption of the packet succeeds. - receivedWrongInitialKeyPhase = true - } else { - if a.prevRcvAEAD == nil { - return nil, ErrKeysDropped - } - // we updated the key, but the peer hasn't updated yet - dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err != nil { - err = ErrDecryptionFailed - } - return dec, err + if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { + if a.prevRcvAEAD == nil { + return nil, ErrKeysDropped } + // we updated the key, but the peer hasn't updated yet + dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + err = ErrDecryptionFailed + } + return dec, err } // try opening the packet with the next key phase dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) - if err == nil && receivedWrongInitialKeyPhase { - return nil, qerr.NewError(qerr.KeyUpdateError, "wrong initial key phase") - } else if err != nil { + if err != nil { return nil, ErrDecryptionFailed } // Opening succeeded. Check if the peer was allowed to update. - if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { return nil, qerr.NewError(qerr.KeyUpdateError, "keys updated too quickly") } a.rollKeys() @@ -256,10 +246,20 @@ func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { return nil } +func (a *updatableAEAD) SetHandshakeConfirmed() { + a.handshakeConfirmed = true +} + func (a *updatableAEAD) updateAllowed() bool { - return a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && - a.largestAcked != protocol.InvalidPacketNumber && - a.largestAcked >= a.firstSentWithCurrentKey + if !a.handshakeConfirmed { + return false + } + // the first key update is allowed as soon as the handshake is confirmed + return a.keyPhase == 0 || + // subsequent key updates as soon as a packet sent with that key phase has been acknowledged + (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + a.largestAcked != protocol.InvalidPacketNumber && + a.largestAcked >= a.firstSentWithCurrentKey) } func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index ccb785cc..0c115cdc 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -215,11 +215,13 @@ var _ = Describe("Updatable AEAD", func() { Expect(err).To(MatchError(ErrKeysDropped)) }) - It("errors when the peer starts with key phase 1", func() { + It("allows the first key update immediately", func() { + // receive a packet at key phase one, before having sent or received any packets at key phase 0 client.rollKeys() - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError("KEY_UPDATE_ERROR: wrong initial key phase")) + encrypted1 := client.Seal(nil, msg, 0x1337, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) }) It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { @@ -231,14 +233,16 @@ var _ = Describe("Updatable AEAD", func() { }) It("errors when the peer updates keys too frequently", func() { - // receive the first packet at key phase zero + server.rollKeys() + client.rollKeys() + // receive the first packet at key phase one encrypted0 := client.Seal(nil, msg, 0x42, ad) - _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) - // now receive a packet at key phase one, before having sent any packets + // now receive a packet at key phase two, before having sent any packets client.rollKeys() encrypted1 := client.Seal(nil, msg, 0x42, ad) - _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad) + _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) Expect(err).To(MatchError("KEY_UPDATE_ERROR: keys updated too quickly")) }) }) @@ -249,25 +253,40 @@ var _ = Describe("Updatable AEAD", func() { BeforeEach(func() { Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) server.keyUpdateInterval = keyUpdateInterval + server.SetHandshakeConfirmed() }) - It("initiates a key update after sealing the maximum number of packets", func() { + It("initiates a key update after sealing the maximum number of packets, for the first update", func() { for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - // receive an ACK for a packet sent in key phase 0 - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + // the first update is allowed without receiving an acknowledgement serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) + It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // receive an ACK for a packet sent in key phase 0 + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { // First make sure that we update our keys. for i := 0; i < keyUpdateInterval; i++ { @@ -275,14 +294,9 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) } - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - // Now that our keys are updated, send a packet using the new keys. Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // Now that our keys are updated, send a packet using the new keys. const nextPN = keyUpdateInterval + 1 server.Seal(nil, msg, nextPN, ad) // We haven't decrypted any packet in the new key phase yet. @@ -297,7 +311,6 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) } - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) Expect(err).ToNot(HaveOccurred()) @@ -310,7 +323,7 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) }) - It("initiates a key update after opening the maximum number of packets", func() { + It("initiates a key update after opening the maximum number of packets, for the first update", func() { for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) @@ -318,14 +331,30 @@ var _ = Describe("Updatable AEAD", func() { _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, 1, ad) - Expect(server.SetLargestAcked(1)).To(Succeed()) + // the first update is allowed without receiving an acknowledgement serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) + It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, 1, ad) + Expect(server.SetLargestAcked(1)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + It("drops keys 3 PTOs after a key update", func() { now := time.Now() for i := 0; i < keyUpdateInterval; i++ { @@ -415,6 +444,7 @@ var _ = Describe("Updatable AEAD", func() { }) It("drops keys early when we initiate another key update within the 3 PTO period", func() { + server.SetHandshakeConfirmed() // send so many packets that we initiate the first key update for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i)