diff --git a/conn_id_generator.go b/conn_id_generator.go index caec6b01..ab748f1a 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -16,6 +16,7 @@ type connIDGenerator struct { activeSrcConnIDs map[uint64]protocol.ConnectionID addConnectionID func(protocol.ConnectionID) [16]byte + removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) queueControlFrame func(wire.Frame) } @@ -23,6 +24,7 @@ type connIDGenerator struct { func newConnIDGenerator( initialConnectionID protocol.ConnectionID, addConnectionID func(protocol.ConnectionID) [16]byte, + removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), queueControlFrame func(wire.Frame), ) *connIDGenerator { @@ -30,6 +32,7 @@ func newConnIDGenerator( connIDLen: initialConnectionID.Len(), activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), addConnectionID: addConnectionID, + removeConnectionID: removeConnectionID, retireConnectionID: retireConnectionID, queueControlFrame: queueControlFrame, } @@ -84,3 +87,9 @@ func (m *connIDGenerator) issueNewConnID() error { m.highestSeq++ return nil } + +func (m *connIDGenerator) RemoveAll() { + for _, connID := range m.activeSrcConnIDs { + m.removeConnectionID(connID) + } +} diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 5d3cec91..c775d694 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -12,6 +12,7 @@ var _ = Describe("Connection ID Generator", func() { var ( addedConnIDs []protocol.ConnectionID retiredConnIDs []protocol.ConnectionID + removedConnIDs []protocol.ConnectionID queuedFrames []wire.Frame g *connIDGenerator ) @@ -19,6 +20,7 @@ var _ = Describe("Connection ID Generator", func() { BeforeEach(func() { addedConnIDs = nil retiredConnIDs = nil + removedConnIDs = nil queuedFrames = nil g = newConnIDGenerator( protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}, @@ -27,6 +29,7 @@ var _ = Describe("Connection ID Generator", func() { l := uint8(len(addedConnIDs)) return [16]byte{l, l, l, l, l, l, l, l, l, l, l, l, l, l, l, l} }, + func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, ) @@ -76,6 +79,7 @@ var _ = Describe("Connection ID Generator", func() { It("retires the initial connection ID", func() { Expect(g.Retire(0)).To(Succeed()) + Expect(removedConnIDs).To(BeEmpty()) Expect(retiredConnIDs).To(HaveLen(1)) Expect(retiredConnIDs[0]).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7})) Expect(addedConnIDs).To(BeEmpty()) @@ -92,4 +96,16 @@ var _ = Describe("Connection ID Generator", func() { Expect(retiredConnIDs).To(HaveLen(1)) Expect(queuedFrames).To(HaveLen(1)) }) + + It("removes all connection IDs", func() { + Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(5)) + g.RemoveAll() + Expect(removedConnIDs).To(HaveLen(6)) // initial connection ID and newly issued ones + Expect(removedConnIDs).To(ContainElement(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7})) // initial connection ID + for _, f := range queuedFrames { + nf := f.(*wire.NewConnectionIDFrame) + Expect(removedConnIDs).To(ContainElement(nf.ConnectionID)) + } + }) })