handshake: remove unnecessary member variable from updatableAEAD

This commit is contained in:
Marten Seemann 2023-03-28 22:50:21 +09:00
parent af6c19617c
commit 1ebd359b20
2 changed files with 11 additions and 6 deletions

View file

@ -27,7 +27,6 @@ type updatableAEAD struct {
firstPacketNumber protocol.PacketNumber firstPacketNumber protocol.PacketNumber
handshakeConfirmed bool handshakeConfirmed bool
keyUpdateInterval uint64
invalidPacketLimit uint64 invalidPacketLimit uint64
invalidPacketCount uint64 invalidPacketCount uint64
@ -74,7 +73,6 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer,
largestAcked: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber,
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
firstSentWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber,
keyUpdateInterval: KeyUpdateInterval,
rttStats: rttStats, rttStats: rttStats,
tracer: tracer, tracer: tracer,
logger: logger, logger: logger,
@ -116,6 +114,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
} }
// SetReadKey sets the read key.
// For the client, this function is called before SetWriteKey. // For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey. // For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
@ -129,6 +128,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
} }
// SetWriteKey sets the write key.
// For the client, this function is called after SetReadKey. // For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey. // For the server, this function is called before SetWriteKey.
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
@ -284,11 +284,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
if !a.updateAllowed() { if !a.updateAllowed() {
return false return false
} }
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
return true return true
} }
if a.numSentWithCurrentKey >= a.keyUpdateInterval { if a.numSentWithCurrentKey >= KeyUpdateInterval {
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
return true return true
} }

View file

@ -283,13 +283,18 @@ var _ = Describe("Updatable AEAD", func() {
Context("initiating key updates", func() { Context("initiating key updates", func() {
const keyUpdateInterval = 20 const keyUpdateInterval = 20
var origKeyUpdateInterval uint64
BeforeEach(func() { BeforeEach(func() {
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) origKeyUpdateInterval = KeyUpdateInterval
server.keyUpdateInterval = keyUpdateInterval KeyUpdateInterval = keyUpdateInterval
server.SetHandshakeConfirmed() server.SetHandshakeConfirmed()
}) })
AfterEach(func() {
KeyUpdateInterval = origKeyUpdateInterval
})
It("initiates a key update after sealing the maximum number of packets, for the first update", func() { It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < keyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)