ackhandler: use Go iterators to iterate over sent packets (#4952)

This commit is contained in:
Marten Seemann 2025-02-13 13:53:25 +01:00 committed by GitHub
parent 12f2be058b
commit b32f1fa0e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 56 additions and 91 deletions

View file

@ -177,10 +177,9 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
if pnSpace == nil {
return
}
pnSpace.history.Iterate(func(p *packet) bool {
for p := range pnSpace.history.Packets() {
h.removeFromBytesInFlight(p)
return true
})
}
}
// drop the packet history
//nolint:exhaustive // Not every packet number space can be dropped.
@ -197,14 +196,13 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
// and not when the client drops 0-RTT keys when the handshake completes.
// When 0-RTT is rejected, all application data sent so far becomes invalid.
// Delete the packets from the history and remove them from bytes_in_flight.
h.appDataPackets.history.Iterate(func(p *packet) bool {
for p := range h.appDataPackets.history.Packets() {
if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
return false
break
}
h.removeFromBytesInFlight(p)
h.appDataPackets.history.Remove(p.PacketNumber)
return true
})
}
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
@ -430,15 +428,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
ackRangeIndex := 0
lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked()
var processErr error
pnSpace.history.Iterate(func(p *packet) bool {
for p := range pnSpace.history.Packets() {
// ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true
continue
}
// break after largest acked is reached
if p.PacketNumber > largestAcked {
return false
break
}
if ack.HasMissingRanges() {
@ -450,19 +446,17 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
}
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
return true
continue
}
if p.PacketNumber > ackRange.Largest {
processErr = fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
return false
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
}
if p.skippedPacket {
processErr = &qerr.TransportError{
return nil, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
}
return false
}
if p.isPathProbePacket {
probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber)
@ -470,11 +464,10 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
panic(fmt.Sprintf("path probe doesn't exist: %d", p.PacketNumber))
}
h.ackedPackets = append(h.ackedPackets, probePacket)
} else {
h.ackedPackets = append(h.ackedPackets, p)
continue
}
return true
})
h.ackedPackets = append(h.ackedPackets, p)
}
if h.logger.Debug() && len(h.ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
for i, p := range h.ackedPackets {
@ -482,9 +475,6 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
}
h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns)
}
if processErr != nil {
return nil, processErr
}
for _, p := range h.ackedPackets {
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
@ -659,12 +649,11 @@ func (h *sentPacketHandler) detectLostPathProbes(now time.Time) {
lossTime := now.Add(-pathProbePacketLossTimeout)
// RemovePathProbe cannot be called while iterating.
var lostPathProbes []*packet
h.appDataPackets.history.IteratePathProbes(func(p *packet) bool {
for p := range h.appDataPackets.history.PathProbes() {
if !p.SendTime.After(lossTime) {
lostPathProbes = append(lostPathProbes, p)
}
return true
})
}
for _, p := range lostPathProbes {
for _, f := range p.Frames {
f.Handler.OnLost(f.Frame)
@ -687,9 +676,9 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
lostSendTime := now.Add(-lossDelay)
priorInFlight := h.bytesInFlight
pnSpace.history.Iterate(func(p *packet) bool {
for p := range pnSpace.history.Packets() {
if p.PacketNumber > pnSpace.largestAcked {
return false
break
}
isRegularPacket := !p.skippedPacket && !p.isPathProbePacket
@ -736,8 +725,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
}
}
}
return true
})
}
}
func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error {
@ -945,24 +933,21 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
func (h *sentPacketHandler) ResetForRetry(now time.Time) {
h.bytesInFlight = 0
var firstPacketSendTime time.Time
h.initialPackets.history.Iterate(func(p *packet) bool {
for p := range h.initialPackets.history.Packets() {
if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime
}
if p.declaredLost || p.skippedPacket {
return true
}
h.queueFramesForRetransmission(p)
return true
})
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
h.appDataPackets.history.Iterate(func(p *packet) bool {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
return true
})
}
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
for p := range h.appDataPackets.history.Packets() {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
}
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
// Otherwise, we don't know which Initial the Retry was sent in response to.
@ -993,18 +978,16 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) {
func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize protocol.ByteCount) {
h.rttStats.ResetForPathMigration()
h.appDataPackets.history.Iterate(func(p *packet) bool {
for p := range h.appDataPackets.history.Packets() {
h.appDataPackets.history.DeclareLost(p.PacketNumber)
if !p.skippedPacket && !p.isPathProbePacket {
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
}
return true
})
h.appDataPackets.history.IteratePathProbes(func(p *packet) bool {
}
for p := range h.appDataPackets.history.PathProbes() {
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
return true
})
}
h.congestion = congestion.NewCubicSender(
congestion.DefaultClock{},
h.rttStats,

View file

@ -2,6 +2,7 @@ package ackhandler
import (
"fmt"
"iter"
"github.com/quic-go/quic-go/internal/protocol"
)
@ -68,23 +69,25 @@ func (h *sentPacketHistory) SentPathProbePacket(p *packet) {
h.pathProbePackets = append(h.pathProbePackets, p)
}
// Iterate iterates through all packets.
func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool)) {
for _, p := range h.packets {
if p == nil {
continue
}
if cont := cb(p); !cont {
return
func (h *sentPacketHistory) Packets() iter.Seq[*packet] {
return func(yield func(*packet) bool) {
for _, p := range h.packets {
if p == nil {
continue
}
if !yield(p) {
return
}
}
}
}
// IteratePathProbes iterates through all packets.
func (h *sentPacketHistory) IteratePathProbes(cb func(*packet) (cont bool)) {
for _, p := range h.pathProbePackets {
if cont := cb(p); !cont {
return
func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] {
return func(yield func(*packet) bool) {
for _, p := range h.pathProbePackets {
if !yield(p) {
return
}
}
}
}

View file

@ -149,36 +149,18 @@ func TestSentPacketHistoryIterating(t *testing.T) {
require.NoError(t, hist.Remove(4))
var packets, skippedPackets []protocol.PacketNumber
hist.Iterate(func(p *packet) bool {
for p := range hist.Packets() {
if p.skippedPacket {
skippedPackets = append(skippedPackets, p.PacketNumber)
} else {
packets = append(packets, p.PacketNumber)
}
return true
})
}
require.Equal(t, []protocol.PacketNumber{1, 2, 6}, packets)
require.Equal(t, []protocol.PacketNumber{0, 5}, skippedPackets)
}
func TestSentPacketHistoryStopIterating(t *testing.T) {
hist := newSentPacketHistory(true)
hist.SkippedPacket(0)
hist.SentAckElicitingPacket(&packet{PacketNumber: 1})
hist.SentAckElicitingPacket(&packet{PacketNumber: 2})
var iterations []protocol.PacketNumber
hist.Iterate(func(p *packet) bool {
if p.skippedPacket {
return true
}
iterations = append(iterations, p.PacketNumber)
return p.PacketNumber < 1
})
require.Equal(t, []protocol.PacketNumber{1}, iterations)
}
func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
hist := newSentPacketHistory(true)
hist.SentAckElicitingPacket(&packet{PacketNumber: 0})
@ -189,7 +171,7 @@ func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
hist.SentAckElicitingPacket(&packet{PacketNumber: 5})
var iterations []protocol.PacketNumber
hist.Iterate(func(p *packet) bool {
for p := range hist.Packets() {
iterations = append(iterations, p.PacketNumber)
switch p.PacketNumber {
case 0:
@ -197,8 +179,7 @@ func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
case 4:
require.NoError(t, hist.Remove(4))
}
return true
})
}
require.Equal(t, []protocol.PacketNumber{0, 1, 2, 3, 4, 5}, iterations)
require.Equal(t, []protocol.PacketNumber{1, 3, 5}, hist.getPacketNumbers())
@ -217,7 +198,7 @@ func TestSentPacketHistoryPathProbes(t *testing.T) {
getPacketsInHistory := func(t *testing.T) []protocol.PacketNumber {
t.Helper()
var pns []protocol.PacketNumber
hist.Iterate(func(p *packet) bool {
for p := range hist.Packets() {
pns = append(pns, p.PacketNumber)
switch p.PacketNumber {
case 2, 5:
@ -225,18 +206,16 @@ func TestSentPacketHistoryPathProbes(t *testing.T) {
default:
require.False(t, p.isPathProbePacket)
}
return true
})
}
return pns
}
getPacketsInPathProbeHistory := func(t *testing.T) []protocol.PacketNumber {
t.Helper()
var pns []protocol.PacketNumber
hist.IteratePathProbes(func(p *packet) bool {
for p := range hist.PathProbes() {
pns = append(pns, p.PacketNumber)
return true
})
}
return pns
}