add Transport config options to limit the number of handshakes (#4248)

* add Transport config options to limit the number of handshakes

* fix accounting for failed handshakes

* increase handshake limits, improve documentation
This commit is contained in:
Marten Seemann 2024-01-22 21:04:25 -08:00 committed by GitHub
parent bda5b7e6dc
commit 892851eb8c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 466 additions and 142 deletions

145
server.go
View file

@ -7,6 +7,7 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/handshake"
@ -110,6 +111,11 @@ type baseServer struct {
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket
maxNumHandshakesUnvalidated int
maxNumHandshakesTotal int
numHandshakesUnvalidated atomic.Int64
numHandshakesValidated atomic.Int64
connQueue chan quicConn
tracer *logging.Tracer
@ -238,31 +244,34 @@ func newServer(
onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
maxNumHandshakesUnvalidated, maxNumHandshakesTotal int,
disableVersionNegotiation bool,
acceptEarly bool,
) *baseServer {
s := &baseServer{
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan rejectedPacket, 4),
connectionRefusedQueue: make(chan rejectedPacket, 4),
retryQueue: make(chan rejectedPacket, 8),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
maxNumHandshakesUnvalidated: maxNumHandshakesUnvalidated,
maxNumHandshakesTotal: maxNumHandshakesTotal,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan rejectedPacket, 4),
connectionRefusedQueue: make(chan rejectedPacket, 4),
retryQueue: make(chan rejectedPacket, 8),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
@ -570,8 +579,8 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
}
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrIsValid {
clientAddrValidated := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrValidated {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later.
@ -590,7 +599,25 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
}
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
// Until the next call to handleInitialImpl, these numbers are guaranteed to not increase.
// They might decrease if another connection completes the handshake.
numHandshakesUnvalidated := s.numHandshakesUnvalidated.Load()
numHandshakesValidated := s.numHandshakesValidated.Load()
// Check the total handshake limit first. It's better to reject than to initiate a retry.
if total := numHandshakesUnvalidated + numHandshakesValidated; total >= int64(s.maxNumHandshakesTotal) {
s.logger.Debugf("Rejecting new connection. Server currently busy. Currently handshaking: %d (max %d)", total, s.maxNumHandshakesTotal)
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
return nil
}
if token == nil && (s.config.RequireAddressValidation(p.remoteAddr) || numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated)) {
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
@ -602,17 +629,6 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
if queueLen := len(s.connQueue); queueLen >= protocol.MaxAcceptQueueSize {
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
return nil
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
@ -652,7 +668,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
config,
s.tlsConf,
s.tokenGenerator,
clientAddrIsValid,
clientAddrValidated,
tracer,
tracingID,
s.logger,
@ -677,8 +693,31 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
return nil
}
if clientAddrValidated {
s.numHandshakesValidated.Add(1)
} else {
s.numHandshakesUnvalidated.Add(1)
}
go conn.run()
go s.handleNewConn(conn)
go func() {
completed := s.handleNewConn(conn)
if clientAddrValidated {
if s.numHandshakesValidated.Add(-1) < 0 {
panic("server BUG: number of validated handshakes negative")
}
} else if s.numHandshakesUnvalidated.Add(-1) < 0 {
panic("server BUG: number of unvalidated handshakes negative")
}
if !completed {
return
}
select {
case s.connQueue <- conn:
default:
conn.closeWithTransportError(ConnectionRefused)
}
}()
if conn == nil {
p.buffer.Release()
return nil
@ -686,34 +725,28 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
func (s *baseServer) handleNewConn(conn quicConn) {
connCtx := conn.Context()
func (s *baseServer) handleNewConn(conn quicConn) bool {
if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return
return false
case <-conn.Context().Done():
return false
case <-conn.earlyConnReady():
case <-connCtx.Done():
return
}
} else {
// wait until the handshake is complete (or fails)
select {
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.HandshakeComplete():
case <-connCtx.Done():
return
return true
}
}
// wait until the handshake completes, fails, or the server is closed
select {
case s.connQueue <- conn:
default:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return false
case <-conn.Context().Done():
return false
case <-conn.HandshakeComplete():
return true
}
}