save dummy packets in the packet history when skipping packet numbers

This commit is contained in:
Marten Seemann 2020-09-07 16:52:56 +07:00
parent 0044a69d73
commit 12c9427f43
8 changed files with 124 additions and 110 deletions

View file

@ -19,6 +19,7 @@ type Packet struct {
includedInBytesInFlight bool
declaredLost bool
skippedPacket bool
}
// SentPacketHandler handles ACKs received for outgoing packets

View file

@ -5,19 +5,16 @@ import (
"math"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// The packetNumberGenerator generates the packet number for the next packet
// it randomly skips a packet number every averagePeriod packets (on average)
// it is guarantued to never skip two consecutive packet numbers
// it randomly skips a packet number every averagePeriod packets (on average).
// It is guaranteed to never skip two consecutive packet numbers.
type packetNumberGenerator struct {
averagePeriod protocol.PacketNumber
next protocol.PacketNumber
nextToSkip protocol.PacketNumber
history []protocol.PacketNumber
}
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
@ -40,14 +37,9 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
p.next++
if p.next == p.nextToSkip {
if len(p.history)+1 > protocol.MaxTrackedSkippedPackets {
p.history = p.history[1:]
}
p.history = append(p.history, p.next)
p.next++
p.generateNewSkip()
}
return next
}
@ -67,12 +59,3 @@ func (p *packetNumberGenerator) getRandomNumber() uint16 {
num := uint16(b[0])<<8 + uint16(b[1])
return num
}
func (p *packetNumberGenerator) Validate(ack *wire.AckFrame) bool {
for _, pn := range p.history {
if ack.AcksPacket(pn) {
return false
}
}
return true
}

View file

@ -4,7 +4,6 @@ import (
"math"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@ -98,45 +97,4 @@ var _ = Describe("Packet Number Generator", func() {
Expect(largest).To(BeNumerically(">", math.MaxUint16-300))
Expect(sum / uint64(rep)).To(BeNumerically("==", uint64(math.MaxUint16/2), 1000))
})
It("validates ACK frames", func() {
var skipped []protocol.PacketNumber
var lastPN protocol.PacketNumber
for len(skipped) < 3 {
if png.Peek() > lastPN+1 {
skipped = append(skipped, lastPN+1)
}
lastPN = png.Pop()
}
invalidACK := &wire.AckFrame{
AckRanges: []wire.AckRange{{Smallest: 1, Largest: lastPN}},
}
Expect(png.Validate(invalidACK)).To(BeFalse())
validACK1 := &wire.AckFrame{
AckRanges: []wire.AckRange{{Smallest: 1, Largest: skipped[0] - 1}},
}
Expect(png.Validate(validACK1)).To(BeTrue())
validACK2 := &wire.AckFrame{
AckRanges: []wire.AckRange{
{Smallest: 1, Largest: skipped[0] - 1},
{Smallest: skipped[0] + 1, Largest: skipped[1] - 1},
{Smallest: skipped[1] + 1, Largest: skipped[2] - 1},
{Smallest: skipped[2] + 1, Largest: skipped[2] + 100},
},
}
Expect(png.Validate(validACK2)).To(BeTrue())
})
It("tracks a maximum number of protocol.MaxTrackedSkippedPackets packets", func() {
var skipped []protocol.PacketNumber
var lastPN protocol.PacketNumber
for len(skipped) < protocol.MaxTrackedSkippedPackets+3 {
if png.Peek() > lastPN+1 {
skipped = append(skipped, lastPN+1)
}
lastPN = png.Pop()
Expect(len(png.history)).To(BeNumerically("<=", protocol.MaxTrackedSkippedPackets))
}
Expect(len(png.history)).To(Equal(protocol.MaxTrackedSkippedPackets))
})
})

View file

@ -161,6 +161,9 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
case protocol.Encryption0RTT:
// TODO(#2067): invalidate sent data
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) {
if p.skippedPacket {
return true, nil
}
if p.EncryptionLevel != protocol.Encryption0RTT {
return false, nil
}
@ -266,10 +269,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked)
if !pnSpace.pns.Validate(ack) {
return qerr.NewError(qerr.ProtocolViolation, "Received an ACK for a skipped packet number")
}
// Servers complete address validation when a protected packet is received.
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
@ -310,6 +309,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
for _, p := range ackedPackets {
if p.skippedPacket {
return fmt.Errorf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel)
}
if p.includedInBytesInFlight && !p.declaredLost {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
@ -518,7 +520,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
if packet.PacketNumber > pnSpace.largestAcked {
return false, nil
}
if packet.declaredLost {
if packet.declaredLost || packet.skippedPacket {
return true, nil
}
@ -763,7 +765,7 @@ func (h *sentPacketHandler) ResetForRetry() error {
if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime
}
if p.declaredLost {
if p.declaredLost || p.skippedPacket {
return true, nil
}
h.queueFramesForRetransmission(p)
@ -772,7 +774,7 @@ func (h *sentPacketHandler) ResetForRetry() error {
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) {
if !p.declaredLost {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
return true, nil

View file

@ -88,7 +88,7 @@ var _ = Describe("SentPacketHandler", func() {
pnSpace := handler.getPacketNumberSpace(encLevel)
var length int
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
if !p.declaredLost {
if !p.declaredLost && !p.skippedPacket {
length++
}
return true, nil
@ -550,11 +550,10 @@ var _ = Describe("SentPacketHandler", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes()
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
for i := protocol.PacketNumber(1); i < protocol.MaxOutstandingSentPackets; i++ {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i}))
for i := protocol.PacketNumber(0); i < protocol.MaxOutstandingSentPackets; i++ {
Expect(handler.SendMode()).To(Equal(SendAny))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i}))
}
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: protocol.MaxOutstandingSentPackets}))
Expect(handler.SendMode()).To(Equal(SendAck))
})
@ -747,7 +746,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.SendMode()).To(Equal(SendPTOInitial))
handler.SentPacket(initialPacket(&Packet{PacketNumber: 3}))
Expect(handler.SendMode()).To(Equal(SendPTOInitial))
handler.SentPacket(initialPacket(&Packet{PacketNumber: 3}))
handler.SentPacket(initialPacket(&Packet{PacketNumber: 4}))
Expect(handler.SendMode()).To(Equal(SendAny))
})
@ -1038,6 +1037,9 @@ var _ = Describe("SentPacketHandler", func() {
// TODO(#2067): invalidate 0-RTT data when 0-RTT is rejected
It("retransmits 0-RTT packets when 0-RTT keys are dropped", func() {
for i := protocol.PacketNumber(0); i < 6; i++ {
if i == 3 {
continue
}
handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: i,
EncryptionLevel: protocol.Encryption0RTT,
@ -1046,9 +1048,9 @@ var _ = Describe("SentPacketHandler", func() {
for i := protocol.PacketNumber(6); i < 12; i++ {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: i}))
}
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(12)))
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11)))
handler.DropPackets(protocol.Encryption0RTT)
Expect(lostPackets).To(Equal([]protocol.PacketNumber{0, 1, 2, 3, 4, 5}))
Expect(lostPackets).To(Equal([]protocol.PacketNumber{0, 1, 2, 4, 5}))
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6)))
})

View file

@ -9,25 +9,41 @@ import (
)
type sentPacketHistory struct {
rttStats *utils.RTTStats
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
rttStats *utils.RTTStats
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
highestSent protocol.PacketNumber
}
func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory {
return &sentPacketHistory{
rttStats: rttStats,
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
rttStats: rttStats,
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
highestSent: protocol.InvalidPacketNumber,
}
}
func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) {
if !isAckEliciting {
return
if p.PacketNumber <= h.highestSent {
panic("non-sequential packet number use")
}
// Skipped packet numbers.
for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ {
el := h.packetList.PushBack(Packet{
PacketNumber: pn,
EncryptionLevel: p.EncryptionLevel,
SendTime: p.SendTime,
skippedPacket: true,
})
h.packetMap[pn] = el
}
h.highestSent = p.PacketNumber
if isAckEliciting {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
}
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
}
// Iterate iterates through all packets.
@ -48,7 +64,7 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err
// FirstOutStanding returns the first outstanding packet.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
for el := h.packetList.Front(); el != nil; el = el.Next() {
if !el.Value.declaredLost {
if !el.Value.declaredLost && !el.Value.skippedPacket {
return &el.Value
}
}
@ -82,7 +98,7 @@ func (h *sentPacketHistory) DeleteOldPackets(now time.Time) {
if p.SendTime.After(now.Add(-maxAge)) {
break
}
if !p.declaredLost { // should only happen in the case of drastic RTT changes
if !p.skippedPacket && !p.declaredLost { // should only happen in the case of drastic RTT changes
continue
}
delete(h.packetMap, p.PacketNumber)

View file

@ -18,10 +18,25 @@ var _ = Describe("SentPacketHistory", func() {
)
expectInHistory := func(packetNumbers []protocol.PacketNumber) {
ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers)))
ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers)))
var mapLen int
for _, el := range hist.packetMap {
if !el.Value.skippedPacket {
mapLen++
}
}
var listLen int
for el := hist.packetList.Front(); el != nil; el = el.Next() {
if !el.Value.skippedPacket {
listLen++
}
}
ExpectWithOffset(1, mapLen).To(Equal(len(packetNumbers)))
ExpectWithOffset(1, listLen).To(Equal(len(packetNumbers)))
i := 0
err := hist.Iterate(func(p *Packet) (bool, error) {
if p.skippedPacket {
return true, nil
}
pn := packetNumbers[i]
ExpectWithOffset(1, p.PacketNumber).To(Equal(pn))
ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn))
@ -48,12 +63,16 @@ var _ = Describe("SentPacketHistory", func() {
hist.SentPacket(&Packet{PacketNumber: 3}, false)
hist.SentPacket(&Packet{PacketNumber: 4}, true)
expectInHistory([]protocol.PacketNumber{1, 4})
for el := hist.packetList.Front(); el != nil; el = el.Next() {
Expect(el.Value.PacketNumber).ToNot(Equal(protocol.PacketNumber(3)))
}
})
It("gets the length", func() {
hist.SentPacket(&Packet{PacketNumber: 0}, true)
hist.SentPacket(&Packet{PacketNumber: 1}, true)
hist.SentPacket(&Packet{PacketNumber: 10}, true)
Expect(hist.Len()).To(Equal(2))
hist.SentPacket(&Packet{PacketNumber: 2}, true)
Expect(hist.Len()).To(Equal(3))
})
Context("getting the first outstanding packet", func() {
@ -87,53 +106,79 @@ var _ = Describe("SentPacketHistory", func() {
Context("iterating", func() {
BeforeEach(func() {
hist.SentPacket(&Packet{PacketNumber: 10}, true)
hist.SentPacket(&Packet{PacketNumber: 14}, true)
hist.SentPacket(&Packet{PacketNumber: 18}, true)
hist.SentPacket(&Packet{PacketNumber: 1}, true)
hist.SentPacket(&Packet{PacketNumber: 4}, true)
hist.SentPacket(&Packet{PacketNumber: 8}, true)
})
It("iterates over all packets", func() {
var iterations []protocol.PacketNumber
Expect(hist.Iterate(func(p *Packet) (bool, error) {
if p.skippedPacket {
return true, nil
}
iterations = append(iterations, p.PacketNumber)
return true, nil
})).To(Succeed())
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18}))
Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4, 8}))
})
It("also iterates over skipped packets", func() {
var packets, skippedPackets []protocol.PacketNumber
Expect(hist.Iterate(func(p *Packet) (bool, error) {
if p.skippedPacket {
skippedPackets = append(skippedPackets, p.PacketNumber)
} else {
packets = append(packets, p.PacketNumber)
}
return true, nil
})).To(Succeed())
Expect(packets).To(Equal([]protocol.PacketNumber{1, 4, 8}))
Expect(skippedPackets).To(Equal([]protocol.PacketNumber{0, 2, 3, 5, 6, 7}))
})
It("stops iterating", func() {
var iterations []protocol.PacketNumber
Expect(hist.Iterate(func(p *Packet) (bool, error) {
if p.skippedPacket {
return true, nil
}
iterations = append(iterations, p.PacketNumber)
return p.PacketNumber != 14, nil
return p.PacketNumber != 4, nil
})).To(Succeed())
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4}))
})
It("returns the error", func() {
testErr := errors.New("test error")
var iterations []protocol.PacketNumber
Expect(hist.Iterate(func(p *Packet) (bool, error) {
if p.skippedPacket {
return true, nil
}
iterations = append(iterations, p.PacketNumber)
if p.PacketNumber == 14 {
if p.PacketNumber == 4 {
return false, testErr
}
return true, nil
})).To(MatchError(testErr))
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4}))
})
It("allows deletions", func() {
var iterations []protocol.PacketNumber
Expect(hist.Iterate(func(p *Packet) (bool, error) {
if p.skippedPacket {
return true, nil
}
iterations = append(iterations, p.PacketNumber)
if p.PacketNumber == 14 {
Expect(hist.Remove(14)).To(Succeed())
if p.PacketNumber == 4 {
Expect(hist.Remove(4)).To(Succeed())
}
return true, nil
})).To(Succeed())
expectInHistory([]protocol.PacketNumber{10, 18})
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18}))
expectInHistory([]protocol.PacketNumber{1, 8})
Expect(iterations).To(Equal([]protocol.PacketNumber{1, 4, 8}))
})
})
@ -181,19 +226,29 @@ var _ = Describe("SentPacketHistory", func() {
It("deletes old packets after 3 PTOs", func() {
now := time.Now()
hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}, true)
Expect(hist.Len()).To(Equal(1))
expectInHistory([]protocol.PacketNumber{10})
hist.DeleteOldPackets(now.Add(-time.Nanosecond))
Expect(hist.Len()).To(Equal(1))
expectInHistory([]protocol.PacketNumber{10})
hist.DeleteOldPackets(now)
Expect(hist.Len()).To(BeZero())
expectInHistory([]protocol.PacketNumber{})
})
It("doesn't delete a packet if it hasn't been declared lost yet", func() {
now := time.Now()
hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto), declaredLost: true}, true)
hist.SentPacket(&Packet{PacketNumber: 11, SendTime: now.Add(-3 * pto), declaredLost: false}, true)
Expect(hist.Len()).To(Equal(2))
expectInHistory([]protocol.PacketNumber{10, 11})
hist.DeleteOldPackets(now)
expectInHistory([]protocol.PacketNumber{11})
})
It("deletes skipped packets", func() {
now := time.Now()
hist.SentPacket(&Packet{PacketNumber: 10, SendTime: now.Add(-3 * pto)}, true)
expectInHistory([]protocol.PacketNumber{10})
Expect(hist.Len()).To(Equal(11))
hist.DeleteOldPackets(now)
expectInHistory([]protocol.PacketNumber{10}) // the packet was not declared lost
Expect(hist.Len()).To(Equal(1))
})
})

View file

@ -48,9 +48,6 @@ const MaxSessionUnprocessedPackets = 256
// SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack
const SkipPacketAveragePeriodLength PacketNumber = 500
// MaxTrackedSkippedPackets is the maximum number of skipped packet numbers the SentPacketHandler keep track of for Optimistic ACK attack mitigation
const MaxTrackedSkippedPackets = 10
// MaxAcceptQueueSize is the maximum number of sessions that the server queues for accepting.
// If the queue is full, new connection attempts will be rejected.
const MaxAcceptQueueSize = 32