mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
handshake: initiate the first key update after 100 packets
This commit is contained in:
parent
1ebd359b20
commit
b03585160f
4 changed files with 31 additions and 22 deletions
|
@ -51,12 +51,14 @@ func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader
|
|||
}
|
||||
|
||||
var _ = Describe("Key Update tests", func() {
|
||||
var server quic.Listener
|
||||
It("downloads a large file", func() {
|
||||
origKeyUpdateInterval := handshake.KeyUpdateInterval
|
||||
defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }()
|
||||
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
|
||||
|
||||
runServer := func() {
|
||||
var err error
|
||||
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -68,14 +70,7 @@ var _ = Describe("Key Update tests", func() {
|
|||
_, err = str.Write(PRDataLong)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
}
|
||||
|
||||
It("downloads a large file", func() {
|
||||
origKeyUpdateInterval := handshake.KeyUpdateInterval
|
||||
defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }()
|
||||
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
|
||||
|
||||
runServer()
|
||||
conn, err := quic.DialAddr(
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
|
|
|
@ -19,6 +19,10 @@ import (
|
|||
// It's a package-level variable to allow modifying it for testing purposes.
|
||||
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
|
||||
|
||||
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
|
||||
// It's a package-level variable to allow modifying it for testing purposes.
|
||||
var FirstKeyUpdateInterval uint64 = 100
|
||||
|
||||
type updatableAEAD struct {
|
||||
suite *qtls.CipherSuiteTLS13
|
||||
|
||||
|
@ -284,6 +288,12 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
|||
if !a.updateAllowed() {
|
||||
return false
|
||||
}
|
||||
// Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
|
||||
if a.keyPhase == 0 {
|
||||
if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
|
||||
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
||||
return true
|
||||
|
|
|
@ -282,21 +282,25 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
})
|
||||
|
||||
Context("initiating key updates", func() {
|
||||
const firstKeyUpdateInterval = 5
|
||||
const keyUpdateInterval = 20
|
||||
var origKeyUpdateInterval uint64
|
||||
var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64
|
||||
|
||||
BeforeEach(func() {
|
||||
origKeyUpdateInterval = KeyUpdateInterval
|
||||
origFirstKeyUpdateInterval = FirstKeyUpdateInterval
|
||||
KeyUpdateInterval = keyUpdateInterval
|
||||
FirstKeyUpdateInterval = firstKeyUpdateInterval
|
||||
server.SetHandshakeConfirmed()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
KeyUpdateInterval = origKeyUpdateInterval
|
||||
FirstKeyUpdateInterval = origFirstKeyUpdateInterval
|
||||
})
|
||||
|
||||
It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
|
@ -328,7 +332,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
|
||||
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
|
||||
// First make sure that we update our keys.
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
|
@ -336,7 +340,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||
// Now that our keys are updated, send a packet using the new keys.
|
||||
const nextPN = keyUpdateInterval + 1
|
||||
const nextPN = firstKeyUpdateInterval + 1
|
||||
server.Seal(nil, msg, nextPN, ad)
|
||||
// We haven't decrypted any packet in the new key phase yet.
|
||||
// This means that the ACK must have been sent in the old key phase.
|
||||
|
@ -348,7 +352,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
|
||||
It("doesn't error before actually sending a packet in the new key phase", func() {
|
||||
// First make sure that we update our keys.
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
|
@ -366,7 +370,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
})
|
||||
|
||||
It("initiates a key update after opening the maximum number of packets, for the first update", func() {
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
encrypted := client.Seal(nil, msg, pn, ad)
|
||||
|
@ -399,7 +403,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
|
||||
It("drops keys 3 PTOs after a key update", func() {
|
||||
now := time.Now()
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
|
@ -435,7 +439,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
data1 := client.Seal(nil, msg, 1, ad)
|
||||
_, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
|
@ -451,7 +455,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
})
|
||||
|
||||
It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
|
@ -488,7 +492,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
|
||||
server.SetHandshakeConfirmed()
|
||||
// send so many packets that we initiate the first key update
|
||||
for i := 0; i < keyUpdateInterval; i++ {
|
||||
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||
pn := protocol.PacketNumber(i)
|
||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||
server.Seal(nil, msg, pn, ad)
|
||||
|
|
|
@ -88,7 +88,7 @@ func runTestcase(testcase string) error {
|
|||
switch testcase {
|
||||
case "handshake", "transfer", "retry":
|
||||
case "keyupdate":
|
||||
handshake.KeyUpdateInterval = 100
|
||||
handshake.FirstKeyUpdateInterval = 100
|
||||
case "chacha20":
|
||||
tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256}
|
||||
case "multiconnect":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue