diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 54c39f32..7e0bb99a 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -275,8 +275,8 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { } } -func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) { - h.aead.SetLargestAcked(pn) +func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error { + return h.aead.SetLargestAcked(pn) } func (h *cryptoSetup) RunHandshake() { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 4c72281d..3bf9dc06 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -76,7 +76,7 @@ type CryptoSetup interface { GetSessionTicket() ([]byte, error) HandleMessage([]byte, protocol.EncryptionLevel) bool - SetLargest1RTTAcked(protocol.PacketNumber) + SetLargest1RTTAcked(protocol.PacketNumber) error DropHandshakeKeys() ConnectionState() ConnectionState diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 71d42fb4..d79a35c4 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -262,8 +262,13 @@ func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byt return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) } -func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) { +func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { + if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { + return qerr.NewError(qerr.KeyUpdateError, fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase)) + } a.largestAcked = pn + return nil } func (a *updatableAEAD) updateAllowed() bool { diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 46ab5bcb..a16c9c01 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -259,11 +259,57 @@ var _ = Describe("Updatable AEAD", func() { } // no update allowed before receiving an acknowledgement for the current key phase Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.SetLargestAcked(0) + // receive an ACK for a packet sent in key phase 0 + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) + It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { + // First make sure that we update our keys. + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + // Now that our keys are updated, send a packet using the new keys. + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + const nextPN = keyUpdateInterval + 1 + server.Seal(nil, msg, nextPN, ad) + // We haven't decrypted any packet in the new key phase yet. + // This means that the ACK must have been sent in the old key phase. + Expect(server.SetLargestAcked(nextPN)).To(MatchError("KEY_UPDATE_ERROR: received ACK for key phase 1, but peer didn't update keys")) + }) + + It("doesn't error before actually sending a packet in the new key phase", func() { + // First make sure that we update our keys. + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + // Now that our keys are updated, send a packet using the new keys. + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // We haven't decrypted any packet in the new key phase yet. + // This means that the ACK must have been sent in the old key phase. + Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) + }) + It("initiates a key update after opening the maximum number of packets", func() { for i := 0; i < keyUpdateInterval; i++ { pn := protocol.PacketNumber(i) @@ -275,7 +321,7 @@ var _ = Describe("Updatable AEAD", func() { // no update allowed before receiving an acknowledgement for the current key phase Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, 1, ad) - server.SetLargestAcked(1) + Expect(server.SetLargestAcked(1)).To(Succeed()) serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) @@ -286,13 +332,16 @@ var _ = Describe("Updatable AEAD", func() { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) - server.SetLargestAcked(pn) } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + Expect(server.SetLargestAcked(0)).To(Succeed()) // Now we've initiated the first key update. // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there threePTO := 3 * rttStats.PTO(false) dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) - _, err := server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) + _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) // Now receive a packet with key phase 1. // This should start the timer to drop the keys after 3 PTOs. @@ -319,7 +368,7 @@ var _ = Describe("Updatable AEAD", func() { pn := protocol.PacketNumber(i) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, pn, ad) - server.SetLargestAcked(pn) + Expect(server.SetLargestAcked(pn)).To(Succeed()) } serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 193756e3..fb247276 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -250,9 +250,11 @@ func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call { } // SetLargest1RTTAcked mocks base method -func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) { +func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLargest1RTTAcked", arg0) + ret := m.ctrl.Call(m, "SetLargest1RTTAcked", arg0) + ret0, _ := ret[0].(error) + return ret0 } // SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked diff --git a/session.go b/session.go index d768d8e4..a1a23cde 100644 --- a/session.go +++ b/session.go @@ -51,7 +51,7 @@ type streamManager interface { type cryptoStreamHandler interface { RunHandshake() ChangeConnectionID(protocol.ConnectionID) - SetLargest1RTTAcked(protocol.PacketNumber) + SetLargest1RTTAcked(protocol.PacketNumber) error DropHandshakeKeys() GetSessionTicket() ([]byte, error) io.Closer @@ -1243,10 +1243,10 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt if err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime); err != nil { return err } - if encLevel == protocol.Encryption1RTT { - s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) + if encLevel != protocol.Encryption1RTT { + return nil } - return nil + return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } // closeLocal closes the session and send a CONNECTION_CLOSE containing the error