allow the first key update immediately after handshake confirmation

This commit is contained in:
Marten Seemann 2020-09-30 14:12:07 +07:00
parent b9090d71ae
commit 1c38acd8c9
3 changed files with 87 additions and 56 deletions

View file

@ -653,6 +653,7 @@ func (h *cryptoSetup) dropInitialKeys() {
}
func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed()
// drop Handshake keys
var dropped bool
h.mutex.Lock()

View file

@ -22,9 +22,10 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
type updatableAEAD struct {
suite *qtls.CipherSuiteTLS13
keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber
firstPacketNumber protocol.PacketNumber
keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber
firstPacketNumber protocol.PacketNumber
handshakeConfirmed bool
keyUpdateInterval uint64
invalidPacketLimit uint64
@ -172,35 +173,24 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
}
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
if kp != a.keyPhase.Bit() {
var receivedWrongInitialKeyPhase bool
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
if a.keyPhase == 0 {
// This can only occur when the first packet received has key phase 1.
// This is an error, since the key phase starts at 0,
// and peers are only allowed to update keys after the handshake is confirmed.
// Proceed from here, and only return an error if decryption of the packet succeeds.
receivedWrongInitialKeyPhase = true
} else {
if a.prevRcvAEAD == nil {
return nil, ErrKeysDropped
}
// we updated the key, but the peer hasn't updated yet
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
}
return dec, err
if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
if a.prevRcvAEAD == nil {
return nil, ErrKeysDropped
}
// we updated the key, but the peer hasn't updated yet
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
}
return dec, err
}
// try opening the packet with the next key phase
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err == nil && receivedWrongInitialKeyPhase {
return nil, qerr.NewError(qerr.KeyUpdateError, "wrong initial key phase")
} else if err != nil {
if err != nil {
return nil, ErrDecryptionFailed
}
// Opening succeeded. Check if the peer was allowed to update.
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
return nil, qerr.NewError(qerr.KeyUpdateError, "keys updated too quickly")
}
a.rollKeys()
@ -256,10 +246,20 @@ func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
return nil
}
func (a *updatableAEAD) SetHandshakeConfirmed() {
a.handshakeConfirmed = true
}
func (a *updatableAEAD) updateAllowed() bool {
return a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
a.largestAcked != protocol.InvalidPacketNumber &&
a.largestAcked >= a.firstSentWithCurrentKey
if !a.handshakeConfirmed {
return false
}
// the first key update is allowed as soon as the handshake is confirmed
return a.keyPhase == 0 ||
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
a.largestAcked != protocol.InvalidPacketNumber &&
a.largestAcked >= a.firstSentWithCurrentKey)
}
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {

View file

@ -215,11 +215,13 @@ var _ = Describe("Updatable AEAD", func() {
Expect(err).To(MatchError(ErrKeysDropped))
})
It("errors when the peer starts with key phase 1", func() {
It("allows the first key update immediately", func() {
// receive a packet at key phase one, before having sent or received any packets at key phase 0
client.rollKeys()
encrypted := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("KEY_UPDATE_ERROR: wrong initial key phase"))
encrypted1 := client.Seal(nil, msg, 0x1337, ad)
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
_, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
})
It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() {
@ -231,14 +233,16 @@ var _ = Describe("Updatable AEAD", func() {
})
It("errors when the peer updates keys too frequently", func() {
// receive the first packet at key phase zero
server.rollKeys()
client.rollKeys()
// receive the first packet at key phase one
encrypted0 := client.Seal(nil, msg, 0x42, ad)
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
// now receive a packet at key phase one, before having sent any packets
// now receive a packet at key phase two, before having sent any packets
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x42, ad)
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError("KEY_UPDATE_ERROR: keys updated too quickly"))
})
})
@ -249,25 +253,40 @@ var _ = Describe("Updatable AEAD", func() {
BeforeEach(func() {
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
server.keyUpdateInterval = keyUpdateInterval
server.SetHandshakeConfirmed()
})
It("initiates a key update after sealing the maximum number of packets", func() {
It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
}
// no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
// receive an ACK for a packet sent in key phase 0
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
// the first update is allowed without receiving an acknowledgement
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() {
server.rollKeys()
client.rollKeys()
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
server.Seal(nil, msg, pn, ad)
}
// no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// receive an ACK for a packet sent in key phase 0
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad"))
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
})
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
// First make sure that we update our keys.
for i := 0; i < keyUpdateInterval; i++ {
@ -275,14 +294,9 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
}
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
// Now that our keys are updated, send a packet using the new keys.
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// Now that our keys are updated, send a packet using the new keys.
const nextPN = keyUpdateInterval + 1
server.Seal(nil, msg, nextPN, ad)
// We haven't decrypted any packet in the new key phase yet.
@ -297,7 +311,6 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
}
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
Expect(err).ToNot(HaveOccurred())
@ -310,7 +323,7 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred())
})
It("initiates a key update after opening the maximum number of packets", func() {
It("initiates a key update after opening the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
@ -318,14 +331,30 @@ var _ = Describe("Updatable AEAD", func() {
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
}
// no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, 1, ad)
Expect(server.SetLargestAcked(1)).To(Succeed())
// the first update is allowed without receiving an acknowledgement
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() {
server.rollKeys()
client.rollKeys()
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
encrypted := client.Seal(nil, msg, pn, ad)
_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
}
// no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
server.Seal(nil, msg, 1, ad)
Expect(server.SetLargestAcked(1)).To(Succeed())
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
})
It("drops keys 3 PTOs after a key update", func() {
now := time.Now()
for i := 0; i < keyUpdateInterval; i++ {
@ -415,6 +444,7 @@ var _ = Describe("Updatable AEAD", func() {
})
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
server.SetHandshakeConfirmed()
// send so many packets that we initiate the first key update
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)