Merge pull request #2455 from lucas-clemente/loss-before-ack

notify the congestion controller of losses first
This commit is contained in:
Marten Seemann 2020-04-02 14:38:39 +07:00 committed by GitHub
commit c10af76a4a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 51 deletions

View file

@ -267,28 +267,24 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
}
}
ackedPackets, err := h.determineNewlyAckedPackets(ack, encLevel)
priorInFlight := h.bytesInFlight
ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel)
if err != nil || len(ackedPackets) == 0 {
return err
}
priorInFlight := h.bytesInFlight
lostPackets, err := h.detectAndRemoveLostPackets(rcvTime, encLevel)
if err != nil {
return err
}
for _, p := range lostPackets {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
for _, p := range ackedPackets {
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1)
}
if err := h.onPacketAcked(p); err != nil {
return err
}
if p.includedInBytesInFlight {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
}
if err := h.detectLostPackets(rcvTime, encLevel, priorInFlight); err != nil {
return err
}
if h.qlogger != nil && h.ptoCount != 0 {
h.qlogger.UpdatedPTOCount(0)
}
@ -303,15 +299,12 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
return h.lowestNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(
ackFrame *wire.AckFrame,
encLevel protocol.EncryptionLevel,
) ([]*Packet, error) {
func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) {
pnSpace := h.getPacketNumberSpace(encLevel)
var ackedPackets []*Packet
ackRangeIndex := 0
lowestAcked := ackFrame.LowestAcked()
largestAcked := ackFrame.LargestAcked()
lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked()
err := pnSpace.history.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
@ -322,12 +315,12 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(
return false, nil
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
if ack.HasMissingRanges() {
ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
@ -348,6 +341,28 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(
}
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
}
for _, p := range ackedPackets {
if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil {
continue
}
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1)
}
for _, f := range p.Frames {
if f.OnAcked != nil {
f.OnAcked(f.Frame)
}
}
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
return nil, err
}
}
return ackedPackets, err
}
@ -429,11 +444,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() {
h.alarm = sentTime.Add(h.rttStats.PTO(encLevel == protocol.Encryption1RTT) << h.ptoCount)
}
func (h *sentPacketHandler) detectLostPackets(
now time.Time,
encLevel protocol.EncryptionLevel,
priorInFlight protocol.ByteCount,
) error {
func (h *sentPacketHandler) detectAndRemoveLostPackets(now time.Time, encLevel protocol.EncryptionLevel) ([]*Packet, error) {
pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.lossTime = time.Time{}
@ -486,7 +497,6 @@ func (h *sentPacketHandler) detectLostPackets(
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
pnSpace.history.Remove(p.PacketNumber)
if h.traceCallback != nil {
@ -505,7 +515,7 @@ func (h *sentPacketHandler) detectLostPackets(
})
}
}
return nil
return lostPackets, nil
}
func (h *sentPacketHandler) OnLossDetectionTimeout() error {
@ -529,7 +539,14 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime)
}
// Early retransmit or time loss detection
return h.detectLostPackets(time.Now(), encLevel, h.bytesInFlight)
priorInFlight := h.bytesInFlight
lostPackets, err := h.detectAndRemoveLostPackets(time.Now(), encLevel)
if err != nil {
return err
}
for _, p := range lostPackets {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
}
// PTO
@ -559,23 +576,6 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(p *Packet) error {
pnSpace := h.getPacketNumberSpace(p.EncryptionLevel)
if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil {
return nil
}
for _, f := range p.Frames {
if f.OnAcked != nil {
f.OnAcked(f.Frame)
}
}
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
return pnSpace.history.Remove(p.PacketNumber)
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel)