diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 73bbff6a..ccd7c6c3 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -16,8 +16,6 @@ const ( // Maximum reordering in time space before time based loss detection considers a packet lost. // In fraction of an RTT. timeReorderingFraction = 1.0 / 8 - // The default RTT used before an RTT sample is taken. - defaultInitialRTT = 100 * time.Millisecond // defaultRTOTimeout is the RTO time on new connections defaultRTOTimeout = 500 * time.Millisecond // Minimum time in the future a tail loss probe alarm may be set for. @@ -567,11 +565,7 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error { } func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { - duration := 2 * h.rttStats.SmoothedRTT() - if duration == 0 { - duration = 2 * defaultInitialRTT - } - duration = utils.MaxDuration(duration, minTPLTimeout) + duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout) // exponential backoff // There's an implicit limit to this set by the handshake timeout. return duration << h.handshakeCount @@ -579,11 +573,7 @@ func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration { func (h *sentPacketHandler) computeTLPTimeout() time.Duration { // TODO(#1236): include the max_ack_delay - srtt := h.rttStats.SmoothedRTT() - if srtt == 0 { - srtt = defaultInitialRTT - } - return utils.MaxDuration(srtt*3/2, minTPLTimeout) + return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout) } func (h *sentPacketHandler) computeRTOTimeout() time.Duration { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 37df5965..04c32a49 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -781,10 +781,6 @@ var _ = Describe("SentPacketHandler", func() { }) Context("TLPs", func() { - It("uses the default RTT", func() { - Expect(handler.computeTLPTimeout()).To(Equal(defaultInitialRTT * 3 / 2)) - }) - It("uses the RTT from RTT stats", func() { rtt := 2 * time.Second updateRTT(rtt) diff --git a/internal/congestion/rtt_stats.go b/internal/congestion/rtt_stats.go index 539410dd..f0ebbb23 100644 --- a/internal/congestion/rtt_stats.go +++ b/internal/congestion/rtt_stats.go @@ -11,6 +11,8 @@ const ( oneMinusAlpha float32 = (1 - rttAlpha) rttBeta float32 = 0.25 oneMinusBeta float32 = (1 - rttBeta) + // The default RTT used before an RTT sample is taken. + defaultInitialRTT = 100 * time.Millisecond ) // RTTStats provides round-trip statistics @@ -38,6 +40,15 @@ func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } // May return Zero if no valid updates have occurred. func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } +// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection. +// If no valid updates have occurred, it returns the initial RTT. +func (r *RTTStats) SmoothedOrInitialRTT() time.Duration { + if r.smoothedRTT != 0 { + return r.smoothedRTT + } + return defaultInitialRTT +} + // MeanDeviation gets the mean deviation func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } diff --git a/internal/congestion/rtt_stats_test.go b/internal/congestion/rtt_stats_test.go index 55c8ed18..e8f8c69c 100644 --- a/internal/congestion/rtt_stats_test.go +++ b/internal/congestion/rtt_stats_test.go @@ -37,6 +37,12 @@ var _ = Describe("RTT stats", func() { Expect(rttStats.SmoothedRTT()).To(Equal((287500 * time.Microsecond))) }) + It("SmoothedOrInitialRTT", func() { + Expect(rttStats.SmoothedOrInitialRTT()).To(Equal(defaultInitialRTT)) + rttStats.UpdateRTT((300 * time.Millisecond), (100 * time.Millisecond), time.Time{}) + Expect(rttStats.SmoothedOrInitialRTT()).To(Equal((300 * time.Millisecond))) + }) + It("MinRTT", func() { rttStats.UpdateRTT((200 * time.Millisecond), 0, time.Time{}) Expect(rttStats.MinRTT()).To(Equal((200 * time.Millisecond)))