diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index bbd1fb44..57aac995 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -439,20 +439,20 @@ func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.Encryptio } // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime -func (h *sentPacketHandler) getPTOTimeAndSpace() (time.Time, protocol.EncryptionLevel) { - if !h.hasOutstandingPackets() { +func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { + // We only send application data probe packets once the handshake is confirmed, + // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. + if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() { + if h.peerCompletedAddressValidation { + return + } t := time.Now().Add(h.rttStats.PTO(false) << h.ptoCount) if h.initialPackets != nil { - return t, protocol.EncryptionInitial + return t, protocol.EncryptionInitial, true } - return t, protocol.EncryptionHandshake + return t, protocol.EncryptionHandshake, true } - var ( - encLevel protocol.EncryptionLevel - pto time.Time - ) - if h.initialPackets != nil { encLevel = protocol.EncryptionInitial if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { @@ -473,7 +473,7 @@ func (h *sentPacketHandler) getPTOTimeAndSpace() (time.Time, protocol.Encryption encLevel = protocol.Encryption1RTT } } - return pto, encLevel + return pto, encLevel, true } func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { @@ -488,10 +488,7 @@ func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { } func (h *sentPacketHandler) hasOutstandingPackets() bool { - // We only send application data probe packets once the handshake completes, - // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. - return (h.handshakeConfirmed && h.appDataPackets.history.HasOutstandingPackets()) || - h.hasOutstandingCryptoPackets() + return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets() } func (h *sentPacketHandler) setLossDetectionTimer() { @@ -531,7 +528,10 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } // PTO alarm - ptoTime, encLevel := h.getPTOTimeAndSpace() + ptoTime, encLevel, ok := h.getPTOTimeAndSpace() + if !ok { + return + } h.alarm = ptoTime if h.tracer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) @@ -599,20 +599,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E } func (h *sentPacketHandler) OnLossDetectionTimeout() error { - // When all outstanding are acknowledged, the alarm is canceled in - // setLossDetectionTimer. This doesn't reset the timer in the session though. - // When OnAlarm is called, we therefore need to make sure that there are - // actually packets outstanding. - if h.hasOutstandingPackets() || !h.peerCompletedAddressValidation { - if err := h.onVerifiedLossDetectionTimeout(); err != nil { - return err - } - } - h.setLossDetectionTimer() - return nil -} - -func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { + defer h.setLossDetectionTimer() earliestLossTime, encLevel := h.getLossTimeAndSpace() if !earliestLossTime.IsZero() { if h.logger.Debug() { @@ -626,34 +613,12 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { } // PTO - h.ptoCount++ - if h.bytesInFlight > 0 { - _, encLevel = h.getPTOTimeAndSpace() - if h.logger.Debug() { - h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) - } - if h.tracer != nil { - h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) - h.tracer.UpdatedPTOCount(h.ptoCount) - } - h.numProbesToSend += 2 - //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. - switch encLevel { - case protocol.EncryptionInitial: - h.ptoMode = SendPTOInitial - case protocol.EncryptionHandshake: - h.ptoMode = SendPTOHandshake - case protocol.Encryption1RTT: - // skip a packet number in order to elicit an immediate ACK - _ = h.PopPacketNumber(protocol.Encryption1RTT) - h.ptoMode = SendPTOAppData - default: - return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) - } - } else { - if h.perspective == protocol.PerspectiveServer { - return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0") - } + // When all outstanding are acknowledged, the alarm is canceled in + // setLossDetectionTimer. This doesn't reset the timer in the session though. + // When OnAlarm is called, we therefore need to make sure that there are + // actually packets outstanding. + if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation { + h.ptoCount++ h.numProbesToSend++ if h.initialPackets != nil { h.ptoMode = SendPTOInitial @@ -662,6 +627,37 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { } else { return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped") } + return nil + } + + _, encLevel, ok := h.getPTOTimeAndSpace() + if !ok { + return nil + } + if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation { + return nil + } + h.ptoCount++ + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) + } + if h.tracer != nil { + h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) + h.tracer.UpdatedPTOCount(h.ptoCount) + } + h.numProbesToSend += 2 + //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. + switch encLevel { + case protocol.EncryptionInitial: + h.ptoMode = SendPTOInitial + case protocol.EncryptionHandshake: + h.ptoMode = SendPTOHandshake + case protocol.Encryption1RTT: + // skip a packet number in order to elicit an immediate ACK + _ = h.PopPacketNumber(protocol.Encryption1RTT) + h.ptoMode = SendPTOAppData + default: + return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) } return nil } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 20fa985f..26875f00 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -832,7 +832,7 @@ var _ = Describe("SentPacketHandler", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) updateRTT(time.Hour) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) Expect(handler.SendMode()).To(Equal(SendAny)) handler.SetHandshakeConfirmed() @@ -845,7 +845,7 @@ var _ = Describe("SentPacketHandler", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeConfirmed() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) + updateRTT(time.Second) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTOAppData)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} @@ -857,10 +857,18 @@ var _ = Describe("SentPacketHandler", func() { It("handles ACKs for the original packet", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) - handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) + updateRTT(time.Second) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) }) + + It("doesn't set the PTO timer for Path MTU probe packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.SetHandshakeConfirmed() + updateRTT(time.Second) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now(), IsPathMTUProbePacket: true})) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) }) Context("amplification limit, for the server", func() { @@ -1093,6 +1101,26 @@ var _ = Describe("SentPacketHandler", func() { expectInPacketHistory([]protocol.PacketNumber{3}, protocol.EncryptionInitial) Expect(handler.SendMode()).To(Equal(SendAny)) }) + + It("sets the early retransmit alarm for Path MTU probe packets", func() { + var mtuPacketDeclaredLost bool + now := time.Now() + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 1, + SendTime: now.Add(-3 * time.Second), + IsPathMTUProbePacket: true, + Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { mtuPacketDeclaredLost = true }}}, + })) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-3 * time.Second)})) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, now.Add(-time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(mtuPacketDeclaredLost).To(BeFalse()) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(mtuPacketDeclaredLost).To(BeTrue()) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) }) Context("crypto packets", func() {