From 1c5380c49b9d74fe8e02fb3d7cd4e371182ccf55 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Fri, 24 Mar 2017 18:28:59 +0100 Subject: [PATCH] Implement loss recovery from the current WG draft Fixes #498 and will hopefully go a long way towards fixing the many flaky tests. --- ackhandler/ackhandler_suite_test.go | 2 +- ackhandler/interfaces.go | 12 +- ackhandler/packet.go | 2 - ackhandler/sent_packet_handler.go | 329 +++++++++++--------- ackhandler/sent_packet_handler_test.go | 405 ++++++++++--------------- protocol/protocol.go | 14 +- protocol/server_parameters.go | 3 - session.go | 17 +- session_test.go | 26 +- 9 files changed, 376 insertions(+), 434 deletions(-) diff --git a/ackhandler/ackhandler_suite_test.go b/ackhandler/ackhandler_suite_test.go index 94589bef..53108c19 100644 --- a/ackhandler/ackhandler_suite_test.go +++ b/ackhandler/ackhandler_suite_test.go @@ -9,5 +9,5 @@ import ( func TestCrypto(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "AckHandler (New) Suite") + RunSpecs(t, "AckHandler Suite") } diff --git a/ackhandler/interfaces.go b/ackhandler/interfaces.go index 1d1b4be4..d44a2e0c 100644 --- a/ackhandler/interfaces.go +++ b/ackhandler/interfaces.go @@ -12,18 +12,16 @@ type SentPacketHandler interface { SentPacket(packet *Packet) error ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error + SendingAllowed() bool GetStopWaitingFrame(force bool) *frames.StopWaitingFrame - - MaybeQueueRTOs() DequeuePacketForRetransmission() (packet *Packet) - - BytesInFlight() protocol.ByteCount GetLeastUnacked() protocol.PacketNumber - SendingAllowed() bool - CheckForError() error + GetAlarmTimeout() time.Time + OnAlarm() - TimeOfFirstRTO() time.Time + // TODO(lclemente): Remove this now that the logic is simpler + CheckForError() error } // ReceivedPacketHandler handles ACKs needed to send for incoming packets diff --git a/ackhandler/packet.go b/ackhandler/packet.go index 17218854..e9dbf6ab 100644 --- a/ackhandler/packet.go +++ b/ackhandler/packet.go @@ -15,8 +15,6 @@ type Packet struct { Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel - MissingReports uint8 - SendTime time.Time } diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 686fd9d7..a6013a98 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -12,6 +12,18 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) +const ( + // Maximum reordering in time space before time based loss detection considers a packet lost. + // In fraction of an RTT. + timeReorderingFraction = 1.0 / 8 + // defaultRTOTimeout is the RTO time on new connections + defaultRTOTimeout = 500 * time.Millisecond + // Minimum time in the future an RTO alarm may be set for. + minRTOTimeout = 200 * time.Millisecond + // maxRTOTimeout is the maximum RTO time + maxRTOTimeout = 60 * time.Second +) + var ( // ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK") @@ -22,11 +34,10 @@ var ( errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package") ) -var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number.") +var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number") type sentPacketHandler struct { lastSentPacketNumber protocol.PacketNumber - lastSentPacketTime time.Time skippedPackets []protocol.PacketNumber LargestAcked protocol.PacketNumber @@ -40,10 +51,17 @@ type sentPacketHandler struct { bytesInFlight protocol.ByteCount - rttStats *congestion.RTTStats congestion congestion.SendAlgorithm + rttStats *congestion.RTTStats - consecutiveRTOCount uint32 + // The number of times an RTO has been sent without receiving an ack. + rtoCount uint32 + + // The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time. + lossTime time.Time + + // The alarm timeout + alarm time.Time } // NewSentPacketHandler creates a new sentPacketHandler @@ -64,40 +82,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { } } -func (h *sentPacketHandler) ackPacket(packetElement *PacketElement) { - packet := &packetElement.Value - h.bytesInFlight -= packet.Length - h.packetHistory.Remove(packetElement) -} - -// nackPacket NACKs a packet -// it returns true if a FastRetransmissions was triggered -func (h *sentPacketHandler) nackPacket(packetElement *PacketElement) bool { - packet := &packetElement.Value - - packet.MissingReports++ - - if packet.MissingReports > protocol.RetransmissionThreshold { - utils.Debugf("\tQueueing packet 0x%x for retransmission (fast)", packet.PacketNumber) - h.queuePacketForRetransmission(packetElement) - return true - } - return false -} - -// does NOT set packet.Retransmitted. This variable is not needed anymore -func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) { - packet := &packetElement.Value - h.bytesInFlight -= packet.Length - h.retransmissionQueue = append(h.retransmissionQueue, packet) - - h.packetHistory.Remove(packetElement) - - // strictly speaking, this is only necessary for RTO retransmissions - // this is because FastRetransmissions are triggered by missing ranges in ACKs, and then the LargestAcked will already be higher than the packet number of the retransmitted packet - h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber) -} - func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber { if f := h.packetHistory.Front(); f != nil { return f.Value.PacketNumber - 1 @@ -119,7 +103,6 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { } now := time.Now() - h.lastSentPacketTime = now packet.SendTime = now if packet.Length == 0 { return errors.New("SentPacketHandler: packet cannot be empty") @@ -131,12 +114,14 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { h.congestion.OnPacketSent( now, - h.BytesInFlight(), + h.bytesInFlight, packet.PacketNumber, packet.Length, true, /* TODO: is retransmittable */ ) + h.updateLossDetectionAlarm() + return nil } @@ -149,54 +134,58 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum if withPacketNumber <= h.largestReceivedPacketWithAck { return ErrDuplicateOrOutOfOrderAck } - h.largestReceivedPacketWithAck = withPacketNumber // ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK) if ackFrame.LargestAcked <= h.largestInOrderAcked() { return nil } - - // check if it acks any packets that were skipped - for _, p := range h.skippedPackets { - if ackFrame.AcksPacket(p) { - return ErrAckForSkippedPacket - } - } - h.LargestAcked = ackFrame.LargestAcked - var ackedPackets congestion.PacketVector - var lostPackets congestion.PacketVector - ackRangeIndex := 0 - rttUpdated := false + if h.skippedPacketsAcked(ackFrame) { + return ErrAckForSkippedPacket + } - var el, elNext *PacketElement - for el = h.packetHistory.Front(); el != nil; el = elNext { - // determine the next list element right at the beginning, because el.Next() is not avaible anymore, when the list element is deleted (i.e. when the packet is ACKed) - elNext = el.Next() + rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime) + + ackedPackets, err := h.determineNewlyAckedPackets(ackFrame) + if err != nil { + return err + } + + if len(ackedPackets) > 0 { + var ackedPacketsCongestion congestion.PacketVector + for _, p := range ackedPackets { + h.onPacketAcked(p) + ackedPacketsCongestion = append(ackedPacketsCongestion, congestion.PacketInfo{ + Number: p.Value.PacketNumber, + Length: p.Value.Length, + }) + } + h.congestion.OnCongestionEvent(rttUpdated, h.bytesInFlight, ackedPacketsCongestion, nil) + } + + h.detectLostPackets(rttUpdated) + h.updateLossDetectionAlarm() + + h.garbageCollectSkippedPackets() + h.stopWaitingManager.ReceivedAck(ackFrame) + + return nil +} + +func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame) ([]*PacketElement, error) { + var ackedPackets []*PacketElement + ackRangeIndex := 0 + for el := h.packetHistory.Front(); el != nil; el = el.Next() { packet := el.Value packetNumber := packet.PacketNumber - // NACK packets below the LowestAcked + // Ignore packets below the LowestAcked if packetNumber < ackFrame.LowestAcked { - retransmitted := h.nackPacket(el) - if retransmitted { - lostPackets = append(lostPackets, congestion.PacketInfo{Number: packetNumber, Length: packet.Length}) - } continue } - - // Update the RTT - if packetNumber == h.LargestAcked { - rttUpdated = true - timeDelta := rcvTime.Sub(packet.SendTime) - h.rttStats.UpdateRTT(timeDelta, ackFrame.DelayTime, rcvTime) - if utils.Debug() { - utils.Debugf("\tEstimated RTT: %dms", h.rttStats.SmoothedRTT()/time.Millisecond) - } - } - + // Break after LargestAcked is reached if packetNumber > ackFrame.LargestAcked { break } @@ -211,59 +200,124 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum if packetNumber >= ackRange.FirstPacketNumber { // packet i contained in ACK range if packetNumber > ackRange.LastPacketNumber { - return fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber) - } - h.ackPacket(el) - ackedPackets = append(ackedPackets, congestion.PacketInfo{Number: packetNumber, Length: packet.Length}) - } else { - retransmitted := h.nackPacket(el) - if retransmitted { - lostPackets = append(lostPackets, congestion.PacketInfo{Number: packetNumber, Length: packet.Length}) + return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber) } + ackedPackets = append(ackedPackets, el) } } else { - h.ackPacket(el) - ackedPackets = append(ackedPackets, congestion.PacketInfo{Number: packetNumber, Length: packet.Length}) + ackedPackets = append(ackedPackets, el) } } - if rttUpdated { - // Reset counter if a new packet was acked - h.consecutiveRTOCount = 0 + return ackedPackets, nil +} + +func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool { + for el := h.packetHistory.Front(); el != nil; el = el.Next() { + packet := el.Value + if packet.PacketNumber == largestAcked { + h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now()) + return true + } + // Packets are sorted by number, so we can stop searching + if packet.PacketNumber > largestAcked { + break + } + } + return false +} + +func (h *sentPacketHandler) updateLossDetectionAlarm() { + // Cancel the alarm if no packets are outstanding + if h.packetHistory.Len() == 0 { + h.alarm = time.Time{} + return } - h.garbageCollectSkippedPackets() + // TODO(#496): Handle handshake packets separately + // TODO(#497): TLP + if !h.lossTime.IsZero() { + // Early retransmit timer or time loss detection. + h.alarm = h.lossTime + } else { + // RTO + h.alarm = time.Now().Add(h.computeRTOTimeout()) + } +} - h.stopWaitingManager.ReceivedAck(ackFrame) +// TODO(lucas-clemente): Introducing congestion.MaybeExitSlowStart() would allow us to call through for each packet and eliminate both the rttUpdated param and the packet slices passed to the congestion +func (h *sentPacketHandler) detectLostPackets(rttUpdated bool) { + h.lossTime = time.Time{} + now := time.Now() - h.congestion.OnCongestionEvent( - rttUpdated, - h.BytesInFlight(), - ackedPackets, - lostPackets, - ) + maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) + delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT) - return nil + var lostPackets []*PacketElement + for el := h.packetHistory.Front(); el != nil; el = el.Next() { + packet := el.Value + + if packet.PacketNumber > h.LargestAcked { + break + } + + timeSinceSent := now.Sub(packet.SendTime) + if timeSinceSent > delayUntilLost { + lostPackets = append(lostPackets, el) + } else if h.lossTime.IsZero() { + // Note: This conditional is only entered once per call + h.lossTime = now.Add(delayUntilLost - timeSinceSent) + } + } + + if len(lostPackets) > 0 { + var lostPacketsCongestion congestion.PacketVector + for _, p := range lostPackets { + h.queuePacketForRetransmission(p) + lostPacketsCongestion = append(lostPacketsCongestion, congestion.PacketInfo{ + Number: p.Value.PacketNumber, + Length: p.Value.Length, + }) + } + h.congestion.OnCongestionEvent(rttUpdated, h.bytesInFlight, nil, lostPacketsCongestion) + } +} + +func (h *sentPacketHandler) OnAlarm() { + // TODO(#496): Handle handshake packets separately + // TODO(#497): TLP + if !h.lossTime.IsZero() { + // Early retransmit or time loss detection + h.detectLostPackets(false /* rttUpdated */) + } else { + // RTO + h.retransmitOldestTwoPackets() + h.rtoCount++ + } + + h.updateLossDetectionAlarm() +} + +func (h *sentPacketHandler) GetAlarmTimeout() time.Time { + return h.alarm +} + +func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) { + h.bytesInFlight -= packetElement.Value.Length + h.rtoCount = 0 + // TODO(#497): h.tlpCount = 0 + h.packetHistory.Remove(packetElement) } func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { if len(h.retransmissionQueue) == 0 { return nil } - - if len(h.retransmissionQueue) > 0 { - queueLen := len(h.retransmissionQueue) - // packets are usually NACKed in descending order. So use the slice as a stack - packet := h.retransmissionQueue[queueLen-1] - h.retransmissionQueue = h.retransmissionQueue[:queueLen-1] - return packet - } - - return nil -} - -func (h *sentPacketHandler) BytesInFlight() protocol.ByteCount { - return h.bytesInFlight + queueLen := len(h.retransmissionQueue) + // packets are usually NACKed in descending order. So use the slice as a stack + packet := h.retransmissionQueue[queueLen-1] + h.retransmissionQueue = h.retransmissionQueue[:queueLen-1] + return packet } func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber { @@ -275,7 +329,7 @@ func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingF } func (h *sentPacketHandler) SendingAllowed() bool { - congestionLimited := h.BytesInFlight() > h.congestion.GetCongestionWindow() + congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow() maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets return !(congestionLimited || maxTrackedLimited) } @@ -288,22 +342,13 @@ func (h *sentPacketHandler) CheckForError() error { return nil } -func (h *sentPacketHandler) MaybeQueueRTOs() { - if time.Now().Before(h.TimeOfFirstRTO()) { - return +func (h *sentPacketHandler) retransmitOldestTwoPackets() { + if p := h.packetHistory.Front(); p != nil { + h.queueRTO(p) } - - // Always queue the two oldest packets - if h.packetHistory.Front() != nil { - h.queueRTO(h.packetHistory.Front()) + if p := h.packetHistory.Front(); p != nil { + h.queueRTO(p) } - if h.packetHistory.Front() != nil { - h.queueRTO(h.packetHistory.Front()) - } - - // Reset the RTO timer here, since it's not clear that this packet contained any retransmittable frames - h.lastSentPacketTime = time.Now() - h.consecutiveRTOCount++ } func (h *sentPacketHandler) queueRTO(el *PacketElement) { @@ -312,28 +357,42 @@ func (h *sentPacketHandler) queueRTO(el *PacketElement) { Number: packet.PacketNumber, Length: packet.Length, }} - h.congestion.OnCongestionEvent(false, h.BytesInFlight(), nil, packetsLost) - h.congestion.OnRetransmissionTimeout(true) utils.Debugf("\tQueueing packet 0x%x for retransmission (RTO)", packet.PacketNumber) h.queuePacketForRetransmission(el) + h.congestion.OnCongestionEvent(false, h.bytesInFlight, nil, packetsLost) + h.congestion.OnRetransmissionTimeout(true) } -func (h *sentPacketHandler) getRTO() time.Duration { +func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) { + packet := &packetElement.Value + h.bytesInFlight -= packet.Length + h.retransmissionQueue = append(h.retransmissionQueue, packet) + + h.packetHistory.Remove(packetElement) + + // strictly speaking, this is only necessary for RTO retransmissions + // this is because FastRetransmissions are triggered by missing ranges in ACKs, and then the LargestAcked will already be higher than the packet number of the retransmitted packet + h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber) +} + +func (h *sentPacketHandler) computeRTOTimeout() time.Duration { rto := h.congestion.RetransmissionDelay() if rto == 0 { - rto = protocol.DefaultRetransmissionTime + rto = defaultRTOTimeout } - rto = utils.MaxDuration(rto, protocol.MinRetransmissionTime) + rto = utils.MaxDuration(rto, minRTOTimeout) // Exponential backoff - rto *= 1 << h.consecutiveRTOCount - return utils.MinDuration(rto, protocol.MaxRetransmissionTime) + rto = rto << h.rtoCount + return utils.MinDuration(rto, maxRTOTimeout) } -func (h *sentPacketHandler) TimeOfFirstRTO() time.Time { - if h.lastSentPacketTime.IsZero() { - return time.Time{} +func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool { + for _, p := range h.skippedPackets { + if ackFrame.AcksPacket(p) { + return true + } } - return h.lastSentPacketTime.Add(h.getRTO()) + return false } func (h *sentPacketHandler) garbageCollectSkippedPackets() { diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index 00642b60..de2bea35 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -43,7 +43,7 @@ func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) { } func (m *mockCongestion) RetransmissionDelay() time.Duration { - return protocol.DefaultRetransmissionTime + return defaultRTOTimeout } func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") } @@ -90,7 +90,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(2))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(handler.packetHistory.Back().Value.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(3))) Expect(handler.skippedPackets).To(BeEmpty()) }) @@ -103,7 +103,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).To(MatchError(errPacketNumberNotIncreasing)) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(1))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) Expect(handler.skippedPackets).To(BeEmpty()) }) @@ -116,7 +116,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).To(MatchError(errPacketNumberNotIncreasing)) Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(2))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(1))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) }) It("stores the sent time", func() { @@ -126,13 +126,6 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.packetHistory.Front().Value.SendTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) }) - It("updates the last sent time", func() { - packet := Packet{PacketNumber: 1, Frames: []frames.Frame{&streamFrame}, Length: 1} - err := handler.SentPacket(&packet) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.lastSentPacketTime.Unix()).To(BeNumerically("~", time.Now().Unix(), 1)) - }) - Context("skipped packet numbers", func() { It("works with non-consecutive packet numbers", func() { packet1 := Packet{PacketNumber: 1, Frames: []frames.Frame{&streamFrame}, Length: 1} @@ -146,7 +139,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(3))) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(3))) Expect(handler.skippedPackets).To(HaveLen(1)) Expect(handler.skippedPackets[0]).To(Equal(protocol.PacketNumber(2))) }) @@ -244,9 +237,12 @@ var _ = Describe("SentPacketHandler", func() { {PacketNumber: 12, Frames: []frames.Frame{&streamFrame}, Length: 1}, } for _, packet := range packets { - handler.SentPacket(packet) + err := handler.SentPacket(packet) + Expect(err).NotTo(HaveOccurred()) } - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets)))) + // Increase RTT, because the tests would be flaky otherwise + handler.rttStats.UpdateRTT(time.Hour, 0, time.Now()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets)))) }) Context("ACK validation", func() { @@ -258,10 +254,10 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack, 1337, time.Now()) Expect(err).ToNot(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 3))) err = handler.ReceivedAck(&ack, 1337, time.Now()) Expect(err).To(MatchError(ErrDuplicateOrOutOfOrderAck)) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 3))) }) It("rejects out of order ACKs", func() { @@ -270,11 +266,11 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack, 1337, time.Now()) Expect(err).ToNot(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 3))) err = handler.ReceivedAck(&ack, 1337-1, time.Now()) Expect(err).To(MatchError(ErrDuplicateOrOutOfOrderAck)) Expect(handler.LargestAcked).To(Equal(protocol.PacketNumber(3))) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 3))) }) It("rejects ACKs with a too high LargestAcked packet number", func() { @@ -283,7 +279,7 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack, 1, time.Now()) Expect(err).To(MatchError(errAckForUnsentPacket)) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets)))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets)))) }) It("ignores repeated ACKs", func() { @@ -293,11 +289,11 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack, 1337, time.Now()) Expect(err).ToNot(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 3))) err = handler.ReceivedAck(&ack, 1337+1, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.LargestAcked).To(Equal(protocol.PacketNumber(3))) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 3))) }) It("rejects ACKs for skipped packets", func() { @@ -307,7 +303,6 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack, 1337, time.Now()) Expect(err).To(MatchError(ErrAckForSkippedPacket)) - Expect(handler.LargestAcked).To(BeZero()) }) It("accepts an ACK that correctly nacks a skipped packet", func() { @@ -337,7 +332,6 @@ var _ = Describe("SentPacketHandler", func() { el := handler.packetHistory.Front() for i := 6; i <= 10; i++ { Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(i))) - Expect(el.Value.MissingReports).To(BeZero()) el = el.Next() } Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(12))) @@ -354,10 +348,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(9))) - Expect(el.Value.MissingReports).To(BeZero()) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(10))) - Expect(el.Value.MissingReports).To(BeZero()) Expect(el.Next().Value.PacketNumber).To(Equal(protocol.PacketNumber(12))) }) @@ -376,17 +368,14 @@ var _ = Describe("SentPacketHandler", func() { Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(4))) - Expect(el.Value.MissingReports).To(Equal(uint8(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(5))) - Expect(el.Value.MissingReports).To(Equal(uint8(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(10))) - Expect(el.Value.MissingReports).To(BeZero()) Expect(el.Next().Value.PacketNumber).To(Equal(protocol.PacketNumber(12))) }) - It("NACKs packets below the LowestAcked", func() { + It("Does not ack packets below the LowestAcked", func() { ack := frames.AckFrame{ LargestAcked: 8, LowestAcked: 3, @@ -395,10 +384,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) el := handler.packetHistory.Front() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(el.Value.MissingReports).To(Equal(uint8(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(el.Value.MissingReports).To(Equal(uint8(1))) Expect(el.Next().Value.PacketNumber).To(Equal(protocol.PacketNumber(9))) }) @@ -417,19 +404,14 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) el := handler.packetHistory.Front() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(el.Value.MissingReports).To(Equal(uint8(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(4))) - Expect(el.Value.MissingReports).To(Equal(uint8(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(5))) - Expect(el.Value.MissingReports).To(Equal(uint8(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(8))) - Expect(el.Value.MissingReports).To(Equal(uint8(1))) el = el.Next() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(10))) - Expect(el.Value.MissingReports).To(BeZero()) }) It("processes an ACK frame that would be sent after a late arrival of a packet", func() { @@ -444,7 +426,7 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack1, 1, time.Now()) Expect(err).ToNot(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 5))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 5))) el := handler.packetHistory.Front() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(3))) ack2 := frames.AckFrame{ @@ -453,7 +435,7 @@ var _ = Describe("SentPacketHandler", func() { } err = handler.ReceivedAck(&ack2, 2, time.Now()) Expect(err).ToNot(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 6))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 6))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(7))) }) @@ -468,7 +450,7 @@ var _ = Describe("SentPacketHandler", func() { } err := handler.ReceivedAck(&ack1, 1, time.Now()) Expect(err).ToNot(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 5))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 5))) el := handler.packetHistory.Front() Expect(el.Value.PacketNumber).To(Equal(protocol.PacketNumber(3))) ack2 := frames.AckFrame{ @@ -477,7 +459,7 @@ var _ = Describe("SentPacketHandler", func() { } err = handler.ReceivedAck(&ack2, 2, time.Now()) Expect(err).ToNot(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 7))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 7))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(8))) }) @@ -489,7 +471,7 @@ var _ = Describe("SentPacketHandler", func() { err := handler.ReceivedAck(&ack1, 1, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(7))) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 6))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 6))) ack2 := frames.AckFrame{ LargestAcked: 10, LowestAcked: 1, @@ -501,14 +483,14 @@ var _ = Describe("SentPacketHandler", func() { } err = handler.ReceivedAck(&ack2, 2, time.Now()) Expect(err).ToNot(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(len(packets) - 6 - 3))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(len(packets) - 6 - 3))) Expect(handler.packetHistory.Front().Value.PacketNumber).To(Equal(protocol.PacketNumber(7))) Expect(handler.packetHistory.Back().Value.PacketNumber).To(Equal(protocol.PacketNumber(12))) }) }) Context("calculating RTT", func() { - It("calculates the RTT", func() { + It("computes the RTT", func() { now := time.Now() // First, fake the sent times of the first, second and last packet getPacketElement(1).Value.SendTime = now.Add(-10 * time.Minute) @@ -536,7 +518,7 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("Retransmission handler", func() { + Context("Retransmission handling", func() { var packets []*Packet BeforeEach(func() { @@ -552,67 +534,33 @@ var _ = Describe("SentPacketHandler", func() { for _, packet := range packets { handler.SentPacket(packet) } - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(7))) + // Increase RTT, because the tests would be flaky otherwise + handler.rttStats.UpdateRTT(time.Minute, 0, time.Now()) + // Ack a single packet so that we have non-RTO timings + handler.ReceivedAck(&frames.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, time.Now()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) }) - It("does not dequeue a packet if no packet has been nacked", func() { - for i := uint8(0); i < protocol.RetransmissionThreshold; i++ { - el := getPacketElement(2) - Expect(el).ToNot(BeNil()) - handler.nackPacket(el) - } - Expect(getPacketElement(2)).ToNot(BeNil()) - handler.MaybeQueueRTOs() + It("does not dequeue a packet if no ack has been received", func() { Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) }) - It("queues a packet for retransmission", func() { - for i := uint8(0); i < protocol.RetransmissionThreshold+1; i++ { - el := getPacketElement(2) - Expect(el).ToNot(BeNil()) - handler.nackPacket(el) - } - Expect(getPacketElement(2)).To(BeNil()) - handler.MaybeQueueRTOs() - Expect(handler.retransmissionQueue).To(HaveLen(1)) - Expect(handler.retransmissionQueue[0].PacketNumber).To(Equal(protocol.PacketNumber(2))) - }) - It("dequeues a packet for retransmission", func() { - for i := uint8(0); i < protocol.RetransmissionThreshold+1; i++ { - el := getPacketElement(3) - Expect(el).ToNot(BeNil()) - handler.nackPacket(el) - } + getPacketElement(1).Value.SendTime = time.Now().Add(-time.Hour) + handler.OnAlarm() + Expect(getPacketElement(1)).To(BeNil()) + Expect(handler.retransmissionQueue).To(HaveLen(1)) + Expect(handler.retransmissionQueue[0].PacketNumber).To(Equal(protocol.PacketNumber(1))) packet := handler.DequeuePacketForRetransmission() Expect(packet).ToNot(BeNil()) - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(3))) + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) }) - It("keeps the packets in the right order", func() { - for i := uint8(0); i < protocol.RetransmissionThreshold+1; i++ { - el := getPacketElement(4) - Expect(el).ToNot(BeNil()) - handler.nackPacket(el) - } - for i := uint8(0); i < protocol.RetransmissionThreshold+1; i++ { - el := getPacketElement(2) - Expect(el).ToNot(BeNil()) - handler.nackPacket(el) - } - packet := handler.DequeuePacketForRetransmission() - Expect(packet).ToNot(BeNil()) - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(2))) - packet = handler.DequeuePacketForRetransmission() - Expect(packet).ToNot(BeNil()) - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(4))) - }) - Context("StopWaitings", func() { It("gets a StopWaitingFrame", func() { ack := frames.AckFrame{LargestAcked: 5, LowestAcked: 5} - err := handler.ReceivedAck(&ack, 1, time.Now()) + err := handler.ReceivedAck(&ack, 2, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.GetStopWaitingFrame(false)).To(Equal(&frames.StopWaitingFrame{LeastUnacked: 6})) }) @@ -624,55 +572,40 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("calculating bytes in flight", func() { - It("works in a typical retransmission scenarios", func() { - packet1 := Packet{PacketNumber: 1, Frames: []frames.Frame{&streamFrame}, Length: 1} - packet2 := Packet{PacketNumber: 2, Frames: []frames.Frame{&streamFrame}, Length: 2} - packet3 := Packet{PacketNumber: 3, Frames: []frames.Frame{&streamFrame}, Length: 3} - err := handler.SentPacket(&packet1) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(&packet2) - Expect(err).NotTo(HaveOccurred()) - err = handler.SentPacket(&packet3) - Expect(err).NotTo(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(1 + 2 + 3))) + It("calculates bytes in flight", func() { + packet1 := Packet{PacketNumber: 1, Frames: []frames.Frame{&streamFrame}, Length: 1} + packet2 := Packet{PacketNumber: 2, Frames: []frames.Frame{&streamFrame}, Length: 2} + packet3 := Packet{PacketNumber: 3, Frames: []frames.Frame{&streamFrame}, Length: 3} + err := handler.SentPacket(&packet1) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) + err = handler.SentPacket(&packet2) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1 + 2))) + err = handler.SentPacket(&packet3) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1 + 2 + 3))) - // ACK 1 and 3, NACK 2 - ack := frames.AckFrame{ - LargestAcked: 3, - LowestAcked: 1, - AckRanges: []frames.AckRange{ - {FirstPacketNumber: 3, LastPacketNumber: 3}, - {FirstPacketNumber: 1, LastPacketNumber: 1}, - }, - } - err = handler.ReceivedAck(&ack, 1, time.Now()) - Expect(err).NotTo(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(2))) + // Increase RTT, because the tests would be flaky otherwise + handler.rttStats.UpdateRTT(time.Minute, 0, time.Now()) - // Simulate protocol.RetransmissionThreshold more NACKs - for i := uint8(0); i < protocol.RetransmissionThreshold; i++ { - el := getPacketElement(2) - Expect(el).ToNot(BeNil()) - handler.nackPacket(el) - } - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(0))) + // ACK 1 and 3, NACK 2 + ack := frames.AckFrame{ + LargestAcked: 3, + LowestAcked: 1, + AckRanges: []frames.AckRange{ + {FirstPacketNumber: 3, LastPacketNumber: 3}, + {FirstPacketNumber: 1, LastPacketNumber: 1}, + }, + } + err = handler.ReceivedAck(&ack, 1, time.Now()) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - // Retransmission - packet4 := Packet{PacketNumber: 4, Length: 2} - err = handler.SentPacket(&packet4) - Expect(err).NotTo(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(2))) + handler.packetHistory.Front().Value.SendTime = time.Now().Add(-time.Hour) + handler.OnAlarm() - // ACK - ack = frames.AckFrame{ - LargestAcked: 4, - LowestAcked: 1, - } - err = handler.ReceivedAck(&ack, 2, time.Now()) - Expect(err).NotTo(HaveOccurred()) - Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(0))) - }) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(0))) }) Context("congestion", func() { @@ -688,7 +621,6 @@ var _ = Describe("SentPacketHandler", func() { It("should call OnSent", func() { p := &Packet{ PacketNumber: 1, - Frames: []frames.Frame{&frames.StreamFrame{StreamID: 5}}, Length: 42, } err := handler.SentPacket(p) @@ -700,46 +632,31 @@ var _ = Describe("SentPacketHandler", func() { Expect(cong.argsOnPacketSent[4]).To(BeTrue()) }) - It("should call OnCongestionEvent", func() { + It("should call OnCongestionEvent for ACKs", func() { handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) - handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 2}) - handler.SentPacket(&Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 3}) - ack := frames.AckFrame{ - LargestAcked: 3, - LowestAcked: 1, - AckRanges: []frames.AckRange{ - {FirstPacketNumber: 3, LastPacketNumber: 3}, - {FirstPacketNumber: 1, LastPacketNumber: 1}, - }, - } - err := handler.ReceivedAck(&ack, 1, time.Now()) + handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 1}) + Expect(cong.nCalls).To(Equal(2)) + err := handler.ReceivedAck(&frames.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, time.Now()) Expect(err).NotTo(HaveOccurred()) - Expect(cong.nCalls).To(Equal(4)) // 3 * SentPacket + 1 * ReceivedAck - // rttUpdated, bytesInFlight, ackedPackets, lostPackets + Expect(cong.nCalls).To(Equal(3)) Expect(cong.argsOnCongestionEvent[0]).To(BeTrue()) - Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(2))) - Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{Number: 1, Length: 1}, {Number: 3, Length: 3}})) + Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1))) + Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{Number: 1, Length: 1}})) Expect(cong.argsOnCongestionEvent[3]).To(BeEmpty()) + }) - // Loose the packet - var packetNumber protocol.PacketNumber - for i := uint8(0); i < protocol.RetransmissionThreshold; i++ { - packetNumber = protocol.PacketNumber(4 + i) - handler.SentPacket(&Packet{PacketNumber: packetNumber, Frames: []frames.Frame{}, Length: protocol.ByteCount(packetNumber)}) - ack := frames.AckFrame{ - LargestAcked: packetNumber, - LowestAcked: 1, - AckRanges: []frames.AckRange{ - {FirstPacketNumber: 3, LastPacketNumber: packetNumber}, - {FirstPacketNumber: 1, LastPacketNumber: 1}, - }, - } - err = handler.ReceivedAck(&ack, protocol.PacketNumber(2+i), time.Now()) - Expect(err).NotTo(HaveOccurred()) - } - - Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{Number: packetNumber, Length: protocol.ByteCount(packetNumber)}})) - Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{Number: 2, Length: 2}})) + It("should call OnCongestionEvent for losses", func() { + handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) + handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 1}) + handler.SentPacket(&Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 1}) + Expect(cong.nCalls).To(Equal(3)) + handler.OnAlarm() // RTO, meaning 2 lost packets + Expect(cong.nCalls).To(Equal(3 + 4 /* 2* (OnCongestionEvent+OnRTO)*/)) + Expect(cong.onRetransmissionTimeout).To(BeTrue()) + Expect(cong.argsOnCongestionEvent[0]).To(BeFalse()) + Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1))) + Expect(cong.argsOnCongestionEvent[2]).To(BeEmpty()) + Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{Number: 2, Length: 1}})) }) It("allows or denies sending based on congestion", func() { @@ -754,111 +671,101 @@ var _ = Describe("SentPacketHandler", func() { handler.retransmissionQueue = make([]*Packet, protocol.MaxTrackedSentPackets) Expect(handler.SendingAllowed()).To(BeFalse()) }) - - It("should call OnRetransmissionTimeout", func() { - err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) - Expect(err).NotTo(HaveOccurred()) - handler.lastSentPacketTime = time.Now().Add(-time.Second) - handler.MaybeQueueRTOs() - Expect(cong.nCalls).To(Equal(3)) - // rttUpdated, bytesInFlight, ackedPackets, lostPackets - Expect(cong.argsOnCongestionEvent[0]).To(BeFalse()) - Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(1))) - Expect(cong.argsOnCongestionEvent[2]).To(BeEmpty()) - Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{Number: 1, Length: 1}})) - Expect(cong.onRetransmissionTimeout).To(BeTrue()) - }) }) Context("calculating RTO", func() { It("uses default RTO", func() { - Expect(handler.getRTO()).To(Equal(protocol.DefaultRetransmissionTime)) + Expect(handler.computeRTOTimeout()).To(Equal(defaultRTOTimeout)) }) It("uses RTO from rttStats", func() { rtt := time.Second expected := rtt + rtt/2*4 handler.rttStats.UpdateRTT(rtt, 0, time.Now()) - Expect(handler.getRTO()).To(Equal(expected)) + Expect(handler.computeRTOTimeout()).To(Equal(expected)) }) It("limits RTO min", func() { rtt := time.Millisecond handler.rttStats.UpdateRTT(rtt, 0, time.Now()) - Expect(handler.getRTO()).To(Equal(protocol.MinRetransmissionTime)) + Expect(handler.computeRTOTimeout()).To(Equal(minRTOTimeout)) }) It("limits RTO max", func() { rtt := time.Hour handler.rttStats.UpdateRTT(rtt, 0, time.Now()) - Expect(handler.getRTO()).To(Equal(protocol.MaxRetransmissionTime)) + Expect(handler.computeRTOTimeout()).To(Equal(maxRTOTimeout)) }) It("implements exponential backoff", func() { - handler.consecutiveRTOCount = 0 - Expect(handler.getRTO()).To(Equal(protocol.DefaultRetransmissionTime)) - handler.consecutiveRTOCount = 1 - Expect(handler.getRTO()).To(Equal(2 * protocol.DefaultRetransmissionTime)) - handler.consecutiveRTOCount = 2 - Expect(handler.getRTO()).To(Equal(4 * protocol.DefaultRetransmissionTime)) + handler.rtoCount = 0 + Expect(handler.computeRTOTimeout()).To(Equal(defaultRTOTimeout)) + handler.rtoCount = 1 + Expect(handler.computeRTOTimeout()).To(Equal(2 * defaultRTOTimeout)) + handler.rtoCount = 2 + Expect(handler.computeRTOTimeout()).To(Equal(4 * defaultRTOTimeout)) + }) + }) + + Context("Delay-based loss detection", func() { + It("detects a packet as lost", func() { + err := handler.SentPacket(&Packet{PacketNumber: 1, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + err = handler.SentPacket(&Packet{PacketNumber: 2, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.lossTime.IsZero()).To(BeTrue()) + + err = handler.ReceivedAck(&frames.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, time.Now().Add(time.Hour)) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.lossTime.IsZero()).To(BeFalse()) + + // RTT is around 1h now. + // The formula is (1+1/8) * RTT, so this should be around that number + Expect(handler.lossTime.Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) + Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", time.Hour*9/8, time.Minute)) + + handler.packetHistory.Front().Value.SendTime = time.Now().Add(-2 * time.Hour) + handler.OnAlarm() + Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil()) + }) + + It("does not detect packets as lost without ACKs", func() { + err := handler.SentPacket(&Packet{PacketNumber: 1, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + err = handler.SentPacket(&Packet{PacketNumber: 2, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + err = handler.SentPacket(&Packet{PacketNumber: 3, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.lossTime.IsZero()).To(BeTrue()) + + err = handler.ReceivedAck(&frames.AckFrame{LargestAcked: 1, LowestAcked: 1}, 1, time.Now().Add(time.Hour)) + Expect(err).NotTo(HaveOccurred()) + Expect(handler.lossTime.IsZero()).To(BeTrue()) + Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", handler.computeRTOTimeout(), time.Minute)) + + // This means RTO, so both packets should be lost + handler.OnAlarm() + Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil()) + Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil()) }) }) Context("RTO retransmission", func() { - Context("calculating the time to first RTO", func() { - It("defaults to zero", func() { - Expect(handler.TimeOfFirstRTO().IsZero()).To(BeTrue()) - }) - - It("returns time to RTO", func() { - err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) - Expect(err).NotTo(HaveOccurred()) - Expect(handler.TimeOfFirstRTO().Sub(time.Now())).To(BeNumerically("~", protocol.DefaultRetransmissionTime, time.Millisecond)) - }) - }) - - Context("queuing packets due to RTO", func() { - It("does nothing if not required", func() { - err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1}) - Expect(err).NotTo(HaveOccurred()) - handler.MaybeQueueRTOs() - Expect(handler.retransmissionQueue).To(BeEmpty()) - }) - - It("queues a packet if RTO expired", func() { - p := &Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1} - err := handler.SentPacket(p) - Expect(err).NotTo(HaveOccurred()) - handler.lastSentPacketTime = time.Now().Add(-time.Second) - handler.MaybeQueueRTOs() - Expect(handler.retransmissionQueue).To(HaveLen(1)) - Expect(handler.retransmissionQueue[0].PacketNumber).To(Equal(p.PacketNumber)) - Expect(time.Now().Sub(handler.lastSentPacketTime)).To(BeNumerically("<", time.Second/2)) - }) - - It("queues two packets if RTO expired", func() { - for i := 1; i < 4; i++ { - p := &Packet{PacketNumber: protocol.PacketNumber(i), Length: 1} - err := handler.SentPacket(p) - Expect(err).NotTo(HaveOccurred()) - } - handler.lastSentPacketTime = time.Now().Add(-time.Second) - handler.MaybeQueueRTOs() - Expect(handler.retransmissionQueue).To(HaveLen(2)) - Expect(handler.retransmissionQueue[0].PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(handler.retransmissionQueue[1].PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(time.Now().Sub(handler.lastSentPacketTime)).To(BeNumerically("<", time.Second/2)) - Expect(handler.consecutiveRTOCount).To(Equal(uint32(1))) - }) - }) - - It("works with DequeuePacketForRetransmission", func() { - p := &Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1} - err := handler.SentPacket(p) + It("queues two packets if RTO expires", func() { + err := handler.SentPacket(&Packet{PacketNumber: 1, Length: 1}) Expect(err).NotTo(HaveOccurred()) - handler.lastSentPacketTime = time.Now().Add(-time.Second) - handler.MaybeQueueRTOs() - Expect(handler.DequeuePacketForRetransmission().PacketNumber).To(Equal(p.PacketNumber)) + err = handler.SentPacket(&Packet{PacketNumber: 2, Length: 1}) + Expect(err).NotTo(HaveOccurred()) + + handler.rttStats.UpdateRTT(time.Hour, 0, time.Now()) + Expect(handler.lossTime.IsZero()).To(BeTrue()) + Expect(handler.GetAlarmTimeout().Sub(time.Now())).To(BeNumerically("~", handler.computeRTOTimeout(), time.Minute)) + + handler.OnAlarm() + Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil()) + Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil()) + + Expect(handler.rtoCount).To(BeEquivalentTo(1)) }) }) }) diff --git a/protocol/protocol.go b/protocol/protocol.go index 93c66948..a7a82d79 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -1,9 +1,6 @@ package protocol -import ( - "math" - "time" -) +import "math" // A PacketNumber in QUIC type PacketNumber uint64 @@ -55,15 +52,6 @@ const InitialStreamFlowControlWindow ByteCount = (1 << 14) // 16 kB // InitialConnectionFlowControlWindow is the initial connection-level flow control window for sending const InitialConnectionFlowControlWindow ByteCount = (1 << 14) // 16 kB -// DefaultRetransmissionTime is the RTO time on new connections -const DefaultRetransmissionTime = 500 * time.Millisecond - -// MinRetransmissionTime is the minimum RTO time -const MinRetransmissionTime = 200 * time.Millisecond - -// MaxRetransmissionTime is the maximum RTO time -const MaxRetransmissionTime = 60 * time.Second - // ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. const ClientHelloMinimumSize = 1024 diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index 54f04bf2..334938a6 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -78,9 +78,6 @@ const MaxNewStreamIDDelta = 4 * MaxStreamsPerConnection // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. const MaxSessionUnprocessedPackets = DefaultMaxCongestionWindow -// RetransmissionThreshold + 1 is the number of times a packet has to be NACKed so that it gets retransmitted -const RetransmissionThreshold = 3 - // SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack const SkipPacketAveragePeriodLength PacketNumber = 500 diff --git a/session.go b/session.go index de01586e..56b53796 100644 --- a/session.go +++ b/session.go @@ -32,7 +32,7 @@ type receivedPacket struct { var ( errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream") - errSessionAlreadyClosed = errors.New("Cannot close session. It was already closed before.") + errSessionAlreadyClosed = errors.New("cannot close session; it was already closed before") ) // cryptoChangeCallback is called every time the encryption level changes @@ -251,10 +251,16 @@ runLoop: s.close(err) } + now := time.Now() + if s.sentPacketHandler.GetAlarmTimeout().Before(now) { + // This could cause packets to be retransmitted, so check it before trying + // to send packets. + s.sentPacketHandler.OnAlarm() + } + if err := s.sendPacket(); err != nil { s.close(err) } - now := time.Now() if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) { s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) } @@ -277,8 +283,8 @@ func (s *session) maybeResetTimer() { if !s.nextAckScheduledTime.IsZero() { nextDeadline = utils.MinTime(nextDeadline, s.nextAckScheduledTime) } - if rtoTime := s.sentPacketHandler.TimeOfFirstRTO(); !rtoTime.IsZero() { - nextDeadline = utils.MinTime(nextDeadline, rtoTime) + if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() { + nextDeadline = utils.MinTime(nextDeadline, lossTime) } if !s.cryptoSetup.HandshakeComplete() { handshakeDeadline := s.sessionCreationTime.Add(protocol.MaxTimeForCryptoHandshake) @@ -566,9 +572,6 @@ func (s *session) sendPacket() error { return err } - // Do this before checking the congestion, since we might de-congestionize here :) - s.sentPacketHandler.MaybeQueueRTOs() - if !s.sentPacketHandler.SendingAllowed() { return nil } diff --git a/session_test.go b/session_test.go index 23db05f2..43e9287b 100644 --- a/session_test.go +++ b/session_test.go @@ -62,7 +62,6 @@ type mockSentPacketHandler struct { retransmissionQueue []*ackhandler.Packet sentPackets []*ackhandler.Packet congestionLimited bool - maybeQueueRTOsCalled bool requestedStopWaiting bool } @@ -70,22 +69,22 @@ func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error { h.sentPackets = append(h.sentPackets, packet) return nil } + func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error { return nil } -func (h *mockSentPacketHandler) BytesInFlight() protocol.ByteCount { return 0 } + func (h *mockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return 1 } +func (h *mockSentPacketHandler) CheckForError() error { return nil } + +func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") } +func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") } +func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited } + func (h *mockSentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame { h.requestedStopWaiting = true return &frames.StopWaitingFrame{LeastUnacked: 0x1337} } -func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited } -func (h *mockSentPacketHandler) CheckForError() error { return nil } -func (h *mockSentPacketHandler) TimeOfFirstRTO() time.Time { panic("not implemented") } - -func (h *mockSentPacketHandler) MaybeQueueRTOs() { - h.maybeQueueRTOsCalled = true -} func (h *mockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet { if len(h.retransmissionQueue) > 0 { @@ -433,6 +432,7 @@ var _ = Describe("Session", func() { It("doesn't queue a RST_STREAM for a stream that it already sent a FIN on", func() { str, err := sess.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) str.(*stream).sentFin() str.Close() err = sess.handleRstStreamFrame(&frames.RstStreamFrame{ @@ -973,14 +973,6 @@ var _ = Describe("Session", func() { Expect(ok).To(BeTrue()) }) - It("calls MaybeQueueRTOs even if congestion blocked, so that bytesInFlight is updated", func() { - sph.congestionLimited = true - sess.sentPacketHandler = sph - err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sph.maybeQueueRTOsCalled).To(BeTrue()) - }) - It("retransmits a WindowUpdates if it hasn't already sent a WindowUpdate with a higher ByteOffset", func() { _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred())