diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index ff547a8c..1175c790 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -14,7 +14,7 @@ type receivedPacketHandler struct { initialPackets *receivedPacketTracker handshakePackets *receivedPacketTracker - appDataPackets *receivedPacketTracker + appDataPackets appDataReceivedPacketTracker lowest1RTTPacket protocol.PacketNumber } @@ -24,9 +24,9 @@ var _ ReceivedPacketHandler = &receivedPacketHandler{} func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler { return &receivedPacketHandler{ sentPackets: sentPackets, - initialPackets: newReceivedPacketTracker(logger), - handshakePackets: newReceivedPacketTracker(logger), - appDataPackets: newReceivedPacketTracker(logger), + initialPackets: newReceivedPacketTracker(), + handshakePackets: newReceivedPacketTracker(), + appDataPackets: *newAppDataReceivedPacketTracker(logger), lowest1RTTPacket: protocol.InvalidPacketNumber, } } @@ -84,41 +84,28 @@ func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { } func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { - var initialAlarm, handshakeAlarm time.Time - if h.initialPackets != nil { - initialAlarm = h.initialPackets.GetAlarmTimeout() - } - if h.handshakePackets != nil { - handshakeAlarm = h.handshakePackets.GetAlarmTimeout() - } - oneRTTAlarm := h.appDataPackets.GetAlarmTimeout() - return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) + return h.appDataPackets.GetAlarmTimeout() } func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { - var ack *wire.AckFrame //nolint:exhaustive // 0-RTT packets can't contain ACK frames. switch encLevel { case protocol.EncryptionInitial: if h.initialPackets != nil { - ack = h.initialPackets.GetAckFrame(onlyIfQueued) + return h.initialPackets.GetAckFrame() } + return nil case protocol.EncryptionHandshake: if h.handshakePackets != nil { - ack = h.handshakePackets.GetAckFrame(onlyIfQueued) + return h.handshakePackets.GetAckFrame() } + return nil case protocol.Encryption1RTT: - // 0-RTT packets can't contain ACK frames return h.appDataPackets.GetAckFrame(onlyIfQueued) default: + // 0-RTT packets can't contain ACK frames return nil } - // For Initial and Handshake ACKs, the delay time is ignored by the receiver. - // Set it to 0 in order to save bytes. - if ack != nil { - ack.DelayTime = 0 - } - return ack } func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool { diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 26662f11..08af6f1e 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -9,35 +9,19 @@ import ( "github.com/quic-go/quic-go/internal/wire" ) -// number of ack-eliciting packets received before sending an ack. -const packetsBeforeAck = 2 - +// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space. +// Every received packet is acknowledged immediately. type receivedPacketTracker struct { - largestObserved protocol.PacketNumber - ignoreBelow protocol.PacketNumber - largestObservedRcvdTime time.Time - ect0, ect1, ecnce uint64 + ect0, ect1, ecnce uint64 - packetHistory *receivedPacketHistory - - maxAckDelay time.Duration + packetHistory receivedPacketHistory + lastAck *wire.AckFrame hasNewAck bool // true as soon as we received an ack-eliciting new packet - ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets - - ackElicitingPacketsReceivedSinceLastAck int - ackAlarm time.Time - lastAck *wire.AckFrame - - logger utils.Logger } -func newReceivedPacketTracker(logger utils.Logger) *receivedPacketTracker { - return &receivedPacketTracker{ - packetHistory: newReceivedPacketHistory(), - maxAckDelay: protocol.MaxAckDelay, - logger: logger, - } +func newReceivedPacketTracker() *receivedPacketTracker { + return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()} } func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { @@ -45,12 +29,6 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn) } - isMissing := h.isMissing(pn) - if pn >= h.largestObserved { - h.largestObserved = pn - h.largestObservedRcvdTime = rcvTime - } - //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE. switch ecn { case protocol.ECT0: @@ -60,13 +38,82 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro case protocol.ECNCE: h.ecnce++ } - if !ackEliciting { return nil } - h.hasNewAck = true + return nil +} + +func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame { + if !h.hasNewAck { + return nil + } + + // This function always returns the same ACK frame struct, filled with the most recent values. + ack := h.lastAck + if ack == nil { + ack = &wire.AckFrame{} + } + ack.Reset() + ack.ECT0 = h.ect0 + ack.ECT1 = h.ect1 + ack.ECNCE = h.ecnce + ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) + + h.lastAck = ack + h.hasNewAck = false + return ack +} + +func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { + return h.packetHistory.IsPotentiallyDuplicate(pn) +} + +// number of ack-eliciting packets received before sending an ACK +const packetsBeforeAck = 2 + +// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space. +// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached. +type appDataReceivedPacketTracker struct { + receivedPacketTracker + + largestObservedRcvdTime time.Time + + largestObserved protocol.PacketNumber + ignoreBelow protocol.PacketNumber + + maxAckDelay time.Duration + ackQueued bool // true if we need send a new ACK + + ackElicitingPacketsReceivedSinceLastAck int + ackAlarm time.Time + + logger utils.Logger +} + +func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker { + h := &appDataReceivedPacketTracker{ + receivedPacketTracker: *newReceivedPacketTracker(), + maxAckDelay: protocol.MaxAckDelay, + logger: logger, + } + return h +} + +func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { + if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil { + return err + } + if pn >= h.largestObserved { + h.largestObserved = pn + h.largestObservedRcvdTime = rcvTime + } + if !ackEliciting { + return nil + } h.ackElicitingPacketsReceivedSinceLastAck++ + isMissing := h.isMissing(pn) if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) { h.ackQueued = true h.ackAlarm = time.Time{} // cancel the ack alarm @@ -83,7 +130,7 @@ func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn pro // IgnoreBelow sets a lower limit for acknowledging packets. // Packets with packet numbers smaller than p will not be acked. -func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { +func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { if pn <= h.ignoreBelow { return } @@ -95,14 +142,14 @@ func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { } // isMissing says if a packet was reported missing in the last ACK. -func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool { +func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool { if h.lastAck == nil || p < h.ignoreBelow { return false } return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) } -func (h *receivedPacketTracker) hasNewMissingPackets() bool { +func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool { if h.lastAck == nil { return false } @@ -110,7 +157,7 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool { return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 } -func (h *receivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool { +func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool { // always acknowledge the first packet if h.lastAck == nil { h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") @@ -119,7 +166,7 @@ func (h *receivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn pro // Send an ACK if this packet was reported missing in an ACK sent before. // Ack decimation with reordering relies on the timer to send an ACK, but if - // missing packets we reported in the previous ack, send an ACK immediately. + // missing packets we reported in the previous ACK, send an ACK immediately. if wasMissing { if h.logger.Debug() { h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn) @@ -149,42 +196,25 @@ func (h *receivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn pro return false } -func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { - if !h.hasNewAck { - return nil - } +func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { now := time.Now() - if onlyIfQueued { - if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { + if onlyIfQueued && !h.ackQueued { + if h.ackAlarm.IsZero() || h.ackAlarm.After(now) { return nil } - if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { + if h.logger.Debug() && !h.ackAlarm.IsZero() { h.logger.Debugf("Sending ACK because the ACK timer expired.") } } - - // This function always returns the same ACK frame struct, filled with the most recent values. - ack := h.lastAck + ack := h.receivedPacketTracker.GetAckFrame() if ack == nil { - ack = &wire.AckFrame{} + return nil } - ack.Reset() ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime)) - ack.ECT0 = h.ect0 - ack.ECT1 = h.ect1 - ack.ECNCE = h.ecnce - ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) - - h.lastAck = ack - h.ackAlarm = time.Time{} h.ackQueued = false - h.hasNewAck = false + h.ackAlarm = time.Time{} h.ackElicitingPacketsReceivedSinceLastAck = 0 return ack } -func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } - -func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { - return h.packetHistory.IsPotentiallyDuplicate(pn) -} +func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } diff --git a/internal/ackhandler/received_packet_tracker_test.go b/internal/ackhandler/received_packet_tracker_test.go index 1c29f980..c970fbf1 100644 --- a/internal/ackhandler/received_packet_tracker_test.go +++ b/internal/ackhandler/received_packet_tracker_test.go @@ -15,13 +15,67 @@ var _ = Describe("Received Packet Tracker", func() { var tracker *receivedPacketTracker BeforeEach(func() { - tracker = newReceivedPacketTracker(utils.DefaultLogger) + tracker = newReceivedPacketTracker() + }) + + It("acknowledges packets", func() { + t := time.Now().Add(-10 * time.Second) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, t, true)).To(Succeed()) + ack := tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 3, Largest: 3}})) + Expect(ack.DelayTime).To(BeZero()) + // now receive another packet + Expect(tracker.ReceivedPacket(protocol.PacketNumber(4), protocol.ECNNon, t.Add(time.Second), true)).To(Succeed()) + ack = tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 3, Largest: 4}})) + Expect(ack.DelayTime).To(BeZero()) + }) + + It("also acknowledges delayed packets", func() { + t := time.Now().Add(-10 * time.Second) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, t, true)).To(Succeed()) + ack := tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(3))) + Expect(ack.DelayTime).To(BeZero()) + // now receive another packet + Expect(tracker.ReceivedPacket(protocol.PacketNumber(1), protocol.ECNNon, t.Add(time.Second), true)).To(Succeed()) + ack = tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(HaveLen(2)) + Expect(ack.AckRanges).To(ContainElement(wire.AckRange{Smallest: 1, Largest: 1})) + Expect(ack.AckRanges).To(ContainElement(wire.AckRange{Smallest: 3, Largest: 3})) + Expect(ack.DelayTime).To(BeZero()) + }) + + It("doesn't trigger ACKs for non-ack-eliciting packets", func() { + t := time.Now().Add(-10 * time.Second) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, t, false)).To(Succeed()) + Expect(tracker.GetAckFrame()).To(BeNil()) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(4), protocol.ECNNon, t.Add(5*time.Second), false)).To(Succeed()) + Expect(tracker.GetAckFrame()).To(BeNil()) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(5), protocol.ECNNon, t.Add(10*time.Second), true)).To(Succeed()) + ack := tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 3, Largest: 5}})) + }) +}) + +var _ = Describe("Application Data Received Packet Tracker", func() { + var tracker *appDataReceivedPacketTracker + + BeforeEach(func() { + tracker = newAppDataReceivedPacketTracker(utils.DefaultLogger) }) Context("accepting packets", func() { It("saves the time when each packet arrived", func() { - Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, time.Now(), true)).To(Succeed()) - Expect(tracker.largestObservedRcvdTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) + t := time.Now() + Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, t, true)).To(Succeed()) + Expect(tracker.largestObservedRcvdTime).To(Equal(t)) }) It("updates the largestObserved and the largestObservedRcvdTime", func() { diff --git a/internal/utils/minmax.go b/internal/utils/minmax.go index 6884ef40..03a9c9a8 100644 --- a/internal/utils/minmax.go +++ b/internal/utils/minmax.go @@ -27,18 +27,6 @@ func MinTime(a, b time.Time) time.Time { return a } -// MinNonZeroTime returns the earlist time that is not time.Time{} -// If both a and b are time.Time{}, it returns time.Time{} -func MinNonZeroTime(a, b time.Time) time.Time { - if a.IsZero() { - return b - } - if b.IsZero() { - return a - } - return MinTime(a, b) -} - // MaxTime returns the later time func MaxTime(a, b time.Time) time.Time { if a.After(b) { diff --git a/internal/utils/minmax_test.go b/internal/utils/minmax_test.go index c480219c..3d648502 100644 --- a/internal/utils/minmax_test.go +++ b/internal/utils/minmax_test.go @@ -30,14 +30,4 @@ var _ = Describe("Min / Max", func() { Expect(MinNonZeroDuration(b, a)).To(Equal(b)) Expect(MinNonZeroDuration(time.Minute, time.Hour)).To(Equal(time.Minute)) }) - - It("returns the minium non-zero time", func() { - a := time.Time{} - b := time.Now() - Expect(MinNonZeroTime(time.Time{}, time.Time{})).To(Equal(time.Time{})) - Expect(MinNonZeroTime(a, b)).To(Equal(b)) - Expect(MinNonZeroTime(b, a)).To(Equal(b)) - Expect(MinNonZeroTime(b, b.Add(time.Second))).To(Equal(b)) - Expect(MinNonZeroTime(b.Add(time.Second), b)).To(Equal(b)) - }) })