mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
don't use timers to clean up 0-RTT queues
This commit is contained in:
parent
58487803d3
commit
6617f3d39a
2 changed files with 111 additions and 53 deletions
93
server.go
93
server.go
|
@ -56,8 +56,8 @@ type quicConn interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type zeroRTTQueue struct {
|
type zeroRTTQueue struct {
|
||||||
packets []*receivedPacket
|
packets []*receivedPacket
|
||||||
retireTimer *time.Timer
|
expiration time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Listener of QUIC
|
// A Listener of QUIC
|
||||||
|
@ -80,8 +80,8 @@ type baseServer struct {
|
||||||
|
|
||||||
receivedPackets chan *receivedPacket
|
receivedPackets chan *receivedPacket
|
||||||
|
|
||||||
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
|
nextZeroRTTCleanup time.Time
|
||||||
zeroRTTQueueDuration time.Duration
|
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
|
||||||
|
|
||||||
// set as a member, so they can be set in the tests
|
// set as a member, so they can be set in the tests
|
||||||
newConn func(
|
newConn func(
|
||||||
|
@ -210,19 +210,18 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s := &baseServer{
|
s := &baseServer{
|
||||||
conn: c,
|
conn: c,
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
config: config,
|
config: config,
|
||||||
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
|
tokenGenerator: tokenGenerator,
|
||||||
tokenGenerator: tokenGenerator,
|
connHandler: connHandler,
|
||||||
connHandler: connHandler,
|
connQueue: make(chan quicConn),
|
||||||
connQueue: make(chan quicConn),
|
errorChan: make(chan struct{}),
|
||||||
errorChan: make(chan struct{}),
|
running: make(chan struct{}),
|
||||||
running: make(chan struct{}),
|
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||||
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
|
newConn: newConnection,
|
||||||
newConn: newConnection,
|
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
acceptEarlyConns: acceptEarly,
|
||||||
acceptEarlyConns: acceptEarly,
|
|
||||||
}
|
}
|
||||||
if acceptEarly {
|
if acceptEarly {
|
||||||
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
||||||
|
@ -323,6 +322,10 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ {
|
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ {
|
||||||
|
if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
|
||||||
|
defer s.cleanupZeroRTTQueues(p.rcvTime)
|
||||||
|
}
|
||||||
|
|
||||||
if wire.IsVersionNegotiationPacket(p.data) {
|
if wire.IsVersionNegotiationPacket(p.data) {
|
||||||
s.logger.Debugf("Dropping Version Negotiation packet.")
|
s.logger.Debugf("Dropping Version Negotiation packet.")
|
||||||
if s.config.Tracer != nil {
|
if s.config.Tracer != nil {
|
||||||
|
@ -422,9 +425,6 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if q, ok := s.zeroRTTQueues[connID]; ok {
|
if q, ok := s.zeroRTTQueues[connID]; ok {
|
||||||
if len(q.packets) >= protocol.Max0RTTQueueLen {
|
if len(q.packets) >= protocol.Max0RTTQueueLen {
|
||||||
if s.config.Tracer != nil {
|
if s.config.Tracer != nil {
|
||||||
|
@ -444,26 +444,40 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
||||||
}
|
}
|
||||||
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)}
|
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)}
|
||||||
queue.packets[0] = p
|
queue.packets[0] = p
|
||||||
queue.retireTimer = time.AfterFunc(s.zeroRTTQueueDuration, func() {
|
expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
|
||||||
s.mutex.Lock()
|
queue.expiration = expiration
|
||||||
defer s.mutex.Unlock()
|
if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) {
|
||||||
// The entry might have been replaced by an actual connection.
|
s.nextZeroRTTCleanup = expiration
|
||||||
// Only delete it if it's still a 0-RTT queue.
|
}
|
||||||
if q, ok := s.zeroRTTQueues[connID]; ok {
|
|
||||||
for _, p := range q.packets {
|
|
||||||
// TODO: trace
|
|
||||||
p.buffer.Release()
|
|
||||||
}
|
|
||||||
delete(s.zeroRTTQueues, connID)
|
|
||||||
if s.logger.Debug() {
|
|
||||||
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
s.zeroRTTQueues[connID] = queue
|
s.zeroRTTQueues[connID] = queue
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
|
||||||
|
// Iterate over all queues to find those that are expired.
|
||||||
|
// This is ok since we're placing a pretty low limit on the number of queues.
|
||||||
|
var nextCleanup time.Time
|
||||||
|
for connID, q := range s.zeroRTTQueues {
|
||||||
|
if q.expiration.After(now) {
|
||||||
|
if nextCleanup.IsZero() || nextCleanup.After(q.expiration) {
|
||||||
|
nextCleanup = q.expiration
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, p := range q.packets {
|
||||||
|
if s.config.Tracer != nil {
|
||||||
|
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||||
|
}
|
||||||
|
p.buffer.Release()
|
||||||
|
}
|
||||||
|
delete(s.zeroRTTQueues, connID)
|
||||||
|
if s.logger.Debug() {
|
||||||
|
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.nextZeroRTTCleanup = nextCleanup
|
||||||
|
}
|
||||||
|
|
||||||
// validateToken returns false if:
|
// validateToken returns false if:
|
||||||
// - address is invalid
|
// - address is invalid
|
||||||
// - token is expired
|
// - token is expired
|
||||||
|
@ -538,10 +552,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
|
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
|
||||||
s.mutex.Lock()
|
// Retry invalidates all 0-RTT packets sent.
|
||||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
s.mutex.Unlock()
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer p.buffer.Release()
|
defer p.buffer.Release()
|
||||||
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
|
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
|
||||||
|
@ -603,15 +615,12 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
||||||
)
|
)
|
||||||
conn.handlePacket(p)
|
conn.handlePacket(p)
|
||||||
|
|
||||||
s.mutex.Lock()
|
|
||||||
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
|
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
|
||||||
q.retireTimer.Stop()
|
|
||||||
for _, p := range q.packets {
|
for _, p := range q.packets {
|
||||||
conn.handlePacket(p)
|
conn.handlePacket(p)
|
||||||
}
|
}
|
||||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
}
|
}
|
||||||
s.mutex.Unlock()
|
|
||||||
|
|
||||||
return conn
|
return conn
|
||||||
}); !added {
|
}); !added {
|
||||||
|
|
|
@ -1255,7 +1255,6 @@ var _ = Describe("Server", func() {
|
||||||
serv = ln.(*earlyServer)
|
serv = ln.(*earlyServer)
|
||||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||||
serv.connHandler = phm
|
serv.connHandler = phm
|
||||||
serv.zeroRTTQueueDuration = time.Hour
|
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
|
@ -1383,23 +1382,73 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops queues after a while", func() {
|
It("drops queues after a while", func() {
|
||||||
serv.zeroRTTQueueDuration = time.Millisecond
|
now := time.Now()
|
||||||
|
|
||||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||||
p := getPacket(&wire.Header{
|
p := getPacket(&wire.Header{
|
||||||
Type: protocol.PacketType0RTT,
|
Type: protocol.PacketType0RTT,
|
||||||
DestConnectionID: connID,
|
DestConnectionID: connID,
|
||||||
Version: serv.config.Versions[0],
|
Version: serv.config.Versions[0],
|
||||||
}, make([]byte, 200))
|
}, make([]byte, 200))
|
||||||
got := make(chan struct{})
|
p.rcvTime = now
|
||||||
phm.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(got) })
|
|
||||||
|
connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
|
||||||
|
p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2)
|
||||||
|
p2 := getPacket(&wire.Header{
|
||||||
|
Type: protocol.PacketType0RTT,
|
||||||
|
DestConnectionID: connID2,
|
||||||
|
Version: serv.config.Versions[0],
|
||||||
|
}, make([]byte, 300))
|
||||||
|
p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet
|
||||||
|
|
||||||
|
dropped1 := make(chan struct{})
|
||||||
|
dropped2 := make(chan struct{})
|
||||||
|
// need to register the call before handling the packet to avoid race condition
|
||||||
|
gomock.InOrder(
|
||||||
|
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
||||||
|
close(dropped1)
|
||||||
|
}),
|
||||||
|
tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
||||||
|
close(dropped2)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
phm.EXPECT().Get(connID)
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(p)
|
||||||
<-got
|
|
||||||
// TODO: use the tracer call
|
// There's no cleanup Go routine.
|
||||||
Eventually(func() int {
|
// Cleanup is triggered when new packets are received.
|
||||||
serv.mutex.Lock()
|
|
||||||
defer serv.mutex.Unlock()
|
phm.EXPECT().Get(connID2)
|
||||||
return len(serv.zeroRTTQueues)
|
serv.handlePacket(p2)
|
||||||
}).Should(BeZero())
|
// make sure no cleanup is executed
|
||||||
|
Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed())
|
||||||
|
|
||||||
|
// There's no cleanup Go routine.
|
||||||
|
// Cleanup is triggered when new packets are received.
|
||||||
|
connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
|
||||||
|
p3 := getPacket(&wire.Header{
|
||||||
|
Type: protocol.PacketType0RTT,
|
||||||
|
DestConnectionID: connID3,
|
||||||
|
Version: serv.config.Versions[0],
|
||||||
|
}, make([]byte, 200))
|
||||||
|
p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
|
||||||
|
phm.EXPECT().Get(connID3)
|
||||||
|
serv.handlePacket(p3)
|
||||||
|
Eventually(dropped1).Should(BeClosed())
|
||||||
|
Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed())
|
||||||
|
|
||||||
|
// make sure the second packet is also cleaned up
|
||||||
|
connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
|
||||||
|
p4 := getPacket(&wire.Header{
|
||||||
|
Type: protocol.PacketType0RTT,
|
||||||
|
DestConnectionID: connID4,
|
||||||
|
Version: serv.config.Versions[0],
|
||||||
|
}, make([]byte, 200))
|
||||||
|
p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup
|
||||||
|
phm.EXPECT().Get(connID4)
|
||||||
|
serv.handlePacket(p4)
|
||||||
|
Eventually(dropped2).Should(BeClosed())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue