From 12c9427f43c5e56692c8ec3ed5a7fe88fba8632a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 7 Sep 2020 16:52:56 +0700 Subject: [PATCH] save dummy packets in the packet history when skipping packet numbers --- internal/ackhandler/interfaces.go | 1 + .../ackhandler/packet_number_generator.go | 21 +--- .../packet_number_generator_test.go | 42 -------- internal/ackhandler/sent_packet_handler.go | 16 ++-- .../ackhandler/sent_packet_handler_test.go | 16 ++-- internal/ackhandler/sent_packet_history.go | 40 +++++--- .../ackhandler/sent_packet_history_test.go | 95 +++++++++++++++---- internal/protocol/params.go | 3 - 8 files changed, 124 insertions(+), 110 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 42d83d98..1465a654 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -19,6 +19,7 @@ type Packet struct { includedInBytesInFlight bool declaredLost bool + skippedPacket bool } // SentPacketHandler handles ACKs received for outgoing packets diff --git a/internal/ackhandler/packet_number_generator.go b/internal/ackhandler/packet_number_generator.go index 56fbf3d8..4569be9a 100644 --- a/internal/ackhandler/packet_number_generator.go +++ b/internal/ackhandler/packet_number_generator.go @@ -5,19 +5,16 @@ import ( "math" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" ) // The packetNumberGenerator generates the packet number for the next packet -// it randomly skips a packet number every averagePeriod packets (on average) -// it is guarantued to never skip two consecutive packet numbers +// it randomly skips a packet number every averagePeriod packets (on average). +// It is guaranteed to never skip two consecutive packet numbers. type packetNumberGenerator struct { averagePeriod protocol.PacketNumber next protocol.PacketNumber nextToSkip protocol.PacketNumber - - history []protocol.PacketNumber } func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator { @@ -40,14 +37,9 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber { p.next++ if p.next == p.nextToSkip { - if len(p.history)+1 > protocol.MaxTrackedSkippedPackets { - p.history = p.history[1:] - } - p.history = append(p.history, p.next) p.next++ p.generateNewSkip() } - return next } @@ -67,12 +59,3 @@ func (p *packetNumberGenerator) getRandomNumber() uint16 { num := uint16(b[0])<<8 + uint16(b[1]) return num } - -func (p *packetNumberGenerator) Validate(ack *wire.AckFrame) bool { - for _, pn := range p.history { - if ack.AcksPacket(pn) { - return false - } - } - return true -} diff --git a/internal/ackhandler/packet_number_generator_test.go b/internal/ackhandler/packet_number_generator_test.go index 9465a833..73bd914f 100644 --- a/internal/ackhandler/packet_number_generator_test.go +++ b/internal/ackhandler/packet_number_generator_test.go @@ -4,7 +4,6 @@ import ( "math" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -98,45 +97,4 @@ var _ = Describe("Packet Number Generator", func() { Expect(largest).To(BeNumerically(">", math.MaxUint16-300)) Expect(sum / uint64(rep)).To(BeNumerically("==", uint64(math.MaxUint16/2), 1000)) }) - - It("validates ACK frames", func() { - var skipped []protocol.PacketNumber - var lastPN protocol.PacketNumber - for len(skipped) < 3 { - if png.Peek() > lastPN+1 { - skipped = append(skipped, lastPN+1) - } - lastPN = png.Pop() - } - invalidACK := &wire.AckFrame{ - AckRanges: []wire.AckRange{{Smallest: 1, Largest: lastPN}}, - } - Expect(png.Validate(invalidACK)).To(BeFalse()) - validACK1 := &wire.AckFrame{ - AckRanges: []wire.AckRange{{Smallest: 1, Largest: skipped[0] - 1}}, - } - Expect(png.Validate(validACK1)).To(BeTrue()) - validACK2 := &wire.AckFrame{ - AckRanges: []wire.AckRange{ - {Smallest: 1, Largest: skipped[0] - 1}, - {Smallest: skipped[0] + 1, Largest: skipped[1] - 1}, - {Smallest: skipped[1] + 1, Largest: skipped[2] - 1}, - {Smallest: skipped[2] + 1, Largest: skipped[2] + 100}, - }, - } - Expect(png.Validate(validACK2)).To(BeTrue()) - }) - - It("tracks a maximum number of protocol.MaxTrackedSkippedPackets packets", func() { - var skipped []protocol.PacketNumber - var lastPN protocol.PacketNumber - for len(skipped) < protocol.MaxTrackedSkippedPackets+3 { - if png.Peek() > lastPN+1 { - skipped = append(skipped, lastPN+1) - } - lastPN = png.Pop() - Expect(len(png.history)).To(BeNumerically("<=", protocol.MaxTrackedSkippedPackets)) - } - Expect(len(png.history)).To(Equal(protocol.MaxTrackedSkippedPackets)) - }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 97358d40..6b441443 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -161,6 +161,9 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { case protocol.Encryption0RTT: // TODO(#2067): invalidate sent data h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } if p.EncryptionLevel != protocol.Encryption0RTT { return false, nil } @@ -266,10 +269,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked) - if !pnSpace.pns.Validate(ack) { - return qerr.NewError(qerr.ProtocolViolation, "Received an ACK for a skipped packet number") - } - // Servers complete address validation when a protected packet is received. if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && (encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) { @@ -310,6 +309,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) } for _, p := range ackedPackets { + if p.skippedPacket { + return fmt.Errorf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel) + } if p.includedInBytesInFlight && !p.declaredLost { h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) } @@ -518,7 +520,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if packet.PacketNumber > pnSpace.largestAcked { return false, nil } - if packet.declaredLost { + if packet.declaredLost || packet.skippedPacket { return true, nil } @@ -763,7 +765,7 @@ func (h *sentPacketHandler) ResetForRetry() error { if firstPacketSendTime.IsZero() { firstPacketSendTime = p.SendTime } - if p.declaredLost { + if p.declaredLost || p.skippedPacket { return true, nil } h.queueFramesForRetransmission(p) @@ -772,7 +774,7 @@ func (h *sentPacketHandler) ResetForRetry() error { // All application data packets sent at this point are 0-RTT packets. // In the case of a Retry, we can assume that the server dropped all of them. h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { - if !p.declaredLost { + if !p.declaredLost && !p.skippedPacket { h.queueFramesForRetransmission(p) } return true, nil diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index c5837f01..9b428809 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -88,7 +88,7 @@ var _ = Describe("SentPacketHandler", func() { pnSpace := handler.getPacketNumberSpace(encLevel) var length int pnSpace.history.Iterate(func(p *Packet) (bool, error) { - if !p.declaredLost { + if !p.declaredLost && !p.skippedPacket { length++ } return true, nil @@ -550,11 +550,10 @@ var _ = Describe("SentPacketHandler", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes() cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - for i := protocol.PacketNumber(1); i < protocol.MaxOutstandingSentPackets; i++ { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) + for i := protocol.PacketNumber(0); i < protocol.MaxOutstandingSentPackets; i++ { Expect(handler.SendMode()).To(Equal(SendAny)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) } - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: protocol.MaxOutstandingSentPackets})) Expect(handler.SendMode()).To(Equal(SendAck)) }) @@ -747,7 +746,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode()).To(Equal(SendPTOInitial)) handler.SentPacket(initialPacket(&Packet{PacketNumber: 3})) Expect(handler.SendMode()).To(Equal(SendPTOInitial)) - handler.SentPacket(initialPacket(&Packet{PacketNumber: 3})) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 4})) Expect(handler.SendMode()).To(Equal(SendAny)) }) @@ -1038,6 +1037,9 @@ var _ = Describe("SentPacketHandler", func() { // TODO(#2067): invalidate 0-RTT data when 0-RTT is rejected It("retransmits 0-RTT packets when 0-RTT keys are dropped", func() { for i := protocol.PacketNumber(0); i < 6; i++ { + if i == 3 { + continue + } handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: i, EncryptionLevel: protocol.Encryption0RTT, @@ -1046,9 +1048,9 @@ var _ = Describe("SentPacketHandler", func() { for i := protocol.PacketNumber(6); i < 12; i++ { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i})) } - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(12))) + Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11))) handler.DropPackets(protocol.Encryption0RTT) - Expect(lostPackets).To(Equal([]protocol.PacketNumber{0, 1, 2, 3, 4, 5})) + Expect(lostPackets).To(Equal([]protocol.PacketNumber{0, 1, 2, 4, 5})) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) }) diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index 055d92d9..6ed44ad0 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -9,25 +9,41 @@ import ( ) type sentPacketHistory struct { - rttStats *utils.RTTStats - packetList *PacketList - packetMap map[protocol.PacketNumber]*PacketElement + rttStats *utils.RTTStats + packetList *PacketList + packetMap map[protocol.PacketNumber]*PacketElement + highestSent protocol.PacketNumber } func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { return &sentPacketHistory{ - rttStats: rttStats, - packetList: NewPacketList(), - packetMap: make(map[protocol.PacketNumber]*PacketElement), + rttStats: rttStats, + packetList: NewPacketList(), + packetMap: make(map[protocol.PacketNumber]*PacketElement), + highestSent: protocol.InvalidPacketNumber, } } func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { - if !isAckEliciting { - return + if p.PacketNumber <= h.highestSent { + panic("non-sequential packet number use") + } + // Skipped packet numbers. + for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ { + el := h.packetList.PushBack(Packet{ + PacketNumber: pn, + EncryptionLevel: p.EncryptionLevel, + SendTime: p.SendTime, + skippedPacket: true, + }) + h.packetMap[pn] = el + } + h.highestSent = p.PacketNumber + + if isAckEliciting { + el := h.packetList.PushBack(*p) + h.packetMap[p.PacketNumber] = el } - el := h.packetList.PushBack(*p) - h.packetMap[p.PacketNumber] = el } // Iterate iterates through all packets. @@ -48,7 +64,7 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err // FirstOutStanding returns the first outstanding packet. func (h *sentPacketHistory) FirstOutstanding() *Packet { for el := h.packetList.Front(); el != nil; el = el.Next() { - if !el.Value.declaredLost { + if !el.Value.declaredLost && !el.Value.skippedPacket { return &el.Value } } @@ -82,7 +98,7 @@ func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { if p.SendTime.After(now.Add(-maxAge)) { break } - if !p.declaredLost { // should only happen in the case of drastic RTT changes + if !p.skippedPacket && !p.declaredLost { // should only happen in the case of drastic RTT changes continue } delete(h.packetMap, p.PacketNumber) diff --git a/internal/ackhandler/sent_packet_history_test.go b/internal/ackhandler/sent_packet_history_test.go index bce01936..f691954d 100644 --- a/internal/ackhandler/sent_packet_history_test.go +++ b/internal/ackhandler/sent_packet_history_test.go @@ -18,10 +18,25 @@ var _ = Describe("SentPacketHistory", func() { ) expectInHistory := func(packetNumbers []protocol.PacketNumber) { - ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers))) - ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers))) + var mapLen int + for _, el := range hist.packetMap { + if !el.Value.skippedPacket { + mapLen++ + } + } + var listLen int + for el := hist.packetList.Front(); el != nil; el = el.Next() { + if !el.Value.skippedPacket { + listLen++ + } + } + ExpectWithOffset(1, mapLen).To(Equal(len(packetNumbers))) + ExpectWithOffset(1, listLen).To(Equal(len(packetNumbers))) i := 0 err := hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } pn := packetNumbers[i] ExpectWithOffset(1, p.PacketNumber).To(Equal(pn)) ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn)) @@ -48,12 +63,16 @@ var _ = Describe("SentPacketHistory", func() { hist.SentPacket(&Packet{PacketNumber: 3}, false) hist.SentPacket(&Packet{PacketNumber: 4}, true) expectInHistory([]protocol.PacketNumber{1, 4}) + for el := hist.packetList.Front(); el != nil; el = el.Next() { + Expect(el.Value.PacketNumber).ToNot(Equal(protocol.PacketNumber(3))) + } }) It("gets the length", func() { + hist.SentPacket(&Packet{PacketNumber: 0}, true) hist.SentPacket(&Packet{PacketNumber: 1}, true) - hist.SentPacket(&Packet{PacketNumber: 10}, true) - Expect(hist.Len()).To(Equal(2)) + hist.SentPacket(&Packet{PacketNumber: 2}, true) + Expect(hist.Len()).To(Equal(3)) }) Context("getting the first outstanding packet", func() { @@ -87,53 +106,79 @@ var _ = Describe("SentPacketHistory", func() { Context("iterating", func() { BeforeEach(func() { - hist.SentPacket(&Packet{PacketNumber: 10}, true) - hist.SentPacket(&Packet{PacketNumber: 14}, true) - hist.SentPacket(&Packet{PacketNumber: 18}, true) + hist.SentPacket(&Packet{PacketNumber: 1}, true) + hist.SentPacket(&Packet{PacketNumber: 4}, true) + hist.SentPacket(&Packet{PacketNumber: 8}, true) }) It("iterates over all packets", func() { var iterations []protocol.PacketNumber Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } iterations = append(iterations, p.PacketNumber) return true, nil })).To(Succeed()) - Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18})) + Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4, 8})) + }) + + It("also iterates over skipped packets", func() { + var packets, skippedPackets []protocol.PacketNumber + Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + skippedPackets = append(skippedPackets, p.PacketNumber) + } else { + packets = append(packets, p.PacketNumber) + } + return true, nil + })).To(Succeed()) + Expect(packets).To(Equal([]protocol.PacketNumber{1, 4, 8})) + Expect(skippedPackets).To(Equal([]protocol.PacketNumber{0, 2, 3, 5, 6, 7})) }) It("stops iterating", func() { var iterations []protocol.PacketNumber Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } iterations = append(iterations, p.PacketNumber) - return p.PacketNumber != 14, nil + return p.PacketNumber != 4, nil })).To(Succeed()) - Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14})) + Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4})) }) It("returns the error", func() { testErr := errors.New("test error") var iterations []protocol.PacketNumber Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } iterations = append(iterations, p.PacketNumber) - if p.PacketNumber == 14 { + if p.PacketNumber == 4 { return false, testErr } return true, nil })).To(MatchError(testErr)) - Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14})) + Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4})) }) It("allows deletions", func() { var iterations []protocol.PacketNumber Expect(hist.Iterate(func(p *Packet) (bool, error) { + if p.skippedPacket { + return true, nil + } iterations = append(iterations, p.PacketNumber) - if p.PacketNumber == 14 { - Expect(hist.Remove(14)).To(Succeed()) + if p.PacketNumber == 4 { + Expect(hist.Remove(4)).To(Succeed()) } return true, nil })).To(Succeed()) - expectInHistory([]protocol.PacketNumber{10, 18}) - Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18})) + expectInHistory([]protocol.PacketNumber{1, 8}) + Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4, 8})) }) }) @@ -181,19 +226,29 @@ var _ = Describe("SentPacketHistory", func() { It("deletes old packets after 3 PTOs", func() { now := time.Now() hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}, true) - Expect(hist.Len()).To(Equal(1)) + expectInHistory([]protocol.PacketNumber{10}) hist.DeleteOldPackets(now.Add(-time.Nanosecond)) - Expect(hist.Len()).To(Equal(1)) + expectInHistory([]protocol.PacketNumber{10}) hist.DeleteOldPackets(now) - Expect(hist.Len()).To(BeZero()) + expectInHistory([]protocol.PacketNumber{}) }) It("doesn't delete a packet if it hasn't been declared lost yet", func() { now := time.Now() hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}, true) hist.SentPacket(&Packet{PacketNumber: 11, SendTime: now.Add(-3 * pto), declaredLost: false}, true) - Expect(hist.Len()).To(Equal(2)) + expectInHistory([]protocol.PacketNumber{10, 11}) hist.DeleteOldPackets(now) + expectInHistory([]protocol.PacketNumber{11}) + }) + + It("deletes skipped packets", func() { + now := time.Now() + hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto)}, true) + expectInHistory([]protocol.PacketNumber{10}) + Expect(hist.Len()).To(Equal(11)) + hist.DeleteOldPackets(now) + expectInHistory([]protocol.PacketNumber{10}) // the packet was not declared lost Expect(hist.Len()).To(Equal(1)) }) }) diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 6bdfd0d3..1fdfab82 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -48,9 +48,6 @@ const MaxSessionUnprocessedPackets = 256 // SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack const SkipPacketAveragePeriodLength PacketNumber = 500 -// MaxTrackedSkippedPackets is the maximum number of skipped packet numbers the SentPacketHandler keep track of for Optimistic ACK attack mitigation -const MaxTrackedSkippedPackets = 10 - // MaxAcceptQueueSize is the maximum number of sessions that the server queues for accepting. // If the queue is full, new connection attempts will be rejected. const MaxAcceptQueueSize = 32