mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
replace all connection IDs at the same time when connection is closed
This commit is contained in:
parent
635dc90475
commit
c3ab9c4ea9
7 changed files with 34 additions and 25 deletions
|
@ -20,7 +20,7 @@ type connIDGenerator struct {
|
||||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
||||||
removeConnectionID func(protocol.ConnectionID)
|
removeConnectionID func(protocol.ConnectionID)
|
||||||
retireConnectionID func(protocol.ConnectionID)
|
retireConnectionID func(protocol.ConnectionID)
|
||||||
replaceWithClosed func(protocol.ConnectionID, packetHandler)
|
replaceWithClosed func([]protocol.ConnectionID, packetHandler)
|
||||||
queueControlFrame func(wire.Frame)
|
queueControlFrame func(wire.Frame)
|
||||||
|
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
@ -33,7 +33,7 @@ func newConnIDGenerator(
|
||||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
||||||
removeConnectionID func(protocol.ConnectionID),
|
removeConnectionID func(protocol.ConnectionID),
|
||||||
retireConnectionID func(protocol.ConnectionID),
|
retireConnectionID func(protocol.ConnectionID),
|
||||||
replaceWithClosed func(protocol.ConnectionID, packetHandler),
|
replaceWithClosed func([]protocol.ConnectionID, packetHandler),
|
||||||
queueControlFrame func(wire.Frame),
|
queueControlFrame func(wire.Frame),
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) *connIDGenerator {
|
) *connIDGenerator {
|
||||||
|
@ -131,10 +131,12 @@ func (m *connIDGenerator) RemoveAll() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
|
func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
|
||||||
|
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
||||||
if m.initialClientDestConnID != nil {
|
if m.initialClientDestConnID != nil {
|
||||||
m.replaceWithClosed(m.initialClientDestConnID, handler)
|
connIDs = append(connIDs, m.initialClientDestConnID)
|
||||||
}
|
}
|
||||||
for _, connID := range m.activeSrcConnIDs {
|
for _, connID := range m.activeSrcConnIDs {
|
||||||
m.replaceWithClosed(connID, handler)
|
connIDs = append(connIDs, connID)
|
||||||
}
|
}
|
||||||
|
m.replaceWithClosed(connIDs, handler)
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,11 @@ var _ = Describe("Connection ID Generator", func() {
|
||||||
connIDToToken,
|
connIDToToken,
|
||||||
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
|
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
|
||||||
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
|
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
|
||||||
func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h },
|
func(cs []protocol.ConnectionID, h packetHandler) {
|
||||||
|
for _, c := range cs {
|
||||||
|
replacedWithClosed[string(c)] = h
|
||||||
|
}
|
||||||
|
},
|
||||||
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
|
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
|
||||||
protocol.VersionDraft29,
|
protocol.VersionDraft29,
|
||||||
)
|
)
|
||||||
|
|
|
@ -95,7 +95,7 @@ type connRunner interface {
|
||||||
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
|
||||||
Retire(protocol.ConnectionID)
|
Retire(protocol.ConnectionID)
|
||||||
Remove(protocol.ConnectionID)
|
Remove(protocol.ConnectionID)
|
||||||
ReplaceWithClosed(protocol.ConnectionID, packetHandler)
|
ReplaceWithClosed([]protocol.ConnectionID, packetHandler)
|
||||||
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
AddResetToken(protocol.StatelessResetToken, packetHandler)
|
||||||
RemoveResetToken(protocol.StatelessResetToken)
|
RemoveResetToken(protocol.StatelessResetToken)
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,8 +72,11 @@ var _ = Describe("Connection", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectReplaceWithClosed := func() {
|
expectReplaceWithClosed := func() {
|
||||||
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1)
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
||||||
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
Expect(connIDs).To(ContainElement(srcConnID))
|
||||||
|
if len(connIDs) > 1 {
|
||||||
|
Expect(connIDs).To(ContainElement(clientDestConnID))
|
||||||
|
}
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
|
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
|
||||||
s.shutdown()
|
s.shutdown()
|
||||||
Eventually(areClosedConnsRunning).Should(BeFalse())
|
Eventually(areClosedConnsRunning).Should(BeFalse())
|
||||||
|
@ -330,10 +333,8 @@ var _ = Describe("Connection", func() {
|
||||||
ErrorMessage: "foobar",
|
ErrorMessage: "foobar",
|
||||||
}
|
}
|
||||||
streamManager.EXPECT().CloseWithError(expectedErr)
|
streamManager.EXPECT().CloseWithError(expectedErr)
|
||||||
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
|
||||||
})
|
|
||||||
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
||||||
})
|
})
|
||||||
cryptoSetup.EXPECT().Close()
|
cryptoSetup.EXPECT().Close()
|
||||||
|
@ -361,10 +362,8 @@ var _ = Describe("Connection", func() {
|
||||||
ErrorMessage: "foobar",
|
ErrorMessage: "foobar",
|
||||||
}
|
}
|
||||||
streamManager.EXPECT().CloseWithError(testErr)
|
streamManager.EXPECT().CloseWithError(testErr)
|
||||||
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
|
||||||
})
|
|
||||||
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
|
||||||
})
|
})
|
||||||
cryptoSetup.EXPECT().Close()
|
cryptoSetup.EXPECT().Close()
|
||||||
|
@ -2433,7 +2432,7 @@ var _ = Describe("Client Connection", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectReplaceWithClosed := func() {
|
expectReplaceWithClosed := func() {
|
||||||
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
|
||||||
s.shutdown()
|
s.shutdown()
|
||||||
Eventually(areClosedConnsRunning).Should(BeFalse())
|
Eventually(areClosedConnsRunning).Should(BeFalse())
|
||||||
})
|
})
|
||||||
|
@ -2767,7 +2766,7 @@ var _ = Describe("Client Connection", func() {
|
||||||
|
|
||||||
expectClose := func(applicationClose bool) {
|
expectClose := func(applicationClose bool) {
|
||||||
if !closed {
|
if !closed {
|
||||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
|
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
|
||||||
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
|
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
|
||||||
s.shutdown()
|
s.shutdown()
|
||||||
})
|
})
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceWithClosed mocks base method.
|
// ReplaceWithClosed mocks base method.
|
||||||
func (m *MockConnRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
|
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,7 +139,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceWithClosed mocks base method.
|
// ReplaceWithClosed mocks base method.
|
||||||
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
|
m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -219,18 +219,22 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) {
|
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handler packetHandler) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
h.handlers[string(id)] = handler
|
for _, id := range ids {
|
||||||
|
h.handlers[string(id)] = handler
|
||||||
|
}
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id)
|
h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
|
||||||
|
|
||||||
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
|
time.AfterFunc(h.deleteRetiredConnsAfter, func() {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
handler.shutdown()
|
handler.shutdown()
|
||||||
delete(h.handlers, string(id))
|
for _, id := range ids {
|
||||||
|
delete(h.handlers, string(id))
|
||||||
|
}
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id)
|
h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue