diff --git a/internal/ackhandler/packet_number_generator.go b/internal/ackhandler/packet_number_generator.go index 84ba2532..e93ef967 100644 --- a/internal/ackhandler/packet_number_generator.go +++ b/internal/ackhandler/packet_number_generator.go @@ -6,6 +6,7 @@ import ( mrand "math/rand" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) type packetNumberGenerator interface { @@ -37,8 +38,9 @@ func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { // it randomly skips a packet number every averagePeriod packets (on average). // It is guaranteed to never skip two consecutive packet numbers. type skippingPacketNumberGenerator struct { - rand *mrand.Rand - averagePeriod protocol.PacketNumber + rand *mrand.Rand + period protocol.PacketNumber + maxPeriod protocol.PacketNumber next protocol.PacketNumber nextToSkip protocol.PacketNumber @@ -46,13 +48,14 @@ type skippingPacketNumberGenerator struct { var _ packetNumberGenerator = &skippingPacketNumberGenerator{} -func newSkippingPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) packetNumberGenerator { +func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator { b := make([]byte, 8) rand.Read(b) // it's not the end of the world if we don't get perfect random here g := &skippingPacketNumberGenerator{ - rand: mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(b)))), - next: initial, - averagePeriod: averagePeriod, + rand: mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(b)))), + next: initial, + period: initialPeriod, + maxPeriod: maxPeriod, } g.generateNewSkip() return g @@ -74,5 +77,6 @@ func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped - p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rand.Int31n(int32(2*p.averagePeriod))) + p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rand.Int63n(int64(2*p.period))) + p.period = utils.MinPacketNumber(2*p.period, p.maxPeriod) } diff --git a/internal/ackhandler/packet_number_generator_test.go b/internal/ackhandler/packet_number_generator_test.go index 4321284f..94678256 100644 --- a/internal/ackhandler/packet_number_generator_test.go +++ b/internal/ackhandler/packet_number_generator_test.go @@ -1,7 +1,10 @@ package ackhandler import ( + "fmt" + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -20,24 +23,17 @@ var _ = Describe("Sequential Packet Number Generator", func() { }) var _ = Describe("Skipping Packet Number Generator", func() { - var png *skippingPacketNumberGenerator const initialPN protocol.PacketNumber = 8 - - BeforeEach(func() { - png = newSkippingPacketNumberGenerator(initialPN, 100).(*skippingPacketNumberGenerator) - }) + const initialPeriod protocol.PacketNumber = 25 + const maxPeriod protocol.PacketNumber = 300 It("can be initialized to return any first packet number", func() { - png = newSkippingPacketNumberGenerator(12345, 100).(*skippingPacketNumberGenerator) + png := newSkippingPacketNumberGenerator(12345, initialPeriod, maxPeriod) Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345))) }) - It("gets the first packet number", func() { - num := png.Pop() - Expect(num).To(Equal(initialPN)) - }) - It("allows peeking", func() { + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod).(*skippingPacketNumberGenerator) png.nextToSkip = 1000 Expect(png.Peek()).To(Equal(initialPN)) Expect(png.Peek()).To(Equal(initialPN)) @@ -47,6 +43,7 @@ var _ = Describe("Skipping Packet Number Generator", func() { }) It("skips a packet number", func() { + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) var last protocol.PacketNumber var skipped bool for i := 0; i < 1000; i++ { @@ -60,35 +57,38 @@ var _ = Describe("Skipping Packet Number Generator", func() { Expect(skipped).To(BeTrue()) }) - It("skips a specific packet number", func() { - png.nextToSkip = initialPN + 1 - Expect(png.Pop()).To(Equal(initialPN)) - Expect(png.Peek()).To(Equal(initialPN + 2)) - Expect(png.Pop()).To(Equal(initialPN + 2)) - }) - It("generates a new packet number to skip", func() { - const averagePeriod = 25 - png.averagePeriod = averagePeriod + const rep = 500 + periods := make([][]protocol.PacketNumber, rep) + expectedPeriods := []protocol.PacketNumber{25, 50, 100, 200, 300, 300, 300} - periods := make([]protocol.PacketNumber, 0, 500) - last := initialPN - var lastSkip protocol.PacketNumber - for len(periods) < cap(periods) { - next := png.Pop() - if next > last+1 { - skipped := next - 1 - Expect(skipped).To(BeNumerically(">", lastSkip+1)) - periods = append(periods, skipped-lastSkip-1) - lastSkip = skipped + for i := 0; i < rep; i++ { + png := newSkippingPacketNumberGenerator(initialPN, initialPeriod, maxPeriod) + last := initialPN + lastSkip := initialPN + for len(periods[i]) < len(expectedPeriods) { + next := png.Pop() + if next > last+1 { + skipped := next - 1 + Expect(skipped).To(BeNumerically(">", lastSkip+1)) + periods[i] = append(periods[i], skipped-lastSkip-1) + lastSkip = skipped + } + last = next } - last = next } - var average float64 - for _, p := range periods { - average += float64(p) / float64(len(periods)) + for j := 0; j < len(expectedPeriods); j++ { + var average float64 + for i := 0; i < rep; i++ { + average += float64(periods[i][j]) / float64(len(periods)) + } + fmt.Fprintf(GinkgoWriter, "Period %d: %.2f (expected %d)\n", j, average, expectedPeriods[j]) + tolerance := protocol.PacketNumber(5) + if t := expectedPeriods[j] / 10; t > tolerance { + tolerance = t + } + Expect(average).To(BeNumerically("~", expectedPeriods[j]+1 /* we never skip two packet numbers at the same time */, tolerance)) } - Expect(average).To(BeNumerically("~", averagePeriod+1 /* we never skip two packet numbers at the same time */, 5)) }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index c9dd1280..99a51eb8 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -37,7 +37,7 @@ type packetNumberSpace struct { func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { var pns packetNumberGenerator if skipPNs { - pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketAveragePeriodLength) + pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) } else { pns = newSequentialPacketNumberGenerator(initialPN) } diff --git a/internal/protocol/params.go b/internal/protocol/params.go index a124f3aa..73b5cdbd 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -48,8 +48,12 @@ const MaxServerUnprocessedPackets = 1024 // MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. const MaxSessionUnprocessedPackets = 256 -// SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack -const SkipPacketAveragePeriodLength PacketNumber = 500 +// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. +// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. +const SkipPacketInitialPeriod PacketNumber = 256 + +// SkipPacketMaxPeriod is the maximum period length used for packet number skipping. +const SkipPacketMaxPeriod PacketNumber = 128 * 1024 // MaxAcceptQueueSize is the maximum number of sessions that the server queues for accepting. // If the queue is full, new connection attempts will be rejected.