Merge pull request #2455 from lucas-clemente/loss-before-ack

notify the congestion controller of losses first
This commit is contained in:
Marten Seemann 2020-04-02 14:38:39 +07:00 committed by GitHub
commit c10af76a4a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 51 deletions

View file

@ -267,28 +267,24 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
} }
} }
ackedPackets, err := h.determineNewlyAckedPackets(ack, encLevel) priorInFlight := h.bytesInFlight
ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel)
if err != nil || len(ackedPackets) == 0 { if err != nil || len(ackedPackets) == 0 {
return err return err
} }
lostPackets, err := h.detectAndRemoveLostPackets(rcvTime, encLevel)
priorInFlight := h.bytesInFlight if err != nil {
return err
}
for _, p := range lostPackets {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
for _, p := range ackedPackets { for _, p := range ackedPackets {
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1)
}
if err := h.onPacketAcked(p); err != nil {
return err
}
if p.includedInBytesInFlight { if p.includedInBytesInFlight {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
} }
} }
if err := h.detectLostPackets(rcvTime, encLevel, priorInFlight); err != nil {
return err
}
if h.qlogger != nil && h.ptoCount != 0 { if h.qlogger != nil && h.ptoCount != 0 {
h.qlogger.UpdatedPTOCount(0) h.qlogger.UpdatedPTOCount(0)
} }
@ -303,15 +299,12 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
return h.lowestNotConfirmedAcked return h.lowestNotConfirmedAcked
} }
func (h *sentPacketHandler) determineNewlyAckedPackets( func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) {
ackFrame *wire.AckFrame,
encLevel protocol.EncryptionLevel,
) ([]*Packet, error) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
var ackedPackets []*Packet var ackedPackets []*Packet
ackRangeIndex := 0 ackRangeIndex := 0
lowestAcked := ackFrame.LowestAcked() lowestAcked := ack.LowestAcked()
largestAcked := ackFrame.LargestAcked() largestAcked := ack.LargestAcked()
err := pnSpace.history.Iterate(func(p *Packet) (bool, error) { err := pnSpace.history.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the lowest acked // Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked { if p.PacketNumber < lowestAcked {
@ -322,12 +315,12 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(
return false, nil return false, nil
} }
if ackFrame.HasMissingRanges() { if ack.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 { for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
ackRangeIndex++ ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex] ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
} }
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
@ -348,6 +341,28 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(
} }
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns) h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
} }
for _, p := range ackedPackets {
if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil {
continue
}
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1)
}
for _, f := range p.Frames {
if f.OnAcked != nil {
f.OnAcked(f.Frame)
}
}
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
return nil, err
}
}
return ackedPackets, err return ackedPackets, err
} }
@ -429,11 +444,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() {
h.alarm = sentTime.Add(h.rttStats.PTO(encLevel == protocol.Encryption1RTT) << h.ptoCount) h.alarm = sentTime.Add(h.rttStats.PTO(encLevel == protocol.Encryption1RTT) << h.ptoCount)
} }
func (h *sentPacketHandler) detectLostPackets( func (h *sentPacketHandler) detectAndRemoveLostPackets(now time.Time, encLevel protocol.EncryptionLevel) ([]*Packet, error) {
now time.Time,
encLevel protocol.EncryptionLevel,
priorInFlight protocol.ByteCount,
) error {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.lossTime = time.Time{} pnSpace.lossTime = time.Time{}
@ -486,7 +497,6 @@ func (h *sentPacketHandler) detectLostPackets(
// the bytes in flight need to be reduced no matter if this packet will be retransmitted // the bytes in flight need to be reduced no matter if this packet will be retransmitted
if p.includedInBytesInFlight { if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length h.bytesInFlight -= p.Length
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
} }
pnSpace.history.Remove(p.PacketNumber) pnSpace.history.Remove(p.PacketNumber)
if h.traceCallback != nil { if h.traceCallback != nil {
@ -505,7 +515,7 @@ func (h *sentPacketHandler) detectLostPackets(
}) })
} }
} }
return nil return lostPackets, nil
} }
func (h *sentPacketHandler) OnLossDetectionTimeout() error { func (h *sentPacketHandler) OnLossDetectionTimeout() error {
@ -529,7 +539,14 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime)
} }
// Early retransmit or time loss detection // Early retransmit or time loss detection
return h.detectLostPackets(time.Now(), encLevel, h.bytesInFlight) priorInFlight := h.bytesInFlight
lostPackets, err := h.detectAndRemoveLostPackets(time.Now(), encLevel)
if err != nil {
return err
}
for _, p := range lostPackets {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
} }
// PTO // PTO
@ -559,23 +576,6 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm return h.alarm
} }
func (h *sentPacketHandler) onPacketAcked(p *Packet) error {
pnSpace := h.getPacketNumberSpace(p.EncryptionLevel)
if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil {
return nil
}
for _, f := range p.Frames {
if f.OnAcked != nil {
f.OnAcked(f.Frame)
}
}
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
return pnSpace.history.Remove(p.PacketNumber)
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)

View file

@ -460,8 +460,8 @@ var _ = Describe("SentPacketHandler", func() {
// lose packet 1 // lose packet 1
gomock.InOrder( gomock.InOrder(
cong.EXPECT().MaybeExitSlowStart(), cong.EXPECT().MaybeExitSlowStart(),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()),
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()),
) )
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed())
@ -480,16 +480,16 @@ var _ = Describe("SentPacketHandler", func() {
// receive the first ACK // receive the first ACK
gomock.InOrder( gomock.InOrder(
cong.EXPECT().MaybeExitSlowStart(), cong.EXPECT().MaybeExitSlowStart(),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()),
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4), gomock.Any()),
) )
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute))).To(Succeed()) Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute))).To(Succeed())
// receive the second ACK // receive the second ACK
gomock.InOrder( gomock.InOrder(
cong.EXPECT().MaybeExitSlowStart(), cong.EXPECT().MaybeExitSlowStart(),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()),
cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)),
cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2), gomock.Any()),
) )
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}} ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}}
Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed())