diff --git a/mtu_discoverer.go b/mtu_discoverer.go index 9b6fcfcf..ee636a6d 100644 --- a/mtu_discoverer.go +++ b/mtu_discoverer.go @@ -22,7 +22,7 @@ type mtuDiscoverer interface { const ( // At some point, we have to stop searching for a higher MTU. // We're happy to send a packet that's 10 bytes smaller than the actual MTU. - maxMTUDiff = 20 + maxMTUDiff protocol.ByteCount = 20 // send a probe packet every mtuProbeDelay RTTs mtuProbeDelay = 5 // Once maxLostMTUProbes MTU probe packets larger than a certain size are lost, diff --git a/mtu_discoverer_test.go b/mtu_discoverer_test.go index e71337a1..6d87b9ec 100644 --- a/mtu_discoverer_test.go +++ b/mtu_discoverer_test.go @@ -2,200 +2,206 @@ package quic import ( "fmt" + "math/rand/v2" + "testing" "time" - "golang.org/x/exp/rand" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/logging" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("MTU Discoverer", func() { - const ( - rtt = 100 * time.Millisecond - startMTU protocol.ByteCount = 1000 - maxMTU protocol.ByteCount = 2000 - ) +func TestMTUDiscovererTiming(t *testing.T) { + const rtt = 100 * time.Millisecond + var rttStats utils.RTTStats + rttStats.UpdateRTT(rtt, 0) + d := newMTUDiscoverer(&rttStats, 1000, 2000, func(s protocol.ByteCount) {}, nil) - var ( - d *mtuFinder - rttStats *utils.RTTStats - now time.Time - discoveredMTU protocol.ByteCount - ) - r := rand.New(rand.NewSource(uint64(GinkgoRandomSeed()))) + now := time.Now() + require.False(t, d.ShouldSendProbe(now)) + d.Start(now) + require.False(t, d.ShouldSendProbe(now)) + require.False(t, d.ShouldSendProbe(now.Add(rtt*9/2))) + now = now.Add(5 * rtt) + require.True(t, d.ShouldSendProbe(now)) - BeforeEach(func() { - rttStats = &utils.RTTStats{} - rttStats.SetInitialRTT(rtt) - Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) - d = newMTUDiscoverer( - rttStats, - startMTU, - maxMTU, - func(s protocol.ByteCount) { discoveredMTU = s }, - nil, - ) - d.Start(time.Now()) - now = time.Now() - }) + // only a single outstanding probe packet is permitted + ping, _ := d.GetPing(now) + require.False(t, d.ShouldSendProbe(now)) + now = now.Add(5 * rtt) + require.False(t, d.ShouldSendProbe(now)) + ping.Handler.OnLost(ping.Frame) + require.True(t, d.ShouldSendProbe(now)) +} - It("only allows a probe 5 RTTs after the handshake completes", func() { - Expect(d.ShouldSendProbe(now)).To(BeFalse()) - Expect(d.ShouldSendProbe(now.Add(rtt * 9 / 2))).To(BeFalse()) - Expect(d.ShouldSendProbe(now.Add(rtt * 5))).To(BeTrue()) - }) +func TestMTUDiscovererAckAndLoss(t *testing.T) { + var mtu protocol.ByteCount + d := newMTUDiscoverer(&utils.RTTStats{}, 1000, 2000, func(s protocol.ByteCount) { mtu = s }, nil) + // we use an RTT of 0 here, so we don't have to advance the timer on every step + now := time.Now() + ping, size := d.GetPing(now) + require.Equal(t, protocol.ByteCount(1500), size) + // the MTU is reduced if the frame is lost + ping.Handler.OnLost(ping.Frame) + require.Zero(t, mtu) // no change to the MTU yet - It("doesn't allow a probe if another probe is still in flight", func() { - ping, _ := d.GetPing(time.Now()) - Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeFalse()) - ping.Handler.OnLost(ping.Frame) - Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeTrue()) - }) + require.True(t, d.ShouldSendProbe(now)) + ping, size = d.GetPing(now) + require.Equal(t, protocol.ByteCount(1250), size) + ping.Handler.OnAcked(ping.Frame) + require.Equal(t, protocol.ByteCount(1250), mtu) // the MTU is increased - It("tries a lower size when a probe is lost", func() { - ping, size := d.GetPing(time.Now()) - Expect(size).To(Equal(protocol.ByteCount(1500))) - ping.Handler.OnLost(ping.Frame) - _, size = d.GetPing(time.Now()) - Expect(size).To(Equal(protocol.ByteCount(1250))) - }) + // Even though the 1500 byte MTU probe packet was lost, we try again with a higher MTU. + // This protects against regular (non-MTU-related) packet loss. + require.True(t, d.ShouldSendProbe(now)) + ping, size = d.GetPing(now) + require.Greater(t, size, protocol.ByteCount(1500)) + ping.Handler.OnAcked(ping.Frame) + require.Equal(t, size, mtu) - It("tries a higher size and calls the callback when a probe is acknowledged", func() { - ping, size := d.GetPing(time.Now()) - Expect(size).To(Equal(protocol.ByteCount(1500))) + // We continue probing until the MTU is close to the maximum. + var steps int + oldSize := size + for d.ShouldSendProbe(now) { + ping, size = d.GetPing(now) + require.Greater(t, size, oldSize) + oldSize = size ping.Handler.OnAcked(ping.Frame) - Expect(discoveredMTU).To(Equal(protocol.ByteCount(1500))) - _, size = d.GetPing(time.Now()) - Expect(size).To(Equal(protocol.ByteCount(1750))) - }) + steps++ + require.Less(t, steps, 10) + } + require.Less(t, 2000-maxMTUDiff, size) +} - It("stops discovery after getting close enough to the MTU", func() { - var sizes []protocol.ByteCount - t := now.Add(5 * rtt) - for d.ShouldSendProbe(t) { - ping, size := d.GetPing(time.Now()) - fmt.Println("sending", size) +func TestMTUDiscovererMTUDiscovery(t *testing.T) { + for i := range 5 { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + testMTUDiscovererMTUDiscovery(t) + }) + } +} + +func testMTUDiscovererMTUDiscovery(t *testing.T) { + const rtt = 100 * time.Millisecond + const startMTU protocol.ByteCount = 1000 + + var rttStats utils.RTTStats + rttStats.UpdateRTT(rtt, 0) + + maxMTU := protocol.ByteCount(rand.IntN(int(3000-startMTU))) + startMTU + 1 + currentMTU := startMTU + var tracedMTU protocol.ByteCount + var tracerDone bool + d := newMTUDiscoverer( + &rttStats, + startMTU, maxMTU, + func(s protocol.ByteCount) { currentMTU = s }, + &logging.ConnectionTracer{ + UpdatedMTU: func(mtu logging.ByteCount, done bool) { + tracedMTU = mtu + tracerDone = done + }, + }, + ) + now := time.Now() + d.Start(now) + realMTU := protocol.ByteCount(rand.IntN(int(maxMTU-startMTU))) + startMTU + t.Logf("MTU: %d, max: %d", realMTU, maxMTU) + now = now.Add(mtuProbeDelay * rtt) + var probes []protocol.ByteCount + for d.ShouldSendProbe(now) { + require.Less(t, len(probes), 25, fmt.Sprintf("too many iterations: %v", probes)) + ping, size := d.GetPing(now) + probes = append(probes, size) + if size <= realMTU { ping.Handler.OnAcked(ping.Frame) - sizes = append(sizes, size) - t = t.Add(5 * rtt) + } else { + ping.Handler.OnLost(ping.Frame) } - Expect(sizes).To(Equal([]protocol.ByteCount{1500, 1750, 1875, 1937, 1968, 1984})) - Expect(d.ShouldSendProbe(t.Add(10 * rtt))).To(BeFalse()) - }) + now = now.Add(mtuProbeDelay * rtt) + } + diff := realMTU - currentMTU + require.GreaterOrEqual(t, diff, protocol.ByteCount(0)) + if maxMTU > currentMTU+maxMTU { + require.Equal(t, currentMTU, tracedMTU) + require.True(t, tracerDone) + } + t.Logf("MTU discovered: %d (diff: %d)", currentMTU, diff) + t.Logf("probes sent (%d): %v", len(probes), probes) + require.LessOrEqual(t, diff, maxMTUDiff) +} - It("doesn't do discovery before being started", func() { - d := newMTUDiscoverer(rttStats, startMTU, protocol.MaxByteCount, func(s protocol.ByteCount) {}, nil) - for i := 0; i < 5; i++ { - Expect(d.ShouldSendProbe(time.Now())).To(BeFalse()) - } - }) - - It("finds the MTU", MustPassRepeatedly(300), func() { - maxMTU := protocol.ByteCount(r.Intn(int(3000-startMTU))) + startMTU + 1 - currentMTU := startMTU - var tracedMTU protocol.ByteCount - var tracerDone bool - d := newMTUDiscoverer( - rttStats, - startMTU, - maxMTU, - func(s protocol.ByteCount) { currentMTU = s }, - &logging.ConnectionTracer{ - UpdatedMTU: func(mtu logging.ByteCount, done bool) { - tracedMTU = mtu - tracerDone = done - }, - }, - ) - d.Start(time.Now()) - now := time.Now() - realMTU := protocol.ByteCount(r.Intn(int(maxMTU-startMTU))) + startMTU - fmt.Fprintf(GinkgoWriter, "MTU: %d, max: %d\n", realMTU, maxMTU) - t := now.Add(mtuProbeDelay * rtt) - var probes []protocol.ByteCount - for d.ShouldSendProbe(t) { - if len(probes) > 24 { - Fail(fmt.Sprintf("too many iterations: %v", probes)) - } - ping, size := d.GetPing(time.Now()) - probes = append(probes, size) - if size <= realMTU { - ping.Handler.OnAcked(ping.Frame) - } else { - ping.Handler.OnLost(ping.Frame) - } - t = t.Add(mtuProbeDelay * rtt) - } - diff := realMTU - currentMTU - Expect(diff).To(BeNumerically(">=", 0)) - if maxMTU > currentMTU+maxMTU { - Expect(tracedMTU).To(Equal(currentMTU)) - Expect(tracerDone).To(BeTrue()) - } - fmt.Fprintf(GinkgoWriter, "MTU discovered: %d (diff: %d)\n", currentMTU, diff) - fmt.Fprintf(GinkgoWriter, "probes sent (%d): %v\n", len(probes), probes) - Expect(diff).To(BeNumerically("<=", maxMTUDiff)) - }) +func TestMTUDiscovererWithRandomLoss(t *testing.T) { + for i := range 5 { + t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { + testMTUDiscovererWithRandomLoss(t) + }) + } +} +func testMTUDiscovererWithRandomLoss(t *testing.T) { + const rtt = 100 * time.Millisecond + const startMTU protocol.ByteCount = 1000 const maxRandomLoss = maxLostMTUProbes - 1 - It(fmt.Sprintf("finds the MTU, with up to %d packets lost", maxRandomLoss), MustPassRepeatedly(500), func() { - maxMTU := protocol.ByteCount(r.Intn(int(3000-startMTU))) + startMTU + 1 - currentMTU := startMTU - var tracedMTU protocol.ByteCount - var tracerDone bool - d := newMTUDiscoverer( - rttStats, - startMTU, - maxMTU, - func(s protocol.ByteCount) { currentMTU = s }, - &logging.ConnectionTracer{ - UpdatedMTU: func(mtu logging.ByteCount, done bool) { - tracedMTU = mtu - tracerDone = done - }, + + rttStats := &utils.RTTStats{} + rttStats.SetInitialRTT(rtt) + require.Equal(t, rtt, rttStats.SmoothedRTT()) + + maxMTU := protocol.ByteCount(rand.IntN(int(3000-startMTU))) + startMTU + 1 + currentMTU := startMTU + var tracedMTU protocol.ByteCount + var tracerDone bool + + d := newMTUDiscoverer( + rttStats, + startMTU, + maxMTU, + func(s protocol.ByteCount) { currentMTU = s }, + &logging.ConnectionTracer{ + UpdatedMTU: func(mtu logging.ByteCount, done bool) { + tracedMTU = mtu + tracerDone = done }, - ) - d.Start(time.Now()) - now := time.Now() - realMTU := protocol.ByteCount(r.Intn(int(maxMTU-startMTU))) + startMTU - fmt.Fprintf(GinkgoWriter, "MTU: %d, max: %d\n", realMTU, maxMTU) - t := now.Add(mtuProbeDelay * rtt) - var probes, randomLosses []protocol.ByteCount - for d.ShouldSendProbe(t) { - if len(probes) > 32 { - Fail(fmt.Sprintf("too many iterations: %v", probes)) + }, + ) + d.Start(time.Now()) + now := time.Now() + realMTU := protocol.ByteCount(rand.IntN(int(maxMTU-startMTU))) + startMTU + t.Logf("MTU: %d, max: %d", realMTU, maxMTU) + now = now.Add(mtuProbeDelay * rtt) + var probes, randomLosses []protocol.ByteCount + + for d.ShouldSendProbe(now) { + require.Less(t, len(probes), 32, fmt.Sprintf("too many iterations: %v", probes)) + ping, size := d.GetPing(now) + probes = append(probes, size) + packetFits := size <= realMTU + var acked bool + if packetFits { + randomLoss := rand.IntN(maxLostMTUProbes) == 0 && len(randomLosses) < maxRandomLoss + if randomLoss { + randomLosses = append(randomLosses, size) + } else { + ping.Handler.OnAcked(ping.Frame) + acked = true } - ping, size := d.GetPing(time.Now()) - probes = append(probes, size) - packetFits := size <= realMTU - var acked bool - if packetFits { - randomLoss := r.Intn(maxLostMTUProbes) == 0 && len(randomLosses) < maxRandomLoss - if randomLoss { - randomLosses = append(randomLosses, size) - } else { - ping.Handler.OnAcked(ping.Frame) - acked = true - } - } - if !acked { - ping.Handler.OnLost(ping.Frame) - } - t = t.Add(mtuProbeDelay * rtt) } - diff := realMTU - currentMTU - Expect(diff).To(BeNumerically(">=", 0)) - if maxMTU > currentMTU+maxMTU { - Expect(tracedMTU).To(Equal(currentMTU)) - Expect(tracerDone).To(BeTrue()) + if !acked { + ping.Handler.OnLost(ping.Frame) } - fmt.Fprintf(GinkgoWriter, "MTU discovered with random losses %v: %d (diff: %d)\n", randomLosses, currentMTU, diff) - fmt.Fprintf(GinkgoWriter, "probes sent (%d): %v\n", len(probes), probes) - Expect(diff).To(BeNumerically("<=", maxMTUDiff)) - }) -}) + now = now.Add(mtuProbeDelay * rtt) + } + + diff := realMTU - currentMTU + require.GreaterOrEqual(t, diff, protocol.ByteCount(0)) + if maxMTU > currentMTU+maxMTU { + require.Equal(t, currentMTU, tracedMTU) + require.True(t, tracerDone) + } + t.Logf("MTU discovered with random losses %v: %d (diff: %d)", randomLosses, currentMTU, diff) + t.Logf("probes sent (%d): %v", len(probes), probes) + require.LessOrEqual(t, diff, maxMTUDiff) +}