mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
Merge pull request #3745 from quic-go/aggressive-key-updates
initiate the first key update after sending / receiving 100 packets
This commit is contained in:
commit
379e7ec848
4 changed files with 100 additions and 29 deletions
|
@ -51,12 +51,14 @@ func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ = Describe("Key Update tests", func() {
|
var _ = Describe("Key Update tests", func() {
|
||||||
var server quic.Listener
|
It("downloads a large file", func() {
|
||||||
|
origKeyUpdateInterval := handshake.KeyUpdateInterval
|
||||||
|
defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }()
|
||||||
|
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
|
||||||
|
|
||||||
runServer := func() {
|
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
||||||
var err error
|
|
||||||
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), nil)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -68,14 +70,7 @@ var _ = Describe("Key Update tests", func() {
|
||||||
_, err = str.Write(PRDataLong)
|
_, err = str.Write(PRDataLong)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}()
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
It("downloads a large file", func() {
|
|
||||||
origKeyUpdateInterval := handshake.KeyUpdateInterval
|
|
||||||
defer func() { handshake.KeyUpdateInterval = origKeyUpdateInterval }()
|
|
||||||
handshake.KeyUpdateInterval = 1 // update keys as frequently as possible
|
|
||||||
|
|
||||||
runServer()
|
|
||||||
conn, err := quic.DialAddr(
|
conn, err := quic.DialAddr(
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
getTLSClientConfig(),
|
getTLSClientConfig(),
|
||||||
|
|
|
@ -19,6 +19,10 @@ import (
|
||||||
// It's a package-level variable to allow modifying it for testing purposes.
|
// It's a package-level variable to allow modifying it for testing purposes.
|
||||||
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
|
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
|
||||||
|
|
||||||
|
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
|
||||||
|
// It's a package-level variable to allow modifying it for testing purposes.
|
||||||
|
var FirstKeyUpdateInterval uint64 = 100
|
||||||
|
|
||||||
type updatableAEAD struct {
|
type updatableAEAD struct {
|
||||||
suite *qtls.CipherSuiteTLS13
|
suite *qtls.CipherSuiteTLS13
|
||||||
|
|
||||||
|
@ -27,7 +31,6 @@ type updatableAEAD struct {
|
||||||
firstPacketNumber protocol.PacketNumber
|
firstPacketNumber protocol.PacketNumber
|
||||||
handshakeConfirmed bool
|
handshakeConfirmed bool
|
||||||
|
|
||||||
keyUpdateInterval uint64
|
|
||||||
invalidPacketLimit uint64
|
invalidPacketLimit uint64
|
||||||
invalidPacketCount uint64
|
invalidPacketCount uint64
|
||||||
|
|
||||||
|
@ -74,7 +77,6 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer,
|
||||||
largestAcked: protocol.InvalidPacketNumber,
|
largestAcked: protocol.InvalidPacketNumber,
|
||||||
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
|
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
|
||||||
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
|
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
|
||||||
keyUpdateInterval: KeyUpdateInterval,
|
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
@ -116,6 +118,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte
|
||||||
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
|
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetReadKey sets the read key.
|
||||||
// For the client, this function is called before SetWriteKey.
|
// For the client, this function is called before SetWriteKey.
|
||||||
// For the server, this function is called after SetWriteKey.
|
// For the server, this function is called after SetWriteKey.
|
||||||
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||||
|
@ -129,6 +132,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [
|
||||||
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
|
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWriteKey sets the write key.
|
||||||
// For the client, this function is called after SetReadKey.
|
// For the client, this function is called after SetReadKey.
|
||||||
// For the server, this function is called before SetWriteKey.
|
// For the server, this function is called before SetWriteKey.
|
||||||
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||||
|
@ -284,11 +288,17 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
||||||
if !a.updateAllowed() {
|
if !a.updateAllowed() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
|
// Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
|
||||||
|
if a.keyPhase == 0 {
|
||||||
|
if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
|
||||||
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if a.numSentWithCurrentKey >= a.keyUpdateInterval {
|
if a.numSentWithCurrentKey >= KeyUpdateInterval {
|
||||||
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
|
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,15 +4,15 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
|
|
||||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/qerr"
|
"github.com/quic-go/quic-go/internal/qerr"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
"github.com/quic-go/quic-go/internal/utils"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
@ -282,16 +282,25 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("initiating key updates", func() {
|
Context("initiating key updates", func() {
|
||||||
|
const firstKeyUpdateInterval = 5
|
||||||
const keyUpdateInterval = 20
|
const keyUpdateInterval = 20
|
||||||
|
var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval))
|
origKeyUpdateInterval = KeyUpdateInterval
|
||||||
server.keyUpdateInterval = keyUpdateInterval
|
origFirstKeyUpdateInterval = FirstKeyUpdateInterval
|
||||||
|
KeyUpdateInterval = keyUpdateInterval
|
||||||
|
FirstKeyUpdateInterval = firstKeyUpdateInterval
|
||||||
server.SetHandshakeConfirmed()
|
server.SetHandshakeConfirmed()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
KeyUpdateInterval = origKeyUpdateInterval
|
||||||
|
FirstKeyUpdateInterval = origFirstKeyUpdateInterval
|
||||||
|
})
|
||||||
|
|
||||||
It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
|
It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
@ -323,7 +332,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
|
|
||||||
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
|
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
|
||||||
// First make sure that we update our keys.
|
// First make sure that we update our keys.
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
@ -331,7 +340,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
// Now that our keys are updated, send a packet using the new keys.
|
// Now that our keys are updated, send a packet using the new keys.
|
||||||
const nextPN = keyUpdateInterval + 1
|
const nextPN = firstKeyUpdateInterval + 1
|
||||||
server.Seal(nil, msg, nextPN, ad)
|
server.Seal(nil, msg, nextPN, ad)
|
||||||
// We haven't decrypted any packet in the new key phase yet.
|
// We haven't decrypted any packet in the new key phase yet.
|
||||||
// This means that the ACK must have been sent in the old key phase.
|
// This means that the ACK must have been sent in the old key phase.
|
||||||
|
@ -343,7 +352,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
|
|
||||||
It("doesn't error before actually sending a packet in the new key phase", func() {
|
It("doesn't error before actually sending a packet in the new key phase", func() {
|
||||||
// First make sure that we update our keys.
|
// First make sure that we update our keys.
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
@ -361,7 +370,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("initiates a key update after opening the maximum number of packets, for the first update", func() {
|
It("initiates a key update after opening the maximum number of packets, for the first update", func() {
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
encrypted := client.Seal(nil, msg, pn, ad)
|
encrypted := client.Seal(nil, msg, pn, ad)
|
||||||
|
@ -394,7 +403,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
|
|
||||||
It("drops keys 3 PTOs after a key update", func() {
|
It("drops keys 3 PTOs after a key update", func() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
@ -430,7 +439,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
data1 := client.Seal(nil, msg, 1, ad)
|
data1 := client.Seal(nil, msg, 1, ad)
|
||||||
_, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad)
|
_, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
@ -446,7 +455,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
|
It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
@ -483,7 +492,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
|
It("drops keys early when we initiate another key update within the 3 PTO period", func() {
|
||||||
server.SetHandshakeConfirmed()
|
server.SetHandshakeConfirmed()
|
||||||
// send so many packets that we initiate the first key update
|
// send so many packets that we initiate the first key update
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < firstKeyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
@ -525,3 +534,60 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
func getClientAndServer() (client, server *updatableAEAD) {
|
||||||
|
trafficSecret1 := make([]byte, 16)
|
||||||
|
trafficSecret2 := make([]byte, 16)
|
||||||
|
rand.Read(trafficSecret1)
|
||||||
|
rand.Read(trafficSecret2)
|
||||||
|
|
||||||
|
cs := cipherSuites[0]
|
||||||
|
rttStats := utils.NewRTTStats()
|
||||||
|
client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
|
||||||
|
server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
|
||||||
|
client.SetReadKey(cs, trafficSecret2)
|
||||||
|
client.SetWriteKey(cs, trafficSecret1)
|
||||||
|
server.SetReadKey(cs, trafficSecret1)
|
||||||
|
server.SetWriteKey(cs, trafficSecret2)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPacketEncryption(b *testing.B) {
|
||||||
|
client, _ := getClientAndServer()
|
||||||
|
const l = 1200
|
||||||
|
src := make([]byte, l)
|
||||||
|
rand.Read(src)
|
||||||
|
ad := make([]byte, 32)
|
||||||
|
rand.Read(ad)
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPacketDecryption(b *testing.B) {
|
||||||
|
client, server := getClientAndServer()
|
||||||
|
const l = 1200
|
||||||
|
src := make([]byte, l)
|
||||||
|
dst := make([]byte, l)
|
||||||
|
rand.Read(src)
|
||||||
|
ad := make([]byte, 32)
|
||||||
|
rand.Read(ad)
|
||||||
|
src = client.Seal(src[:0], src[:l], 1337, ad)
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil {
|
||||||
|
b.Fatalf("opening failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRollKeys(b *testing.B) {
|
||||||
|
client, _ := getClientAndServer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
client.rollKeys()
|
||||||
|
}
|
||||||
|
if int(client.keyPhase) != b.N {
|
||||||
|
b.Fatal("didn't roll keys often enough")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -89,7 +89,7 @@ func runTestcase(testcase string) error {
|
||||||
switch testcase {
|
switch testcase {
|
||||||
case "handshake", "transfer", "retry":
|
case "handshake", "transfer", "retry":
|
||||||
case "keyupdate":
|
case "keyupdate":
|
||||||
handshake.KeyUpdateInterval = 100
|
handshake.FirstKeyUpdateInterval = 100
|
||||||
case "chacha20":
|
case "chacha20":
|
||||||
reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
|
reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
|
||||||
defer reset()
|
defer reset()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue