mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
correctly handle key updates within the 3 PTO period
We need to: * stop the timer to drop the previous generation * correctly log that the N-1 keys are dropped immediately when keys are updated to N+1
This commit is contained in:
parent
27448c1fa5
commit
e94db952b7
2 changed files with 86 additions and 5 deletions
|
@ -100,6 +100,14 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) rollKeys() {
|
func (a *updatableAEAD) rollKeys() {
|
||||||
|
if a.prevRcvAEAD != nil {
|
||||||
|
a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
|
||||||
|
if a.tracer != nil {
|
||||||
|
a.tracer.DroppedKey(a.keyPhase - 1)
|
||||||
|
}
|
||||||
|
a.prevRcvAEADExpiry = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
a.keyPhase++
|
a.keyPhase++
|
||||||
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
||||||
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
||||||
|
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
|
||||||
mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging"
|
mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
|
@ -40,13 +42,12 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
|
|
||||||
Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() {
|
Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() {
|
||||||
var (
|
var (
|
||||||
client, server *updatableAEAD
|
client, server *updatableAEAD
|
||||||
clientTracer, serverTracer *mocklogging.MockConnectionTracer
|
serverTracer *mocklogging.MockConnectionTracer
|
||||||
rttStats *utils.RTTStats
|
rttStats *utils.RTTStats
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
clientTracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
|
||||||
serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||||
trafficSecret1 := make([]byte, 16)
|
trafficSecret1 := make([]byte, 16)
|
||||||
trafficSecret2 := make([]byte, 16)
|
trafficSecret2 := make([]byte, 16)
|
||||||
|
@ -54,7 +55,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
rand.Read(trafficSecret2)
|
rand.Read(trafficSecret2)
|
||||||
|
|
||||||
rttStats = utils.NewRTTStats()
|
rttStats = utils.NewRTTStats()
|
||||||
client = newUpdatableAEAD(rttStats, clientTracer, utils.DefaultLogger)
|
client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger)
|
||||||
server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger)
|
server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger)
|
||||||
client.SetReadKey(cs, trafficSecret2)
|
client.SetReadKey(cs, trafficSecret2)
|
||||||
client.SetWriteKey(cs, trafficSecret1)
|
client.SetWriteKey(cs, trafficSecret1)
|
||||||
|
@ -378,6 +379,78 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
_, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad)
|
_, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
|
||||||
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
|
pn := protocol.PacketNumber(i)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
}
|
||||||
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
||||||
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
||||||
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
const nextPN = keyUpdateInterval + 1
|
||||||
|
// Send and receive an acknowledgement for a packet in key phase 1.
|
||||||
|
// We are now running a timer to drop the keys with 3 PTO.
|
||||||
|
server.Seal(nil, msg, nextPN, ad)
|
||||||
|
client.rollKeys()
|
||||||
|
dataKeyPhaseOne := client.Seal(nil, msg, 2, ad)
|
||||||
|
now := time.Now()
|
||||||
|
_, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(server.SetLargestAcked(nextPN))
|
||||||
|
// Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over.
|
||||||
|
// This mean that we need to drop the keys for key phase 0 immediately.
|
||||||
|
client.rollKeys()
|
||||||
|
dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad)
|
||||||
|
gomock.InOrder(
|
||||||
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)),
|
||||||
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true),
|
||||||
|
)
|
||||||
|
_, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
|
||||||
|
// send so many packets that we initiate the first key update
|
||||||
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
|
pn := protocol.PacketNumber(i)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
}
|
||||||
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
||||||
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
||||||
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
// send so many packets that we initiate the next key update
|
||||||
|
for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ {
|
||||||
|
pn := protocol.PacketNumber(i)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
}
|
||||||
|
client.rollKeys()
|
||||||
|
b = client.Seal(nil, []byte("foobar"), 2, []byte("ad"))
|
||||||
|
now := time.Now()
|
||||||
|
_, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed())
|
||||||
|
gomock.InOrder(
|
||||||
|
serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)),
|
||||||
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false),
|
||||||
|
)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
// We haven't received an ACK for a packet sent in key phase 2 yet.
|
||||||
|
// Make sure we canceled the timer to drop the previous key phase.
|
||||||
|
b = client.Seal(nil, []byte("foobar"), 3, []byte("ad"))
|
||||||
|
_, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("reading the key update env", func() {
|
Context("reading the key update env", func() {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue