diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index b735d76f..222388f4 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -267,28 +267,24 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - ackedPackets, err := h.determineNewlyAckedPackets(ack, encLevel) + priorInFlight := h.bytesInFlight + ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel) if err != nil || len(ackedPackets) == 0 { return err } - - priorInFlight := h.bytesInFlight + lostPackets, err := h.detectAndRemoveLostPackets(rcvTime, encLevel) + if err != nil { + return err + } + for _, p := range lostPackets { + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + } for _, p := range ackedPackets { - if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { - h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1) - } - if err := h.onPacketAcked(p); err != nil { - return err - } if p.includedInBytesInFlight { h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) } } - if err := h.detectLostPackets(rcvTime, encLevel, priorInFlight); err != nil { - return err - } - if h.qlogger != nil && h.ptoCount != 0 { h.qlogger.UpdatedPTOCount(0) } @@ -303,15 +299,12 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu return h.lowestNotConfirmedAcked } -func (h *sentPacketHandler) determineNewlyAckedPackets( - ackFrame *wire.AckFrame, - encLevel protocol.EncryptionLevel, -) ([]*Packet, error) { +func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) { pnSpace := h.getPacketNumberSpace(encLevel) var ackedPackets []*Packet ackRangeIndex := 0 - lowestAcked := ackFrame.LowestAcked() - largestAcked := ackFrame.LargestAcked() + lowestAcked := ack.LowestAcked() + largestAcked := ack.LargestAcked() err := pnSpace.history.Iterate(func(p *Packet) (bool, error) { // Ignore packets below the lowest acked if p.PacketNumber < lowestAcked { @@ -322,12 +315,12 @@ func (h *sentPacketHandler) determineNewlyAckedPackets( return false, nil } - if ackFrame.HasMissingRanges() { - ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] + if ack.HasMissingRanges() { + ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] - for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 { + for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 { ackRangeIndex++ - ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] + ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] } if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range @@ -348,6 +341,28 @@ func (h *sentPacketHandler) determineNewlyAckedPackets( } h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns) } + + for _, p := range ackedPackets { + if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil { + continue + } + if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { + h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1) + } + + for _, f := range p.Frames { + if f.OnAcked != nil { + f.OnAcked(f.Frame) + } + } + if p.includedInBytesInFlight { + h.bytesInFlight -= p.Length + } + if err := pnSpace.history.Remove(p.PacketNumber); err != nil { + return nil, err + } + } + return ackedPackets, err } @@ -429,11 +444,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = sentTime.Add(h.rttStats.PTO(encLevel == protocol.Encryption1RTT) << h.ptoCount) } -func (h *sentPacketHandler) detectLostPackets( - now time.Time, - encLevel protocol.EncryptionLevel, - priorInFlight protocol.ByteCount, -) error { +func (h *sentPacketHandler) detectAndRemoveLostPackets(now time.Time, encLevel protocol.EncryptionLevel) ([]*Packet, error) { pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} @@ -486,7 +497,6 @@ func (h *sentPacketHandler) detectLostPackets( // the bytes in flight need to be reduced no matter if this packet will be retransmitted if p.includedInBytesInFlight { h.bytesInFlight -= p.Length - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) } pnSpace.history.Remove(p.PacketNumber) if h.traceCallback != nil { @@ -505,7 +515,7 @@ func (h *sentPacketHandler) detectLostPackets( }) } } - return nil + return lostPackets, nil } func (h *sentPacketHandler) OnLossDetectionTimeout() error { @@ -529,7 +539,14 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } // Early retransmit or time loss detection - return h.detectLostPackets(time.Now(), encLevel, h.bytesInFlight) + priorInFlight := h.bytesInFlight + lostPackets, err := h.detectAndRemoveLostPackets(time.Now(), encLevel) + if err != nil { + return err + } + for _, p := range lostPackets { + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + } } // PTO @@ -559,23 +576,6 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { return h.alarm } -func (h *sentPacketHandler) onPacketAcked(p *Packet) error { - pnSpace := h.getPacketNumberSpace(p.EncryptionLevel) - if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil { - return nil - } - - for _, f := range p.Frames { - if f.OnAcked != nil { - f.OnAcked(f.Frame) - } - } - if p.includedInBytesInFlight { - h.bytesInFlight -= p.Length - } - return pnSpace.history.Remove(p.PacketNumber) -} - func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel) diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 9404e212..5f4b063d 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -460,8 +460,8 @@ var _ = Describe("SentPacketHandler", func() { // lose packet 1 gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) @@ -480,16 +480,16 @@ var _ = Describe("SentPacketHandler", func() { // receive the first ACK gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()), cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()), ) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute))).To(Succeed()) // receive the second ACK gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()), ) ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}} Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed())