diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 1465a654..460a2571 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -62,7 +62,7 @@ type sentPacketTracker interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool - ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error + ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index f2124fe6..f0bd10df 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -38,6 +38,7 @@ func newReceivedPacketHandler( func (h *receivedPacketHandler) ReceivedPacket( pn protocol.PacketNumber, + ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool, @@ -45,20 +46,20 @@ func (h *receivedPacketHandler) ReceivedPacket( h.sentPackets.ReceivedPacket(encLevel) switch encLevel { case protocol.EncryptionInitial: - h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) case protocol.EncryptionHandshake: - h.handshakePackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) case protocol.Encryption0RTT: if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) } - h.appDataPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) case protocol.Encryption1RTT: if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { h.lowest1RTTPacket = pn } h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) - h.appDataPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) + h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) default: panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) } diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index e3f4bd64..ebd989fa 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -33,27 +33,36 @@ var _ = Describe("Received Packet Handler", func() { sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial).Times(2) sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake).Times(2) sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT).Times(2) - Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(5, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(3, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(2, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(4, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.ECT1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(5, protocol.ECNCE, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECT1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(4, protocol.ECNCE, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) initialAck := handler.GetAckFrame(protocol.EncryptionInitial, true) Expect(initialAck).ToNot(BeNil()) Expect(initialAck.AckRanges).To(HaveLen(1)) Expect(initialAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 2, Largest: 3})) Expect(initialAck.DelayTime).To(BeZero()) + Expect(initialAck.ECT0).To(BeEquivalentTo(2)) + Expect(initialAck.ECT1).To(BeZero()) + Expect(initialAck.ECNCE).To(BeZero()) handshakeAck := handler.GetAckFrame(protocol.EncryptionHandshake, true) Expect(handshakeAck).ToNot(BeNil()) Expect(handshakeAck.AckRanges).To(HaveLen(1)) Expect(handshakeAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2})) Expect(handshakeAck.DelayTime).To(BeZero()) + Expect(handshakeAck.ECT0).To(BeZero()) + Expect(handshakeAck.ECT1).To(BeEquivalentTo(2)) + Expect(handshakeAck.ECNCE).To(BeZero()) oneRTTAck := handler.GetAckFrame(protocol.Encryption1RTT, true) Expect(oneRTTAck).ToNot(BeNil()) Expect(oneRTTAck.AckRanges).To(HaveLen(1)) Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5})) Expect(oneRTTAck.DelayTime).To(BeNumerically("~", time.Second, 50*time.Millisecond)) + Expect(oneRTTAck.ECT0).To(BeZero()) + Expect(oneRTTAck.ECT1).To(BeZero()) + Expect(oneRTTAck.ECNCE).To(BeEquivalentTo(2)) }) It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { @@ -61,8 +70,8 @@ var _ = Describe("Received Packet Handler", func() { sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT) sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT) sendTime := time.Now().Add(-time.Second) - Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) ack := handler.GetAckFrame(protocol.Encryption1RTT, true) Expect(ack).ToNot(BeNil()) Expect(ack.AckRanges).To(HaveLen(1)) @@ -73,25 +82,25 @@ var _ = Describe("Received Packet Handler", func() { sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now() - Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(11, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(12, protocol.Encryption0RTT, sendTime, true)).To(MatchError("received packet number 12 on a 0-RTT packet after receiving 11 on a 1-RTT packet")) + Expect(handler.ReceivedPacket(10, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(11, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(12, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(MatchError("received packet number 12 on a 0-RTT packet after receiving 11 on a 1-RTT packet")) }) It("allows reordered 0-RTT packets", func() { sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now() - Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(12, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(11, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(10, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(12, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(11, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) }) It("drops Initial packets", func() { sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) sendTime := time.Now().Add(-time.Second) - Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.GetAckFrame(protocol.EncryptionInitial, true)).ToNot(BeNil()) handler.DropPackets(protocol.EncryptionInitial) Expect(handler.GetAckFrame(protocol.EncryptionInitial, true)).To(BeNil()) @@ -102,8 +111,8 @@ var _ = Describe("Received Packet Handler", func() { sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now().Add(-time.Second) - Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).ToNot(BeNil()) handler.DropPackets(protocol.EncryptionInitial) Expect(handler.GetAckFrame(protocol.EncryptionHandshake, true)).To(BeNil()) @@ -118,16 +127,16 @@ var _ = Describe("Received Packet Handler", func() { sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() sendTime := time.Now() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2) - Expect(handler.ReceivedPacket(1, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) - Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(1, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) ack := handler.GetAckFrame(protocol.Encryption1RTT, true) Expect(ack).ToNot(BeNil()) Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked() - Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Return(protocol.PacketNumber(2)) - Expect(handler.ReceivedPacket(4, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) ack = handler.GetAckFrame(protocol.Encryption1RTT, true) Expect(ack).ToNot(BeNil()) Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(2))) @@ -140,20 +149,20 @@ var _ = Describe("Received Packet Handler", func() { sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() // Initial Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionInitial)).To(BeFalse()) - Expect(handler.ReceivedPacket(3, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionInitial)).To(BeTrue()) // Handshake Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionHandshake)).To(BeFalse()) - Expect(handler.ReceivedPacket(3, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.IsPotentiallyDuplicate(3, protocol.EncryptionHandshake)).To(BeTrue()) // 0-RTT Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption0RTT)).To(BeFalse()) - Expect(handler.ReceivedPacket(3, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption0RTT)).To(BeTrue()) // 1-RTT Expect(handler.IsPotentiallyDuplicate(3, protocol.Encryption1RTT)).To(BeTrue()) Expect(handler.IsPotentiallyDuplicate(4, protocol.Encryption1RTT)).To(BeFalse()) - Expect(handler.ReceivedPacket(4, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(4, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) Expect(handler.IsPotentiallyDuplicate(4, protocol.Encryption1RTT)).To(BeTrue()) }) }) diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 085293b3..56e79269 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -15,6 +15,7 @@ type receivedPacketTracker struct { largestObserved protocol.PacketNumber ignoreBelow protocol.PacketNumber largestObservedReceivedTime time.Time + ect0, ect1, ecnce uint64 packetHistory *receivedPacketHistory @@ -47,7 +48,7 @@ func newReceivedPacketTracker( } } -func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) { +func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) { if packetNumber < h.ignoreBelow { return } @@ -64,6 +65,15 @@ func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumbe if shouldInstigateAck { h.maybeQueueAck(packetNumber, rcvTime, isMissing) } + switch ecn { + case protocol.ECNNon: + case protocol.ECT0: + h.ect0++ + case protocol.ECT1: + h.ect1++ + case protocol.ECNCE: + h.ecnce++ + } } // IgnoreBelow sets a lower limit for acknowledging packets. @@ -166,6 +176,9 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { // Make sure that the DelayTime is always positive. // This is not guaranteed on systems that don't have a monotonic clock. DelayTime: utils.MaxDuration(0, now.Sub(h.largestObservedReceivedTime)), + ECT0: h.ect0, + ECT1: h.ect1, + ECNCE: h.ecnce, } h.lastAck = ack diff --git a/internal/ackhandler/received_packet_tracker_test.go b/internal/ackhandler/received_packet_tracker_test.go index 731e28b2..9980634c 100644 --- a/internal/ackhandler/received_packet_tracker_test.go +++ b/internal/ackhandler/received_packet_tracker_test.go @@ -24,7 +24,7 @@ var _ = Describe("Received Packet Tracker", func() { Context("accepting packets", func() { It("saves the time when each packet arrived", func() { - tracker.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true) + tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, time.Now(), true) Expect(tracker.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) }) @@ -32,7 +32,7 @@ var _ = Describe("Received Packet Tracker", func() { now := time.Now() tracker.largestObserved = 3 tracker.largestObservedReceivedTime = now.Add(-1 * time.Second) - tracker.ReceivedPacket(5, now, true) + tracker.ReceivedPacket(5, protocol.ECNNon, now, true) Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(tracker.largestObservedReceivedTime).To(Equal(now)) }) @@ -42,7 +42,7 @@ var _ = Describe("Received Packet Tracker", func() { timestamp := now.Add(-1 * time.Second) tracker.largestObserved = 5 tracker.largestObservedReceivedTime = timestamp - tracker.ReceivedPacket(4, now, true) + tracker.ReceivedPacket(4, protocol.ECNNon, now, true) Expect(tracker.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(tracker.largestObservedReceivedTime).To(Equal(timestamp)) }) @@ -52,34 +52,51 @@ var _ = Describe("Received Packet Tracker", func() { Context("queueing ACKs", func() { receiveAndAck10Packets := func() { for i := 1; i <= 10; i++ { - tracker.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true) + tracker.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, time.Time{}, true) } Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) Expect(tracker.ackQueued).To(BeFalse()) } It("always queues an ACK for the first packet", func() { - tracker.ReceivedPacket(1, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) Expect(tracker.ackQueued).To(BeTrue()) Expect(tracker.GetAlarmTimeout()).To(BeZero()) Expect(tracker.GetAckFrame(true).DelayTime).To(BeNumerically("~", 0, time.Second)) }) It("works with packet number 0", func() { - tracker.ReceivedPacket(0, time.Now(), true) + tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) Expect(tracker.ackQueued).To(BeTrue()) Expect(tracker.GetAlarmTimeout()).To(BeZero()) Expect(tracker.GetAckFrame(true).DelayTime).To(BeNumerically("~", 0, time.Second)) }) + It("sets ECN flags", func() { + tracker.ReceivedPacket(0, protocol.ECT0, time.Now(), true) + pn := protocol.PacketNumber(1) + for i := 0; i < 2; i++ { + tracker.ReceivedPacket(pn, protocol.ECT1, time.Now(), true) + pn++ + } + for i := 0; i < 3; i++ { + tracker.ReceivedPacket(pn, protocol.ECNCE, time.Now(), true) + pn++ + } + ack := tracker.GetAckFrame(false) + Expect(ack.ECT0).To(BeEquivalentTo(1)) + Expect(ack.ECT1).To(BeEquivalentTo(2)) + Expect(ack.ECNCE).To(BeEquivalentTo(3)) + }) + It("queues an ACK for every second ack-eliciting packet", func() { receiveAndAck10Packets() p := protocol.PacketNumber(11) for i := 0; i <= 20; i++ { - tracker.ReceivedPacket(p, time.Time{}, true) + tracker.ReceivedPacket(p, protocol.ECNNon, time.Time{}, true) Expect(tracker.ackQueued).To(BeFalse()) p++ - tracker.ReceivedPacket(p, time.Time{}, true) + tracker.ReceivedPacket(p, protocol.ECNNon, time.Time{}, true) Expect(tracker.ackQueued).To(BeTrue()) p++ // dequeue the ACK frame @@ -90,47 +107,47 @@ var _ = Describe("Received Packet Tracker", func() { It("resets the counter when a non-queued ACK frame is generated", func() { receiveAndAck10Packets() rcvTime := time.Now() - tracker.ReceivedPacket(11, rcvTime, true) + tracker.ReceivedPacket(11, protocol.ECNNon, rcvTime, true) Expect(tracker.GetAckFrame(false)).ToNot(BeNil()) - tracker.ReceivedPacket(12, rcvTime, true) + tracker.ReceivedPacket(12, protocol.ECNNon, rcvTime, true) Expect(tracker.GetAckFrame(true)).To(BeNil()) - tracker.ReceivedPacket(13, rcvTime, true) + tracker.ReceivedPacket(13, protocol.ECNNon, rcvTime, true) Expect(tracker.GetAckFrame(false)).ToNot(BeNil()) }) It("only sets the timer when receiving a ack-eliciting packets", func() { receiveAndAck10Packets() - tracker.ReceivedPacket(11, time.Now(), false) + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), false) Expect(tracker.ackQueued).To(BeFalse()) Expect(tracker.GetAlarmTimeout()).To(BeZero()) rcvTime := time.Now().Add(10 * time.Millisecond) - tracker.ReceivedPacket(12, rcvTime, true) + tracker.ReceivedPacket(12, protocol.ECNNon, rcvTime, true) Expect(tracker.ackQueued).To(BeFalse()) Expect(tracker.GetAlarmTimeout()).To(Equal(rcvTime.Add(protocol.MaxAckDelay))) }) It("queues an ACK if it was reported missing before", func() { receiveAndAck10Packets() - tracker.ReceivedPacket(11, time.Now(), true) - tracker.ReceivedPacket(13, time.Now(), true) + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) // ACK: 1-11 and 13, missing: 12 Expect(ack).ToNot(BeNil()) Expect(ack.HasMissingRanges()).To(BeTrue()) Expect(tracker.ackQueued).To(BeFalse()) - tracker.ReceivedPacket(12, time.Now(), true) + tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) Expect(tracker.ackQueued).To(BeTrue()) }) It("doesn't queue an ACK if it was reported missing before, but is below the threshold", func() { receiveAndAck10Packets() // 11 is missing - tracker.ReceivedPacket(12, time.Now(), true) - tracker.ReceivedPacket(13, time.Now(), true) + tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) // ACK: 1-10, 12-13 Expect(ack).ToNot(BeNil()) // now receive 11 tracker.IgnoreBelow(12) - tracker.ReceivedPacket(11, time.Now(), false) + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), false) ack = tracker.GetAckFrame(true) Expect(ack).To(BeNil()) }) @@ -140,7 +157,7 @@ var _ = Describe("Received Packet Tracker", func() { Expect(tracker.lastAck.LargestAcked()).To(Equal(protocol.PacketNumber(10))) Expect(tracker.ackQueued).To(BeFalse()) tracker.IgnoreBelow(11) - tracker.ReceivedPacket(11, time.Now(), true) + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) Expect(tracker.GetAckFrame(true)).To(BeNil()) }) @@ -149,7 +166,7 @@ var _ = Describe("Received Packet Tracker", func() { Expect(tracker.lastAck.LargestAcked()).To(Equal(protocol.PacketNumber(10))) Expect(tracker.ackQueued).To(BeFalse()) tracker.IgnoreBelow(11) - tracker.ReceivedPacket(12, time.Now(), true) + tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 12, Largest: 12}})) @@ -157,38 +174,38 @@ var _ = Describe("Received Packet Tracker", func() { It("doesn't queue an ACK if for non-ack-eliciting packets arriving out-of-order", func() { receiveAndAck10Packets() - tracker.ReceivedPacket(11, time.Now(), true) + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) Expect(tracker.GetAckFrame(true)).To(BeNil()) - tracker.ReceivedPacket(13, time.Now(), false) // receive a non-ack-eliciting packet out-of-order + tracker.ReceivedPacket(13, protocol.ECNNon, time.Now(), false) // receive a non-ack-eliciting packet out-of-order Expect(tracker.GetAckFrame(true)).To(BeNil()) }) It("doesn't queue an ACK if packets arrive out-of-order, but haven't been acknowledged yet", func() { receiveAndAck10Packets() Expect(tracker.lastAck).ToNot(BeNil()) - tracker.ReceivedPacket(12, time.Now(), false) + tracker.ReceivedPacket(12, protocol.ECNNon, time.Now(), false) Expect(tracker.GetAckFrame(true)).To(BeNil()) // 11 is received out-of-order, but this hasn't been reported in an ACK frame yet - tracker.ReceivedPacket(11, time.Now(), true) + tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true) Expect(tracker.GetAckFrame(true)).To(BeNil()) }) }) Context("ACK generation", func() { It("generates an ACK for an ack-eliciting packet, if no ACK is queued yet", func() { - tracker.ReceivedPacket(1, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) // The first packet is always acknowledged. Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) }) It("doesn't generate ACK for a non-ack-eliciting packet, if no ACK is queued yet", func() { - tracker.ReceivedPacket(1, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) // The first packet is always acknowledged. Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) - tracker.ReceivedPacket(2, time.Now(), false) + tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), false) Expect(tracker.GetAckFrame(false)).To(BeNil()) - tracker.ReceivedPacket(3, time.Now(), true) + tracker.ReceivedPacket(3, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(false) Expect(ack).ToNot(BeNil()) Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) @@ -201,8 +218,8 @@ var _ = Describe("Received Packet Tracker", func() { }) It("generates a simple ACK frame", func() { - tracker.ReceivedPacket(1, time.Now(), true) - tracker.ReceivedPacket(2, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) @@ -211,7 +228,7 @@ var _ = Describe("Received Packet Tracker", func() { }) It("generates an ACK for packet number 0", func() { - tracker.ReceivedPacket(0, time.Now(), true) + tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0))) @@ -220,26 +237,26 @@ var _ = Describe("Received Packet Tracker", func() { }) It("sets the delay time", func() { - tracker.ReceivedPacket(1, time.Now(), true) - tracker.ReceivedPacket(2, time.Now().Add(-1337*time.Millisecond), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(2, protocol.ECNNon, time.Now().Add(-1337*time.Millisecond), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond)) }) It("uses a 0 delay time if the delay would be negative", func() { - tracker.ReceivedPacket(0, time.Now().Add(time.Hour), true) + tracker.ReceivedPacket(0, protocol.ECNNon, time.Now().Add(time.Hour), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.DelayTime).To(BeZero()) }) It("saves the last sent ACK", func() { - tracker.ReceivedPacket(1, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(tracker.lastAck).To(Equal(ack)) - tracker.ReceivedPacket(2, time.Now(), true) + tracker.ReceivedPacket(2, protocol.ECNNon, time.Now(), true) tracker.ackQueued = true ack = tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) @@ -247,8 +264,8 @@ var _ = Describe("Received Packet Tracker", func() { }) It("generates an ACK frame with missing packets", func() { - tracker.ReceivedPacket(1, time.Now(), true) - tracker.ReceivedPacket(4, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(4, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) @@ -260,9 +277,9 @@ var _ = Describe("Received Packet Tracker", func() { }) It("generates an ACK for packet number 0 and other packets", func() { - tracker.ReceivedPacket(0, time.Now(), true) - tracker.ReceivedPacket(1, time.Now(), true) - tracker.ReceivedPacket(3, time.Now(), true) + tracker.ReceivedPacket(0, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(3, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) @@ -275,8 +292,8 @@ var _ = Describe("Received Packet Tracker", func() { It("doesn't add delayed packets to the packetHistory", func() { tracker.IgnoreBelow(7) - tracker.ReceivedPacket(4, time.Now(), true) - tracker.ReceivedPacket(10, time.Now(), true) + tracker.ReceivedPacket(4, protocol.ECNNon, time.Now(), true) + tracker.ReceivedPacket(10, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10))) @@ -285,7 +302,7 @@ var _ = Describe("Received Packet Tracker", func() { It("deletes packets from the packetHistory when a lower limit is set", func() { for i := 1; i <= 12; i++ { - tracker.ReceivedPacket(protocol.PacketNumber(i), time.Now(), true) + tracker.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, time.Now(), true) } tracker.IgnoreBelow(7) // check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame @@ -299,14 +316,14 @@ var _ = Describe("Received Packet Tracker", func() { // TODO: remove this test when dropping support for STOP_WAITINGs It("handles a lower limit of 0", func() { tracker.IgnoreBelow(0) - tracker.ReceivedPacket(1337, time.Now(), true) + tracker.ReceivedPacket(1337, protocol.ECNNon, time.Now(), true) ack := tracker.GetAckFrame(true) Expect(ack).ToNot(BeNil()) Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(1337))) }) It("resets all counters needed for the ACK queueing decision when sending an ACK", func() { - tracker.ReceivedPacket(1, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) tracker.ackAlarm = time.Now().Add(-time.Minute) Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) Expect(tracker.GetAlarmTimeout()).To(BeZero()) @@ -315,21 +332,21 @@ var _ = Describe("Received Packet Tracker", func() { }) It("doesn't generate an ACK when none is queued and the timer is not set", func() { - tracker.ReceivedPacket(1, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) tracker.ackQueued = false tracker.ackAlarm = time.Time{} Expect(tracker.GetAckFrame(true)).To(BeNil()) }) It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() { - tracker.ReceivedPacket(1, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) tracker.ackQueued = false tracker.ackAlarm = time.Now().Add(time.Minute) Expect(tracker.GetAckFrame(true)).To(BeNil()) }) It("generates an ACK when the timer has expired", func() { - tracker.ReceivedPacket(1, time.Now(), true) + tracker.ReceivedPacket(1, protocol.ECNNon, time.Now(), true) tracker.ackQueued = false tracker.ackAlarm = time.Now().Add(-time.Minute) Expect(tracker.GetAckFrame(true)).ToNot(BeNil()) diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index 6362cec1..d0c0f330 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -91,15 +91,15 @@ func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, ar } // ReceivedPacket mocks base method -func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel, arg2 time.Time, arg3 bool) error { +func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.ECN, arg2 protocol.EncryptionLevel, arg3 time.Time, arg4 bool) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) return ret0 } // ReceivedPacket indicates an expected call of ReceivedPacket -func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) } diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index bf5333ab..014b371d 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -34,6 +34,15 @@ func (t PacketType) String() string { } } +type ECN uint8 + +const ( + ECNNon ECN = iota + ECT0 + ECT1 + ECNCE +) + // A ByteCount in QUIC type ByteCount uint64 diff --git a/session.go b/session.go index e813510e..1b5bcc74 100644 --- a/session.go +++ b/session.go @@ -59,11 +59,13 @@ type cryptoStreamHandler interface { } type receivedPacket struct { + buffer *packetBuffer + remoteAddr net.Addr rcvTime time.Time data []byte - buffer *packetBuffer + ecn protocol.ECN } func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } @@ -74,6 +76,7 @@ func (p *receivedPacket) Clone() *receivedPacket { rcvTime: p.rcvTime, data: p.data, buffer: p.buffer, + ecn: p.ecn, } } @@ -1067,7 +1070,7 @@ func (s *session) handleUnpackedPacket( } } - return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, packet.encryptionLevel, rcvTime, isAckEliciting) + return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, protocol.ECNNon, packet.encryptionLevel, rcvTime, isAckEliciting) } func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { diff --git a/session_test.go b/session_test.go index 68c271eb..056f9fde 100644 --- a/session_test.go +++ b/session_test.go @@ -766,7 +766,7 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) gomock.InOrder( rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial), - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false), + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNNon, protocol.EncryptionInitial, rcvTime, false), ) sess.receivedPacketHandler = rph packet.rcvTime = rcvTime @@ -794,7 +794,7 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) gomock.InOrder( rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), - rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.Encryption1RTT, rcvTime, true), + rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNNon, protocol.Encryption1RTT, rcvTime, true), ) sess.receivedPacketHandler = rph packet.rcvTime = rcvTime @@ -1213,7 +1213,7 @@ var _ = Describe("Session", func() { sess.handshakeConfirmed = true runSession() packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() - sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) + sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) sess.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() }) @@ -1853,7 +1853,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { sess.config.MaxIdleTimeout = 30 * time.Second sess.config.KeepAlive = true - sess.receivedPacketHandler.ReceivedPacket(0, protocol.EncryptionHandshake, time.Now(), true) + sess.receivedPacketHandler.ReceivedPacket(0, protocol.ECNNon, protocol.EncryptionHandshake, time.Now(), true) }) AfterEach(func() {