mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
add Transport config options to limit the number of handshakes (#4248)
* add Transport config options to limit the number of handshakes * fix accounting for failed handshakes * increase handshake limits, improve documentation
This commit is contained in:
parent
bda5b7e6dc
commit
892851eb8c
4 changed files with 466 additions and 142 deletions
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
|
@ -14,6 +15,7 @@ import (
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/qerr"
|
"github.com/quic-go/quic-go/internal/qerr"
|
||||||
"github.com/quic-go/quic-go/internal/qtls"
|
"github.com/quic-go/quic-go/internal/qtls"
|
||||||
|
"github.com/quic-go/quic-go/logging"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
@ -301,7 +303,7 @@ var _ = Describe("Handshake tests", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("rate limiting", func() {
|
Context("queuening and accepting connections", func() {
|
||||||
var (
|
var (
|
||||||
server *quic.Listener
|
server *quic.Listener
|
||||||
pconn net.PacketConn
|
pconn net.PacketConn
|
||||||
|
@ -343,8 +345,11 @@ var _ = Describe("Handshake tests", func() {
|
||||||
}
|
}
|
||||||
time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued
|
time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued
|
||||||
|
|
||||||
_, err := dial()
|
conn, err := dial()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_, err = conn.AcceptStream(ctx)
|
||||||
var transportErr *quic.TransportError
|
var transportErr *quic.TransportError
|
||||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||||
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
||||||
|
@ -353,18 +358,21 @@ var _ = Describe("Handshake tests", func() {
|
||||||
_, err = server.Accept(context.Background())
|
_, err = server.Accept(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
// dial again, and expect that this dial succeeds
|
// dial again, and expect that this dial succeeds
|
||||||
conn, err := dial()
|
conn2, err := dial()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer conn.CloseWithError(0, "")
|
defer conn2.CloseWithError(0, "")
|
||||||
time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued
|
time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued
|
||||||
|
|
||||||
_, err = dial()
|
conn3, err := dial()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_, err = conn3.AcceptStream(ctx)
|
||||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||||
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("removes closed connections from the accept queue", func() {
|
It("also returns closed connections from the accept queue", func() {
|
||||||
firstConn, err := dial()
|
firstConn, err := dial()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
@ -375,8 +383,11 @@ var _ = Describe("Handshake tests", func() {
|
||||||
}
|
}
|
||||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
|
||||||
|
|
||||||
_, err = dial()
|
conn, err := dial()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_, err = conn.AcceptStream(ctx)
|
||||||
var transportErr *quic.TransportError
|
var transportErr *quic.TransportError
|
||||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||||
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
||||||
|
@ -388,8 +399,11 @@ var _ = Describe("Handshake tests", func() {
|
||||||
time.Sleep(scaleDuration(200 * time.Millisecond))
|
time.Sleep(scaleDuration(200 * time.Millisecond))
|
||||||
|
|
||||||
// dial again, and expect that this fails again
|
// dial again, and expect that this fails again
|
||||||
_, err = dial()
|
conn2, err := dial()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_, err = conn2.AcceptStream(ctx)
|
||||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||||
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
||||||
|
|
||||||
|
@ -448,6 +462,145 @@ var _ = Describe("Handshake tests", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("limiting handshakes", func() {
|
||||||
|
var conn *net.UDPConn
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
conn, err = net.ListenUDP("udp", addr)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() { conn.Close() })
|
||||||
|
|
||||||
|
It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
|
||||||
|
const limit = 3
|
||||||
|
tr := quic.Transport{
|
||||||
|
Conn: conn,
|
||||||
|
MaxUnvalidatedHandshakes: limit,
|
||||||
|
}
|
||||||
|
defer tr.Close()
|
||||||
|
|
||||||
|
// Block all handshakes.
|
||||||
|
handshakes := make(chan struct{})
|
||||||
|
var tlsConf tls.Config
|
||||||
|
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
handshakes <- struct{}{}
|
||||||
|
return getTLSConfig(), nil
|
||||||
|
}
|
||||||
|
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
const additional = 2
|
||||||
|
results := make([]struct{ retry, closed atomic.Bool }, limit+additional)
|
||||||
|
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
|
||||||
|
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
|
||||||
|
// exactly 2 to experience a Retry.
|
||||||
|
for i := 0; i < limit+additional; i++ {
|
||||||
|
go func(index int) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
quicConf := getQuicConfig(&quic.Config{
|
||||||
|
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||||
|
return &logging.ConnectionTracer{
|
||||||
|
ReceivedRetry: func(*logging.Header) { results[index].retry.Store(true) },
|
||||||
|
ClosedConnection: func(error) { results[index].closed.Store(true) },
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
conn.CloseWithError(0, "")
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
numRetries := func() (n int) {
|
||||||
|
for i := 0; i < limit+additional; i++ {
|
||||||
|
if results[i].retry.Load() {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
numClosed := func() (n int) {
|
||||||
|
for i := 0; i < limit+2; i++ {
|
||||||
|
if results[i].closed.Load() {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Eventually(numRetries).Should(Equal(additional))
|
||||||
|
// allow the handshakes to complete
|
||||||
|
for i := 0; i < limit+additional; i++ {
|
||||||
|
Eventually(handshakes).Should(Receive())
|
||||||
|
}
|
||||||
|
Eventually(numClosed).Should(Equal(limit + additional))
|
||||||
|
Expect(numRetries()).To(Equal(additional)) // just to be on the safe side
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
|
||||||
|
const limit = 3
|
||||||
|
tr := quic.Transport{
|
||||||
|
Conn: conn,
|
||||||
|
MaxHandshakes: limit,
|
||||||
|
}
|
||||||
|
defer tr.Close()
|
||||||
|
|
||||||
|
// Block all handshakes.
|
||||||
|
handshakes := make(chan struct{})
|
||||||
|
var tlsConf tls.Config
|
||||||
|
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
handshakes <- struct{}{}
|
||||||
|
return getTLSConfig(), nil
|
||||||
|
}
|
||||||
|
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
const additional = 2
|
||||||
|
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
|
||||||
|
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
|
||||||
|
// exactly 2 to experience a Retry.
|
||||||
|
var numSuccessful, numFailed atomic.Int32
|
||||||
|
for i := 0; i < limit+additional; i++ {
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
quicConf := getQuicConfig(&quic.Config{
|
||||||
|
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||||
|
return &logging.ConnectionTracer{
|
||||||
|
ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") },
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
|
||||||
|
if err != nil {
|
||||||
|
var transportErr *quic.TransportError
|
||||||
|
if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused {
|
||||||
|
Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err))
|
||||||
|
}
|
||||||
|
numFailed.Add(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
numSuccessful.Add(1)
|
||||||
|
conn.CloseWithError(0, "")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional))
|
||||||
|
// allow the handshakes to complete
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
Eventually(handshakes).Should(Receive())
|
||||||
|
}
|
||||||
|
Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit))
|
||||||
|
|
||||||
|
// make sure that the server is reachable again after these handshakes have completed
|
||||||
|
go func() { <-handshakes }() // allow this handshake to complete immediately
|
||||||
|
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
conn.CloseWithError(0, "")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Context("ALPN", func() {
|
Context("ALPN", func() {
|
||||||
It("negotiates an application protocol", func() {
|
It("negotiates an application protocol", func() {
|
||||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||||
|
|
145
server.go
145
server.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/handshake"
|
"github.com/quic-go/quic-go/internal/handshake"
|
||||||
|
@ -110,6 +111,11 @@ type baseServer struct {
|
||||||
connectionRefusedQueue chan rejectedPacket
|
connectionRefusedQueue chan rejectedPacket
|
||||||
retryQueue chan rejectedPacket
|
retryQueue chan rejectedPacket
|
||||||
|
|
||||||
|
maxNumHandshakesUnvalidated int
|
||||||
|
maxNumHandshakesTotal int
|
||||||
|
numHandshakesUnvalidated atomic.Int64
|
||||||
|
numHandshakesValidated atomic.Int64
|
||||||
|
|
||||||
connQueue chan quicConn
|
connQueue chan quicConn
|
||||||
|
|
||||||
tracer *logging.Tracer
|
tracer *logging.Tracer
|
||||||
|
@ -238,31 +244,34 @@ func newServer(
|
||||||
onClose func(),
|
onClose func(),
|
||||||
tokenGeneratorKey TokenGeneratorKey,
|
tokenGeneratorKey TokenGeneratorKey,
|
||||||
maxTokenAge time.Duration,
|
maxTokenAge time.Duration,
|
||||||
|
maxNumHandshakesUnvalidated, maxNumHandshakesTotal int,
|
||||||
disableVersionNegotiation bool,
|
disableVersionNegotiation bool,
|
||||||
acceptEarly bool,
|
acceptEarly bool,
|
||||||
) *baseServer {
|
) *baseServer {
|
||||||
s := &baseServer{
|
s := &baseServer{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
config: config,
|
config: config,
|
||||||
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
|
||||||
maxTokenAge: maxTokenAge,
|
maxTokenAge: maxTokenAge,
|
||||||
connIDGenerator: connIDGenerator,
|
maxNumHandshakesUnvalidated: maxNumHandshakesUnvalidated,
|
||||||
connHandler: connHandler,
|
maxNumHandshakesTotal: maxNumHandshakesTotal,
|
||||||
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
connIDGenerator: connIDGenerator,
|
||||||
errorChan: make(chan struct{}),
|
connHandler: connHandler,
|
||||||
running: make(chan struct{}),
|
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||||
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
errorChan: make(chan struct{}),
|
||||||
versionNegotiationQueue: make(chan receivedPacket, 4),
|
running: make(chan struct{}),
|
||||||
invalidTokenQueue: make(chan rejectedPacket, 4),
|
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||||
connectionRefusedQueue: make(chan rejectedPacket, 4),
|
versionNegotiationQueue: make(chan receivedPacket, 4),
|
||||||
retryQueue: make(chan rejectedPacket, 8),
|
invalidTokenQueue: make(chan rejectedPacket, 4),
|
||||||
newConn: newConnection,
|
connectionRefusedQueue: make(chan rejectedPacket, 4),
|
||||||
tracer: tracer,
|
retryQueue: make(chan rejectedPacket, 8),
|
||||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
newConn: newConnection,
|
||||||
acceptEarlyConns: acceptEarly,
|
tracer: tracer,
|
||||||
disableVersionNegotiation: disableVersionNegotiation,
|
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||||
onClose: onClose,
|
acceptEarlyConns: acceptEarly,
|
||||||
|
disableVersionNegotiation: disableVersionNegotiation,
|
||||||
|
onClose: onClose,
|
||||||
}
|
}
|
||||||
if acceptEarly {
|
if acceptEarly {
|
||||||
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
||||||
|
@ -570,8 +579,8 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAddrIsValid := s.validateToken(token, p.remoteAddr)
|
clientAddrValidated := s.validateToken(token, p.remoteAddr)
|
||||||
if token != nil && !clientAddrIsValid {
|
if token != nil && !clientAddrValidated {
|
||||||
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
|
||||||
// We just ignore them, and act as if there was no token on this packet at all.
|
// We just ignore them, and act as if there was no token on this packet at all.
|
||||||
// This also means we might send a Retry later.
|
// This also means we might send a Retry later.
|
||||||
|
@ -590,7 +599,25 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
|
|
||||||
|
// Until the next call to handleInitialImpl, these numbers are guaranteed to not increase.
|
||||||
|
// They might decrease if another connection completes the handshake.
|
||||||
|
numHandshakesUnvalidated := s.numHandshakesUnvalidated.Load()
|
||||||
|
numHandshakesValidated := s.numHandshakesValidated.Load()
|
||||||
|
|
||||||
|
// Check the total handshake limit first. It's better to reject than to initiate a retry.
|
||||||
|
if total := numHandshakesUnvalidated + numHandshakesValidated; total >= int64(s.maxNumHandshakesTotal) {
|
||||||
|
s.logger.Debugf("Rejecting new connection. Server currently busy. Currently handshaking: %d (max %d)", total, s.maxNumHandshakesTotal)
|
||||||
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
|
select {
|
||||||
|
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
||||||
|
default:
|
||||||
|
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
||||||
|
p.buffer.Release()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if token == nil && (s.config.RequireAddressValidation(p.remoteAddr) || numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated)) {
|
||||||
// Retry invalidates all 0-RTT packets sent.
|
// Retry invalidates all 0-RTT packets sent.
|
||||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||||
select {
|
select {
|
||||||
|
@ -602,17 +629,6 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if queueLen := len(s.connQueue); queueLen >= protocol.MaxAcceptQueueSize {
|
|
||||||
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
|
|
||||||
select {
|
|
||||||
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
|
|
||||||
default:
|
|
||||||
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
|
|
||||||
p.buffer.Release()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
connID, err := s.connIDGenerator.GenerateConnectionID()
|
connID, err := s.connIDGenerator.GenerateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -652,7 +668,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
config,
|
config,
|
||||||
s.tlsConf,
|
s.tlsConf,
|
||||||
s.tokenGenerator,
|
s.tokenGenerator,
|
||||||
clientAddrIsValid,
|
clientAddrValidated,
|
||||||
tracer,
|
tracer,
|
||||||
tracingID,
|
tracingID,
|
||||||
s.logger,
|
s.logger,
|
||||||
|
@ -677,8 +693,31 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if clientAddrValidated {
|
||||||
|
s.numHandshakesValidated.Add(1)
|
||||||
|
} else {
|
||||||
|
s.numHandshakesUnvalidated.Add(1)
|
||||||
|
}
|
||||||
go conn.run()
|
go conn.run()
|
||||||
go s.handleNewConn(conn)
|
go func() {
|
||||||
|
completed := s.handleNewConn(conn)
|
||||||
|
if clientAddrValidated {
|
||||||
|
if s.numHandshakesValidated.Add(-1) < 0 {
|
||||||
|
panic("server BUG: number of validated handshakes negative")
|
||||||
|
}
|
||||||
|
} else if s.numHandshakesUnvalidated.Add(-1) < 0 {
|
||||||
|
panic("server BUG: number of unvalidated handshakes negative")
|
||||||
|
}
|
||||||
|
if !completed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s.connQueue <- conn:
|
||||||
|
default:
|
||||||
|
conn.closeWithTransportError(ConnectionRefused)
|
||||||
|
}
|
||||||
|
}()
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
p.buffer.Release()
|
p.buffer.Release()
|
||||||
return nil
|
return nil
|
||||||
|
@ -686,34 +725,28 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *baseServer) handleNewConn(conn quicConn) {
|
func (s *baseServer) handleNewConn(conn quicConn) bool {
|
||||||
connCtx := conn.Context()
|
|
||||||
if s.acceptEarlyConns {
|
if s.acceptEarlyConns {
|
||||||
// wait until the early connection is ready, the handshake fails, or the server is closed
|
// wait until the early connection is ready, the handshake fails, or the server is closed
|
||||||
select {
|
select {
|
||||||
case <-s.errorChan:
|
case <-s.errorChan:
|
||||||
conn.closeWithTransportError(ConnectionRefused)
|
conn.closeWithTransportError(ConnectionRefused)
|
||||||
return
|
return false
|
||||||
|
case <-conn.Context().Done():
|
||||||
|
return false
|
||||||
case <-conn.earlyConnReady():
|
case <-conn.earlyConnReady():
|
||||||
case <-connCtx.Done():
|
return true
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// wait until the handshake is complete (or fails)
|
|
||||||
select {
|
|
||||||
case <-s.errorChan:
|
|
||||||
conn.closeWithTransportError(ConnectionRefused)
|
|
||||||
return
|
|
||||||
case <-conn.HandshakeComplete():
|
|
||||||
case <-connCtx.Done():
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// wait until the handshake completes, fails, or the server is closed
|
||||||
select {
|
select {
|
||||||
case s.connQueue <- conn:
|
case <-s.errorChan:
|
||||||
default:
|
conn.closeWithTransportError(ConnectionRefused)
|
||||||
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
|
return false
|
||||||
|
case <-conn.Context().Done():
|
||||||
|
return false
|
||||||
|
case <-conn.HandshakeComplete():
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
244
server_test.go
244
server_test.go
|
@ -83,6 +83,25 @@ var _ = Describe("Server", func() {
|
||||||
return hdr
|
return hdr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) {
|
||||||
|
replyHdr := parseHeader(b)
|
||||||
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||||
|
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
|
||||||
|
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
|
||||||
|
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
|
||||||
|
extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||||
|
ccf := f.(*wire.ConnectionCloseFrame)
|
||||||
|
Expect(ccf.IsApplicationError).To(BeFalse())
|
||||||
|
Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode))
|
||||||
|
Expect(ccf.ReasonPhrase).To(BeEmpty())
|
||||||
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
conn = NewMockPacketConn(mockCtrl)
|
conn = NewMockPacketConn(mockCtrl)
|
||||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||||
|
@ -534,6 +553,9 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops packets if the receive queue is full", func() {
|
It("drops packets if the receive queue is full", func() {
|
||||||
|
serv.maxNumHandshakesTotal = 10000
|
||||||
|
serv.maxNumHandshakesUnvalidated = 10000
|
||||||
|
|
||||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
@ -542,7 +564,7 @@ var _ = Describe("Server", func() {
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
acceptConn := make(chan struct{})
|
acceptConn := make(chan struct{})
|
||||||
var counter uint32 // to be used as an atomic, so we query it in Eventually
|
var counter atomic.Uint32
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
runner connRunner,
|
runner connRunner,
|
||||||
|
@ -563,7 +585,7 @@ var _ = Describe("Server", func() {
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
) quicConn {
|
) quicConn {
|
||||||
<-acceptConn
|
<-acceptConn
|
||||||
atomic.AddUint32(&counter, 1)
|
counter.Add(1)
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
|
conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
|
||||||
conn.EXPECT().run().MaxTimes(1)
|
conn.EXPECT().run().MaxTimes(1)
|
||||||
|
@ -590,10 +612,10 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
close(acceptConn)
|
close(acceptConn)
|
||||||
Eventually(
|
Eventually(
|
||||||
func() uint32 { return atomic.LoadUint32(&counter) },
|
func() uint32 { return counter.Load() },
|
||||||
scaleDuration(100*time.Millisecond),
|
scaleDuration(100*time.Millisecond),
|
||||||
).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
|
).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
|
||||||
Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
|
Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("only creates a single connection for a duplicate Initial", func() {
|
It("only creates a single connection for a duplicate Initial", func() {
|
||||||
|
@ -633,7 +655,20 @@ var _ = Describe("Server", func() {
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects new connection attempts if the accept queue is full", func() {
|
It("limits the number of unvalidated handshakes", func() {
|
||||||
|
const limit = 3
|
||||||
|
serv.maxNumHandshakesTotal = 10000
|
||||||
|
serv.maxNumHandshakesUnvalidated = limit
|
||||||
|
|
||||||
|
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||||
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||||
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
_, ok := fn()
|
||||||
|
return ok
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
|
handshakeChan := make(chan struct{})
|
||||||
|
connChan := make(chan *MockQUICConn, 1)
|
||||||
serv.newConn = func(
|
serv.newConn = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
runner connRunner,
|
runner connRunner,
|
||||||
|
@ -653,73 +688,140 @@ var _ = Describe("Server", func() {
|
||||||
_ utils.Logger,
|
_ utils.Logger,
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
) quicConn {
|
) quicConn {
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
conn := <-connChan
|
||||||
conn.EXPECT().handlePacket(gomock.Any())
|
conn.EXPECT().handlePacket(gomock.Any())
|
||||||
conn.EXPECT().run()
|
conn.EXPECT().run()
|
||||||
conn.EXPECT().Context().Return(context.Background())
|
conn.EXPECT().Context().Return(context.Background())
|
||||||
c := make(chan struct{})
|
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||||
close(c)
|
|
||||||
conn.EXPECT().HandshakeComplete().Return(c)
|
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
|
// Initiate the maximum number of allowed connection attempts.
|
||||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
for i := 0; i < limit; i++ {
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
_, ok := fn()
|
connChan <- conn
|
||||||
return ok
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||||
}).Times(protocol.MaxAcceptQueueSize)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(protocol.MaxAcceptQueueSize)
|
|
||||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
|
||||||
go func() {
|
|
||||||
defer GinkgoRecover()
|
|
||||||
defer wg.Done()
|
|
||||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
|
||||||
// make sure there are no Write calls on the packet conn
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
wg.Wait()
|
|
||||||
|
// Now initiate another connection attempt.
|
||||||
p := getInitialWithRandomDestConnID()
|
p := getInitialWithRandomDestConnID()
|
||||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
|
})
|
||||||
|
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||||
|
defer GinkgoRecover()
|
||||||
defer close(done)
|
defer close(done)
|
||||||
rejectHdr := parseHeader(b)
|
hdr, _, _, err := wire.ParsePacket(b)
|
||||||
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(rejectHdr.Version).To(Equal(hdr.Version))
|
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
|
||||||
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
})
|
})
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(p)
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
|
|
||||||
|
close(handshakeChan)
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
_, err := serv.Accept(context.Background())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
|
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
|
||||||
|
connChan <- conn
|
||||||
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("limits the number of total handshakes", func() {
|
||||||
|
const limit = 3
|
||||||
|
serv.maxNumHandshakesTotal = limit
|
||||||
|
serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry
|
||||||
|
|
||||||
|
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||||
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||||
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
_, ok := fn()
|
||||||
|
return ok
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
|
handshakeChan := make(chan struct{})
|
||||||
|
connChan := make(chan *MockQUICConn, 1)
|
||||||
|
serv.newConn = func(
|
||||||
|
_ sendConn,
|
||||||
|
runner connRunner,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ *protocol.ConnectionID,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ ConnectionIDGenerator,
|
||||||
|
_ protocol.StatelessResetToken,
|
||||||
|
_ *Config,
|
||||||
|
_ *tls.Config,
|
||||||
|
_ *handshake.TokenGenerator,
|
||||||
|
_ bool,
|
||||||
|
_ *logging.ConnectionTracer,
|
||||||
|
_ uint64,
|
||||||
|
_ utils.Logger,
|
||||||
|
_ protocol.VersionNumber,
|
||||||
|
) quicConn {
|
||||||
|
conn := <-connChan
|
||||||
|
conn.EXPECT().handlePacket(gomock.Any())
|
||||||
|
conn.EXPECT().run()
|
||||||
|
conn.EXPECT().Context().Return(context.Background())
|
||||||
|
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
|
connChan <- conn
|
||||||
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||||
|
}
|
||||||
|
|
||||||
|
p := getInitialWithRandomDestConnID()
|
||||||
|
done := make(chan struct{})
|
||||||
|
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||||
|
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||||
|
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||||
|
Expect(frames).To(HaveLen(1))
|
||||||
|
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
||||||
|
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
||||||
|
Expect(ccf.IsApplicationError).To(BeFalse())
|
||||||
|
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused))
|
||||||
|
})
|
||||||
|
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
defer close(done)
|
||||||
|
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
checkConnectionCloseError(b, hdr, qerr.ConnectionRefused)
|
||||||
|
return len(b), nil
|
||||||
|
})
|
||||||
|
serv.handlePacket(p)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
|
||||||
|
close(handshakeChan)
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
_, err := serv.Accept(context.Background())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
|
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
|
||||||
|
connChan <- conn
|
||||||
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("token validation", func() {
|
Context("token validation", func() {
|
||||||
checkInvalidToken := func(b []byte, origHdr *wire.Header) {
|
|
||||||
replyHdr := parseHeader(b)
|
|
||||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
|
||||||
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
|
|
||||||
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
|
|
||||||
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
|
|
||||||
extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
|
||||||
ccf := f.(*wire.ConnectionCloseFrame)
|
|
||||||
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
|
|
||||||
Expect(ccf.ReasonPhrase).To(BeEmpty())
|
|
||||||
}
|
|
||||||
|
|
||||||
It("decodes the token from the token field", func() {
|
It("decodes the token from the token field", func() {
|
||||||
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
|
||||||
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
|
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
|
||||||
|
@ -771,7 +873,7 @@ var _ = Describe("Server", func() {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
checkInvalidToken(b, hdr)
|
checkConnectionCloseError(b, hdr, qerr.InvalidToken)
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
})
|
})
|
||||||
phm.EXPECT().Get(gomock.Any())
|
phm.EXPECT().Get(gomock.Any())
|
||||||
|
@ -809,7 +911,7 @@ var _ = Describe("Server", func() {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
checkInvalidToken(b, hdr)
|
checkConnectionCloseError(b, hdr, qerr.InvalidToken)
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
})
|
})
|
||||||
phm.EXPECT().Get(gomock.Any())
|
phm.EXPECT().Get(gomock.Any())
|
||||||
|
@ -1186,8 +1288,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects new connection attempts if the accept queue is full", func() {
|
It("rejects new connection attempts if the accept queue is full", func() {
|
||||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
connChan := make(chan *MockQUICConn, 1)
|
||||||
|
|
||||||
serv.baseServer.newConn = func(
|
serv.baseServer.newConn = func(
|
||||||
_ sendConn,
|
_ sendConn,
|
||||||
runner connRunner,
|
runner connRunner,
|
||||||
|
@ -1209,7 +1310,7 @@ var _ = Describe("Server", func() {
|
||||||
) quicConn {
|
) quicConn {
|
||||||
ready := make(chan struct{})
|
ready := make(chan struct{})
|
||||||
close(ready)
|
close(ready)
|
||||||
conn := NewMockQUICConn(mockCtrl)
|
conn := <-connChan
|
||||||
conn.EXPECT().handlePacket(gomock.Any())
|
conn.EXPECT().handlePacket(gomock.Any())
|
||||||
conn.EXPECT().run()
|
conn.EXPECT().run()
|
||||||
conn.EXPECT().earlyConnReady().Return(ready)
|
conn.EXPECT().earlyConnReady().Return(ready)
|
||||||
|
@ -1224,27 +1325,22 @@ var _ = Describe("Server", func() {
|
||||||
return ok
|
return ok
|
||||||
}).Times(protocol.MaxAcceptQueueSize)
|
}).Times(protocol.MaxAcceptQueueSize)
|
||||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||||
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
|
connChan <- conn
|
||||||
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
|
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
|
||||||
}
|
}
|
||||||
|
|
||||||
Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
|
Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))
|
||||||
// make sure there are no Write calls on the packet conn
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
p := getInitialWithRandomDestConnID()
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||||
hdr := parseHeader(p.data)
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
done := make(chan struct{})
|
_, ok := fn()
|
||||||
conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
return ok
|
||||||
defer close(done)
|
|
||||||
rejectHdr := parseHeader(b)
|
|
||||||
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
|
||||||
Expect(rejectHdr.Version).To(Equal(hdr.Version))
|
|
||||||
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
|
||||||
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
|
||||||
return len(b), nil
|
|
||||||
})
|
})
|
||||||
serv.baseServer.handlePacket(p)
|
conn := NewMockQUICConn(mockCtrl)
|
||||||
Eventually(done).Should(BeClosed())
|
conn.EXPECT().closeWithTransportError(ConnectionRefused)
|
||||||
|
connChan <- conn
|
||||||
|
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't accept new connections if they were closed in the mean time", func() {
|
It("doesn't accept new connections if they were closed in the mean time", func() {
|
||||||
|
|
42
transport.go
42
transport.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -18,6 +19,18 @@ import (
|
||||||
|
|
||||||
var errListenerAlreadySet = errors.New("listener already set")
|
var errListenerAlreadySet = errors.New("listener already set")
|
||||||
|
|
||||||
|
const (
|
||||||
|
// defaultMaxNumUnvalidatedHandshakes is the default value for Transport.MaxUnvalidatedHandshakes.
|
||||||
|
defaultMaxNumUnvalidatedHandshakes = 128
|
||||||
|
// defaultMaxNumHandshakes is the default value for Transport.MaxHandshakes.
|
||||||
|
// It's not clear how to choose a reasonable value that works for all use cases.
|
||||||
|
// In production, implementations should:
|
||||||
|
// 1. Choose a lower value.
|
||||||
|
// 2. Implement some kind of IP-address based filtering using the Config.GetConfigForClient
|
||||||
|
// callback in order to prevent flooding attacks from a single / small number of IP addresses.
|
||||||
|
defaultMaxNumHandshakes = math.MaxInt32
|
||||||
|
)
|
||||||
|
|
||||||
// The Transport is the central point to manage incoming and outgoing QUIC connections.
|
// The Transport is the central point to manage incoming and outgoing QUIC connections.
|
||||||
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
|
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
|
||||||
// This means that a single UDP socket can be used for listening for incoming connections, as well as
|
// This means that a single UDP socket can be used for listening for incoming connections, as well as
|
||||||
|
@ -77,6 +90,25 @@ type Transport struct {
|
||||||
// It has no effect for clients.
|
// It has no effect for clients.
|
||||||
DisableVersionNegotiationPackets bool
|
DisableVersionNegotiationPackets bool
|
||||||
|
|
||||||
|
// MaxUnvalidatedHandshakes is the maximum number of concurrent incoming QUIC handshakes
|
||||||
|
// originating from unvalidated source addresses.
|
||||||
|
// If the number of handshakes from unvalidated addresses reaches this number, new incoming
|
||||||
|
// connection attempts will need to proof reachability at the respective source address using the
|
||||||
|
// Retry mechanism, as described in RFC 9000 section 8.1.2.
|
||||||
|
// Validating the source address adds one additional network roundtrip to the handshake.
|
||||||
|
// If unset, a default value of 128 will be used.
|
||||||
|
// When set to a negative value, every connection attempt will need to validate the source address.
|
||||||
|
// It does not make sense to set this value higher than MaxHandshakes.
|
||||||
|
MaxUnvalidatedHandshakes int
|
||||||
|
// MaxHandshakes is the maximum number of concurrent incoming handshakes, both from validated
|
||||||
|
// and unvalidated source addresses.
|
||||||
|
// If unset, the number of concurrent handshakes will not be limited.
|
||||||
|
// Applications should choose a reasonable value based on their thread model, and consider
|
||||||
|
// implementing IP-based rate limiting using Config.GetConfigForClient.
|
||||||
|
// If the number of handshakes reaches this number, new connection attempts will be rejected by
|
||||||
|
// terminating the connection attempt using a CONNECTION_REFUSED error.
|
||||||
|
MaxHandshakes int
|
||||||
|
|
||||||
// A Tracer traces events that don't belong to a single QUIC connection.
|
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||||
Tracer *logging.Tracer
|
Tracer *logging.Tracer
|
||||||
|
|
||||||
|
@ -151,6 +183,14 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
||||||
if err := t.init(false); err != nil {
|
if err := t.init(false); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
maxUnvalidatedHandshakes := t.MaxUnvalidatedHandshakes
|
||||||
|
if maxUnvalidatedHandshakes == 0 {
|
||||||
|
maxUnvalidatedHandshakes = defaultMaxNumUnvalidatedHandshakes
|
||||||
|
}
|
||||||
|
maxHandshakes := t.MaxHandshakes
|
||||||
|
if maxHandshakes == 0 {
|
||||||
|
maxHandshakes = defaultMaxNumHandshakes
|
||||||
|
}
|
||||||
s := newServer(
|
s := newServer(
|
||||||
t.conn,
|
t.conn,
|
||||||
t.handlerMap,
|
t.handlerMap,
|
||||||
|
@ -161,6 +201,8 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
|
||||||
t.closeServer,
|
t.closeServer,
|
||||||
*t.TokenGeneratorKey,
|
*t.TokenGeneratorKey,
|
||||||
t.MaxTokenAge,
|
t.MaxTokenAge,
|
||||||
|
maxUnvalidatedHandshakes,
|
||||||
|
maxHandshakes,
|
||||||
t.DisableVersionNegotiationPackets,
|
t.DisableVersionNegotiationPackets,
|
||||||
allow0RTT,
|
allow0RTT,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue