retire the stateless reset token when the session is closed

This commit is contained in:
Marten Seemann 2019-10-26 17:56:37 +07:00
parent 9a81603b4a
commit ccb777b834
6 changed files with 62 additions and 0 deletions

View file

@ -18,6 +18,7 @@ type connIDManager struct {
packetsSinceLastChange uint64 packetsSinceLastChange uint64
addStatelessResetToken func([16]byte) addStatelessResetToken func([16]byte)
removeStatelessResetToken func([16]byte)
retireStatelessResetToken func([16]byte) retireStatelessResetToken func([16]byte)
queueControlFrame func(wire.Frame) queueControlFrame func(wire.Frame)
} }
@ -25,12 +26,14 @@ type connIDManager struct {
func newConnIDManager( func newConnIDManager(
initialDestConnID protocol.ConnectionID, initialDestConnID protocol.ConnectionID,
addStatelessResetToken func([16]byte), addStatelessResetToken func([16]byte),
removeStatelessResetToken func([16]byte),
retireStatelessResetToken func([16]byte), retireStatelessResetToken func([16]byte),
queueControlFrame func(wire.Frame), queueControlFrame func(wire.Frame),
) *connIDManager { ) *connIDManager {
return &connIDManager{ return &connIDManager{
activeConnectionID: initialDestConnID, activeConnectionID: initialDestConnID,
addStatelessResetToken: addStatelessResetToken, addStatelessResetToken: addStatelessResetToken,
removeStatelessResetToken: removeStatelessResetToken,
retireStatelessResetToken: retireStatelessResetToken, retireStatelessResetToken: retireStatelessResetToken,
queueControlFrame: queueControlFrame, queueControlFrame: queueControlFrame,
} }
@ -114,6 +117,12 @@ func (h *connIDManager) updateConnectionID() {
h.addStatelessResetToken(*h.activeStatelessResetToken) h.addStatelessResetToken(*h.activeStatelessResetToken)
} }
func (h *connIDManager) Close() {
if h.activeStatelessResetToken != nil {
h.removeStatelessResetToken(*h.activeStatelessResetToken)
}
}
// is called when the server performs a Retry // is called when the server performs a Retry
// and when the server changes the connection ID in the first Initial sent // and when the server changes the connection ID in the first Initial sent
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {

View file

@ -13,6 +13,7 @@ var _ = Describe("Connection ID Manager", func() {
frameQueue []wire.Frame frameQueue []wire.Frame
tokenAdded *[16]byte tokenAdded *[16]byte
retiredTokens [][16]byte retiredTokens [][16]byte
removedTokens [][16]byte
) )
initialConnID := protocol.ConnectionID{1, 1, 1, 1} initialConnID := protocol.ConnectionID{1, 1, 1, 1}
@ -20,9 +21,11 @@ var _ = Describe("Connection ID Manager", func() {
frameQueue = nil frameQueue = nil
tokenAdded = nil tokenAdded = nil
retiredTokens = nil retiredTokens = nil
removedTokens = nil
m = newConnIDManager( m = newConnIDManager(
initialConnID, initialConnID,
func(token [16]byte) { tokenAdded = &token }, func(token [16]byte) { tokenAdded = &token },
func(token [16]byte) { removedTokens = append(removedTokens, token) },
func(token [16]byte) { retiredTokens = append(retiredTokens, token) }, func(token [16]byte) { retiredTokens = append(retiredTokens, token) },
func(f wire.Frame, func(f wire.Frame,
) { ) {
@ -209,4 +212,20 @@ var _ = Describe("Connection ID Manager", func() {
Expect(retiredTokens).To(HaveLen(1)) Expect(retiredTokens).To(HaveLen(1))
Expect(retiredTokens[0]).To(Equal([16]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})) Expect(retiredTokens[0]).To(Equal([16]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}))
}) })
It("removes the currently active stateless reset token when it is closed", func() {
m.Close()
Expect(retiredTokens).To(BeEmpty())
Expect(removedTokens).To(BeEmpty())
Expect(m.Add(&wire.NewConnectionIDFrame{
SequenceNumber: 1,
ConnectionID: protocol.ConnectionID{1, 2, 3, 4},
StatelessResetToken: [16]byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
})).To(Succeed())
Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4}))
m.Close()
Expect(retiredTokens).To(BeEmpty())
Expect(removedTokens).To(HaveLen(1))
Expect(removedTokens[0]).To(Equal([16]byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}))
})
}) })

