don't use timers to clean up 0-RTT queues

This commit is contained in:
Marten Seemann 2023-04-26 13:43:37 +02:00
parent 58487803d3
commit 6617f3d39a
2 changed files with 111 additions and 53 deletions

View file

@ -56,8 +56,8 @@ type quicConn interface {
} }
type zeroRTTQueue struct { type zeroRTTQueue struct {
packets []*receivedPacket packets []*receivedPacket
retireTimer *time.Timer expiration time.Time
} }
// A Listener of QUIC // A Listener of QUIC
@ -80,8 +80,8 @@ type baseServer struct {
receivedPackets chan *receivedPacket receivedPackets chan *receivedPacket
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true nextZeroRTTCleanup time.Time
zeroRTTQueueDuration time.Duration zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
// set as a member, so they can be set in the tests // set as a member, so they can be set in the tests
newConn func( newConn func(
@ -210,19 +210,18 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
return nil, err return nil, err
} }
s := &baseServer{ s := &baseServer{
conn: c, conn: c,
tlsConf: tlsConf, tlsConf: tlsConf,
config: config, config: config,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, tokenGenerator: tokenGenerator,
tokenGenerator: tokenGenerator, connHandler: connHandler,
connHandler: connHandler, connQueue: make(chan quicConn),
connQueue: make(chan quicConn), errorChan: make(chan struct{}),
errorChan: make(chan struct{}), running: make(chan struct{}),
running: make(chan struct{}), receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), newConn: newConnection,
newConn: newConnection, logger: utils.DefaultLogger.WithPrefix("server"),
logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlyConns: acceptEarly,
acceptEarlyConns: acceptEarly,
} }
if acceptEarly { if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
@ -323,6 +322,10 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
} }
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ {
if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
defer s.cleanupZeroRTTQueues(p.rcvTime)
}
if wire.IsVersionNegotiationPacket(p.data) { if wire.IsVersionNegotiationPacket(p.data) {
s.logger.Debugf("Dropping Version Negotiation packet.") s.logger.Debugf("Dropping Version Negotiation packet.")
if s.config.Tracer != nil { if s.config.Tracer != nil {
@ -422,9 +425,6 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
return true return true
} }
s.mutex.Lock()
defer s.mutex.Unlock()
if q, ok := s.zeroRTTQueues[connID]; ok { if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen { if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.config.Tracer != nil { if s.config.Tracer != nil {
@ -444,26 +444,40 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
} }
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)} queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)}
queue.packets[0] = p queue.packets[0] = p
queue.retireTimer = time.AfterFunc(s.zeroRTTQueueDuration, func() { expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
s.mutex.Lock() queue.expiration = expiration
defer s.mutex.Unlock() if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) {
// The entry might have been replaced by an actual connection. s.nextZeroRTTCleanup = expiration
// Only delete it if it's still a 0-RTT queue. }
if q, ok := s.zeroRTTQueues[connID]; ok {
for _, p := range q.packets {
// TODO: trace
p.buffer.Release()
}
delete(s.zeroRTTQueues, connID)
if s.logger.Debug() {
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
})
s.zeroRTTQueues[connID] = queue s.zeroRTTQueues[connID] = queue
return true return true
} }
// cleanupZeroRTTQueues removes every 0-RTT queue whose expiration is not
// after now. Each packet in a dropped queue is traced (when a tracer is
// configured) and its buffer released. Afterwards, s.nextZeroRTTCleanup is
// set to the earliest expiration among the surviving queues; it stays the
// zero time when no queues remain, which signals that no cleanup is pending.
func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
	// A full scan of the map is acceptable: the number of 0-RTT queues is
	// capped at a small value.
	var earliest time.Time
	for connID, q := range s.zeroRTTQueues {
		if q.expiration.After(now) {
			// Queue survives this round; track the soonest upcoming expiration.
			if earliest.IsZero() || q.expiration.Before(earliest) {
				earliest = q.expiration
			}
			continue
		}
		// Queue expired: drop all buffered packets.
		for _, p := range q.packets {
			if tracer := s.config.Tracer; tracer != nil {
				tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
			}
			p.buffer.Release()
		}
		delete(s.zeroRTTQueues, connID)
		if s.logger.Debug() {
			s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
		}
	}
	s.nextZeroRTTCleanup = earliest
}
// validateToken returns false if: // validateToken returns false if:
// - address is invalid // - address is invalid
// - token is expired // - token is expired
@ -538,10 +552,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
} }
} }
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
s.mutex.Lock() // Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID) delete(s.zeroRTTQueues, hdr.DestConnectionID)
s.mutex.Unlock()
go func() { go func() {
defer p.buffer.Release() defer p.buffer.Release()
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
@ -603,15 +615,12 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
) )
conn.handlePacket(p) conn.handlePacket(p)
s.mutex.Lock()
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok { if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
q.retireTimer.Stop()
for _, p := range q.packets { for _, p := range q.packets {
conn.handlePacket(p) conn.handlePacket(p)
} }
delete(s.zeroRTTQueues, hdr.DestConnectionID) delete(s.zeroRTTQueues, hdr.DestConnectionID)
} }
s.mutex.Unlock()
return conn return conn
}); !added { }); !added {

View file

@ -1255,7 +1255,6 @@ var _ = Describe("Server", func() {
serv = ln.(*earlyServer) serv = ln.(*earlyServer)
phm = NewMockPacketHandlerManager(mockCtrl) phm = NewMockPacketHandlerManager(mockCtrl)
serv.connHandler = phm serv.connHandler = phm
serv.zeroRTTQueueDuration = time.Hour
}) })
AfterEach(func() { AfterEach(func() {
@ -1383,23 +1382,73 @@ var _ = Describe("Server", func() {
}) })
It("drops queues after a while", func() { It("drops queues after a while", func() {
serv.zeroRTTQueueDuration = time.Millisecond now := time.Now()
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{ p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT, Type: protocol.PacketType0RTT,
DestConnectionID: connID, DestConnectionID: connID,
Version: serv.config.Versions[0], Version: serv.config.Versions[0],
}, make([]byte, 200)) }, make([]byte, 200))
got := make(chan struct{}) p.rcvTime = now
phm.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(got) })
connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
p2 := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID2,
Version: serv.config.Versions[0],
}, make([]byte, 300))
p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
dropped1 := make(chan struct{})
dropped2 := make(chan struct{})
// need to register the call before handling the packet to avoid race condition
gomock.InOrder(
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(dropped1)
}),
tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(dropped2)
}),
)
phm.EXPECT().Get(connID)
serv.handlePacket(p) serv.handlePacket(p)
<-got
// TODO: use the tracer call // There's no cleanup Go routine.
Eventually(func() int { // Cleanup is triggered when new packets are received.
serv.mutex.Lock()
defer serv.mutex.Unlock() phm.EXPECT().Get(connID2)
return len(serv.zeroRTTQueues) serv.handlePacket(p2)
}).Should(BeZero()) // make sure no cleanup is executed
Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
// There's no cleanup Go routine.
// Cleanup is triggered when new packets are received.
connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
p3 := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID3,
Version: serv.config.Versions[0],
}, make([]byte, 200))
p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
phm.EXPECT().Get(connID3)
serv.handlePacket(p3)
Eventually(dropped1).Should(BeClosed())
Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
// make sure the second packet is also cleaned up
connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
p4 := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID4,
Version: serv.config.Versions[0],
}, make([]byte, 200))
p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
phm.EXPECT().Get(connID4)
serv.handlePacket(p4)
Eventually(dropped2).Should(BeClosed())
}) })
}) })
}) })