mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
replace all connection IDs at the same time when connection is closed
This commit is contained in:
parent
635dc90475
commit
c3ab9c4ea9
7 changed files with 34 additions and 25 deletions
|
@ -20,7 +20,7 @@ type connIDGenerator struct {
|
|||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
||||
removeConnectionID func(protocol.ConnectionID)
|
||||
retireConnectionID func(protocol.ConnectionID)
|
||||
replaceWithClosed func(protocol.ConnectionID, packetHandler)
|
||||
replaceWithClosed func([]protocol.ConnectionID, packetHandler)
|
||||
queueControlFrame func(wire.Frame)
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
@ -33,7 +33,7 @@ func newConnIDGenerator(
|
|||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
||||
removeConnectionID func(protocol.ConnectionID),
|
||||
retireConnectionID func(protocol.ConnectionID),
|
||||
replaceWithClosed func(protocol.ConnectionID, packetHandler),
|
||||
replaceWithClosed func([]protocol.ConnectionID, packetHandler),
|
||||
queueControlFrame func(wire.Frame),
|
||||
version protocol.VersionNumber,
|
||||
) *connIDGenerator {
|
||||
|
@ -131,10 +131,12 @@ func (m *connIDGenerator) RemoveAll() {
|
|||
}
|
||||
|
||||
func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
|
||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.replaceWithClosed(m.initialClientDestConnID, handler)
|
||||
connIDs = append(connIDs, m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
m.replaceWithClosed(connID, handler)
|
||||
connIDs = append(connIDs, connID)
|
||||
}
|
||||
m.replaceWithClosed(connIDs, handler)
|
||||
}
|
||||
|
|
|
@ -40,7 +40,11 @@ var _ = Describe("Connection ID Generator", func() {
|
|||
connIDToToken,
|
||||
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
|
||||
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
|
||||
func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h },
|
||||
func(cs []protocol.ConnectionID, h packetHandler) {
|
||||
for _, c := range cs {
|
||||
replacedWithClosed[string(c)] = h
|
||||
}
|
||||
},
|
||||
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
|
||||
protocol.VersionDraft29,
|
||||
)
|
||||
|
|
|
@ -95,7 +95,7 @@ type connRunner interface {
|
|||
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
||||
Retire(protocol.ConnectionID)
|
||||
Remove(protocol.ConnectionID)
|
||||
ReplaceWithClosed(protocol.ConnectionID, packetHandler)
|
||||
ReplaceWithClosed([]protocol.ConnectionID, packetHandler)
|
||||
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
||||
RemoveResetToken(protocol.StatelessResetToken)
|
||||
}
|
||||
|
|
|
@ -72,8 +72,11 @@ var _ = Describe("Connection", func() {
|
|||
}
|
||||
|
||||
expectReplaceWithClosed := func() {
|
||||
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1)
|
||||
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
||||
Expect(connIDs).To(ContainElement(srcConnID))
|
||||
if len(connIDs) > 1 {
|
||||
Expect(connIDs).To(ContainElement(clientDestConnID))
|
||||
}
|
||||
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
|
||||
s.shutdown()
|
||||
Eventually(areClosedConnsRunning).Should(BeFalse())
|
||||
|
@ -330,10 +333,8 @@ var _ = Describe("Connection", func() {
|
|||
ErrorMessage: "foobar",
|
||||
}
|
||||
streamManager.EXPECT().CloseWithError(expectedErr)
|
||||
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
||||
})
|
||||
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
||||
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
|
||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
||||
})
|
||||
cryptoSetup.EXPECT().Close()
|
||||
|
@ -361,10 +362,8 @@ var _ = Describe("Connection", func() {
|
|||
ErrorMessage: "foobar",
|
||||
}
|
||||
streamManager.EXPECT().CloseWithError(testErr)
|
||||
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
||||
})
|
||||
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
||||
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
|
||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
||||
})
|
||||
cryptoSetup.EXPECT().Close()
|
||||
|
@ -2433,7 +2432,7 @@ var _ = Describe("Client Connection", func() {
|
|||
}
|
||||
|
||||
expectReplaceWithClosed := func() {
|
||||
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
||||
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
|
||||
s.shutdown()
|
||||
Eventually(areClosedConnsRunning).Should(BeFalse())
|
||||
})
|
||||
|
@ -2767,7 +2766,7 @@ var _ = Describe("Client Connection", func() {
|
|||
|
||||
expectClose := func(applicationClose bool) {
|
||||
if !closed {
|
||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
|
||||
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
|
||||
s.shutdown()
|
||||
})
|
||||
|
|
|
@ -99,7 +99,7 @@ func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock
|
|||
}
|
||||
|
||||
// ReplaceWithClosed mocks base method.
|
||||
func (m *MockConnRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
||||
func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
|
||||
}
|
||||
|
|
|
@ -139,7 +139,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{
|
|||
}
|
||||
|
||||
// ReplaceWithClosed mocks base method.
|
||||
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
||||
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
|
||||
}
|
||||
|
|
|
@ -219,18 +219,22 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
|
|||
})
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) {
|
||||
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handler packetHandler) {
|
||||
h.mutex.Lock()
|
||||
for _, id := range ids {
|
||||
h.handlers[string(id)] = handler
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id)
|
||||
h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
|
||||
|
||||
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
|
||||
h.mutex.Lock()
|
||||
handler.shutdown()
|
||||
for _, id := range ids {
|
||||
delete(h.handlers, string(id))
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id)
|
||||
h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue