fix logging of locally initiated key updates

This commit is contained in:
Marten Seemann 2020-09-09 16:16:00 +07:00
parent e0824e3163
commit ade2882ef4
2 changed files with 21 additions and 11 deletions

View file

@ -242,10 +242,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.shouldInitiateKeyUpdate() { if a.shouldInitiateKeyUpdate() {
a.rollKeys(time.Now())
a.logger.Debugf("Initiating key update to key phase %s", a.keyPhase)
if a.tracer != nil { if a.tracer != nil {
a.tracer.UpdatedKey(a.keyPhase, false) a.tracer.UpdatedKey(a.keyPhase, false)
} }
a.rollKeys(time.Now())
} }
return a.keyPhase.Bit() return a.keyPhase.Bit()
} }

View file

@ -7,6 +7,7 @@ import (
"os" "os"
"time" "time"
mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
@ -37,24 +38,31 @@ var _ = Describe("Updatable AEAD", func() {
cs := cipherSuites[i] cs := cipherSuites[i]
Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() { Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() {
getPeers := func(rttStats *utils.RTTStats) (client, server *updatableAEAD) { var (
client, server *updatableAEAD
clientTracer, serverTracer *mocklogging.MockConnectionTracer
rttStats *utils.RTTStats
)
BeforeEach(func() {
clientTracer = mocklogging.NewMockConnectionTracer(mockCtrl)
serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl)
trafficSecret1 := make([]byte, 16) trafficSecret1 := make([]byte, 16)
trafficSecret2 := make([]byte, 16) trafficSecret2 := make([]byte, 16)
rand.Read(trafficSecret1) rand.Read(trafficSecret1)
rand.Read(trafficSecret2) rand.Read(trafficSecret2)
client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger) rttStats = utils.NewRTTStats()
server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger) client = newUpdatableAEAD(rttStats, clientTracer, utils.DefaultLogger)
server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger)
client.SetReadKey(cs, trafficSecret2) client.SetReadKey(cs, trafficSecret2)
client.SetWriteKey(cs, trafficSecret1) client.SetWriteKey(cs, trafficSecret1)
server.SetReadKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1)
server.SetWriteKey(cs, trafficSecret2) server.SetWriteKey(cs, trafficSecret2)
return })
}
Context("header protection", func() { Context("header protection", func() {
It("encrypts and decrypts the header", func() { It("encrypts and decrypts the header", func() {
server, client := getPeers(&utils.RTTStats{})
var lastFiveBitsDifferent int var lastFiveBitsDifferent int
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
sample := make([]byte, 16) sample := make([]byte, 16)
@ -76,12 +84,8 @@ var _ = Describe("Updatable AEAD", func() {
Context("message encryption", func() { Context("message encryption", func() {
var msg, ad []byte var msg, ad []byte
var server, client *updatableAEAD
var rttStats *utils.RTTStats
BeforeEach(func() { BeforeEach(func() {
rttStats = &utils.RTTStats{}
server, client = getPeers(rttStats)
msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
ad = []byte("Donec in velit neque.") ad = []byte("Donec in velit neque.")
}) })
@ -144,6 +148,7 @@ var _ = Describe("Updatable AEAD", func() {
// now received a message at key phase one // now received a message at key phase one
client.rollKeys(now) client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x43, ad) encrypted1 := client.Seal(nil, msg, 0x43, ad)
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(msg)) Expect(decrypted).To(Equal(msg))
@ -163,6 +168,7 @@ var _ = Describe("Updatable AEAD", func() {
client.rollKeys(now) client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x44, ad) encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
@ -188,6 +194,7 @@ var _ = Describe("Updatable AEAD", func() {
client.rollKeys(now) client.rollKeys(now)
encrypted1 := client.Seal(nil, msg, 0x44, ad) encrypted1 := client.Seal(nil, msg, 0x44, ad)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
@ -233,6 +240,7 @@ var _ = Describe("Updatable AEAD", func() {
// no update allowed before receiving an acknowledgement for the current key phase // no update allowed before receiving an acknowledgement for the current key phase
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.SetLargestAcked(0) server.SetLargestAcked(0)
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
}) })
@ -248,6 +256,7 @@ var _ = Describe("Updatable AEAD", func() {
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
server.Seal(nil, msg, 1, ad) server.Seal(nil, msg, 1, ad)
server.SetLargestAcked(1) server.SetLargestAcked(1)
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
}) })
}) })