mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
check that the peer updated its keys when acknowledging a key update
This commit is contained in:
parent
272a2c88e6
commit
9d4b4f6bf0
6 changed files with 71 additions and 15 deletions
|
@ -275,8 +275,8 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) {
|
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
|
||||||
h.aead.SetLargestAcked(pn)
|
return h.aead.SetLargestAcked(pn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) RunHandshake() {
|
func (h *cryptoSetup) RunHandshake() {
|
||||||
|
|
|
@ -76,7 +76,7 @@ type CryptoSetup interface {
|
||||||
GetSessionTicket() ([]byte, error)
|
GetSessionTicket() ([]byte, error)
|
||||||
|
|
||||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||||
SetLargest1RTTAcked(protocol.PacketNumber)
|
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||||
DropHandshakeKeys()
|
DropHandshakeKeys()
|
||||||
ConnectionState() ConnectionState
|
ConnectionState() ConnectionState
|
||||||
|
|
||||||
|
|
|
@ -262,8 +262,13 @@ func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byt
|
||||||
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
|
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) {
|
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
|
||||||
|
if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
||||||
|
pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
|
||||||
|
return qerr.NewError(qerr.KeyUpdateError, fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase))
|
||||||
|
}
|
||||||
a.largestAcked = pn
|
a.largestAcked = pn
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *updatableAEAD) updateAllowed() bool {
|
func (a *updatableAEAD) updateAllowed() bool {
|
||||||
|
|
|
@ -259,11 +259,57 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
}
|
}
|
||||||
// no update allowed before receiving an acknowledgement for the current key phase
|
// no update allowed before receiving an acknowledgement for the current key phase
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.SetLargestAcked(0)
|
// receive an ACK for a packet sent in key phase 0
|
||||||
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
||||||
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
||||||
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
|
||||||
|
// First make sure that we update our keys.
|
||||||
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
|
pn := protocol.PacketNumber(i)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
}
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
||||||
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
||||||
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
|
// Now that our keys are updated, send a packet using the new keys.
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
const nextPN = keyUpdateInterval + 1
|
||||||
|
server.Seal(nil, msg, nextPN, ad)
|
||||||
|
// We haven't decrypted any packet in the new key phase yet.
|
||||||
|
// This means that the ACK must have been sent in the old key phase.
|
||||||
|
Expect(server.SetLargestAcked(nextPN)).To(MatchError("KEY_UPDATE_ERROR: received ACK for key phase 1, but peer didn't update keys"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't error before actually sending a packet in the new key phase", func() {
|
||||||
|
// First make sure that we update our keys.
|
||||||
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
|
pn := protocol.PacketNumber(i)
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
server.Seal(nil, msg, pn, ad)
|
||||||
|
}
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
||||||
|
_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
|
||||||
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
|
// Now that our keys are updated, send a packet using the new keys.
|
||||||
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
// We haven't decrypted any packet in the new key phase yet.
|
||||||
|
// This means that the ACK must have been sent in the old key phase.
|
||||||
|
Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
It("initiates a key update after opening the maximum number of packets", func() {
|
It("initiates a key update after opening the maximum number of packets", func() {
|
||||||
for i := 0; i < keyUpdateInterval; i++ {
|
for i := 0; i < keyUpdateInterval; i++ {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
|
@ -275,7 +321,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
// no update allowed before receiving an acknowledgement for the current key phase
|
// no update allowed before receiving an acknowledgement for the current key phase
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, 1, ad)
|
server.Seal(nil, msg, 1, ad)
|
||||||
server.SetLargestAcked(1)
|
Expect(server.SetLargestAcked(1)).To(Succeed())
|
||||||
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
})
|
})
|
||||||
|
@ -286,13 +332,16 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
server.SetLargestAcked(pn)
|
|
||||||
}
|
}
|
||||||
|
b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
|
||||||
|
_, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(server.SetLargestAcked(0)).To(Succeed())
|
||||||
// Now we've initiated the first key update.
|
// Now we've initiated the first key update.
|
||||||
// Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there
|
// Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there
|
||||||
threePTO := 3 * rttStats.PTO(false)
|
threePTO := 3 * rttStats.PTO(false)
|
||||||
dataKeyPhaseZero := client.Seal(nil, msg, 1, ad)
|
dataKeyPhaseZero := client.Seal(nil, msg, 1, ad)
|
||||||
_, err := server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad)
|
_, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// Now receive a packet with key phase 1.
|
// Now receive a packet with key phase 1.
|
||||||
// This should start the timer to drop the keys after 3 PTOs.
|
// This should start the timer to drop the keys after 3 PTOs.
|
||||||
|
@ -319,7 +368,7 @@ var _ = Describe("Updatable AEAD", func() {
|
||||||
pn := protocol.PacketNumber(i)
|
pn := protocol.PacketNumber(i)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
|
||||||
server.Seal(nil, msg, pn, ad)
|
server.Seal(nil, msg, pn, ad)
|
||||||
server.SetLargestAcked(pn)
|
Expect(server.SetLargestAcked(pn)).To(Succeed())
|
||||||
}
|
}
|
||||||
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
|
||||||
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
|
||||||
|
|
|
@ -250,9 +250,11 @@ func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLargest1RTTAcked mocks base method
|
// SetLargest1RTTAcked mocks base method
|
||||||
func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) {
|
func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "SetLargest1RTTAcked", arg0)
|
ret := m.ctrl.Call(m, "SetLargest1RTTAcked", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked
|
// SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked
|
||||||
|
|
|
@ -51,7 +51,7 @@ type streamManager interface {
|
||||||
type cryptoStreamHandler interface {
|
type cryptoStreamHandler interface {
|
||||||
RunHandshake()
|
RunHandshake()
|
||||||
ChangeConnectionID(protocol.ConnectionID)
|
ChangeConnectionID(protocol.ConnectionID)
|
||||||
SetLargest1RTTAcked(protocol.PacketNumber)
|
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||||
DropHandshakeKeys()
|
DropHandshakeKeys()
|
||||||
GetSessionTicket() ([]byte, error)
|
GetSessionTicket() ([]byte, error)
|
||||||
io.Closer
|
io.Closer
|
||||||
|
@ -1243,10 +1243,10 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt
|
||||||
if err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime); err != nil {
|
if err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if encLevel == protocol.Encryption1RTT {
|
if encLevel != protocol.Encryption1RTT {
|
||||||
s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked())
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked())
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeLocal closes the session and send a CONNECTION_CLOSE containing the error
|
// closeLocal closes the session and send a CONNECTION_CLOSE containing the error
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue