move 0-RTT queue handling from the packet handler map to the server

This commit is contained in:
Marten Seemann 2023-04-26 10:07:28 +02:00
parent c0b94ee4b0
commit 58487803d3
7 changed files with 367 additions and 217 deletions

131
server.go
View file

@ -36,6 +36,7 @@ type unknownPacketHandler interface {
}
type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
Destroy() error
connRunner
@ -54,6 +55,11 @@ type quicConn interface {
shutdown()
}
type zeroRTTQueue struct {
packets []*receivedPacket
retireTimer *time.Timer
}
// A Listener of QUIC
type baseServer struct {
mutex sync.Mutex
@ -74,6 +80,9 @@ type baseServer struct {
receivedPackets chan *receivedPacket
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
zeroRTTQueueDuration time.Duration
// set as a member, so they can be set in the tests
newConn func(
sendConn,
@ -201,18 +210,22 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
return nil, err
}
s := &baseServer{
conn: c,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newConn: newConnection,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
conn: c,
tlsConf: tlsConf,
config: config,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration,
tokenGenerator: tokenGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newConn: newConnection,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
}
go s.run()
connHandler.SetServer(s)
@ -344,6 +357,17 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
}
return false
}
if wire.Is0RTTPacket(p.data) {
if !s.acceptEarlyConns {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
return s.handle0RTTPacket(p)
}
// If we're creating a new connection, the packet will be passed to the connection.
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data)
@ -383,6 +407,63 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
return true
}
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
}
return false
}
// check again if we might have a connection now
if handler, ok := s.connHandler.Get(connID); ok {
handler.handlePacket(p)
return true
}
s.mutex.Lock()
defer s.mutex.Unlock()
if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
q.packets = append(q.packets, p)
return true
}
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)}
queue.packets[0] = p
queue.retireTimer = time.AfterFunc(s.zeroRTTQueueDuration, func() {
s.mutex.Lock()
defer s.mutex.Unlock()
// The entry might have been replaced by an actual connection.
// Only delete it if it's still a 0-RTT queue.
if q, ok := s.zeroRTTQueues[connID]; ok {
for _, p := range q.packets {
// TODO: trace
p.buffer.Release()
}
delete(s.zeroRTTQueues, connID)
if s.logger.Debug() {
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
})
s.zeroRTTQueues[connID] = queue
return true
}
// validateToken returns false if:
// - address is invalid
// - token is expired
@ -412,6 +493,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return errors.New("too short connection ID")
}
// The server queues packets for a while, and we might already have established a connection by now.
// This results in a second check in the connection map.
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok {
handler.handlePacket(p)
return nil
}
var (
token *handshake.Token
retrySrcConnID *protocol.ConnectionID
@ -429,7 +518,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrIsValid {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
@ -450,6 +538,10 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
}
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
s.mutex.Lock()
delete(s.zeroRTTQueues, hdr.DestConnectionID)
s.mutex.Unlock()
go func() {
defer p.buffer.Release()
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
@ -510,8 +602,21 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
hdr.Version,
)
conn.handlePacket(p)
s.mutex.Lock()
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
q.retireTimer.Stop()
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
s.mutex.Unlock()
return conn
}); !added {
// TODO: don't just drop the packet
// Properly reject the connection attempt.
return nil
}
go conn.run()