diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index f8b8cd92..6c138e45 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -177,10 +177,9 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t if pnSpace == nil { return } - pnSpace.history.Iterate(func(p *packet) bool { + for p := range pnSpace.history.Packets() { h.removeFromBytesInFlight(p) - return true - }) + } } // drop the packet history //nolint:exhaustive // Not every packet number space can be dropped. @@ -197,14 +196,13 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t // and not when the client drops 0-RTT keys when the handshake completes. // When 0-RTT is rejected, all application data sent so far becomes invalid. // Delete the packets from the history and remove them from bytes_in_flight. - h.appDataPackets.history.Iterate(func(p *packet) bool { + for p := range h.appDataPackets.history.Packets() { if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket { - return false + break } h.removeFromBytesInFlight(p) h.appDataPackets.history.Remove(p.PacketNumber) - return true - }) + } default: panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) } @@ -430,15 +428,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL ackRangeIndex := 0 lowestAcked := ack.LowestAcked() largestAcked := ack.LargestAcked() - var processErr error - pnSpace.history.Iterate(func(p *packet) bool { + for p := range pnSpace.history.Packets() { // ignore packets below the lowest acked if p.PacketNumber < lowestAcked { - return true + continue } - // break after largest acked is reached if p.PacketNumber > largestAcked { - return false + break } if ack.HasMissingRanges() { @@ -450,19 +446,17 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL } if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range - return true + continue } if p.PacketNumber > ackRange.Largest { - processErr = fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest) - return false + return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest) } } if p.skippedPacket { - processErr = &qerr.TransportError{ + return nil, &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel), } - return false } if p.isPathProbePacket { probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber) @@ -470,11 +464,10 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL panic(fmt.Sprintf("path probe doesn't exist: %d", p.PacketNumber)) } h.ackedPackets = append(h.ackedPackets, probePacket) - } else { - h.ackedPackets = append(h.ackedPackets, p) + continue } - return true - }) + h.ackedPackets = append(h.ackedPackets, p) + } if h.logger.Debug() && len(h.ackedPackets) > 0 { pns := make([]protocol.PacketNumber, len(h.ackedPackets)) for i, p := range h.ackedPackets { @@ -482,9 +475,6 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL } h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns) } - if processErr != nil { - return nil, processErr - } for _, p := range h.ackedPackets { if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { @@ -659,12 +649,11 @@ func (h *sentPacketHandler) detectLostPathProbes(now time.Time) { lossTime := now.Add(-pathProbePacketLossTimeout) // RemovePathProbe cannot be called while iterating. var lostPathProbes []*packet - h.appDataPackets.history.IteratePathProbes(func(p *packet) bool { + for p := range h.appDataPackets.history.PathProbes() { if !p.SendTime.After(lossTime) { lostPathProbes = append(lostPathProbes, p) } - return true - }) + } for _, p := range lostPathProbes { for _, f := range p.Frames { f.Handler.OnLost(f.Frame) @@ -687,9 +676,9 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E lostSendTime := now.Add(-lossDelay) priorInFlight := h.bytesInFlight - pnSpace.history.Iterate(func(p *packet) bool { + for p := range pnSpace.history.Packets() { if p.PacketNumber > pnSpace.largestAcked { - return false + break } isRegularPacket := !p.skippedPacket && !p.isPathProbePacket @@ -736,8 +725,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E } } } - return true - }) + } } func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error { @@ -945,24 +933,21 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) { func (h *sentPacketHandler) ResetForRetry(now time.Time) { h.bytesInFlight = 0 var firstPacketSendTime time.Time - h.initialPackets.history.Iterate(func(p *packet) bool { + for p := range h.initialPackets.history.Packets() { if firstPacketSendTime.IsZero() { firstPacketSendTime = p.SendTime } - if p.declaredLost || p.skippedPacket { - return true - } - h.queueFramesForRetransmission(p) - return true - }) - // All application data packets sent at this point are 0-RTT packets. - // In the case of a Retry, we can assume that the server dropped all of them. - h.appDataPackets.history.Iterate(func(p *packet) bool { if !p.declaredLost && !p.skippedPacket { h.queueFramesForRetransmission(p) } - return true - }) + } + // All application data packets sent at this point are 0-RTT packets. + // In the case of a Retry, we can assume that the server dropped all of them. + for p := range h.appDataPackets.history.Packets() { + if !p.declaredLost && !p.skippedPacket { + h.queueFramesForRetransmission(p) + } + } // Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial. // Otherwise, we don't know which Initial the Retry was sent in response to. @@ -993,18 +978,16 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) { func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize protocol.ByteCount) { h.rttStats.ResetForPathMigration() - h.appDataPackets.history.Iterate(func(p *packet) bool { + for p := range h.appDataPackets.history.Packets() { h.appDataPackets.history.DeclareLost(p.PacketNumber) if !p.skippedPacket && !p.isPathProbePacket { h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) } - return true - }) - h.appDataPackets.history.IteratePathProbes(func(p *packet) bool { + } + for p := range h.appDataPackets.history.PathProbes() { h.appDataPackets.history.RemovePathProbe(p.PacketNumber) - return true - }) + } h.congestion = congestion.NewCubicSender( congestion.DefaultClock{}, h.rttStats, diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index a608f504..0aabc6d9 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -2,6 +2,7 @@ package ackhandler import ( "fmt" + "iter" "github.com/quic-go/quic-go/internal/protocol" ) @@ -68,23 +69,25 @@ func (h *sentPacketHistory) SentPathProbePacket(p *packet) { h.pathProbePackets = append(h.pathProbePackets, p) } -// Iterate iterates through all packets. -func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool)) { - for _, p := range h.packets { - if p == nil { - continue - } - if cont := cb(p); !cont { - return +func (h *sentPacketHistory) Packets() iter.Seq[*packet] { + return func(yield func(*packet) bool) { + for _, p := range h.packets { + if p == nil { + continue + } + if !yield(p) { + return + } } } } -// IteratePathProbes iterates through all packets. -func (h *sentPacketHistory) IteratePathProbes(cb func(*packet) (cont bool)) { - for _, p := range h.pathProbePackets { - if cont := cb(p); !cont { - return +func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] { + return func(yield func(*packet) bool) { + for _, p := range h.pathProbePackets { + if !yield(p) { + return + } } } } diff --git a/internal/ackhandler/sent_packet_history_test.go b/internal/ackhandler/sent_packet_history_test.go index 06e2d97a..a74155d7 100644 --- a/internal/ackhandler/sent_packet_history_test.go +++ b/internal/ackhandler/sent_packet_history_test.go @@ -149,36 +149,18 @@ func TestSentPacketHistoryIterating(t *testing.T) { require.NoError(t, hist.Remove(4)) var packets, skippedPackets []protocol.PacketNumber - hist.Iterate(func(p *packet) bool { + for p := range hist.Packets() { if p.skippedPacket { skippedPackets = append(skippedPackets, p.PacketNumber) } else { packets = append(packets, p.PacketNumber) } - return true - }) + } require.Equal(t, []protocol.PacketNumber{1, 2, 6}, packets) require.Equal(t, []protocol.PacketNumber{0, 5}, skippedPackets) } -func TestSentPacketHistoryStopIterating(t *testing.T) { - hist := newSentPacketHistory(true) - hist.SkippedPacket(0) - hist.SentAckElicitingPacket(&packet{PacketNumber: 1}) - hist.SentAckElicitingPacket(&packet{PacketNumber: 2}) - - var iterations []protocol.PacketNumber - hist.Iterate(func(p *packet) bool { - if p.skippedPacket { - return true - } - iterations = append(iterations, p.PacketNumber) - return p.PacketNumber < 1 - }) - require.Equal(t, []protocol.PacketNumber{1}, iterations) -} - func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) { hist := newSentPacketHistory(true) hist.SentAckElicitingPacket(&packet{PacketNumber: 0}) @@ -189,7 +171,7 @@ func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) { hist.SentAckElicitingPacket(&packet{PacketNumber: 5}) var iterations []protocol.PacketNumber - hist.Iterate(func(p *packet) bool { + for p := range hist.Packets() { iterations = append(iterations, p.PacketNumber) switch p.PacketNumber { case 0: @@ -197,8 +179,7 @@ func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) { case 4: require.NoError(t, hist.Remove(4)) } - return true - }) + } require.Equal(t, []protocol.PacketNumber{0, 1, 2, 3, 4, 5}, iterations) require.Equal(t, []protocol.PacketNumber{1, 3, 5}, hist.getPacketNumbers()) @@ -217,7 +198,7 @@ func TestSentPacketHistoryPathProbes(t *testing.T) { getPacketsInHistory := func(t *testing.T) []protocol.PacketNumber { t.Helper() var pns []protocol.PacketNumber - hist.Iterate(func(p *packet) bool { + for p := range hist.Packets() { pns = append(pns, p.PacketNumber) switch p.PacketNumber { case 2, 5: @@ -225,18 +206,16 @@ func TestSentPacketHistoryPathProbes(t *testing.T) { default: require.False(t, p.isPathProbePacket) } - return true - }) + } return pns } getPacketsInPathProbeHistory := func(t *testing.T) []protocol.PacketNumber { t.Helper() var pns []protocol.PacketNumber - hist.IteratePathProbes(func(p *packet) bool { + for p := range hist.PathProbes() { pns = append(pns, p.PacketNumber) - return true - }) + } return pns }