query MTU discoverer for increases after processing ACK frame (#4941)

This commit is contained in:
Marten Seemann 2025-01-27 13:50:14 +01:00 committed by GitHub
parent 7f5ea8a54d
commit 8f27760e60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 24 additions and 35 deletions

View file

@ -147,7 +147,7 @@ type connection struct {
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received
maxPayloadSizeEstimate atomic.Uint32
currentMTUEstimate atomic.Uint32
initialStream *cryptoStream
handshakeStream *cryptoStream
@ -279,7 +279,7 @@ var newConnection = func(
s.tracer,
s.logger,
)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID)
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -391,7 +391,7 @@ var newClientConnection = func(
s.tracer,
s.logger,
)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -1586,6 +1586,13 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr
return err
}
}
// If one of the acknowledged packets was a Path MTU probe packet, this might have increased the Path MTU estimate.
if s.mtuDiscoverer != nil {
if mtu := s.mtuDiscoverer.CurrentSize(); mtu > protocol.ByteCount(s.currentMTUEstimate.Load()) {
s.currentMTUEstimate.Store(uint32(mtu))
s.sentPacketHandler.SetMaxDatagramSize(mtu)
}
}
return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked())
}
@ -1851,7 +1858,6 @@ func (s *connection) applyTransportParameters() {
s.rttStats,
protocol.ByteCount(s.config.InitialPacketSize),
maxPacketSize,
s.onMTUIncreased,
s.tracer,
)
}
@ -2319,11 +2325,6 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
s.framer.RemoveActiveStream(id)
}
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu)))
s.sentPacketHandler.SetMaxDatagramSize(mtu)
}
func (s *connection) SendDatagram(p []byte) error {
if !s.supportsDatagrams() {
return errors.New("datagram support disabled")
@ -2334,7 +2335,7 @@ func (s *connection) SendDatagram(p []byte) error {
// Under many circumstances we could send a few more bytes.
maxDataLen := min(
f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version),
protocol.ByteCount(s.maxPayloadSizeEstimate.Load()),
protocol.ByteCount(s.currentMTUEstimate.Load()),
)
if protocol.ByteCount(len(p)) > maxDataLen {
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}

View file

@ -88,7 +88,6 @@ const (
type mtuFinder struct {
lastProbeTime time.Time
mtuIncreased func(protocol.ByteCount)
rttStats *utils.RTTStats
@ -107,15 +106,13 @@ var _ mtuDiscoverer = &mtuFinder{}
func newMTUDiscoverer(
rttStats *utils.RTTStats,
start, max protocol.ByteCount,
mtuIncreased func(protocol.ByteCount),
tracer *logging.ConnectionTracer,
) *mtuFinder {
f := &mtuFinder{
inFlight: protocol.InvalidByteCount,
min: start,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
tracer: tracer,
inFlight: protocol.InvalidByteCount,
min: start,
rttStats: rttStats,
tracer: tracer,
}
for i := range f.lost {
if i == 0 {
@ -207,7 +204,6 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
h.tracer.UpdatedMTU(size, h.done())
}
h.mtuIncreased(size)
}
func (h *mtuFinderAckHandler) OnLost(wire.Frame) {

View file

@ -17,7 +17,7 @@ func TestMTUDiscovererTiming(t *testing.T) {
const rtt = 100 * time.Millisecond
var rttStats utils.RTTStats
rttStats.UpdateRTT(rtt, 0)
d := newMTUDiscoverer(&rttStats, 1000, 2000, func(s protocol.ByteCount) {}, nil)
d := newMTUDiscoverer(&rttStats, 1000, 2000, nil)
now := time.Now()
require.False(t, d.ShouldSendProbe(now))
@ -37,21 +37,20 @@ func TestMTUDiscovererTiming(t *testing.T) {
}
func TestMTUDiscovererAckAndLoss(t *testing.T) {
var mtu protocol.ByteCount
d := newMTUDiscoverer(&utils.RTTStats{}, 1000, 2000, func(s protocol.ByteCount) { mtu = s }, nil)
d := newMTUDiscoverer(&utils.RTTStats{}, 1000, 2000, nil)
// we use an RTT of 0 here, so we don't have to advance the timer on every step
now := time.Now()
ping, size := d.GetPing(now)
require.Equal(t, protocol.ByteCount(1500), size)
// the MTU is reduced if the frame is lost
ping.Handler.OnLost(ping.Frame)
require.Zero(t, mtu) // no change to the MTU yet
require.Equal(t, protocol.ByteCount(1000), d.CurrentSize()) // no change to the MTU yet
require.True(t, d.ShouldSendProbe(now))
ping, size = d.GetPing(now)
require.Equal(t, protocol.ByteCount(1250), size)
ping.Handler.OnAcked(ping.Frame)
require.Equal(t, protocol.ByteCount(1250), mtu) // the MTU is increased
require.Equal(t, protocol.ByteCount(1250), d.CurrentSize()) // the MTU is increased
// Even though the 1500 byte MTU probe packet was lost, we try again with a higher MTU.
// This protects against regular (non-MTU-related) packet loss.
@ -59,7 +58,7 @@ func TestMTUDiscovererAckAndLoss(t *testing.T) {
ping, size = d.GetPing(now)
require.Greater(t, size, protocol.ByteCount(1500))
ping.Handler.OnAcked(ping.Frame)
require.Equal(t, size, mtu)
require.Equal(t, size, d.CurrentSize())
// We continue probing until the MTU is close to the maximum.
var steps int
@ -91,13 +90,9 @@ func testMTUDiscovererMTUDiscovery(t *testing.T) {
rttStats.UpdateRTT(rtt, 0)
maxMTU := protocol.ByteCount(rand.IntN(int(3000-startMTU))) + startMTU + 1
currentMTU := startMTU
var tracedMTU protocol.ByteCount
var tracerDone bool
d := newMTUDiscoverer(
&rttStats,
startMTU, maxMTU,
func(s protocol.ByteCount) { currentMTU = s },
d := newMTUDiscoverer(&rttStats, startMTU, maxMTU,
&logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
tracedMTU = mtu
@ -122,6 +117,7 @@ func testMTUDiscovererMTUDiscovery(t *testing.T) {
}
now = now.Add(mtuProbeDelay * rtt)
}
currentMTU := d.CurrentSize()
diff := realMTU - currentMTU
require.GreaterOrEqual(t, diff, protocol.ByteCount(0))
if maxMTU > currentMTU+maxMTU {
@ -151,15 +147,10 @@ func testMTUDiscovererWithRandomLoss(t *testing.T) {
require.Equal(t, rtt, rttStats.SmoothedRTT())
maxMTU := protocol.ByteCount(rand.IntN(int(3000-startMTU))) + startMTU + 1
currentMTU := startMTU
var tracedMTU protocol.ByteCount
var tracerDone bool
d := newMTUDiscoverer(
rttStats,
startMTU,
maxMTU,
func(s protocol.ByteCount) { currentMTU = s },
d := newMTUDiscoverer(rttStats, startMTU, maxMTU,
&logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
tracedMTU = mtu
@ -195,6 +186,7 @@ func testMTUDiscovererWithRandomLoss(t *testing.T) {
now = now.Add(mtuProbeDelay * rtt)
}
currentMTU := d.CurrentSize()
diff := realMTU - currentMTU
require.GreaterOrEqual(t, diff, protocol.ByteCount(0))
if maxMTU > currentMTU+maxMTU {