diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 89a9dcd6..6fa4f76f 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -27,7 +27,6 @@ type updatableAEAD struct { firstPacketNumber protocol.PacketNumber handshakeConfirmed bool - keyUpdateInterval uint64 invalidPacketLimit uint64 invalidPacketCount uint64 @@ -74,7 +73,6 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, largestAcked: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber, - keyUpdateInterval: KeyUpdateInterval, rttStats: rttStats, tracer: tracer, logger: logger, @@ -116,6 +114,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) } +// SetReadKey sets the read key. // For the client, this function is called before SetWriteKey. // For the server, this function is called after SetWriteKey. func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -129,6 +128,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [ a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) } +// SetWriteKey sets the write key. // For the client, this function is called after SetReadKey. // For the server, this function is called before SetWriteKey. func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -284,11 +284,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { if !a.updateAllowed() { return false } - if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { + if a.numRcvdWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) return true } - if a.numSentWithCurrentKey >= a.keyUpdateInterval { + if a.numSentWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) return true } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index b42f3dce..f9ac53a9 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -283,13 +283,18 @@ var _ = Describe("Updatable AEAD", func() { Context("initiating key updates", func() { const keyUpdateInterval = 20 + var origKeyUpdateInterval uint64 BeforeEach(func() { - Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) - server.keyUpdateInterval = keyUpdateInterval + origKeyUpdateInterval = KeyUpdateInterval + KeyUpdateInterval = keyUpdateInterval server.SetHandshakeConfirmed() }) + AfterEach(func() { + KeyUpdateInterval = origKeyUpdateInterval + }) + It("initiates a key update after sealing the maximum number of packets, for the first update", func() { for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i)