diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index cf19b115..10200f4c 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -3,7 +3,9 @@ package ackhandler import ( "time" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -15,6 +17,7 @@ type receivedPacketHandler struct { packetHistory *receivedPacketHistory ackSendDelay time.Duration + rttStats *congestion.RTTStats packetsReceivedSinceLastAck int retransmittablePacketsReceivedSinceLastAck int @@ -26,35 +29,53 @@ type receivedPacketHandler struct { } const ( - // ackSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet + // maximum delay that can be applied to an ACK for a retransmittable packet ackSendDelay = 25 * time.Millisecond - // retransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for + // initial maximum number of retransmittable packets received before sending an ack. + initialRetransmittablePacketsBeforeAck = 2 + // number of retransmittable that an ACK is sent for retransmittablePacketsBeforeAck = 10 + // 1/5 RTT delay when doing ack decimation + ackDecimationDelay = 1.0 / 4 + // 1/8 RTT delay when doing ack decimation + shortAckDecimationDelay = 1.0 / 8 + // Minimum number of packets received before ack decimation is enabled. + // This intends to avoid the beginning of slow start, when CWNDs may be + // rapidly increasing. + minReceivedBeforeAckDecimation = 100 + // Maximum number of packets to ack immediately after a missing packet for + // fast retransmission to kick in at the sender. This limit is created to + // reduce the number of acks sent that have no benefit for fast retransmission. + // Set to the number of nacks needed for fast retransmit plus one for protection + // against an ack loss + maxPacketsAfterNewMissing = 4 ) // NewReceivedPacketHandler creates a new receivedPacketHandler -func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler { +func NewReceivedPacketHandler(rttStats *congestion.RTTStats, version protocol.VersionNumber) ReceivedPacketHandler { return &receivedPacketHandler{ packetHistory: newReceivedPacketHistory(), ackSendDelay: ackSendDelay, + rttStats: rttStats, version: version, } } func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error { + if packetNumber < h.ignoreBelow { + return nil + } + + isMissing := h.isMissing(packetNumber) if packetNumber > h.largestObserved { h.largestObserved = packetNumber h.largestObservedReceivedTime = rcvTime } - if packetNumber < h.ignoreBelow { - return nil - } - if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil { return err } - h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck) + h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing) return nil } @@ -65,35 +86,68 @@ func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) { h.packetHistory.DeleteBelow(p) } -func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) { - h.packetsReceivedSinceLastAck++ - - if shouldInstigateAck { - h.retransmittablePacketsReceivedSinceLastAck++ +// isMissing says if a packet was reported missing in the last ACK. +func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool { + if h.lastAck == nil { + return false } + return p < h.lastAck.LargestAcked && !h.lastAck.AcksPacket(p) +} + +func (h *receivedPacketHandler) hasNewMissingPackets() bool { + if h.lastAck == nil { + return false + } + highestRange := h.packetHistory.GetHighestAckRange() + return highestRange.First >= h.lastAck.LargestAcked && highestRange.Len() <= maxPacketsAfterNewMissing +} + +// maybeQueueAck queues an ACK, if necessary. +// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck() +// in ACK_DECIMATION_WITH_REORDERING mode. +func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) { + h.packetsReceivedSinceLastAck++ // always ack the first packet if h.lastAck == nil { h.ackQueued = true + return } - // if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK - // note that it cannot be a duplicate because they're already filtered out by ReceivedPacket() - if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked { - h.ackQueued = true - } - - // check if a new missing range above the previously was created - if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked { + // Send an ACK if this packet was reported missing in an ACK sent before. + // Ack decimation with reordering relies on the timer to send an ACK, but if + // missing packets we reported in the previous ack, send an ACK immediately. + if wasMissing { h.ackQueued = true } if !h.ackQueued && shouldInstigateAck { - if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck { - h.ackQueued = true + h.retransmittablePacketsReceivedSinceLastAck++ + + if packetNumber > minReceivedBeforeAckDecimation { + // ack up to 10 packets at once + if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck { + h.ackQueued = true + } else if h.ackAlarm.IsZero() { + // wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack + ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay))) + h.ackAlarm = rcvTime.Add(ackDelay) + } } else { - if h.ackAlarm.IsZero() { - h.ackAlarm = rcvTime.Add(h.ackSendDelay) + // send an ACK every 2 retransmittable packets + if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck { + h.ackQueued = true + } else if h.ackAlarm.IsZero() { + h.ackAlarm = rcvTime.Add(ackSendDelay) + } + } + // If there are new missing packets to report, set a short timer to send an ACK. + if h.hasNewMissingPackets() { + // wait the minimum of 1/8 min RTT and the existing ack time + ackDelay := float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay) + ackTime := rcvTime.Add(time.Duration(ackDelay)) + if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) { + h.ackAlarm = ackTime } } } @@ -125,7 +179,6 @@ func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame { h.ackQueued = false h.packetsReceivedSinceLastAck = 0 h.retransmittablePacketsReceivedSinceLastAck = 0 - return ack } diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index 9cbca137..21a6a74b 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -3,6 +3,7 @@ package ackhandler import ( "time" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -12,11 +13,13 @@ import ( var _ = Describe("receivedPacketHandler", func() { var ( - handler *receivedPacketHandler + handler *receivedPacketHandler + rttStats *congestion.RTTStats ) BeforeEach(func() { - handler = NewReceivedPacketHandler(protocol.VersionWhatever).(*receivedPacketHandler) + rttStats = &congestion.RTTStats{} + handler = NewReceivedPacketHandler(rttStats, protocol.VersionWhatever).(*receivedPacketHandler) }) Context("accepting packets", func() { @@ -81,6 +84,15 @@ var _ = Describe("receivedPacketHandler", func() { Expect(handler.ackQueued).To(BeFalse()) } + receiveAndAckPacketsUntilAckDecimation := func() { + for i := 1; i <= minReceivedBeforeAckDecimation; i++ { + err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) + Expect(err).ToNot(HaveOccurred()) + } + Expect(handler.GetAckFrame()).ToNot(BeNil()) + Expect(handler.ackQueued).To(BeFalse()) + } + It("always queues an ACK for the first packet", func() { err := handler.ReceivedPacket(1, time.Time{}, false) Expect(err).ToNot(HaveOccurred()) @@ -95,17 +107,34 @@ var _ = Describe("receivedPacketHandler", func() { Expect(handler.GetAlarmTimeout()).To(BeZero()) }) - It("queues an ACK for every RetransmittablePacketsBeforeAck retransmittable packet, if they are arriving fast", func() { + It("queues an ACK for every second retransmittable packet at the beginning", func() { receiveAndAck10Packets() p := protocol.PacketNumber(11) - for i := 0; i < retransmittablePacketsBeforeAck-1; i++ { + for i := 0; i <= 20; i++ { err := handler.ReceivedPacket(p, time.Time{}, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeFalse()) p++ + err = handler.ReceivedPacket(p, time.Time{}, true) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeTrue()) + p++ + // dequeue the ACK frame + Expect(handler.GetAckFrame()).ToNot(BeNil()) + } + }) + + It("queues an ACK for every 10 retransmittable packet, if they are arriving fast", func() { + receiveAndAck10Packets() + p := protocol.PacketNumber(10000) + for i := 0; i < 9; i++ { + err := handler.ReceivedPacket(p, time.Now(), true) + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeFalse()) + p++ } Expect(handler.GetAlarmTimeout()).NotTo(BeZero()) - err := handler.ReceivedPacket(p, time.Time{}, true) + err := handler.ReceivedPacket(p, time.Now(), true) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeTrue()) Expect(handler.GetAlarmTimeout()).To(BeZero()) @@ -113,15 +142,15 @@ var _ = Describe("receivedPacketHandler", func() { It("only sets the timer when receiving a retransmittable packets", func() { receiveAndAck10Packets() - err := handler.ReceivedPacket(11, time.Time{}, false) + err := handler.ReceivedPacket(11, time.Now(), false) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeFalse()) - Expect(handler.ackAlarm).To(BeZero()) - err = handler.ReceivedPacket(12, time.Time{}, true) + Expect(handler.GetAlarmTimeout()).To(BeZero()) + rcvTime := time.Now().Add(10 * time.Millisecond) + err = handler.ReceivedPacket(12, rcvTime, true) Expect(err).ToNot(HaveOccurred()) Expect(handler.ackQueued).To(BeFalse()) - Expect(handler.ackAlarm).ToNot(BeZero()) - Expect(handler.GetAlarmTimeout()).NotTo(BeZero()) + Expect(handler.GetAlarmTimeout()).To(Equal(rcvTime.Add(ackSendDelay))) }) It("queues an ACK if it was reported missing before", func() { @@ -139,15 +168,32 @@ var _ = Describe("receivedPacketHandler", func() { Expect(handler.ackQueued).To(BeTrue()) }) - It("queues an ACK if it creates a new missing range", func() { - receiveAndAck10Packets() - for i := 11; i < 16; i++ { - err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) + It("doesn't queue an ACK if the packet closes a gap that was not yet reported", func() { + receiveAndAckPacketsUntilAckDecimation() + p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1) + err := handler.ReceivedPacket(p+1, time.Now(), true) // p is missing now + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeFalse()) + Expect(handler.GetAlarmTimeout()).ToNot(BeZero()) + err = handler.ReceivedPacket(p, time.Now(), true) // p is not missing any more + Expect(err).ToNot(HaveOccurred()) + Expect(handler.ackQueued).To(BeFalse()) + }) + + It("sets an ACK alarm after 1/4 RTT if it creates a new missing range", func() { + now := time.Now().Add(-time.Hour) + rtt := 80 * time.Millisecond + rttStats.UpdateRTT(rtt, 0, now) + receiveAndAckPacketsUntilAckDecimation() + p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1) + for i := p; i < p+6; i++ { + err := handler.ReceivedPacket(i, now, true) Expect(err).ToNot(HaveOccurred()) } - err := handler.ReceivedPacket(20, time.Time{}, true) // we now know that packets 16 to 19 are missing + err := handler.ReceivedPacket(p+10, now, true) // we now know that packets p+7, p+8 and p+9 Expect(err).ToNot(HaveOccurred()) - Expect(handler.ackQueued).To(BeTrue()) + Expect(rttStats.MinRTT()).To(Equal(rtt)) + Expect(handler.ackAlarm.Sub(now)).To(Equal(rtt / 8)) ack := handler.GetAckFrame() Expect(ack.HasMissingRanges()).To(BeTrue()) Expect(ack).ToNot(BeNil()) @@ -275,7 +321,7 @@ var _ = Describe("receivedPacketHandler", func() { handler.ackAlarm = time.Now().Add(-time.Minute) Expect(handler.GetAckFrame()).ToNot(BeNil()) Expect(handler.packetsReceivedSinceLastAck).To(BeZero()) - Expect(handler.ackAlarm).To(BeZero()) + Expect(handler.GetAlarmTimeout()).To(BeZero()) Expect(handler.retransmittablePacketsReceivedSinceLastAck).To(BeZero()) Expect(handler.ackQueued).To(BeFalse()) }) diff --git a/internal/wire/ack_range.go b/internal/wire/ack_range.go index c561762d..783528e6 100644 --- a/internal/wire/ack_range.go +++ b/internal/wire/ack_range.go @@ -7,3 +7,8 @@ type AckRange struct { First protocol.PacketNumber Last protocol.PacketNumber } + +// Len returns the number of packets contained in this ACK range +func (r AckRange) Len() protocol.PacketNumber { + return r.Last - r.First + 1 +} diff --git a/internal/wire/ack_range_test.go b/internal/wire/ack_range_test.go new file mode 100644 index 00000000..a9d268f1 --- /dev/null +++ b/internal/wire/ack_range_test.go @@ -0,0 +1,13 @@ +package wire + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ACK range", func() { + It("returns the length", func() { + Expect(AckRange{First: 10, Last: 10}.Len()).To(BeEquivalentTo(1)) + Expect(AckRange{First: 10, Last: 13}.Len()).To(BeEquivalentTo(4)) + }) +}) diff --git a/session.go b/session.go index 07b1415c..0cd76182 100644 --- a/session.go +++ b/session.go @@ -328,7 +328,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.sessionCreationTime = now s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.version) if s.version.UsesTLS() { s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)