store 0-RTT queues in the packet handler map

This prevents a race condition between receiving of 0-RTT packets and
the creation of new session.
This commit is contained in:
Marten Seemann 2021-03-08 14:11:40 +08:00
parent ecc86aa1ab
commit 2bd316b89e
7 changed files with 242 additions and 302 deletions

View file

@ -26,6 +26,38 @@ func (e statelessResetErr) Error() string {
return fmt.Sprintf("received a stateless reset with token %x", e.token)
}
type zeroRTTQueue struct {
queue []*receivedPacket
retireTimer *time.Timer
}
var _ packetHandler = &zeroRTTQueue{}
func (h *zeroRTTQueue) handlePacket(p *receivedPacket) {
if len(h.queue) < protocol.Max0RTTQueueLen {
h.queue = append(h.queue, p)
}
}
func (h *zeroRTTQueue) shutdown() {}
func (h *zeroRTTQueue) destroy(error) {}
func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient }
func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) {
for _, p := range h.queue {
sess.handlePacket(p)
}
}
func (h *zeroRTTQueue) Clear() {
for _, p := range h.queue {
p.buffer.Release()
}
}
type packetHandlerMapEntry struct {
packetHandler packetHandler
is0RTTQueue bool
}
// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store sessions
@ -36,14 +68,16 @@ type packetHandlerMap struct {
conn connection
connIDLen int
handlers map[string] /* string(ConnectionID)*/ packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
numZeroRTTEntries int
listening chan struct{} // is closed when listen returns
closed bool
deleteRetiredSessionsAfter time.Duration
zeroRTTQueueDuration time.Duration
statelessResetEnabled bool
statelessResetMutex sync.Mutex
@ -107,9 +141,10 @@ func newPacketHandlerMap(
conn: conn,
connIDLen: connIDLen,
listening: make(chan struct{}),
handlers: make(map[string]packetHandler),
handlers: make(map[string]packetHandlerMapEntry),
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
statelessResetEnabled: len(statelessResetKey) > 0,
statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
tracer: tracer,
@ -157,7 +192,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
return false
}
h.handlers[string(id)] = handler
h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler}
h.logger.Debugf("Adding connection ID %s.", id)
return true
}
@ -166,14 +201,25 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
h.mutex.Lock()
defer h.mutex.Unlock()
if _, ok := h.handlers[string(clientDestConnID)]; ok {
h.logger.Debugf("Not adding connection ID %s for a new session, as it already exists.", clientDestConnID)
return false
var q *zeroRTTQueue
if entry, ok := h.handlers[string(clientDestConnID)]; ok {
if !entry.is0RTTQueue {
h.logger.Debugf("Not adding connection ID %s for a new session, as it already exists.", clientDestConnID)
return false
}
q = entry.packetHandler.(*zeroRTTQueue)
q.retireTimer.Stop()
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
}
sess := fn()
h.handlers[string(clientDestConnID)] = sess
h.handlers[string(newConnID)] = sess
if q != nil {
q.EnqueueAll(sess)
}
h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess}
h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess}
h.logger.Debugf("Adding connection IDs %s and %s for a new session.", clientDestConnID, newConnID)
return true
}
@ -197,7 +243,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) {
h.mutex.Lock()
h.handlers[string(id)] = handler
h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler}
h.mutex.Unlock()
h.logger.Debugf("Replacing session for connection ID %s with a closed session.", id)
@ -236,14 +282,14 @@ func (h *packetHandlerMap) CloseServer() {
}
h.server = nil
var wg sync.WaitGroup
for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
for _, entry := range h.handlers {
if entry.packetHandler.getPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(handler packetHandler) {
// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
handler.shutdown()
wg.Done()
}(handler)
}(entry.packetHandler)
}
}
h.mutex.Unlock()
@ -268,12 +314,12 @@ func (h *packetHandlerMap) close(e error) error {
}
var wg sync.WaitGroup
for _, handler := range h.handlers {
for _, entry := range h.handlers {
wg.Add(1)
go func(handler packetHandler) {
handler.destroy(e)
wg.Done()
}(handler)
}(entry.packetHandler)
}
if h.server != nil {
@ -319,9 +365,16 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
return
}
if handler, ok := h.handlers[string(connID)]; ok { // existing session
handler.handlePacket(p)
return
if entry, ok := h.handlers[string(connID)]; ok {
if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue
if wire.Is0RTTPacket(p.data) {
entry.packetHandler.handlePacket(p)
return
}
} else { // existing session
entry.packetHandler.handlePacket(p)
return
}
}
if p.data[0]&0x80 == 0 {
go h.maybeSendStatelessReset(p, connID)
@ -331,6 +384,36 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
return
}
if wire.Is0RTTPacket(p.data) {
if h.numZeroRTTEntries >= protocol.Max0RTTQueues {
return
}
h.numZeroRTTEntries++
queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)}
h.handlers[string(connID)] = packetHandlerMapEntry{
packetHandler: queue,
is0RTTQueue: true,
}
queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() {
h.mutex.Lock()
defer h.mutex.Unlock()
// The entry might have been replaced by an actual session.
// Only delete it if it's still a 0-RTT queue.
if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue {
delete(h.handlers, string(connID))
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
entry.packetHandler.(*zeroRTTQueue).Clear()
if h.logger.Debug() {
h.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
})
queue.handlePacket(p)
return
}
h.server.handlePacket(p)
}