mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
ackhandler: use Go iterators to iterate over sent packets (#4952)
This commit is contained in:
parent
12f2be058b
commit
b32f1fa0e4
3 changed files with 56 additions and 91 deletions
|
@ -177,10 +177,9 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
|
|||
if pnSpace == nil {
|
||||
return
|
||||
}
|
||||
pnSpace.history.Iterate(func(p *packet) bool {
|
||||
for p := range pnSpace.history.Packets() {
|
||||
h.removeFromBytesInFlight(p)
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
// drop the packet history
|
||||
//nolint:exhaustive // Not every packet number space can be dropped.
|
||||
|
@ -197,14 +196,13 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
|
|||
// and not when the client drops 0-RTT keys when the handshake completes.
|
||||
// When 0-RTT is rejected, all application data sent so far becomes invalid.
|
||||
// Delete the packets from the history and remove them from bytes_in_flight.
|
||||
h.appDataPackets.history.Iterate(func(p *packet) bool {
|
||||
for p := range h.appDataPackets.history.Packets() {
|
||||
if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
|
||||
return false
|
||||
break
|
||||
}
|
||||
h.removeFromBytesInFlight(p)
|
||||
h.appDataPackets.history.Remove(p.PacketNumber)
|
||||
return true
|
||||
})
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
|
||||
}
|
||||
|
@ -430,15 +428,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
|
|||
ackRangeIndex := 0
|
||||
lowestAcked := ack.LowestAcked()
|
||||
largestAcked := ack.LargestAcked()
|
||||
var processErr error
|
||||
pnSpace.history.Iterate(func(p *packet) bool {
|
||||
for p := range pnSpace.history.Packets() {
|
||||
// ignore packets below the lowest acked
|
||||
if p.PacketNumber < lowestAcked {
|
||||
return true
|
||||
continue
|
||||
}
|
||||
// break after largest acked is reached
|
||||
if p.PacketNumber > largestAcked {
|
||||
return false
|
||||
break
|
||||
}
|
||||
|
||||
if ack.HasMissingRanges() {
|
||||
|
@ -450,19 +446,17 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
|
|||
}
|
||||
|
||||
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
|
||||
return true
|
||||
continue
|
||||
}
|
||||
if p.PacketNumber > ackRange.Largest {
|
||||
processErr = fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
|
||||
return false
|
||||
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
|
||||
}
|
||||
}
|
||||
if p.skippedPacket {
|
||||
processErr = &qerr.TransportError{
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
|
||||
}
|
||||
return false
|
||||
}
|
||||
if p.isPathProbePacket {
|
||||
probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber)
|
||||
|
@ -470,11 +464,10 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
|
|||
panic(fmt.Sprintf("path probe doesn't exist: %d", p.PacketNumber))
|
||||
}
|
||||
h.ackedPackets = append(h.ackedPackets, probePacket)
|
||||
} else {
|
||||
h.ackedPackets = append(h.ackedPackets, p)
|
||||
continue
|
||||
}
|
||||
return true
|
||||
})
|
||||
h.ackedPackets = append(h.ackedPackets, p)
|
||||
}
|
||||
if h.logger.Debug() && len(h.ackedPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
|
||||
for i, p := range h.ackedPackets {
|
||||
|
@ -482,9 +475,6 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
|
|||
}
|
||||
h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns)
|
||||
}
|
||||
if processErr != nil {
|
||||
return nil, processErr
|
||||
}
|
||||
|
||||
for _, p := range h.ackedPackets {
|
||||
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
|
||||
|
@ -659,12 +649,11 @@ func (h *sentPacketHandler) detectLostPathProbes(now time.Time) {
|
|||
lossTime := now.Add(-pathProbePacketLossTimeout)
|
||||
// RemovePathProbe cannot be called while iterating.
|
||||
var lostPathProbes []*packet
|
||||
h.appDataPackets.history.IteratePathProbes(func(p *packet) bool {
|
||||
for p := range h.appDataPackets.history.PathProbes() {
|
||||
if !p.SendTime.After(lossTime) {
|
||||
lostPathProbes = append(lostPathProbes, p)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
for _, p := range lostPathProbes {
|
||||
for _, f := range p.Frames {
|
||||
f.Handler.OnLost(f.Frame)
|
||||
|
@ -687,9 +676,9 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
|
|||
lostSendTime := now.Add(-lossDelay)
|
||||
|
||||
priorInFlight := h.bytesInFlight
|
||||
pnSpace.history.Iterate(func(p *packet) bool {
|
||||
for p := range pnSpace.history.Packets() {
|
||||
if p.PacketNumber > pnSpace.largestAcked {
|
||||
return false
|
||||
break
|
||||
}
|
||||
|
||||
isRegularPacket := !p.skippedPacket && !p.isPathProbePacket
|
||||
|
@ -736,8 +725,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
|
|||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error {
|
||||
|
@ -945,24 +933,21 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
|
|||
func (h *sentPacketHandler) ResetForRetry(now time.Time) {
|
||||
h.bytesInFlight = 0
|
||||
var firstPacketSendTime time.Time
|
||||
h.initialPackets.history.Iterate(func(p *packet) bool {
|
||||
for p := range h.initialPackets.history.Packets() {
|
||||
if firstPacketSendTime.IsZero() {
|
||||
firstPacketSendTime = p.SendTime
|
||||
}
|
||||
if p.declaredLost || p.skippedPacket {
|
||||
return true
|
||||
}
|
||||
h.queueFramesForRetransmission(p)
|
||||
return true
|
||||
})
|
||||
// All application data packets sent at this point are 0-RTT packets.
|
||||
// In the case of a Retry, we can assume that the server dropped all of them.
|
||||
h.appDataPackets.history.Iterate(func(p *packet) bool {
|
||||
if !p.declaredLost && !p.skippedPacket {
|
||||
h.queueFramesForRetransmission(p)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
// All application data packets sent at this point are 0-RTT packets.
|
||||
// In the case of a Retry, we can assume that the server dropped all of them.
|
||||
for p := range h.appDataPackets.history.Packets() {
|
||||
if !p.declaredLost && !p.skippedPacket {
|
||||
h.queueFramesForRetransmission(p)
|
||||
}
|
||||
}
|
||||
|
||||
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
|
||||
// Otherwise, we don't know which Initial the Retry was sent in response to.
|
||||
|
@ -993,18 +978,16 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) {
|
|||
|
||||
func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize protocol.ByteCount) {
|
||||
h.rttStats.ResetForPathMigration()
|
||||
h.appDataPackets.history.Iterate(func(p *packet) bool {
|
||||
for p := range h.appDataPackets.history.Packets() {
|
||||
h.appDataPackets.history.DeclareLost(p.PacketNumber)
|
||||
if !p.skippedPacket && !p.isPathProbePacket {
|
||||
h.removeFromBytesInFlight(p)
|
||||
h.queueFramesForRetransmission(p)
|
||||
}
|
||||
return true
|
||||
})
|
||||
h.appDataPackets.history.IteratePathProbes(func(p *packet) bool {
|
||||
}
|
||||
for p := range h.appDataPackets.history.PathProbes() {
|
||||
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
|
||||
return true
|
||||
})
|
||||
}
|
||||
h.congestion = congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
h.rttStats,
|
||||
|
|
|
@ -2,6 +2,7 @@ package ackhandler
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
@ -68,23 +69,25 @@ func (h *sentPacketHistory) SentPathProbePacket(p *packet) {
|
|||
h.pathProbePackets = append(h.pathProbePackets, p)
|
||||
}
|
||||
|
||||
// Iterate iterates through all packets.
|
||||
func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool)) {
|
||||
for _, p := range h.packets {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if cont := cb(p); !cont {
|
||||
return
|
||||
func (h *sentPacketHistory) Packets() iter.Seq[*packet] {
|
||||
return func(yield func(*packet) bool) {
|
||||
for _, p := range h.packets {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if !yield(p) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IteratePathProbes iterates through all packets.
|
||||
func (h *sentPacketHistory) IteratePathProbes(cb func(*packet) (cont bool)) {
|
||||
for _, p := range h.pathProbePackets {
|
||||
if cont := cb(p); !cont {
|
||||
return
|
||||
func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] {
|
||||
return func(yield func(*packet) bool) {
|
||||
for _, p := range h.pathProbePackets {
|
||||
if !yield(p) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -149,36 +149,18 @@ func TestSentPacketHistoryIterating(t *testing.T) {
|
|||
require.NoError(t, hist.Remove(4))
|
||||
|
||||
var packets, skippedPackets []protocol.PacketNumber
|
||||
hist.Iterate(func(p *packet) bool {
|
||||
for p := range hist.Packets() {
|
||||
if p.skippedPacket {
|
||||
skippedPackets = append(skippedPackets, p.PacketNumber)
|
||||
} else {
|
||||
packets = append(packets, p.PacketNumber)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
require.Equal(t, []protocol.PacketNumber{1, 2, 6}, packets)
|
||||
require.Equal(t, []protocol.PacketNumber{0, 5}, skippedPackets)
|
||||
}
|
||||
|
||||
func TestSentPacketHistoryStopIterating(t *testing.T) {
|
||||
hist := newSentPacketHistory(true)
|
||||
hist.SkippedPacket(0)
|
||||
hist.SentAckElicitingPacket(&packet{PacketNumber: 1})
|
||||
hist.SentAckElicitingPacket(&packet{PacketNumber: 2})
|
||||
|
||||
var iterations []protocol.PacketNumber
|
||||
hist.Iterate(func(p *packet) bool {
|
||||
if p.skippedPacket {
|
||||
return true
|
||||
}
|
||||
iterations = append(iterations, p.PacketNumber)
|
||||
return p.PacketNumber < 1
|
||||
})
|
||||
require.Equal(t, []protocol.PacketNumber{1}, iterations)
|
||||
}
|
||||
|
||||
func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
|
||||
hist := newSentPacketHistory(true)
|
||||
hist.SentAckElicitingPacket(&packet{PacketNumber: 0})
|
||||
|
@ -189,7 +171,7 @@ func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
|
|||
hist.SentAckElicitingPacket(&packet{PacketNumber: 5})
|
||||
|
||||
var iterations []protocol.PacketNumber
|
||||
hist.Iterate(func(p *packet) bool {
|
||||
for p := range hist.Packets() {
|
||||
iterations = append(iterations, p.PacketNumber)
|
||||
switch p.PacketNumber {
|
||||
case 0:
|
||||
|
@ -197,8 +179,7 @@ func TestSentPacketHistoryDeleteWhileIterating(t *testing.T) {
|
|||
case 4:
|
||||
require.NoError(t, hist.Remove(4))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
require.Equal(t, []protocol.PacketNumber{0, 1, 2, 3, 4, 5}, iterations)
|
||||
require.Equal(t, []protocol.PacketNumber{1, 3, 5}, hist.getPacketNumbers())
|
||||
|
@ -217,7 +198,7 @@ func TestSentPacketHistoryPathProbes(t *testing.T) {
|
|||
getPacketsInHistory := func(t *testing.T) []protocol.PacketNumber {
|
||||
t.Helper()
|
||||
var pns []protocol.PacketNumber
|
||||
hist.Iterate(func(p *packet) bool {
|
||||
for p := range hist.Packets() {
|
||||
pns = append(pns, p.PacketNumber)
|
||||
switch p.PacketNumber {
|
||||
case 2, 5:
|
||||
|
@ -225,18 +206,16 @@ func TestSentPacketHistoryPathProbes(t *testing.T) {
|
|||
default:
|
||||
require.False(t, p.isPathProbePacket)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return pns
|
||||
}
|
||||
|
||||
getPacketsInPathProbeHistory := func(t *testing.T) []protocol.PacketNumber {
|
||||
t.Helper()
|
||||
var pns []protocol.PacketNumber
|
||||
hist.IteratePathProbes(func(p *packet) bool {
|
||||
for p := range hist.PathProbes() {
|
||||
pns = append(pns, p.PacketNumber)
|
||||
return true
|
||||
})
|
||||
}
|
||||
return pns
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue