introduce an absolute key phase, use it for key updates

This commit is contained in:
Marten Seemann 2019-06-29 15:38:21 +07:00
parent 5a9c593463
commit a2a4a216de
3 changed files with 20 additions and 13 deletions

View file

@ -43,7 +43,7 @@ func setKeyUpdateInterval() {
type updatableAEAD struct { type updatableAEAD struct {
suite cipherSuite suite cipherSuite
keyPhase protocol.KeyPhaseBit keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber largestAcked protocol.PacketNumber
keyUpdateInterval uint64 keyUpdateInterval uint64
@ -85,7 +85,7 @@ func newUpdatableAEAD(logger utils.Logger) *updatableAEAD {
} }
func (a *updatableAEAD) rollKeys() { func (a *updatableAEAD) rollKeys() {
a.keyPhase = a.keyPhase.Next() a.keyPhase++
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.numRcvdWithCurrentKey = 0 a.numRcvdWithCurrentKey = 0
@ -136,7 +136,7 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
if kp != a.keyPhase { if kp != a.keyPhase.Bit() {
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
if a.prevRcvAEAD == nil { if a.prevRcvAEAD == nil {
// This can only occur when the first packet received has key phase 1. // This can only occur when the first packet received has key phase 1.
@ -205,11 +205,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
return false return false
} }
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase.Next()) a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase+1)
return true return true
} }
if a.numSentWithCurrentKey >= a.keyUpdateInterval { if a.numSentWithCurrentKey >= a.keyUpdateInterval {
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase.Next()) a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase+1)
return true return true
} }
return false return false
@ -219,7 +219,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.shouldInitiateKeyUpdate() { if a.shouldInitiateKeyUpdate() {
a.rollKeys() a.rollKeys()
} }
return a.keyPhase return a.keyPhase.Bit()
} }
func (a *updatableAEAD) Overhead() int { func (a *updatableAEAD) Overhead() int {

View file

@ -1,5 +1,12 @@
package protocol package protocol
// KeyPhase is the key phase
type KeyPhase uint64
func (p KeyPhase) Bit() KeyPhaseBit {
return p%2 == 1
}
// KeyPhaseBit is the key phase bit // KeyPhaseBit is the key phase bit
type KeyPhaseBit bool type KeyPhaseBit bool
@ -16,7 +23,3 @@ func (p KeyPhaseBit) String() string {
} }
return "1" return "1"
} }
func (p KeyPhaseBit) Next() KeyPhaseBit {
return !p
}

View file

@ -11,8 +11,12 @@ var _ = Describe("Key Phases", func() {
Expect(KeyPhaseOne.String()).To(Equal("1")) Expect(KeyPhaseOne.String()).To(Equal("1"))
}) })
It("returns the next key phase", func() { It("converts the key phase to the key phase bit", func() {
Expect(KeyPhaseZero.Next()).To(Equal(KeyPhaseOne)) Expect(KeyPhase(0).Bit()).To(Equal(KeyPhaseZero))
Expect(KeyPhaseOne.Next()).To(Equal(KeyPhaseZero)) Expect(KeyPhase(2).Bit()).To(Equal(KeyPhaseZero))
Expect(KeyPhase(4).Bit()).To(Equal(KeyPhaseZero))
Expect(KeyPhase(1).Bit()).To(Equal(KeyPhaseOne))
Expect(KeyPhase(3).Bit()).To(Equal(KeyPhaseOne))
Expect(KeyPhase(5).Bit()).To(Equal(KeyPhaseOne))
}) })
}) })