diff --git a/mtu_discoverer.go b/mtu_discoverer.go index d2d1ac77..fd1f6679 100644 --- a/mtu_discoverer.go +++ b/mtu_discoverer.go @@ -98,6 +98,11 @@ type mtuFinder struct { lost [maxLostMTUProbes]protocol.ByteCount lastProbeWasLost bool + // The generation is used to ignore ACKs / losses for probe packets sent before a reset. + // Resets happen when the connection is migrated to a new path. + // We're therefore not concerned about overflows of this counter. + generation uint8 + tracer *logging.ConnectionTracer } @@ -110,10 +115,15 @@ func newMTUDiscoverer( ) *mtuFinder { f := &mtuFinder{ inFlight: protocol.InvalidByteCount, - min: start, rttStats: rttStats, tracer: tracer, } + f.init(start, max) + return f +} + +func (f *mtuFinder) init(start, max protocol.ByteCount) { + f.min = start for i := range f.lost { if i == 0 { f.lost[i] = max @@ -121,7 +131,6 @@ func newMTUDiscoverer( } f.lost[i] = protocol.InvalidByteCount } - return f } func (f *mtuFinder) done() bool { @@ -162,7 +171,7 @@ func (f *mtuFinder) GetPing(now time.Time) (ackhandler.Frame, protocol.ByteCount f.inFlight = size return ackhandler.Frame{ Frame: &wire.PingFrame{}, - Handler: &mtuFinderAckHandler{f}, + Handler: &mtuFinderAckHandler{mtuFinder: f, generation: f.generation}, }, size } @@ -170,13 +179,26 @@ func (f *mtuFinder) CurrentSize() protocol.ByteCount { return f.min } +func (f *mtuFinder) Reset(now time.Time, start, max protocol.ByteCount) { + f.generation++ + f.lastProbeTime = now + f.lastProbeWasLost = false + f.inFlight = protocol.InvalidByteCount + f.init(start, max) +} + type mtuFinderAckHandler struct { *mtuFinder + generation uint8 } var _ ackhandler.FrameHandler = &mtuFinderAckHandler{} func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { + if h.generation != h.mtuFinder.generation { + // ACK for probe sent before reset + return + } size := h.inFlight if size == protocol.InvalidByteCount { panic("OnAcked callback called although there's no MTU probe packet in flight") @@ -207,6 +229,10 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { } func (h *mtuFinderAckHandler) OnLost(wire.Frame) { + if h.generation != h.mtuFinder.generation { + // probe sent before reset received + return + } size := h.inFlight if size == protocol.InvalidByteCount { panic("OnLost callback called although there's no MTU probe packet in flight") diff --git a/mtu_discoverer_test.go b/mtu_discoverer_test.go index 415d0638..733d0517 100644 --- a/mtu_discoverer_test.go +++ b/mtu_discoverer_test.go @@ -197,3 +197,53 @@ func testMTUDiscovererWithRandomLoss(t *testing.T) { t.Logf("probes sent (%d): %v", len(probes), probes) require.LessOrEqual(t, diff, maxMTUDiff) } + +func TestMTUDiscovererReset(t *testing.T) { + t.Run("probe on old path acknowledged", func(t *testing.T) { + testMTUDiscovererReset(t, true) + }) + t.Run("probe on old path lost", func(t *testing.T) { + testMTUDiscovererReset(t, false) + }) +} + +func testMTUDiscovererReset(t *testing.T, ackLastProbe bool) { + const startMTU protocol.ByteCount = 1000 + const maxMTU = 1400 + const rtt = 100 * time.Millisecond + + rttStats := &utils.RTTStats{} + rttStats.SetInitialRTT(rtt) + + now := time.Now() + d := newMTUDiscoverer(rttStats, startMTU, maxMTU, nil) + d.Start(now) + + ping, _ := d.GetPing(now.Add(5 * rtt)) + ping.Handler.OnAcked(ping.Frame) + require.Greater(t, d.CurrentSize(), startMTU) + now = now.Add(5 * rtt) + + // send another probe packet, but neither acknowledge nor lose it before resetting + ping, _ = d.GetPing(now.Add(5 * rtt)) + now = now.Add(2 * rtt) // advance the timer by an arbitrary amount + + const newStartMTU protocol.ByteCount = 900 + const newMaxMTU = 1500 + d.Reset(now, newStartMTU, newMaxMTU) + require.Equal(t, d.CurrentSize(), newStartMTU) + + // Now acknowledge / lose the probe packet. + // This should be ignored, since it's on the old path. + if ackLastProbe { + ping.Handler.OnAcked(ping.Frame) + } else { + ping.Handler.OnLost(ping.Frame) + } + + // the MTU should not have changed + require.Equal(t, d.CurrentSize(), newStartMTU) + // the next probe should be sent after 5 RTTs + require.False(t, d.ShouldSendProbe(now.Add(5*rtt).Add(-time.Microsecond))) + require.True(t, d.ShouldSendProbe(now.Add(5*rtt))) +}