use type assertions to identify 0-RTT queues in packet handler map

This commit is contained in:
Marten Seemann 2022-08-21 12:44:40 +03:00
parent 57650fca7a
commit 19c6a1b252

View file

@ -57,11 +57,6 @@ type rawConn interface {
io.Closer
}
type packetHandlerMapEntry struct {
packetHandler packetHandler
is0RTTQueue bool
}
// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store connections
@ -72,7 +67,7 @@ type packetHandlerMap struct {
conn rawConn
connIDLen int
handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry
handlers map[string] /* string(ConnectionID)*/ packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
numZeroRTTEntries int
@ -151,7 +146,7 @@ func newPacketHandlerMap(
conn: conn,
connIDLen: connIDLen,
listening: make(chan struct{}),
handlers: make(map[string]packetHandlerMapEntry),
handlers: make(map[string]packetHandler),
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
@ -202,7 +197,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
return false
}
h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler}
h.handlers[string(id)] = handler
h.logger.Debugf("Adding connection ID %s.", id)
return true
}
@ -212,24 +207,24 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
defer h.mutex.Unlock()
var q *zeroRTTQueue
if entry, ok := h.handlers[string(clientDestConnID)]; ok {
if !entry.is0RTTQueue {
if handler, ok := h.handlers[string(clientDestConnID)]; ok {
q, ok = handler.(*zeroRTTQueue)
if !ok {
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
q = entry.packetHandler.(*zeroRTTQueue)
q.retireTimer.Stop()
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
}
sess := fn()
conn := fn()
if q != nil {
q.EnqueueAll(sess)
q.EnqueueAll(conn)
}
h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess}
h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess}
h.handlers[string(clientDestConnID)] = conn
h.handlers[string(newConnID)] = conn
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
return true
}
@ -253,7 +248,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) {
h.mutex.Lock()
h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler}
h.handlers[string(id)] = handler
h.mutex.Unlock()
h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id)
@ -292,14 +287,14 @@ func (h *packetHandlerMap) CloseServer() {
}
h.server = nil
var wg sync.WaitGroup
for _, entry := range h.handlers {
if entry.packetHandler.getPerspective() == protocol.PerspectiveServer {
for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(handler packetHandler) {
// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
handler.shutdown()
wg.Done()
}(entry.packetHandler)
}(handler)
}
}
h.mutex.Unlock()
@ -324,12 +319,12 @@ func (h *packetHandlerMap) close(e error) error {
}
var wg sync.WaitGroup
for _, entry := range h.handlers {
for _, handler := range h.handlers {
wg.Add(1)
go func(handler packetHandler) {
handler.destroy(e)
wg.Done()
}(entry.packetHandler)
}(handler)
}
if h.server != nil {
@ -379,14 +374,14 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
return
}
if entry, ok := h.handlers[string(connID)]; ok {
if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue
if handler, ok := h.handlers[string(connID)]; ok {
if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue
if wire.Is0RTTPacket(p.data) {
entry.packetHandler.handlePacket(p)
ha.handlePacket(p)
return
}
} else { // existing connection
entry.packetHandler.handlePacket(p)
handler.handlePacket(p)
return
}
}
@ -404,24 +399,23 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
}
h.numZeroRTTEntries++
queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)}
h.handlers[string(connID)] = packetHandlerMapEntry{
packetHandler: queue,
is0RTTQueue: true,
}
h.handlers[string(connID)] = queue
queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() {
h.mutex.Lock()
defer h.mutex.Unlock()
// The entry might have been replaced by an actual connection.
// Only delete it if it's still a 0-RTT queue.
if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue {
delete(h.handlers, string(connID))
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
entry.packetHandler.(*zeroRTTQueue).Clear()
if h.logger.Debug() {
h.logger.Debugf("Removing 0-RTT queue for %s.", connID)
if handler, ok := h.handlers[string(connID)]; ok {
if q, ok := handler.(*zeroRTTQueue); ok {
delete(h.handlers, string(connID))
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
q.Clear()
if h.logger.Debug() {
h.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
}
})