diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index f698db66..52446af9 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -563,6 +563,15 @@ func (h *cryptoSetup) dropInitialKeys() { h.logger.Debugf("Dropping Initial keys.") } +func (h *cryptoSetup) DropHandshakeKeys() { + h.mutex.Lock() + h.handshakeOpener = nil + h.handshakeSealer = nil + h.mutex.Unlock() + h.runner.DropKeys(protocol.EncryptionHandshake) + h.logger.Debugf("Dropping Handshake keys.") +} + func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { h.mutex.Lock() defer h.mutex.Unlock() diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 1baee25a..1159266a 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -73,6 +73,7 @@ type CryptoSetup interface { HandleMessage([]byte, protocol.EncryptionLevel) bool SetLargest1RTTAcked(protocol.PacketNumber) + DropHandshakeKeys() ConnectionState() tls.ConnectionState GetInitialOpener() (LongHeaderOpener, error) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index af24ed46..411bbd04 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -76,6 +76,18 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) } +// DropHandshakeKeys mocks base method +func (m *MockCryptoSetup) DropHandshakeKeys() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropHandshakeKeys") +} + +// DropHandshakeKeys indicates an expected call of DropHandshakeKeys +func (mr *MockCryptoSetupMockRecorder) DropHandshakeKeys() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropHandshakeKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DropHandshakeKeys)) +} + // Get1RTTOpener mocks base method func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { m.ctrl.T.Helper() diff --git a/session.go b/session.go index a3efa271..4665b641 100644 --- a/session.go +++ b/session.go @@ -51,6 +51,7 @@ type cryptoStreamHandler interface { RunHandshake() ChangeConnectionID(protocol.ConnectionID) SetLargest1RTTAcked(protocol.PacketNumber) + DropHandshakeKeys() io.Closer ConnectionState() tls.ConnectionState } @@ -610,6 +611,7 @@ func (s *session) handleHandshakeComplete() { s.closeLocal(err) } s.queueControlFrame(&wire.NewTokenFrame{Token: token}) + s.cryptoStreamHandler.DropHandshakeKeys() s.queueControlFrame(&wire.HandshakeDoneFrame{}) } } diff --git a/session_test.go b/session_test.go index ede06093..c48cd4a8 100644 --- a/session_test.go +++ b/session_test.go @@ -1204,6 +1204,7 @@ var _ = Describe("Session", func() { defer GinkgoRecover() <-finishHandshake cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().DropHandshakeKeys() close(sess.handshakeCompleteChan) sess.run() }() @@ -1256,6 +1257,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().DropHandshakeKeys() close(sess.handshakeCompleteChan) sess.run() }() @@ -1506,6 +1508,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1) close(sess.handshakeCompleteChan) err := sess.run() nerr, ok := err.(net.Error)