From 6617f3d39a63cc3c732618e5a2a459eaf99a3f2c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 26 Apr 2023 13:43:37 +0200 Subject: [PATCH] don't use timers to clean up 0-RTT queues --- server.go | 93 +++++++++++++++++++++++++++----------------------- server_test.go | 71 ++++++++++++++++++++++++++++++++------ 2 files changed, 111 insertions(+), 53 deletions(-) diff --git a/server.go b/server.go index 52bdba51..d984beea 100644 --- a/server.go +++ b/server.go @@ -56,8 +56,8 @@ type quicConn interface { } type zeroRTTQueue struct { - packets []*receivedPacket - retireTimer *time.Timer + packets []*receivedPacket + expiration time.Time } // A Listener of QUIC @@ -80,8 +80,8 @@ type baseServer struct { receivedPackets chan *receivedPacket - zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true - zeroRTTQueueDuration time.Duration + nextZeroRTTCleanup time.Time + zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true // set as a member, so they can be set in the tests newConn func( @@ -210,19 +210,18 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl return nil, err } s := &baseServer{ - conn: c, - tlsConf: tlsConf, - config: config, - zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, - tokenGenerator: tokenGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), - newConn: newConnection, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, + conn: c, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), + newConn: newConnection, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} @@ -323,6 +322,10 @@ func (s *baseServer) handlePacket(p *receivedPacket) { } func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { + if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) { + defer s.cleanupZeroRTTQueues(p.rcvTime) + } + if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") if s.config.Tracer != nil { @@ -422,9 +425,6 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { return true } - s.mutex.Lock() - defer s.mutex.Unlock() - if q, ok := s.zeroRTTQueues[connID]; ok { if len(q.packets) >= protocol.Max0RTTQueueLen { if s.config.Tracer != nil { @@ -444,26 +444,40 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { } queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)} queue.packets[0] = p - queue.retireTimer = time.AfterFunc(s.zeroRTTQueueDuration, func() { - s.mutex.Lock() - defer s.mutex.Unlock() - // The entry might have been replaced by an actual connection. - // Only delete it if it's still a 0-RTT queue. - if q, ok := s.zeroRTTQueues[connID]; ok { - for _, p := range q.packets { - // TODO: trace - p.buffer.Release() - } - delete(s.zeroRTTQueues, connID) - if s.logger.Debug() { - s.logger.Debugf("Removing 0-RTT queue for %s.", connID) - } - } - }) + expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration) + queue.expiration = expiration + if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) { + s.nextZeroRTTCleanup = expiration + } s.zeroRTTQueues[connID] = queue return true } +func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { + // Iterate over all queues to find those that are expired. + // This is ok since we're placing a pretty low limit on the number of queues. + var nextCleanup time.Time + for connID, q := range s.zeroRTTQueues { + if q.expiration.After(now) { + if nextCleanup.IsZero() || nextCleanup.After(q.expiration) { + nextCleanup = q.expiration + } + continue + } + for _, p := range q.packets { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + p.buffer.Release() + } + delete(s.zeroRTTQueues, connID) + if s.logger.Debug() { + s.logger.Debugf("Removing 0-RTT queue for %s.", connID) + } + } + s.nextZeroRTTCleanup = nextCleanup +} + // validateToken returns false if: // - address is invalid // - token is expired @@ -538,10 +552,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } } if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { - s.mutex.Lock() + // Retry invalidates all 0-RTT packets sent. delete(s.zeroRTTQueues, hdr.DestConnectionID) - s.mutex.Unlock() - go func() { defer p.buffer.Release() if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { @@ -603,15 +615,12 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro ) conn.handlePacket(p) - s.mutex.Lock() if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok { - q.retireTimer.Stop() for _, p := range q.packets { conn.handlePacket(p) } delete(s.zeroRTTQueues, hdr.DestConnectionID) } - s.mutex.Unlock() return conn }); !added { diff --git a/server_test.go b/server_test.go index e585cc44..a349efc1 100644 --- a/server_test.go +++ b/server_test.go @@ -1255,7 +1255,6 @@ var _ = Describe("Server", func() { serv = ln.(*earlyServer) phm = NewMockPacketHandlerManager(mockCtrl) serv.connHandler = phm - serv.zeroRTTQueueDuration = time.Hour }) AfterEach(func() { @@ -1383,23 +1382,73 @@ var _ = Describe("Server", func() { }) It("drops queues after a while", func() { - serv.zeroRTTQueueDuration = time.Millisecond + now := time.Now() + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) p := getPacket(&wire.Header{ Type: protocol.PacketType0RTT, DestConnectionID: connID, Version: serv.config.Versions[0], }, make([]byte, 200)) - got := make(chan struct{}) - phm.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(got) }) + p.rcvTime = now + + connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9}) + p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2) + p2 := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID2, + Version: serv.config.Versions[0], + }, make([]byte, 300)) + p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet + + dropped1 := make(chan struct{}) + dropped2 := make(chan struct{}) + // need to register the call before handling the packet to avoid race condition + gomock.InOrder( + tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(dropped1) + }), + tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(dropped2) + }), + ) + + phm.EXPECT().Get(connID) serv.handlePacket(p) - <-got - // TODO: use the tracer call - Eventually(func() int { - serv.mutex.Lock() - defer serv.mutex.Unlock() - return len(serv.zeroRTTQueues) - }).Should(BeZero()) + + // There's no cleanup Go routine. + // Cleanup is triggered when new packets are received. + + phm.EXPECT().Get(connID2) + serv.handlePacket(p2) + // make sure no cleanup is executed + Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed()) + + // There's no cleanup Go routine. + // Cleanup is triggered when new packets are received. + connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0}) + p3 := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID3, + Version: serv.config.Versions[0], + }, make([]byte, 200)) + p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup + phm.EXPECT().Get(connID3) + serv.handlePacket(p3) + Eventually(dropped1).Should(BeClosed()) + Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed()) + + // make sure the second packet is also cleaned up + connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1}) + p4 := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID4, + Version: serv.config.Versions[0], + }, make([]byte, 200)) + p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup + phm.EXPECT().Get(connID4) + serv.handlePacket(p4) + Eventually(dropped2).Should(BeClosed()) }) }) })