mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
introduce an absolute key phase, use it for key updates
This commit is contained in:
parent
5a9c593463
commit
a2a4a216de
3 changed files with 20 additions and 13 deletions
|
@ -43,7 +43,7 @@ func setKeyUpdateInterval() {
|
||||||
type updatableAEAD struct {
|
type updatableAEAD struct {
|
||||||
suite cipherSuite
|
suite cipherSuite
|
||||||
|
|
||||||
keyPhase protocol.KeyPhaseBit
|
keyPhase protocol.KeyPhase
|
||||||
largestAcked protocol.PacketNumber
|
largestAcked protocol.PacketNumber
|
||||||
keyUpdateInterval uint64
|
keyUpdateInterval uint64
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ func newUpdatableAEAD(logger utils.Logger) *updatableAEAD {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) rollKeys() {
|
func (a *updatableAEAD) rollKeys() {
|
||||||
a.keyPhase = a.keyPhase.Next()
|
a.keyPhase++
|
||||||
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
||||||
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
||||||
a.numRcvdWithCurrentKey = 0
|
a.numRcvdWithCurrentKey = 0
|
||||||
|
@ -136,7 +136,7 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) {
|
||||||
|
|
||||||
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
||||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||||
if kp != a.keyPhase {
|
if kp != a.keyPhase.Bit() {
|
||||||
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
||||||
if a.prevRcvAEAD == nil {
|
if a.prevRcvAEAD == nil {
|
||||||
// This can only occur when the first packet received has key phase 1.
|
// This can only occur when the first packet received has key phase 1.
|
||||||
|
@ -205,11 +205,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
|
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
|
||||||
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase.Next())
|
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %s", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if a.numSentWithCurrentKey >= a.keyUpdateInterval {
|
if a.numSentWithCurrentKey >= a.keyUpdateInterval {
|
||||||
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase.Next())
|
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %s", a.numSentWithCurrentKey, a.keyPhase+1)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
@ -219,7 +219,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
|
||||||
if a.shouldInitiateKeyUpdate() {
|
if a.shouldInitiateKeyUpdate() {
|
||||||
a.rollKeys()
|
a.rollKeys()
|
||||||
}
|
}
|
||||||
return a.keyPhase
|
return a.keyPhase.Bit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) Overhead() int {
|
func (a *updatableAEAD) Overhead() int {
|
||||||
|
|
|
@ -1,5 +1,12 @@
|
||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
|
// KeyPhase is the key phase
|
||||||
|
type KeyPhase uint64
|
||||||
|
|
||||||
|
func (p KeyPhase) Bit() KeyPhaseBit {
|
||||||
|
return p%2 == 1
|
||||||
|
}
|
||||||
|
|
||||||
// KeyPhaseBit is the key phase bit
|
// KeyPhaseBit is the key phase bit
|
||||||
type KeyPhaseBit bool
|
type KeyPhaseBit bool
|
||||||
|
|
||||||
|
@ -16,7 +23,3 @@ func (p KeyPhaseBit) String() string {
|
||||||
}
|
}
|
||||||
return "1"
|
return "1"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p KeyPhaseBit) Next() KeyPhaseBit {
|
|
||||||
return !p
|
|
||||||
}
|
|
||||||
|
|
|
@ -11,8 +11,12 @@ var _ = Describe("Key Phases", func() {
|
||||||
Expect(KeyPhaseOne.String()).To(Equal("1"))
|
Expect(KeyPhaseOne.String()).To(Equal("1"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns the next key phase", func() {
|
It("converts the key phase to the key phase bit", func() {
|
||||||
Expect(KeyPhaseZero.Next()).To(Equal(KeyPhaseOne))
|
Expect(KeyPhase(0).Bit()).To(Equal(KeyPhaseZero))
|
||||||
Expect(KeyPhaseOne.Next()).To(Equal(KeyPhaseZero))
|
Expect(KeyPhase(2).Bit()).To(Equal(KeyPhaseZero))
|
||||||
|
Expect(KeyPhase(4).Bit()).To(Equal(KeyPhaseZero))
|
||||||
|
Expect(KeyPhase(1).Bit()).To(Equal(KeyPhaseOne))
|
||||||
|
Expect(KeyPhase(3).Bit()).To(Equal(KeyPhaseOne))
|
||||||
|
Expect(KeyPhase(5).Bit()).To(Equal(KeyPhaseOne))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue