diff --git a/conn_id_generator.go b/conn_id_generator.go index 9cf21d9a..d7be6540 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -19,7 +19,7 @@ type connIDGenerator struct { getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte) + replaceWithClosed func([]protocol.ConnectionID, []byte) queueControlFrame func(wire.Frame) } @@ -30,7 +30,7 @@ func newConnIDGenerator( getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), + replaceWithClosed func([]protocol.ConnectionID, []byte), queueControlFrame func(wire.Frame), generator ConnectionIDGenerator, ) *connIDGenerator { @@ -126,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() { } } -func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { +func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { connIDs = append(connIDs, *m.initialClientDestConnID) @@ -134,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } - m.replaceWithClosed(connIDs, pers, connClose) + m.replaceWithClosed(connIDs, connClose) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 2252de84..c438158e 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -41,9 +41,7 @@ var _ = Describe("Connection ID Generator", func() { connIDToToken, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, - func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { - replacedWithClosed = append(replacedWithClosed, cs...) - }, + func(cs []protocol.ConnectionID, _ []byte) { replacedWithClosed = append(replacedWithClosed, cs...) }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, &protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()}, ) @@ -177,7 +175,7 @@ var _ = Describe("Connection ID Generator", func() { It("replaces with a closed connection for all connection IDs", func() { Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) Expect(queuedFrames).To(HaveLen(4)) - g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar")) + g.ReplaceWithClosed([]byte("foobar")) Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID)) Expect(replacedWithClosed).To(ContainElement(initialConnID)) diff --git a/connection.go b/connection.go index 994d2bf3..8c66cb0b 100644 --- a/connection.go +++ b/connection.go @@ -93,7 +93,7 @@ type connRunner interface { GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) - ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte) + ReplaceWithClosed([]protocol.ConnectionID, []byte) AddResetToken(protocol.StatelessResetToken, packetHandler) RemoveResetToken(protocol.StatelessResetToken) } @@ -1632,7 +1632,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { // If this is a remote close we're done here if closeErr.remote { - s.connIDGenerator.ReplaceWithClosed(s.perspective, nil) + s.connIDGenerator.ReplaceWithClosed(nil) return } if closeErr.immediate { @@ -1649,7 +1649,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } - s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket) + s.connIDGenerator.ReplaceWithClosed(connClosePacket) } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { diff --git a/connection_test.go b/connection_test.go index eb8d2625..74e79618 100644 --- a/connection_test.go +++ b/connection_test.go @@ -76,7 +76,7 @@ var _ = Describe("Connection", func() { } expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ContainElement(srcConnID)) if len(connIDs) > 1 { Expect(connIDs).To(ContainElement(clientDestConnID)) @@ -346,7 +346,7 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(expectedErr) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) }) cryptoSetup.EXPECT().Close() @@ -375,7 +375,7 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(testErr) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) }) cryptoSetup.EXPECT().Close() @@ -558,7 +558,7 @@ var _ = Describe("Connection", func() { runConn() cryptoSetup.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() b, err := wire.AppendShortHeader(nil, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne) Expect(err).ToNot(HaveOccurred()) @@ -2594,7 +2594,7 @@ var _ = Describe("Client Connection", func() { // make sure the go routine returns packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) + connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -2888,7 +2888,7 @@ var _ = Describe("Client Connection", func() { expectClose := func(applicationClose, errored bool) { if !closed && !errored { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()) if applicationClose { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } else { diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index b1fd19f4..cc119ddf 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -223,15 +223,15 @@ func (c *ConnRunnerRemoveResetTokenCall) DoAndReturn(f func(protocol.StatelessRe } // ReplaceWithClosed mocks base method. -func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) { +func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 []byte) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2) + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *ConnRunnerReplaceWithClosedCall { +func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 any) *ConnRunnerReplaceWithClosedCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1) return &ConnRunnerReplaceWithClosedCall{Call: call} } @@ -247,13 +247,13 @@ func (c *ConnRunnerReplaceWithClosedCall) Return() *ConnRunnerReplaceWithClosedC } // Do rewrite *gomock.Call.Do -func (c *ConnRunnerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, protocol.Perspective, []byte)) *ConnRunnerReplaceWithClosedCall { +func (c *ConnRunnerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, []byte)) *ConnRunnerReplaceWithClosedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *ConnRunnerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, protocol.Perspective, []byte)) *ConnRunnerReplaceWithClosedCall { +func (c *ConnRunnerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, []byte)) *ConnRunnerReplaceWithClosedCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index e154e00a..6170981f 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -375,15 +375,15 @@ func (c *PacketHandlerManagerRemoveResetTokenCall) DoAndReturn(f func(protocol.S } // ReplaceWithClosed mocks base method. -func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) { +func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 []byte) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2) + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *PacketHandlerManagerReplaceWithClosedCall { +func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1 any) *PacketHandlerManagerReplaceWithClosedCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1) return &PacketHandlerManagerReplaceWithClosedCall{Call: call} } @@ -399,13 +399,13 @@ func (c *PacketHandlerManagerReplaceWithClosedCall) Return() *PacketHandlerManag } // Do rewrite *gomock.Call.Do -func (c *PacketHandlerManagerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, protocol.Perspective, []byte)) *PacketHandlerManagerReplaceWithClosedCall { +func (c *PacketHandlerManagerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, []byte)) *PacketHandlerManagerReplaceWithClosedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *PacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, protocol.Perspective, []byte)) *PacketHandlerManagerReplaceWithClosedCall { +func (c *PacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, []byte)) *PacketHandlerManagerReplaceWithClosedCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/packet_handler_map.go b/packet_handler_map.go index d2b4ff4e..7840202c 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -164,7 +164,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { // Depending on which side closed the connection, we need to: // * remote close: absorb delayed packets // * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost -func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) { +func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte) { var handler packetHandler if connClosePacket != nil { handler = newClosedLocalConn( diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index ba55a614..d41108cd 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -118,7 +118,7 @@ var _ = Describe("Packet Handler Map", func() { handler := NewMockPacketHandler(mockCtrl) connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) Expect(m.Add(connID, handler)).To(BeTrue()) - m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar")) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, []byte("foobar")) h, ok := m.Get(connID) Expect(ok).To(BeTrue()) Expect(h).ToNot(Equal(handler)) @@ -141,7 +141,7 @@ var _ = Describe("Packet Handler Map", func() { handler := NewMockPacketHandler(mockCtrl) connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) Expect(m.Add(connID, handler)).To(BeTrue()) - m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, nil) h, ok := m.Get(connID) Expect(ok).To(BeTrue()) Expect(h).ToNot(Equal(handler))