diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index e086846c..a08824de 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -43,7 +43,7 @@ func setKeyUpdateInterval() { type updatableAEAD struct { suite cipherSuite - keyPhase protocol.KeyPhaseBit + keyPhase protocol.KeyPhase largestAcked protocol.PacketNumber keyUpdateInterval uint64 @@ -85,7 +85,7 @@ func newUpdatableAEAD(logger utils.Logger) *updatableAEAD { } func (a *updatableAEAD) rollKeys() { - a.keyPhase = a.keyPhase.Next() + a.keyPhase++ a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber a.firstSentWithCurrentKey = protocol.InvalidPacketNumber a.numRcvdWithCurrentKey = 0 @@ -136,7 +136,7 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) - if kp != a.keyPhase { + if kp != a.keyPhase.Bit() { if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { if a.prevRcvAEAD == nil { // This can only occur when the first packet received has key phase 1. @@ -205,11 +205,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { return false } if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { - a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase.Next()) + a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase+1) return true } if a.numSentWithCurrentKey >= a.keyUpdateInterval { - a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase.Next()) + a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase+1) return true } return false @@ -219,7 +219,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { a.rollKeys() } - return a.keyPhase + return a.keyPhase.Bit() } func (a *updatableAEAD) Overhead() int { diff --git a/internal/protocol/key_phase.go b/internal/protocol/key_phase.go index 693aeb15..2ebc3f92 100644 --- a/internal/protocol/key_phase.go +++ b/internal/protocol/key_phase.go @@ -1,5 +1,12 @@ package protocol +// KeyPhase is the key phase +type KeyPhase uint64 + +func (p KeyPhase) Bit() KeyPhaseBit { + return p%2 == 1 +} + // KeyPhaseBit is the key phase bit type KeyPhaseBit bool @@ -16,7 +23,3 @@ func (p KeyPhaseBit) String() string { } return "1" } - -func (p KeyPhaseBit) Next() KeyPhaseBit { - return !p -} diff --git a/internal/protocol/key_phase_test.go b/internal/protocol/key_phase_test.go index 67e3aabf..da06513e 100644 --- a/internal/protocol/key_phase_test.go +++ b/internal/protocol/key_phase_test.go @@ -11,8 +11,12 @@ var _ = Describe("Key Phases", func() { Expect(KeyPhaseOne.String()).To(Equal("1")) }) - It("returns the next key phase", func() { - Expect(KeyPhaseZero.Next()).To(Equal(KeyPhaseOne)) - Expect(KeyPhaseOne.Next()).To(Equal(KeyPhaseZero)) + It("converts the key phase to the key phase bit", func() { + Expect(KeyPhase(0).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(2).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(4).Bit()).To(Equal(KeyPhaseZero)) + Expect(KeyPhase(1).Bit()).To(Equal(KeyPhaseOne)) + Expect(KeyPhase(3).Bit()).To(Equal(KeyPhaseOne)) + Expect(KeyPhase(5).Bit()).To(Equal(KeyPhaseOne)) }) })