don't use timers to clean up 0-RTT queues

This commit is contained in:
Marten Seemann 2023-04-26 13:43:37 +02:00
parent 58487803d3
commit 6617f3d39a
2 changed files with 111 additions and 53 deletions

View file

@ -56,8 +56,8 @@ type quicConn interface {
}
type zeroRTTQueue struct {
packets []*receivedPacket
retireTimer *time.Timer
packets []*receivedPacket
expiration time.Time
}
// A Listener of QUIC
@ -80,8 +80,8 @@ type baseServer struct {
receivedPackets chan *receivedPacket
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
zeroRTTQueueDuration time.Duration
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
// set as a member, so they can be set in the tests
newConn func(
@ -210,19 +210,18 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
return nil, err
}
s := &baseServer{
conn: c,
tlsConf: tlsConf,
config: config,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
tokenGenerator: tokenGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newConn: newConnection,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
conn: c,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newConn: newConnection,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
@ -323,6 +322,10 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
}
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ {
if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
defer s.cleanupZeroRTTQueues(p.rcvTime)
}
if wire.IsVersionNegotiationPacket(p.data) {
s.logger.Debugf("Dropping Version Negotiation packet.")
if s.config.Tracer != nil {
@ -422,9 +425,6 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
return true
}
s.mutex.Lock()
defer s.mutex.Unlock()
if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.config.Tracer != nil {
@ -444,26 +444,40 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
}
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)}
queue.packets[0] = p
queue.retireTimer = time.AfterFunc(s.zeroRTTQueueDuration, func() {
s.mutex.Lock()
defer s.mutex.Unlock()
// The entry might have been replaced by an actual connection.
// Only delete it if it's still a 0-RTT queue.
if q, ok := s.zeroRTTQueues[connID]; ok {
for _, p := range q.packets {
// TODO: trace
p.buffer.Release()
}
delete(s.zeroRTTQueues, connID)
if s.logger.Debug() {
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
})
expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
queue.expiration = expiration
if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) {
s.nextZeroRTTCleanup = expiration
}
s.zeroRTTQueues[connID] = queue
return true
}
func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
// Iterate over all queues to find those that are expired.
// This is ok since we're placing a pretty low limit on the number of queues.
var nextCleanup time.Time
for connID, q := range s.zeroRTTQueues {
if q.expiration.After(now) {
if nextCleanup.IsZero() || nextCleanup.After(q.expiration) {
nextCleanup = q.expiration
}
continue
}
for _, p := range q.packets {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
p.buffer.Release()
}
delete(s.zeroRTTQueues, connID)
if s.logger.Debug() {
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
s.nextZeroRTTCleanup = nextCleanup
}
// validateToken returns false if:
// - address is invalid
// - token is expired
@ -538,10 +552,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
}
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
s.mutex.Lock()
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
s.mutex.Unlock()
go func() {
defer p.buffer.Release()
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
@ -603,15 +615,12 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
)
conn.handlePacket(p)
s.mutex.Lock()
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
q.retireTimer.Stop()
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
s.mutex.Unlock()
return conn
}); !added {