diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index ac541a5a..54518542 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -196,6 +196,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe return err } + priorInFlight := h.bytesInFlight for _, p := range ackedPackets { if encLevel < p.EncryptionLevel { return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel) @@ -209,9 +210,12 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe if err := h.onPacketAcked(p); err != nil { return err } + if p.includedInBytesInFlight { + h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight) + } } - if err := h.detectLostPackets(rcvTime); err != nil { + if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil { return err } h.updateLossDetectionAlarm() @@ -288,7 +292,7 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() { } } -func (h *sentPacketHandler) detectLostPackets(now time.Time) error { +func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error { h.lossTime = time.Time{} maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) @@ -314,6 +318,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time) error { // the bytes in flight need to be reduced no matter if this packet will be retransmitted if p.includedInBytesInFlight { h.bytesInFlight -= p.Length + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) } if p.canBeRetransmitted { // queue the packet for retransmission, and report the loss to the congestion controller @@ -322,9 +327,6 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time) error { return err } } - if p.includedInBytesInFlight { - h.congestion.OnPacketLost(p.PacketNumber, p.Length, h.bytesInFlight) - } h.packetHistory.Remove(p.PacketNumber) } return nil @@ -340,7 +342,7 @@ func (h *sentPacketHandler) OnAlarm() error { err = h.queueHandshakePacketsForRetransmission() } else if !h.lossTime.IsZero() { // Early retransmit or time loss detection - err = h.detectLostPackets(now) + err = h.detectLostPackets(now, h.bytesInFlight) } else { // RTO h.rtoCount++ @@ -388,7 +390,6 @@ func (h *sentPacketHandler) onPacketAcked(p *Packet) error { // this also applies to packets that have been retransmitted as probe packets if p.includedInBytesInFlight { h.bytesInFlight -= p.Length - h.congestion.OnPacketAcked(p.PacketNumber, p.Length, h.bytesInFlight) } if h.rtoCount > 0 { h.verifyRTO(p.PacketNumber) diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 26369a46..c6f8bbbd 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -537,8 +537,8 @@ var _ = Describe("SentPacketHandler", func() { cong.EXPECT().TimeUntilSend(gomock.Any()).Times(3) gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), // must be called before packets are acked - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(3)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(3)), ) handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1})) handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2})) @@ -571,12 +571,12 @@ var _ = Describe("SentPacketHandler", func() { // send one probe packet and receive an ACK for it gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(5), protocol.ByteCount(1), protocol.ByteCount(4)), cong.EXPECT().OnRetransmissionTimeout(true), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(3)), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2)), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(1)), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(0)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(5), protocol.ByteCount(1), protocol.ByteCount(5)), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(5)), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(5)), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(5)), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(5)), ) handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 5})) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 5, LowestAcked: 5}, 1, protocol.EncryptionForwardSecure, time.Now()) @@ -598,8 +598,8 @@ var _ = Describe("SentPacketHandler", func() { // don't EXPECT any call to OnRetransmissionTimeout gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2)), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(1)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(3)), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(3)), ) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now()) Expect(err).ToNot(HaveOccurred()) @@ -613,8 +613,8 @@ var _ = Describe("SentPacketHandler", func() { // lose packet 1 gomock.InOrder( cong.EXPECT().MaybeExitSlowStart(), - cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(1)), - cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(0)), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)), ) err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now()) Expect(err).ToNot(HaveOccurred()) @@ -623,6 +623,31 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) }) + It("calls OnPacketAcked and OnPacketLost with the right bytes_in_flight value", func() { + cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(4) + cong.EXPECT().TimeUntilSend(gomock.Any()).Times(4) + handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) + handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2, SendTime: time.Now().Add(-30 * time.Minute)})) + handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3, SendTime: time.Now().Add(-30 * time.Minute)})) + handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 4, SendTime: time.Now()})) + // receive the first ACK + gomock.InOrder( + cong.EXPECT().MaybeExitSlowStart(), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(2), protocol.ByteCount(1), protocol.ByteCount(4)), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)), + ) + err := handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now().Add(-30*time.Minute)) + Expect(err).ToNot(HaveOccurred()) + // receive the second ACK + gomock.InOrder( + cong.EXPECT().MaybeExitSlowStart(), + cong.EXPECT().OnPacketAcked(protocol.PacketNumber(4), protocol.ByteCount(1), protocol.ByteCount(2)), + cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)), + ) + err = handler.ReceivedAck(&wire.AckFrame{LargestAcked: 4, LowestAcked: 4}, 2, protocol.EncryptionForwardSecure, time.Now()) + Expect(err).ToNot(HaveOccurred()) + }) + It("only allows sending of ACKs when congestion limited", func() { handler.bytesInFlight = 100 cong.EXPECT().GetCongestionWindow().Return(protocol.ByteCount(200))