diff --git a/conn_id_manager.go b/conn_id_manager.go index d2e1ca2d..8fc537c3 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -18,6 +18,7 @@ type connIDManager struct { packetsSinceLastChange uint64 addStatelessResetToken func([16]byte) + removeStatelessResetToken func([16]byte) retireStatelessResetToken func([16]byte) queueControlFrame func(wire.Frame) } @@ -25,12 +26,14 @@ type connIDManager struct { func newConnIDManager( initialDestConnID protocol.ConnectionID, addStatelessResetToken func([16]byte), + removeStatelessResetToken func([16]byte), retireStatelessResetToken func([16]byte), queueControlFrame func(wire.Frame), ) *connIDManager { return &connIDManager{ activeConnectionID: initialDestConnID, addStatelessResetToken: addStatelessResetToken, + removeStatelessResetToken: removeStatelessResetToken, retireStatelessResetToken: retireStatelessResetToken, queueControlFrame: queueControlFrame, } @@ -114,6 +117,12 @@ func (h *connIDManager) updateConnectionID() { h.addStatelessResetToken(*h.activeStatelessResetToken) } +func (h *connIDManager) Close() { + if h.activeStatelessResetToken != nil { + h.removeStatelessResetToken(*h.activeStatelessResetToken) + } +} + // is called when the server performs a Retry // and when the server changes the connection ID in the first Initial sent func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 2f293ba6..453daa8f 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -13,6 +13,7 @@ var _ = Describe("Connection ID Manager", func() { frameQueue []wire.Frame tokenAdded *[16]byte retiredTokens [][16]byte + removedTokens [][16]byte ) initialConnID := protocol.ConnectionID{1, 1, 1, 1} @@ -20,9 +21,11 @@ var _ = Describe("Connection ID Manager", func() { frameQueue = nil tokenAdded = nil retiredTokens = nil + removedTokens = nil m = newConnIDManager( initialConnID, func(token [16]byte) { tokenAdded = &token }, + func(token [16]byte) { removedTokens = append(removedTokens, token) }, func(token [16]byte) { retiredTokens = append(retiredTokens, token) }, func(f wire.Frame, ) { @@ -209,4 +212,20 @@ var _ = Describe("Connection ID Manager", func() { Expect(retiredTokens).To(HaveLen(1)) Expect(retiredTokens[0]).To(Equal([16]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})) }) + + It("removes the currently active stateless reset token when it is closed", func() { + m.Close() + Expect(retiredTokens).To(BeEmpty()) + Expect(removedTokens).To(BeEmpty()) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: [16]byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + })).To(Succeed()) + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + m.Close() + Expect(retiredTokens).To(BeEmpty()) + Expect(removedTokens).To(HaveLen(1)) + Expect(removedTokens[0]).To(Equal([16]byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})) + }) }) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index e129febf..9daf8f0c 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -110,6 +110,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) } +// RemoveResetToken mocks base method +func (m *MockPacketHandlerManager) RemoveResetToken(arg0 [16]byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveResetToken", arg0) +} + +// RemoveResetToken indicates an expected call of RemoveResetToken +func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) +} + // ReplaceWithClosed mocks base method func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { m.ctrl.T.Helper() diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index e8308b20..053cff0c 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -58,6 +58,18 @@ func (mr *MockSessionRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionRunner)(nil).Remove), arg0) } +// RemoveResetToken mocks base method +func (m *MockSessionRunner) RemoveResetToken(arg0 [16]byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveResetToken", arg0) +} + +// RemoveResetToken indicates an expected call of RemoveResetToken +func (mr *MockSessionRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockSessionRunner)(nil).RemoveResetToken), arg0) +} + // ReplaceWithClosed mocks base method func (m *MockSessionRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index f9b22ba6..864d72b6 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -134,6 +134,12 @@ func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) h.mutex.Unlock() } +func (h *packetHandlerMap) RemoveResetToken(token [16]byte) { + h.mutex.Lock() + delete(h.resetTokens, token) + h.mutex.Unlock() +} + func (h *packetHandlerMap) RetireResetToken(token [16]byte) { time.AfterFunc(h.deleteRetiredSessionsAfter, func() { h.mutex.Lock() diff --git a/session.go b/session.go index a0db8705..15e81bb5 100644 --- a/session.go +++ b/session.go @@ -77,6 +77,7 @@ type sessionRunner interface { Remove(protocol.ConnectionID) ReplaceWithClosed(protocol.ConnectionID, packetHandler) AddResetToken([16]byte, packetHandler) + RemoveResetToken([16]byte) RetireResetToken([16]byte) } @@ -209,6 +210,7 @@ var newSession = func( s.connIDManager = newConnIDManager( destConnID, func(token [16]byte) { runner.AddResetToken(token, s) }, + runner.RemoveResetToken, runner.RetireResetToken, s.queueControlFrame, ) @@ -282,6 +284,7 @@ var newClientSession = func( s.connIDManager = newConnIDManager( destConnID, func(token [16]byte) { runner.AddResetToken(token, s) }, + runner.RemoveResetToken, runner.RetireResetToken, s.queueControlFrame, ) @@ -985,6 +988,7 @@ func (s *session) handleCloseError(closeErr closeError) { } s.streamsMap.CloseWithError(quicErr) + s.connIDManager.Close() // If this is a remote close we're done here if closeErr.remote {