diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 9f302228..2ffa77a0 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -342,7 +342,7 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() { // Early retransmit timer or time loss detection. h.alarm = h.lossTime } else { // PTO alarm - h.alarm = h.lastSentAckElicitingPacketTime.Add(h.computePTOTimeout() << h.ptoCount) + h.alarm = h.lastSentAckElicitingPacketTime.Add(h.rttStats.PTO() << h.ptoCount) } } @@ -657,10 +657,6 @@ func (h *sentPacketHandler) computeCryptoTimeout() time.Duration { return duration << h.cryptoCount } -func (h *sentPacketHandler) computePTOTimeout() time.Duration { - return h.rttStats.SmoothedOrInitialRTT() + utils.MaxDuration(4*h.rttStats.MeanDeviation(), protocol.TimerGranularity) + h.rttStats.MaxAckDelay() -} - func (h *sentPacketHandler) ResetForRetry() error { h.cryptoCount = 0 h.bytesInFlight = 0 diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 3942a3fe..db6ffc3d 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -611,20 +611,6 @@ var _ = Describe("SentPacketHandler", func() { }) Context("probe packets", func() { - It("uses the RTT from RTT stats", func() { - rtt := 2 * time.Second - updateRTT(rtt) - Expect(handler.rttStats.SmoothedOrInitialRTT()).To(Equal(2 * time.Second)) - Expect(handler.rttStats.MeanDeviation()).To(Equal(time.Second)) - Expect(handler.computePTOTimeout()).To(Equal(time.Duration(2+4) * time.Second)) - }) - - It("uses the granularity for short RTTs", func() { - rtt := time.Microsecond - updateRTT(rtt) - Expect(handler.computePTOTimeout()).To(Equal(rtt + protocol.TimerGranularity)) - }) - It("implements exponential backoff", func() { sendTime := time.Now().Add(-time.Hour) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) diff --git a/internal/congestion/rtt_stats.go b/internal/congestion/rtt_stats.go index 7c9b79ed..a44d7991 100644 --- a/internal/congestion/rtt_stats.go +++ b/internal/congestion/rtt_stats.go @@ -3,6 +3,7 @@ package congestion import ( "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -56,6 +57,10 @@ func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay } +func (r *RTTStats) PTO() time.Duration { + return r.SmoothedOrInitialRTT() + utils.MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + r.MaxAckDelay() +} + // UpdateRTT updates the RTT based on a new sample. func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { if sendDelta == utils.InfDuration || sendDelta <= 0 { diff --git a/internal/congestion/rtt_stats_test.go b/internal/congestion/rtt_stats_test.go index 2d3114aa..d7ddc8bd 100644 --- a/internal/congestion/rtt_stats_test.go +++ b/internal/congestion/rtt_stats_test.go @@ -3,6 +3,7 @@ package congestion import ( "time" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -64,6 +65,22 @@ var _ = Describe("RTT stats", func() { Expect(rttStats.MaxAckDelay()).To(Equal(42 * time.Minute)) }) + It("computes the PTO", func() { + maxAckDelay := 42 * time.Minute + rttStats.SetMaxAckDelay(maxAckDelay) + rtt := time.Second + rttStats.UpdateRTT(rtt, 0, time.Time{}) + Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) + Expect(rttStats.MeanDeviation()).To(Equal(rtt / 2)) + Expect(rttStats.PTO()).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) + }) + + It("uses the granularity for computing the PTO for short RTTs", func() { + rtt := time.Microsecond + rttStats.UpdateRTT(rtt, 0, time.Time{}) + Expect(rttStats.PTO()).To(Equal(rtt + protocol.TimerGranularity)) + }) + It("ExpireSmoothedMetrics", func() { initialRtt := (10 * time.Millisecond) rttStats.UpdateRTT(initialRtt, 0, time.Time{})