implement replacing with a closed session for all active connection IDs

This commit is contained in:
Marten Seemann 2019-11-03 10:05:23 +07:00
parent f948165824
commit 5d0d0d0679
2 changed files with 35 additions and 9 deletions

View file

@ -18,6 +18,7 @@ type connIDGenerator struct {
addConnectionID func(protocol.ConnectionID) [16]byte addConnectionID func(protocol.ConnectionID) [16]byte
removeConnectionID func(protocol.ConnectionID) removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func(protocol.ConnectionID, packetHandler)
queueControlFrame func(wire.Frame) queueControlFrame func(wire.Frame)
} }
@ -26,6 +27,7 @@ func newConnIDGenerator(
addConnectionID func(protocol.ConnectionID) [16]byte, addConnectionID func(protocol.ConnectionID) [16]byte,
removeConnectionID func(protocol.ConnectionID), removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func(protocol.ConnectionID, packetHandler),
queueControlFrame func(wire.Frame), queueControlFrame func(wire.Frame),
) *connIDGenerator { ) *connIDGenerator {
m := &connIDGenerator{ m := &connIDGenerator{
@ -34,6 +36,7 @@ func newConnIDGenerator(
addConnectionID: addConnectionID, addConnectionID: addConnectionID,
removeConnectionID: removeConnectionID, removeConnectionID: removeConnectionID,
retireConnectionID: retireConnectionID, retireConnectionID: retireConnectionID,
replaceWithClosed: replaceWithClosed,
queueControlFrame: queueControlFrame, queueControlFrame: queueControlFrame,
} }
m.activeSrcConnIDs[0] = initialConnectionID m.activeSrcConnIDs[0] = initialConnectionID
@ -93,3 +96,9 @@ func (m *connIDGenerator) RemoveAll() {
m.removeConnectionID(connID) m.removeConnectionID(connID)
} }
} }
func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
for _, connID := range m.activeSrcConnIDs {
m.replaceWithClosed(connID, handler)
}
}

View file

@ -13,17 +13,20 @@ var _ = Describe("Connection ID Generator", func() {
addedConnIDs []protocol.ConnectionID addedConnIDs []protocol.ConnectionID
retiredConnIDs []protocol.ConnectionID retiredConnIDs []protocol.ConnectionID
removedConnIDs []protocol.ConnectionID removedConnIDs []protocol.ConnectionID
replacedWithClosed map[string]packetHandler
queuedFrames []wire.Frame queuedFrames []wire.Frame
g *connIDGenerator g *connIDGenerator
) )
initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}
BeforeEach(func() { BeforeEach(func() {
addedConnIDs = nil addedConnIDs = nil
retiredConnIDs = nil retiredConnIDs = nil
removedConnIDs = nil removedConnIDs = nil
queuedFrames = nil queuedFrames = nil
replacedWithClosed = make(map[string]packetHandler)
g = newConnIDGenerator( g = newConnIDGenerator(
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}, initialConnID,
func(c protocol.ConnectionID) [16]byte { func(c protocol.ConnectionID) [16]byte {
addedConnIDs = append(addedConnIDs, c) addedConnIDs = append(addedConnIDs, c)
l := uint8(len(addedConnIDs)) l := uint8(len(addedConnIDs))
@ -31,6 +34,7 @@ var _ = Describe("Connection ID Generator", func() {
}, },
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h },
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
) )
}) })
@ -81,7 +85,7 @@ var _ = Describe("Connection ID Generator", func() {
Expect(g.Retire(0)).To(Succeed()) Expect(g.Retire(0)).To(Succeed())
Expect(removedConnIDs).To(BeEmpty()) Expect(removedConnIDs).To(BeEmpty())
Expect(retiredConnIDs).To(HaveLen(1)) Expect(retiredConnIDs).To(HaveLen(1))
Expect(retiredConnIDs[0]).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7})) Expect(retiredConnIDs[0]).To(Equal(initialConnID))
Expect(addedConnIDs).To(BeEmpty()) Expect(addedConnIDs).To(BeEmpty())
}) })
@ -102,10 +106,23 @@ var _ = Describe("Connection ID Generator", func() {
Expect(queuedFrames).To(HaveLen(5)) Expect(queuedFrames).To(HaveLen(5))
g.RemoveAll() g.RemoveAll()
Expect(removedConnIDs).To(HaveLen(6)) // initial connection ID and newly issued ones Expect(removedConnIDs).To(HaveLen(6)) // initial connection ID and newly issued ones
Expect(removedConnIDs).To(ContainElement(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7})) // initial connection ID Expect(removedConnIDs).To(ContainElement(initialConnID))
for _, f := range queuedFrames { for _, f := range queuedFrames {
nf := f.(*wire.NewConnectionIDFrame) nf := f.(*wire.NewConnectionIDFrame)
Expect(removedConnIDs).To(ContainElement(nf.ConnectionID)) Expect(removedConnIDs).To(ContainElement(nf.ConnectionID))
} }
}) })
It("replaces with a closed session for all connection IDs", func() {
Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
Expect(queuedFrames).To(HaveLen(5))
sess := NewMockPacketHandler(mockCtrl)
g.ReplaceWithClosed(sess)
Expect(replacedWithClosed).To(HaveLen(6)) // initial connection ID and newly issued ones
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess))
for _, f := range queuedFrames {
nf := f.(*wire.NewConnectionIDFrame)
Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess))
}
})
}) })