diff --git a/connection.go b/connection.go index 0325a32b..ffb151ed 100644 --- a/connection.go +++ b/connection.go @@ -797,7 +797,7 @@ func (s *connection) handleHandshakeConfirmed(now time.Time) error { s.cryptoStreamHandler.SetHandshakeConfirmed() if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF { - s.mtuDiscoverer.Start() + s.mtuDiscoverer.Start(now) } return nil } @@ -1894,7 +1894,7 @@ func (s *connection) sendPackets(now time.Time) error { // Performance-wise, this doesn't matter, since we only send a very small (<10) number of // MTU probe packets per connection. if s.handshakeConfirmed && s.mtuDiscoverer != nil && s.mtuDiscoverer.ShouldSendProbe(now) { - ping, size := s.mtuDiscoverer.GetPing() + ping, size := s.mtuDiscoverer.GetPing(now) p, buf, err := s.packer.PackMTUProbePacket(ping, size, s.version) if err != nil { return err diff --git a/mock_mtu_discoverer_test.go b/mock_mtu_discoverer_test.go index 4951b6d5..ce58ffc2 100644 --- a/mock_mtu_discoverer_test.go +++ b/mock_mtu_discoverer_test.go @@ -81,18 +81,18 @@ func (c *MockMTUDiscovererCurrentSizeCall) DoAndReturn(f func() protocol.ByteCou } // GetPing mocks base method. -func (m *MockMTUDiscoverer) GetPing() (ackhandler.Frame, protocol.ByteCount) { +func (m *MockMTUDiscoverer) GetPing(now time.Time) (ackhandler.Frame, protocol.ByteCount) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPing") + ret := m.ctrl.Call(m, "GetPing", now) ret0, _ := ret[0].(ackhandler.Frame) ret1, _ := ret[1].(protocol.ByteCount) return ret0, ret1 } // GetPing indicates an expected call of GetPing. -func (mr *MockMTUDiscovererMockRecorder) GetPing() *MockMTUDiscovererGetPingCall { +func (mr *MockMTUDiscovererMockRecorder) GetPing(now any) *MockMTUDiscovererGetPingCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMTUDiscoverer)(nil).GetPing)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMTUDiscoverer)(nil).GetPing), now) return &MockMTUDiscovererGetPingCall{Call: call} } @@ -108,13 +108,13 @@ func (c *MockMTUDiscovererGetPingCall) Return(ping ackhandler.Frame, datagramSiz } // Do rewrite *gomock.Call.Do -func (c *MockMTUDiscovererGetPingCall) Do(f func() (ackhandler.Frame, protocol.ByteCount)) *MockMTUDiscovererGetPingCall { +func (c *MockMTUDiscovererGetPingCall) Do(f func(time.Time) (ackhandler.Frame, protocol.ByteCount)) *MockMTUDiscovererGetPingCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockMTUDiscovererGetPingCall) DoAndReturn(f func() (ackhandler.Frame, protocol.ByteCount)) *MockMTUDiscovererGetPingCall { +func (c *MockMTUDiscovererGetPingCall) DoAndReturn(f func(time.Time) (ackhandler.Frame, protocol.ByteCount)) *MockMTUDiscovererGetPingCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -158,15 +158,15 @@ func (c *MockMTUDiscovererShouldSendProbeCall) DoAndReturn(f func(time.Time) boo } // Start mocks base method. -func (m *MockMTUDiscoverer) Start() { +func (m *MockMTUDiscoverer) Start(now time.Time) { m.ctrl.T.Helper() - m.ctrl.Call(m, "Start") + m.ctrl.Call(m, "Start", now) } // Start indicates an expected call of Start. -func (mr *MockMTUDiscovererMockRecorder) Start() *MockMTUDiscovererStartCall { +func (mr *MockMTUDiscovererMockRecorder) Start(now any) *MockMTUDiscovererStartCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockMTUDiscoverer)(nil).Start)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockMTUDiscoverer)(nil).Start), now) return &MockMTUDiscovererStartCall{Call: call} } @@ -182,13 +182,13 @@ func (c *MockMTUDiscovererStartCall) Return() *MockMTUDiscovererStartCall { } // Do rewrite *gomock.Call.Do -func (c *MockMTUDiscovererStartCall) Do(f func()) *MockMTUDiscovererStartCall { +func (c *MockMTUDiscovererStartCall) Do(f func(time.Time)) *MockMTUDiscovererStartCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockMTUDiscovererStartCall) DoAndReturn(f func()) *MockMTUDiscovererStartCall { +func (c *MockMTUDiscovererStartCall) DoAndReturn(f func(time.Time)) *MockMTUDiscovererStartCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/mtu_discoverer.go b/mtu_discoverer.go index 6f906698..9b6fcfcf 100644 --- a/mtu_discoverer.go +++ b/mtu_discoverer.go @@ -13,10 +13,10 @@ import ( type mtuDiscoverer interface { // Start starts the MTU discovery process. // It's unnecessary to call ShouldSendProbe before that. - Start() + Start(now time.Time) ShouldSendProbe(now time.Time) bool CurrentSize() protocol.ByteCount - GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount) + GetPing(now time.Time) (ping ackhandler.Frame, datagramSize protocol.ByteCount) } const ( @@ -140,8 +140,8 @@ func (f *mtuFinder) max() protocol.ByteCount { return f.lost[len(f.lost)-1] } -func (f *mtuFinder) Start() { - f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately +func (f *mtuFinder) Start(now time.Time) { + f.lastProbeTime = now // makes sure the first probe packet is not sent immediately } func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { @@ -154,14 +154,14 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) } -func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { +func (f *mtuFinder) GetPing(now time.Time) (ackhandler.Frame, protocol.ByteCount) { var size protocol.ByteCount if f.lastProbeWasLost { size = (f.min + f.lost[0]) / 2 } else { size = (f.min + f.max()) / 2 } - f.lastProbeTime = time.Now() + f.lastProbeTime = now f.inFlight = size return ackhandler.Frame{ Frame: &wire.PingFrame{}, diff --git a/mtu_discoverer_test.go b/mtu_discoverer_test.go index ff764670..e71337a1 100644 --- a/mtu_discoverer_test.go +++ b/mtu_discoverer_test.go @@ -40,7 +40,7 @@ var _ = Describe("MTU Discoverer", func() { func(s protocol.ByteCount) { discoveredMTU = s }, nil, ) - d.Start() + d.Start(time.Now()) now = time.Now() }) @@ -51,26 +51,26 @@ var _ = Describe("MTU Discoverer", func() { }) It("doesn't allow a probe if another probe is still in flight", func() { - ping, _ := d.GetPing() + ping, _ := d.GetPing(time.Now()) Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeFalse()) ping.Handler.OnLost(ping.Frame) Expect(d.ShouldSendProbe(now.Add(10 * rtt))).To(BeTrue()) }) It("tries a lower size when a probe is lost", func() { - ping, size := d.GetPing() + ping, size := d.GetPing(time.Now()) Expect(size).To(Equal(protocol.ByteCount(1500))) ping.Handler.OnLost(ping.Frame) - _, size = d.GetPing() + _, size = d.GetPing(time.Now()) Expect(size).To(Equal(protocol.ByteCount(1250))) }) It("tries a higher size and calls the callback when a probe is acknowledged", func() { - ping, size := d.GetPing() + ping, size := d.GetPing(time.Now()) Expect(size).To(Equal(protocol.ByteCount(1500))) ping.Handler.OnAcked(ping.Frame) Expect(discoveredMTU).To(Equal(protocol.ByteCount(1500))) - _, size = d.GetPing() + _, size = d.GetPing(time.Now()) Expect(size).To(Equal(protocol.ByteCount(1750))) }) @@ -78,7 +78,7 @@ var _ = Describe("MTU Discoverer", func() { var sizes []protocol.ByteCount t := now.Add(5 * rtt) for d.ShouldSendProbe(t) { - ping, size := d.GetPing() + ping, size := d.GetPing(time.Now()) fmt.Println("sending", size) ping.Handler.OnAcked(ping.Frame) sizes = append(sizes, size) @@ -112,7 +112,7 @@ var _ = Describe("MTU Discoverer", func() { }, }, ) - d.Start() + d.Start(time.Now()) now := time.Now() realMTU := protocol.ByteCount(r.Intn(int(maxMTU-startMTU))) + startMTU fmt.Fprintf(GinkgoWriter, "MTU: %d, max: %d\n", realMTU, maxMTU) @@ -122,7 +122,7 @@ var _ = Describe("MTU Discoverer", func() { if len(probes) > 24 { Fail(fmt.Sprintf("too many iterations: %v", probes)) } - ping, size := d.GetPing() + ping, size := d.GetPing(time.Now()) probes = append(probes, size) if size <= realMTU { ping.Handler.OnAcked(ping.Frame) @@ -160,7 +160,7 @@ var _ = Describe("MTU Discoverer", func() { }, }, ) - d.Start() + d.Start(time.Now()) now := time.Now() realMTU := protocol.ByteCount(r.Intn(int(maxMTU-startMTU))) + startMTU fmt.Fprintf(GinkgoWriter, "MTU: %d, max: %d\n", realMTU, maxMTU) @@ -170,7 +170,7 @@ var _ = Describe("MTU Discoverer", func() { if len(probes) > 32 { Fail(fmt.Sprintf("too many iterations: %v", probes)) } - ping, size := d.GetPing() + ping, size := d.GetPing(time.Now()) probes = append(probes, size) packetFits := size <= realMTU var acked bool