diff --git a/internal/ackhandler/packet_number_generator.go b/internal/ackhandler/packet_number_generator.go index 4569be9a..e93ef967 100644 --- a/internal/ackhandler/packet_number_generator.go +++ b/internal/ackhandler/packet_number_generator.go @@ -2,40 +2,72 @@ package ackhandler import ( "crypto/rand" - "math" + "encoding/binary" + mrand "math/rand" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) -// The packetNumberGenerator generates the packet number for the next packet +type packetNumberGenerator interface { + Peek() protocol.PacketNumber + Pop() protocol.PacketNumber +} + +type sequentialPacketNumberGenerator struct { + next protocol.PacketNumber +} + +var _ packetNumberGenerator = &sequentialPacketNumberGenerator{} + +func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator { + return &sequentialPacketNumberGenerator{next: initial} +} + +func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber { + return p.next +} + +func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { + next := p.next + p.next++ + return next +} + +// The skippingPacketNumberGenerator generates the packet number for the next packet // it randomly skips a packet number every averagePeriod packets (on average). // It is guaranteed to never skip two consecutive packet numbers. -type packetNumberGenerator struct { - averagePeriod protocol.PacketNumber +type skippingPacketNumberGenerator struct { + rand *mrand.Rand + period protocol.PacketNumber + maxPeriod protocol.PacketNumber next protocol.PacketNumber nextToSkip protocol.PacketNumber } -func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator { - g := &packetNumberGenerator{ - next: initial, - averagePeriod: averagePeriod, +var _ packetNumberGenerator = &skippingPacketNumberGenerator{} + +func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator { + b := make([]byte, 8) + rand.Read(b) // it's not the end of the world if we don't get perfect random here + g := &skippingPacketNumberGenerator{ + rand: mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(b)))), + next: initial, + period: initialPeriod, + maxPeriod: maxPeriod, } g.generateNewSkip() return g } -func (p *packetNumberGenerator) Peek() protocol.PacketNumber { +func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { return p.next } -func (p *packetNumberGenerator) Pop() protocol.PacketNumber { +func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { next := p.next - - // generate a new packet number for the next packet - p.next++ - + p.next++ // generate a new packet number for the next packet if p.next == p.nextToSkip { p.next++ p.generateNewSkip() @@ -43,19 +75,8 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber { return next } -func (p *packetNumberGenerator) generateNewSkip() { - num := p.getRandomNumber() - skip := protocol.PacketNumber(num) * (p.averagePeriod - 1) / (math.MaxUint16 / 2) +func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped - p.nextToSkip = p.next + 2 + skip -} - -// getRandomNumber() generates a cryptographically secure random number between 0 and MaxUint16 (= 65535) -// The expectation value is 65535/2 -func (p *packetNumberGenerator) getRandomNumber() uint16 { - b := make([]byte, 2) - rand.Read(b) // ignore the error here - - num := uint16(b[0])<<8 + uint16(b[1]) - return num + p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rand.Int63n(int64(2*p.period))) + p.period = utils.MinPacketNumber(2*p.period, p.maxPeriod) } diff --git a/internal/ackhandler/packet_number_generator_test.go b/internal/ackhandler/packet_number_generator_test.go index 73bd914f..94678256 100644 --- a/internal/ackhandler/packet_number_generator_test.go +++ b/internal/ackhandler/packet_number_generator_test.go @@ -1,41 +1,49 @@ package ackhandler import ( - "math" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("Packet Number Generator", func() { - var png *packetNumberGenerator +var _ = Describe("Sequential Packet Number Generator", func() { + It("generates sequential packet numbers", func() { + const initialPN protocol.PacketNumber = 123 + png := newSequentialPacketNumberGenerator(initialPN) - BeforeEach(func() { - png = newPacketNumberGenerator(1, 100) + for i := initialPN; i < initialPN+1000; i++ { + Expect(png.Peek()).To(Equal(i)) + Expect(png.Peek()).To(Equal(i)) + Expect(png.Pop()).To(Equal(i)) + } }) +}) + +var _ = Describe("Skipping Packet Number Generator", func() { + const initialPN protocol.PacketNumber = 8 + const initialPeriod protocol.PacketNumber = 25 + const maxPeriod protocol.PacketNumber = 300 It("can be initialized to return any first packet number", func() { - png = newPacketNumberGenerator(12345, 100) + png := newSkippingPacketNumberGenerator(12345, initialPeriod, maxPeriod) Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345))) }) - It("gets 1 as the first packet number", func() { - num := png.Pop() - Expect(num).To(Equal(protocol.PacketNumber(1))) - }) - It("allows peeking", func() { + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod).(*skippingPacketNumberGenerator) png.nextToSkip = 1000 - Expect(png.Peek()).To(Equal(protocol.PacketNumber(1))) - Expect(png.Peek()).To(Equal(protocol.PacketNumber(1))) - num := png.Pop() - Expect(num).To(Equal(protocol.PacketNumber(1))) - Expect(png.Peek()).To(Equal(protocol.PacketNumber(2))) - Expect(png.Peek()).To(Equal(protocol.PacketNumber(2))) + Expect(png.Peek()).To(Equal(initialPN)) + Expect(png.Peek()).To(Equal(initialPN)) + Expect(png.Pop()).To(Equal(initialPN)) + Expect(png.Peek()).To(Equal(initialPN + 1)) + Expect(png.Peek()).To(Equal(initialPN + 1)) }) It("skips a packet number", func() { + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) var last protocol.PacketNumber var skipped bool for i := 0; i < 1000; i++ { @@ -49,52 +57,38 @@ var _ = Describe("Packet Number Generator", func() { Expect(skipped).To(BeTrue()) }) - It("skips a specific packet number", func() { - png.nextToSkip = 2 - num := png.Pop() - Expect(num).To(Equal(protocol.PacketNumber(1))) - Expect(png.Peek()).To(Equal(protocol.PacketNumber(3))) - num = png.Pop() - Expect(num).To(Equal(protocol.PacketNumber(3))) - }) - It("generates a new packet number to skip", func() { - png.next = 100 - png.averagePeriod = 100 - - rep := 5000 - var sum protocol.PacketNumber + const rep = 500 + periods := make([][]protocol.PacketNumber, rep) + expectedPeriods := []protocol.PacketNumber{25, 50, 100, 200, 300, 300, 300} for i := 0; i < rep; i++ { - png.generateNewSkip() - Expect(png.nextToSkip).ToNot(Equal(protocol.PacketNumber(101))) - sum += png.nextToSkip - } - - average := sum / protocol.PacketNumber(rep) - Expect(average).To(BeNumerically("==", protocol.PacketNumber(200), 4)) - }) - - It("uses random numbers", func() { - var smallest uint16 = math.MaxUint16 - var largest uint16 - var sum uint64 - - rep := 10000 - - for i := 0; i < rep; i++ { - num := png.getRandomNumber() - sum += uint64(num) - if num > largest { - largest = num - } - if num < smallest { - smallest = num + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) + last := initialPN + lastSkip := initialPN + for len(periods[i]) < len(expectedPeriods) { + next := png.Pop() + if next > last+1 { + skipped := next - 1 + Expect(skipped).To(BeNumerically(">", lastSkip+1)) + periods[i] = append(periods[i], skipped-lastSkip-1) + lastSkip = skipped + } + last = next } } - Expect(smallest).To(BeNumerically("<", 300)) - Expect(largest).To(BeNumerically(">", math.MaxUint16-300)) - Expect(sum / uint64(rep)).To(BeNumerically("==", uint64(math.MaxUint16/2), 1000)) + for j := 0; j < len(expectedPeriods); j++ { + var average float64 + for i := 0; i < rep; i++ { + average += float64(periods[i][j]) / float64(len(periods)) + } + fmt.Fprintf(GinkgoWriter, "Period %d: %.2f (expected %d)\n", j, average, expectedPeriods[j]) + tolerance := protocol.PacketNumber(5) + if t := expectedPeriods[j] / 10; t > tolerance { + tolerance = t + } + Expect(average).To(BeNumerically("~", expectedPeriods[j]+1 /* we never skip two packet numbers at the same time */, tolerance)) + } }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 2d791a1f..99a51eb8 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -25,7 +25,7 @@ const ( type packetNumberSpace struct { history *sentPacketHistory - pns *packetNumberGenerator + pns packetNumberGenerator lossTime time.Time lastAckElicitingPacketTime time.Time @@ -34,10 +34,16 @@ type packetNumberSpace struct { largestSent protocol.PacketNumber } -func newPacketNumberSpace(initialPN protocol.PacketNumber, rttStats *utils.RTTStats) *packetNumberSpace { +func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { + var pns packetNumberGenerator + if skipPNs { + pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) + } else { + pns = newSequentialPacketNumberGenerator(initialPN) + } return &packetNumberSpace{ history: newSentPacketHistory(rttStats), - pns: newPacketNumberGenerator(initialPN, protocol.SkipPacketAveragePeriodLength), + pns: pns, largestSent: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, } @@ -94,7 +100,7 @@ var ( ) func newSentPacketHandler( - initialPacketNumber protocol.PacketNumber, + initialPN protocol.PacketNumber, rttStats *utils.RTTStats, pers protocol.Perspective, tracer logging.ConnectionTracer, @@ -110,9 +116,9 @@ func newSentPacketHandler( return &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerAddressValidated: pers == protocol.PerspectiveClient, - initialPackets: newPacketNumberSpace(initialPacketNumber, rttStats), - handshakePackets: newPacketNumberSpace(0, rttStats), - appDataPackets: newPacketNumberSpace(0, rttStats), + initialPackets: newPacketNumberSpace(initialPN, false, rttStats), + handshakePackets: newPacketNumberSpace(0, false, rttStats), + appDataPackets: newPacketNumberSpace(0, true, rttStats), rttStats: rttStats, congestion: congestion, perspective: pers, @@ -765,8 +771,8 @@ func (h *sentPacketHandler) ResetForRetry() error { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } - h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), h.rttStats) - h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), h.rttStats) + h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats) + h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats) oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil { diff --git a/internal/protocol/params.go b/internal/protocol/params.go index ab2e22fb..12567fa7 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -48,8 +48,12 @@ const MaxServerUnprocessedPackets = 1024 // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. const MaxSessionUnprocessedPackets = 256 -// SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack -const SkipPacketAveragePeriodLength PacketNumber = 500 +// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. +// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. +const SkipPacketInitialPeriod PacketNumber = 256 + +// SkipPacketMaxPeriod is the maximum period length used for packet number skipping. +const SkipPacketMaxPeriod PacketNumber = 128 * 1024 // MaxAcceptQueueSize is the maximum number of sessions that the server queues for accepting. // If the queue is full, new connection attempts will be rejected.