diff --git a/internal/ackhandler/packet_number_generator.go b/internal/ackhandler/packet_number_generator.go index 4569be9a..f203892a 100644 --- a/internal/ackhandler/packet_number_generator.go +++ b/internal/ackhandler/packet_number_generator.go @@ -2,7 +2,8 @@ package ackhandler import ( "crypto/rand" - "math" + "encoding/binary" + mrand "math/rand" "github.com/lucas-clemente/quic-go/internal/protocol" ) @@ -11,6 +12,7 @@ import ( // it randomly skips a packet number every averagePeriod packets (on average). // It is guaranteed to never skip two consecutive packet numbers. type packetNumberGenerator struct { + rand *mrand.Rand averagePeriod protocol.PacketNumber next protocol.PacketNumber @@ -18,7 +20,10 @@ type packetNumberGenerator struct { } func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator { + b := make([]byte, 8) + rand.Read(b) // it's not the end of the world if we don't get perfect random here g := &packetNumberGenerator{ + rand: mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(b)))), next: initial, averagePeriod: averagePeriod, } @@ -44,18 +49,6 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber { } func (p *packetNumberGenerator) generateNewSkip() { - num := p.getRandomNumber() - skip := protocol.PacketNumber(num) * (p.averagePeriod - 1) / (math.MaxUint16 / 2) // make sure that there are never two consecutive packet numbers that are skipped - p.nextToSkip = p.next + 2 + skip -} - -// getRandomNumber() generates a cryptographically secure random number between 0 and MaxUint16 (= 65535) -// The expectation value is 65535/2 -func (p *packetNumberGenerator) getRandomNumber() uint16 { - b := make([]byte, 2) - rand.Read(b) // ignore the error here - - num := uint16(b[0])<<8 + uint16(b[1]) - return num + p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rand.Int31n(int32(2*p.averagePeriod))) } diff --git a/internal/ackhandler/packet_number_generator_test.go b/internal/ackhandler/packet_number_generator_test.go index 73bd914f..a862a9c4 100644 --- a/internal/ackhandler/packet_number_generator_test.go +++ b/internal/ackhandler/packet_number_generator_test.go @@ -1,8 +1,6 @@ package ackhandler import ( - "math" - "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -10,9 +8,10 @@ import ( var _ = Describe("Packet Number Generator", func() { var png *packetNumberGenerator + const initialPN protocol.PacketNumber = 8 BeforeEach(func() { - png = newPacketNumberGenerator(1, 100) + png = newPacketNumberGenerator(initialPN, 100) }) It("can be initialized to return any first packet number", func() { @@ -20,19 +19,18 @@ var _ = Describe("Packet Number Generator", func() { Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345))) }) - It("gets 1 as the first packet number", func() { + It("gets the first packet number", func() { num := png.Pop() - Expect(num).To(Equal(protocol.PacketNumber(1))) + Expect(num).To(Equal(initialPN)) }) It("allows peeking", func() { png.nextToSkip = 1000 - Expect(png.Peek()).To(Equal(protocol.PacketNumber(1))) - Expect(png.Peek()).To(Equal(protocol.PacketNumber(1))) - num := png.Pop() - Expect(num).To(Equal(protocol.PacketNumber(1))) - Expect(png.Peek()).To(Equal(protocol.PacketNumber(2))) - Expect(png.Peek()).To(Equal(protocol.PacketNumber(2))) + Expect(png.Peek()).To(Equal(initialPN)) + Expect(png.Peek()).To(Equal(initialPN)) + Expect(png.Pop()).To(Equal(initialPN)) + Expect(png.Peek()).To(Equal(initialPN + 1)) + Expect(png.Peek()).To(Equal(initialPN + 1)) }) It("skips a packet number", func() { @@ -50,51 +48,34 @@ var _ = Describe("Packet Number Generator", func() { }) It("skips a specific packet number", func() { - png.nextToSkip = 2 - num := png.Pop() - Expect(num).To(Equal(protocol.PacketNumber(1))) - Expect(png.Peek()).To(Equal(protocol.PacketNumber(3))) - num = png.Pop() - Expect(num).To(Equal(protocol.PacketNumber(3))) + png.nextToSkip = initialPN + 1 + Expect(png.Pop()).To(Equal(initialPN)) + Expect(png.Peek()).To(Equal(initialPN + 2)) + Expect(png.Pop()).To(Equal(initialPN + 2)) }) It("generates a new packet number to skip", func() { - png.next = 100 - png.averagePeriod = 100 + const averagePeriod = 25 + png.averagePeriod = averagePeriod - rep := 5000 - var sum protocol.PacketNumber - - for i := 0; i < rep; i++ { - png.generateNewSkip() - Expect(png.nextToSkip).ToNot(Equal(protocol.PacketNumber(101))) - sum += png.nextToSkip + periods := make([]protocol.PacketNumber, 0, 500) + last := initialPN + var lastSkip protocol.PacketNumber + for len(periods) < cap(periods) { + next := png.Pop() + if next > last+1 { + skipped := next - 1 + Expect(skipped).To(BeNumerically(">", lastSkip+1)) + periods = append(periods, skipped-lastSkip-1) + lastSkip = skipped + } + last = next } - average := sum / protocol.PacketNumber(rep) - Expect(average).To(BeNumerically("==", protocol.PacketNumber(200), 4)) - }) - - It("uses random numbers", func() { - var smallest uint16 = math.MaxUint16 - var largest uint16 - var sum uint64 - - rep := 10000 - - for i := 0; i < rep; i++ { - num := png.getRandomNumber() - sum += uint64(num) - if num > largest { - largest = num - } - if num < smallest { - smallest = num - } + var average float64 + for _, p := range periods { + average += float64(p) / float64(len(periods)) } - - Expect(smallest).To(BeNumerically("<", 300)) - Expect(largest).To(BeNumerically(">", math.MaxUint16-300)) - Expect(sum / uint64(rep)).To(BeNumerically("==", uint64(math.MaxUint16/2), 1000)) + Expect(average).To(BeNumerically("~", averagePeriod+1 /* we never skip two packet numbers at the same time */, 5)) }) })