diff --git a/conn_id_generator.go b/conn_id_generator.go index 07df3d60..b07c7e48 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -20,7 +20,7 @@ type connIDGenerator struct { getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func(protocol.ConnectionID, packetHandler) + replaceWithClosed func([]protocol.ConnectionID, packetHandler) queueControlFrame func(wire.Frame) version protocol.VersionNumber @@ -33,7 +33,7 @@ func newConnIDGenerator( getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func(protocol.ConnectionID, packetHandler), + replaceWithClosed func([]protocol.ConnectionID, packetHandler), queueControlFrame func(wire.Frame), version protocol.VersionNumber, ) *connIDGenerator { @@ -131,10 +131,12 @@ func (m *connIDGenerator) RemoveAll() { } func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { + connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { - m.replaceWithClosed(m.initialClientDestConnID, handler) + connIDs = append(connIDs, m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { - m.replaceWithClosed(connID, handler) + connIDs = append(connIDs, connID) } + m.replaceWithClosed(connIDs, handler) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 8162b2ff..9c832fd4 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -40,7 +40,11 @@ var _ = Describe("Connection ID Generator", func() { connIDToToken, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, - func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h }, + func(cs []protocol.ConnectionID, h packetHandler) { + for _, c := range cs { + replacedWithClosed[string(c)] = h + } + }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, protocol.VersionDraft29, ) diff --git a/connection.go b/connection.go index 3316069a..8c6f50e7 100644 --- a/connection.go +++ b/connection.go @@ -95,7 +95,7 @@ type connRunner interface { GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) - ReplaceWithClosed(protocol.ConnectionID, packetHandler) + ReplaceWithClosed([]protocol.ConnectionID, packetHandler) AddResetToken(protocol.StatelessResetToken, packetHandler) RemoveResetToken(protocol.StatelessResetToken) } diff --git a/connection_test.go b/connection_test.go index 2d792dd2..d7535d5f 100644 --- a/connection_test.go +++ b/connection_test.go @@ -72,8 +72,11 @@ var _ = Describe("Connection", func() { } expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1) - connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) { + Expect(connIDs).To(ContainElement(srcConnID)) + if len(connIDs) > 1 { + Expect(connIDs).To(ContainElement(clientDestConnID)) + } Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) s.shutdown() Eventually(areClosedConnsRunning).Should(BeFalse()) @@ -330,10 +333,8 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(expectedErr) - connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) - }) - connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) { + Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) }) cryptoSetup.EXPECT().Close() @@ -361,10 +362,8 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(testErr) - connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) - }) - connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) { + Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) }) cryptoSetup.EXPECT().Close() @@ -2433,7 +2432,7 @@ var _ = Describe("Client Connection", func() { } expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) { s.shutdown() Eventually(areClosedConnsRunning).Should(BeFalse()) }) @@ -2767,7 +2766,7 @@ var _ = Describe("Client Connection", func() { expectClose := func(applicationClose bool) { if !closed { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) s.shutdown() }) diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index 607bd027..02080834 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -99,7 +99,7 @@ func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock } // ReplaceWithClosed mocks base method. -func (m *MockConnRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { +func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) } diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 01553655..331695c4 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -139,7 +139,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{ } // ReplaceWithClosed mocks base method. -func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { +func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) } diff --git a/packet_handler_map.go b/packet_handler_map.go index 06692e93..3b37bf27 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -219,18 +219,22 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { }) } -func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { +func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handler packetHandler) { h.mutex.Lock() - h.handlers[string(id)] = handler + for _, id := range ids { + h.handlers[string(id)] = handler + } h.mutex.Unlock() - h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id) + h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids) time.AfterFunc(h.deleteRetiredConnsAfter, func() { h.mutex.Lock() handler.shutdown() - delete(h.handlers, string(id)) + for _, id := range ids { + delete(h.handlers, string(id)) + } h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id) + h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids) }) }