initiate a key update after sending / receiving 100000 packets

This commit is contained in:
Marten Seemann 2019-06-13 16:24:06 +08:00
parent 73d5d83a1d
commit a09c045324
5 changed files with 156 additions and 74 deletions

View file

@ -9,7 +9,7 @@
- Use a varint for error codes.
- Add support for [quic-trace](https://github.com/google/quic-trace).
- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`.
- Implement receiving of TLS key updates.
- Implement TLS key updates.
## v0.11.0 (2019-04-05)

View file

@ -223,7 +223,8 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) error {
return nil
}
func (h *cryptoSetup) SetLargest1RTTAcked(_ protocol.PacketNumber) {
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) {
h.aead.SetLargestAcked(pn)
// drop initial keys
// TODO: do this earlier
if h.initialOpener != nil {

View file

@ -15,12 +15,16 @@ import (
type updatableAEAD struct {
suite cipherSuite
keyPhase protocol.KeyPhase
keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber
keyUpdateInterval uint64
prevRcvAEAD cipher.AEAD
firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber
numRcvdWithCurrentKey uint64
numSentWithCurrentKey uint64
rcvAEAD cipher.AEAD
sendAEAD cipher.AEAD
@ -44,17 +48,20 @@ var _ ShortHeaderSealer = &updatableAEAD{}
func newUpdatableAEAD(logger utils.Logger) *updatableAEAD {
return &updatableAEAD{
largestAcked: protocol.InvalidPacketNumber,
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
keyUpdateInterval: protocol.KeyUpdateInterval,
logger: logger,
}
}
func (a *updatableAEAD) rollKeys() {
a.keyPhase = a.keyPhase.Next()
a.logger.Debugf("Updating keys to the next key phase: %s", a.keyPhase)
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.numRcvdWithCurrentKey = 0
a.numSentWithCurrentKey = 0
a.prevRcvAEAD = a.rcvAEAD
a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD
@ -126,6 +133,7 @@ func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp proto
return nil, qerr.Error(qerr.ProtocolViolation, "keys updated too quickly")
}
a.rollKeys()
a.logger.Debugf("Peer updated keys to %s", a.keyPhase)
a.firstRcvdWithCurrentKey = pn
return dec, err
}
@ -134,8 +142,11 @@ func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp proto
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
} else if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
a.firstRcvdWithCurrentKey = pn
} else {
a.numRcvdWithCurrentKey++
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
a.firstRcvdWithCurrentKey = pn
}
}
return dec, err
}
@ -144,13 +155,42 @@ func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byt
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
a.firstSentWithCurrentKey = pn
}
a.numSentWithCurrentKey++
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
}
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) {
a.largestAcked = pn
}
func (a *updatableAEAD) updateAllowed() bool {
return a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
a.largestAcked != protocol.InvalidPacketNumber &&
a.largestAcked >= a.firstSentWithCurrentKey
}
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
if !a.updateAllowed() {
return false
}
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase.Next())
return true
}
if a.numSentWithCurrentKey >= a.keyUpdateInterval {
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase.Next())
return true
}
return false
}
func (a *updatableAEAD) KeyPhase() protocol.KeyPhase {
if a.shouldInitiateKeyUpdate() {
a.rollKeys()
}
return a.keyPhase
}

View file

@ -95,79 +95,117 @@ var _ = Describe("Updatable AEAD", func() {
})
Context("key updates", func() {
It("updates keys", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted0 := server.Seal(nil, msg, 0x1337, ad)
server.rollKeys()
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
encrypted1 := server.Seal(nil, msg, 0x1337, ad)
Expect(encrypted0).ToNot(Equal(encrypted1))
// expect opening to fail. The client didn't roll keys yet
_, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrDecryptionFailed))
client.rollKeys()
decrypted, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
Context("receiving key updates", func() {
It("updates keys", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted0 := server.Seal(nil, msg, 0x1337, ad)
server.rollKeys()
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
encrypted1 := server.Seal(nil, msg, 0x1337, ad)
Expect(encrypted0).ToNot(Equal(encrypted1))
// expect opening to fail. The client didn't roll keys yet
_, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).To(MatchError(ErrDecryptionFailed))
client.rollKeys()
decrypted, err := client.Open(nil, encrypted1, 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
})
It("updates the keys when receiving a packet with the next key phase", func() {
// receive the first packet at key phase zero
encrypted0 := client.Seal(nil, msg, 0x42, ad)
decrypted, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
// send one packet at key phase zero
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_ = server.Seal(nil, msg, 0x1, ad)
// now received a message at key phase one
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x43, ad)
decrypted, err = server.Open(nil, encrypted1, 0x43, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
It("opens a reordered packet with the old keys after an update", func() {
encrypted01 := client.Seal(nil, msg, 0x42, ad)
encrypted02 := client.Seal(nil, msg, 0x43, ad)
// receive the first packet with key phase 0
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// now receive a reordered packet with key phase 0
decrypted, err := server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
It("errors when the peer starts with key phase 1", func() {
client.rollKeys()
encrypted := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase"))
})
It("errors when the peer updates keys too frequently", func() {
// receive the first packet at key phase zero
encrypted0 := client.Seal(nil, msg, 0x42, ad)
_, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// now receive a packet at key phase one, before having sent any packets
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x42, ad)
_, err = server.Open(nil, encrypted1, 0x42, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly"))
})
})
It("updates the keys when receiving a packet with the next key phase", func() {
// receive the first packet at key phase zero
encrypted0 := client.Seal(nil, msg, 0x42, ad)
decrypted, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
// send one packet at key phase zero
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_ = server.Seal(nil, msg, 0x1, ad)
// now received a message at key phase one
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x43, ad)
decrypted, err = server.Open(nil, encrypted1, 0x43, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
Context("initiating key updates", func() {
const keyUpdateInterval = 20
It("opens a reordered packet with the old keys after an update", func() {
encrypted01 := client.Seal(nil, msg, 0x42, ad)
encrypted02 := client.Seal(nil, msg, 0x43, ad)
// receive the first packet with key phase 0
_, err := server.Open(nil, encrypted01, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// send one packet at key phase zero
_ = server.Seal(nil, msg, 0x1, ad)
// now receive a packet with key phase 1
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
_, err = server.Open(nil, encrypted1, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// now receive a reordered packet with key phase 0
decrypted, err := server.Open(nil, encrypted02, 0x43, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg))
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
BeforeEach(func() {
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
server.keyUpdateInterval = keyUpdateInterval
})
It("errors when the peer starts with key phase 1", func() {
client.rollKeys()
encrypted := client.Seal(nil, msg, 0x1337, ad)
_, err := server.Open(nil, encrypted, 0x1337, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: wrong initial keyphase"))
})
It("initiates a key update after sealing the maximum number of packets", func() {
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad)
}
// no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.SetLargestAcked(0)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
It("errors when the peer updates keys too frequently", func() {
// receive the first packet at key phase zero
encrypted0 := client.Seal(nil, msg, 0x42, ad)
_, err := server.Open(nil, encrypted0, 0x42, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
// now receive a packet at key phase one, before having sent any packets
client.rollKeys()
encrypted1 := client.Seal(nil, msg, 0x42, ad)
_, err = server.Open(nil, encrypted1, 0x42, protocol.KeyPhaseOne, ad)
Expect(err).To(MatchError("PROTOCOL_VIOLATION: keys updated too quickly"))
It("initiates a key update after opening the maximum number of packets", func() {
for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted := client.Seal(nil, msg, pn, ad)
_, err := server.Open(nil, encrypted, pn, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
}
// no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, 1, ad)
server.SetLargestAcked(1)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
})
})
})
})

View file

@ -139,3 +139,6 @@ const MaxAckDelay = 25 * time.Millisecond
// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity.
// This is the value that should be advertised to the peer.
const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key udpate.
const KeyUpdateInterval = 100 * 1000