View file

@ -110,6 +110,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0)
} }
// RemoveResetToken mocks base method
func (m *MockPacketHandlerManager) RemoveResetToken(arg0 [16]byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RemoveResetToken", arg0)
}
// RemoveResetToken indicates an expected call of RemoveResetToken
func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0)
}
// ReplaceWithClosed mocks base method // ReplaceWithClosed mocks base method
func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -58,6 +58,18 @@ func (mr *MockSessionRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionRunner)(nil).Remove), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionRunner)(nil).Remove), arg0)
} }
// RemoveResetToken mocks base method
func (m *MockSessionRunner) RemoveResetToken(arg0 [16]byte) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RemoveResetToken", arg0)
}
// RemoveResetToken indicates an expected call of RemoveResetToken
func (mr *MockSessionRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockSessionRunner)(nil).RemoveResetToken), arg0)
}
// ReplaceWithClosed mocks base method // ReplaceWithClosed mocks base method
func (m *MockSessionRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { func (m *MockSessionRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -134,6 +134,12 @@ func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler)
h.mutex.Unlock() h.mutex.Unlock()
} }
func (h *packetHandlerMap) RemoveResetToken(token [16]byte) {
h.mutex.Lock()
delete(h.resetTokens, token)
h.mutex.Unlock()
}
func (h *packetHandlerMap) RetireResetToken(token [16]byte) { func (h *packetHandlerMap) RetireResetToken(token [16]byte) {
time.AfterFunc(h.deleteRetiredSessionsAfter, func() { time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
h.mutex.Lock() h.mutex.Lock()

View file

@ -77,6 +77,7 @@ type sessionRunner interface {
Remove(protocol.ConnectionID) Remove(protocol.ConnectionID)
ReplaceWithClosed(protocol.ConnectionID, packetHandler) ReplaceWithClosed(protocol.ConnectionID, packetHandler)
AddResetToken([16]byte, packetHandler) AddResetToken([16]byte, packetHandler)
RemoveResetToken([16]byte)
RetireResetToken([16]byte) RetireResetToken([16]byte)
} }
@ -209,6 +210,7 @@ var newSession = func(
s.connIDManager = newConnIDManager( s.connIDManager = newConnIDManager(
destConnID, destConnID,
func(token [16]byte) { runner.AddResetToken(token, s) }, func(token [16]byte) { runner.AddResetToken(token, s) },
runner.RemoveResetToken,
runner.RetireResetToken, runner.RetireResetToken,
s.queueControlFrame, s.queueControlFrame,
) )
@ -282,6 +284,7 @@ var newClientSession = func(
s.connIDManager = newConnIDManager( s.connIDManager = newConnIDManager(
destConnID, destConnID,
func(token [16]byte) { runner.AddResetToken(token, s) }, func(token [16]byte) { runner.AddResetToken(token, s) },
runner.RemoveResetToken,
runner.RetireResetToken, runner.RetireResetToken,
s.queueControlFrame, s.queueControlFrame,
) )
@ -985,6 +988,7 @@ func (s *session) handleCloseError(closeErr closeError) {
} }
s.streamsMap.CloseWithError(quicErr) s.streamsMap.CloseWithError(quicErr)
s.connIDManager.Close()
// If this is a remote close we're done here // If this is a remote close we're done here
if closeErr.remote { if closeErr.remote {