Merge pull request #3745 from quic-go/aggressive-key-updates

initiate the first key update after sending / receiving 100 packets
This commit is contained in:
Marten Seemann 2023-04-19 16:28:33 +02:00 committed by GitHub
commit 379e7ec848
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 29 deletions

View file

@ -51,12 +51,14 @@ func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader
} }
var _ = Describe("Key Update tests", func() { var _ = Describe("Key Update tests", func() {
var server quic.Listener It("downloads a large file", func() {
origKeyUpdateInterval := handshake.KeyUpdateInterval
defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }()
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
runServer := func() { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
var err error
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer server.Close()
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -68,14 +70,7 @@ var _ = Describe("Key Update tests", func() {
_, err = str.Write(PRDataLong) _, err = str.Write(PRDataLong)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}() }()
}
It("downloads a large file", func() {
origKeyUpdateInterval := handshake.KeyUpdateInterval
defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }()
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
runServer()
conn, err := quic.DialAddr( conn, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),

View file

@ -19,6 +19,10 @@ import (
// It's a package-level variable to allow modifying it for testing purposes. // It's a package-level variable to allow modifying it for testing purposes.
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
// It's a package-level variable to allow modifying it for testing purposes.
var FirstKeyUpdateInterval uint64 = 100
type updatableAEAD struct { type updatableAEAD struct {
suite *qtls.CipherSuiteTLS13 suite *qtls.CipherSuiteTLS13
@ -27,7 +31,6 @@ type updatableAEAD struct {
firstPacketNumber protocol.PacketNumber firstPacketNumber protocol.PacketNumber
handshakeConfirmed bool handshakeConfirmed bool
keyUpdateInterval uint64
invalidPacketLimit uint64 invalidPacketLimit uint64
invalidPacketCount uint64 invalidPacketCount uint64
@ -74,7 +77,6 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer,
largestAcked: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber,
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
firstSentWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber,
keyUpdateInterval: KeyUpdateInterval,
rttStats: rttStats, rttStats: rttStats,
tracer: tracer, tracer: tracer,
logger: logger, logger: logger,
@ -116,6 +118,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
} }
// SetReadKey sets the read key.
// For the client, this function is called before SetWriteKey. // For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey. // For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
@ -129,6 +132,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
} }
// SetWriteKey sets the write key.
// For the client, this function is called after SetReadKey. // For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey. // For the server, this function is called before SetWriteKey.
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
@ -284,11 +288,17 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
if !a.updateAllowed() { if !a.updateAllowed() {
return false return false
} }
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { // Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
if a.keyPhase == 0 {
if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
return true
}
}
if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
return true return true
} }
if a.numSentWithCurrentKey >= a.keyUpdateInterval { if a.numSentWithCurrentKey >= KeyUpdateInterval {
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
return true return true
} }

View file

@ -4,15 +4,15 @@ import (
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"testing"
"time" "time"
"github.com/golang/mock/gomock"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -282,16 +282,25 @@ var _ = Describe("Updatable AEAD", func() {
}) })
Context("initiating key updates", func() { Context("initiating key updates", func() {
const firstKeyUpdateInterval = 5
const keyUpdateInterval = 20 const keyUpdateInterval = 20
var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64
BeforeEach(func() { BeforeEach(func() {
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) origKeyUpdateInterval = KeyUpdateInterval
server.keyUpdateInterval = keyUpdateInterval origFirstKeyUpdateInterval = FirstKeyUpdateInterval
KeyUpdateInterval = keyUpdateInterval
FirstKeyUpdateInterval = firstKeyUpdateInterval
server.SetHandshakeConfirmed() server.SetHandshakeConfirmed()
}) })
AfterEach(func() {
KeyUpdateInterval = origKeyUpdateInterval
FirstKeyUpdateInterval = origFirstKeyUpdateInterval
})
It("initiates a key update after sealing the maximum number of packets, for the first update", func() { It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
@ -323,7 +332,7 @@ var _ = Describe("Updatable AEAD", func() {
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
// First make sure that we update our keys. // First make sure that we update our keys.
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
@ -331,7 +340,7 @@ var _ = Describe("Updatable AEAD", func() {
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
// Now that our keys are updated, send a packet using the new keys. // Now that our keys are updated, send a packet using the new keys.
const nextPN = keyUpdateInterval + 1 const nextPN = firstKeyUpdateInterval + 1
server.Seal(nil, msg, nextPN, ad) server.Seal(nil, msg, nextPN, ad)
// We haven't decrypted any packet in the new key phase yet. // We haven't decrypted any packet in the new key phase yet.
// This means that the ACK must have been sent in the old key phase. // This means that the ACK must have been sent in the old key phase.
@ -343,7 +352,7 @@ var _ = Describe("Updatable AEAD", func() {
It("doesn't error before actually sending a packet in the new key phase", func() { It("doesn't error before actually sending a packet in the new key phase", func() {
// First make sure that we update our keys. // First make sure that we update our keys.
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
@ -361,7 +370,7 @@ var _ = Describe("Updatable AEAD", func() {
}) })
It("initiates a key update after opening the maximum number of packets, for the first update", func() { It("initiates a key update after opening the maximum number of packets, for the first update", func() {
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
encrypted := client.Seal(nil, msg, pn, ad) encrypted := client.Seal(nil, msg, pn, ad)
@ -394,7 +403,7 @@ var _ = Describe("Updatable AEAD", func() {
It("drops keys 3 PTOs after a key update", func() { It("drops keys 3 PTOs after a key update", func() {
now := time.Now() now := time.Now()
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
@ -430,7 +439,7 @@ var _ = Describe("Updatable AEAD", func() {
data1 := client.Seal(nil, msg, 1, ad) data1 := client.Seal(nil, msg, 1, ad)
_, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
@ -446,7 +455,7 @@ var _ = Describe("Updatable AEAD", func() {
}) })
It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
@ -483,7 +492,7 @@ var _ = Describe("Updatable AEAD", func() {
It("drops keys early when we initiate another key update within the 3 PTO period", func() { It("drops keys early when we initiate another key update within the 3 PTO period", func() {
server.SetHandshakeConfirmed() server.SetHandshakeConfirmed()
// send so many packets that we initiate the first key update // send so many packets that we initiate the first key update
for i := 0; i < keyUpdateInterval; i++ { for i := 0; i < firstKeyUpdateInterval; i++ {
pn := protocol.PacketNumber(i) pn := protocol.PacketNumber(i)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, pn, ad) server.Seal(nil, msg, pn, ad)
@ -525,3 +534,60 @@ var _ = Describe("Updatable AEAD", func() {
}) })
} }
}) })
func getClientAndServer() (client, server *updatableAEAD) {
trafficSecret1 := make([]byte, 16)
trafficSecret2 := make([]byte, 16)
rand.Read(trafficSecret1)
rand.Read(trafficSecret2)
cs := cipherSuites[0]
rttStats := utils.NewRTTStats()
client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
client.SetReadKey(cs, trafficSecret2)
client.SetWriteKey(cs, trafficSecret1)
server.SetReadKey(cs, trafficSecret1)
server.SetWriteKey(cs, trafficSecret2)
return
}
func BenchmarkPacketEncryption(b *testing.B) {
client, _ := getClientAndServer()
const l = 1200
src := make([]byte, l)
rand.Read(src)
ad := make([]byte, 32)
rand.Read(ad)
for i := 0; i < b.N; i++ {
src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad)
}
}
func BenchmarkPacketDecryption(b *testing.B) {
client, server := getClientAndServer()
const l = 1200
src := make([]byte, l)
dst := make([]byte, l)
rand.Read(src)
ad := make([]byte, 32)
rand.Read(ad)
src = client.Seal(src[:0], src[:l], 1337, ad)
for i := 0; i < b.N; i++ {
if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil {
b.Fatalf("opening failed: %v", err)
}
}
}
func BenchmarkRollKeys(b *testing.B) {
client, _ := getClientAndServer()
for i := 0; i < b.N; i++ {
client.rollKeys()
}
if int(client.keyPhase) != b.N {
b.Fatal("didn't roll keys often enough")
}
}

View file

@ -89,7 +89,7 @@ func runTestcase(testcase string) error {
switch testcase { switch testcase {
case "handshake", "transfer", "retry": case "handshake", "transfer", "retry":
case "keyupdate": case "keyupdate":
handshake.KeyUpdateInterval = 100 handshake.FirstKeyUpdateInterval = 100
case "chacha20": case "chacha20":
reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256) reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
defer reset() defer reset()