diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 7a30f7ed..b84f0dcb 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -756,7 +756,7 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) pnSpace := h.getPacketNumberSpace(encLevel) pn := pnSpace.pns.Peek() // See section 17.1 of RFC 9000. - return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked) + return pn, protocol.PacketNumberLengthForHeader(pn, pnSpace.largestAcked) } func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { diff --git a/internal/protocol/packet_number.go b/internal/protocol/packet_number.go index bd340161..9422db92 100644 --- a/internal/protocol/packet_number.go +++ b/internal/protocol/packet_number.go @@ -21,58 +21,36 @@ const ( PacketNumberLen4 PacketNumberLen = 4 ) -// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number -func DecodePacketNumber( - packetNumberLength PacketNumberLen, - lastPacketNumber PacketNumber, - wirePacketNumber PacketNumber, -) PacketNumber { - var epochDelta PacketNumber - switch packetNumberLength { - case PacketNumberLen1: - epochDelta = PacketNumber(1) << 8 - case PacketNumberLen2: - epochDelta = PacketNumber(1) << 16 - case PacketNumberLen3: - epochDelta = PacketNumber(1) << 24 - case PacketNumberLen4: - epochDelta = PacketNumber(1) << 32 +// DecodePacketNumber calculates the packet number based its length and the last seen packet number +// This function is taken from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3. +func DecodePacketNumber(length PacketNumberLen, largest PacketNumber, truncated PacketNumber) PacketNumber { + expected := largest + 1 + win := PacketNumber(1 << (length * 8)) + hwin := win / 2 + mask := win - 1 + candidate := (expected & ^mask) | truncated + if candidate <= expected-hwin && candidate < 1<<62-win { + return candidate + win } - epoch := lastPacketNumber & ^(epochDelta - 1) - var prevEpochBegin PacketNumber - if epoch > epochDelta { - prevEpochBegin = epoch - epochDelta + if candidate > expected+hwin && candidate >= win { + return candidate - win } - nextEpochBegin := epoch + epochDelta - return closestTo( - lastPacketNumber+1, - epoch+wirePacketNumber, - closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber), - ) + return candidate } -func closestTo(target, a, b PacketNumber) PacketNumber { - if delta(target, a) < delta(target, b) { - return a - } - return b -} - -func delta(a, b PacketNumber) PacketNumber { - if a < b { - return b - a - } - return a - b -} - -// GetPacketNumberLengthForHeader gets the length of the packet number for the public header +// PacketNumberLengthForHeader gets the length of the packet number for the public header // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen { - diff := uint64(packetNumber - leastUnacked) - if diff < (1 << (16 - 1)) { +func PacketNumberLengthForHeader(pn, largestAcked PacketNumber) PacketNumberLen { + var numUnacked PacketNumber + if largestAcked == InvalidPacketNumber { + numUnacked = pn + 1 + } else { + numUnacked = pn - largestAcked + } + if numUnacked < 1<<(16-1) { return PacketNumberLen2 } - if diff < (1 << (24 - 1)) { + if numUnacked < 1<<(24-1) { return PacketNumberLen3 } return PacketNumberLen4 diff --git a/internal/protocol/packet_number_test.go b/internal/protocol/packet_number_test.go index 01f881e4..4aa20a0d 100644 --- a/internal/protocol/packet_number_test.go +++ b/internal/protocol/packet_number_test.go @@ -1,204 +1,48 @@ package protocol import ( - "fmt" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -// Tests taken and extended from chrome -var _ = Describe("packet number calculation", func() { +var _ = Describe("Packet Number", func() { It("InvalidPacketNumber is smaller than all valid packet numbers", func() { Expect(InvalidPacketNumber).To(BeNumerically("<", 0)) }) - It("works with the example from the draft", func() { + It("PacketNumberLen has the correct value", func() { + Expect(PacketNumberLen1).To(BeEquivalentTo(1)) + Expect(PacketNumberLen2).To(BeEquivalentTo(2)) + Expect(PacketNumberLen3).To(BeEquivalentTo(3)) + Expect(PacketNumberLen4).To(BeEquivalentTo(4)) + }) + + It("decodes the packet number", func() { + Expect(DecodePacketNumber(PacketNumberLen1, 10, 255)).To(Equal(PacketNumber(255))) + Expect(DecodePacketNumber(PacketNumberLen1, 10, 0)).To(Equal(PacketNumber(0))) + Expect(DecodePacketNumber(PacketNumberLen1, 127, 0)).To(Equal(PacketNumber(256))) + Expect(DecodePacketNumber(PacketNumberLen1, 128, 0)).To(Equal(PacketNumber(256))) + Expect(DecodePacketNumber(PacketNumberLen1, 256+126, 0)).To(Equal(PacketNumber(256))) + Expect(DecodePacketNumber(PacketNumberLen1, 256+127, 0)).To(Equal(PacketNumber(512))) + Expect(DecodePacketNumber(PacketNumberLen2, 0xffff, 0xffff)).To(Equal(PacketNumber(0xffff))) + Expect(DecodePacketNumber(PacketNumberLen2, 0xffff+1, 0xffff)).To(Equal(PacketNumber(0xffff))) + + // example from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3 Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32))) }) - It("works with the examples from the draft", func() { - Expect(GetPacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2)) - Expect(GetPacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3)) + It("encodes the packet number, with the examples from the RFC", func() { + Expect(PacketNumberLengthForHeader(1, InvalidPacketNumber)).To(Equal(PacketNumberLen2)) + Expect(PacketNumberLengthForHeader(1<<15-2, InvalidPacketNumber)).To(Equal(PacketNumberLen2)) + Expect(PacketNumberLengthForHeader(1<<15-1, InvalidPacketNumber)).To(Equal(PacketNumberLen3)) + Expect(PacketNumberLengthForHeader(1<<23-2, InvalidPacketNumber)).To(Equal(PacketNumberLen3)) + Expect(PacketNumberLengthForHeader(1<<23-1, InvalidPacketNumber)).To(Equal(PacketNumberLen4)) + Expect(PacketNumberLengthForHeader(1<<15+9, 10)).To(Equal(PacketNumberLen2)) + Expect(PacketNumberLengthForHeader(1<<15+10, 10)).To(Equal(PacketNumberLen3)) + Expect(PacketNumberLengthForHeader(1<<23+99, 100)).To(Equal(PacketNumberLen3)) + Expect(PacketNumberLengthForHeader(1<<23+100, 100)).To(Equal(PacketNumberLen4)) + // examples from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.2 + Expect(PacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2)) + Expect(PacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3)) }) - - getEpoch := func(len PacketNumberLen) uint64 { - if len > 4 { - Fail("invalid packet number len") - } - return uint64(1) << (len * 8) - } - - check := func(length PacketNumberLen, expected, last uint64) { - epoch := getEpoch(length) - epochMask := epoch - 1 - wirePacketNumber := expected & epochMask - ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) - } - - for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} { - length := l - - Context(fmt.Sprintf("with %d bytes", length), func() { - epoch := getEpoch(length) - epochMask := epoch - 1 - - It("works near epoch start", func() { - // A few quick manual sanity check - check(length, 1, 0) - check(length, epoch+1, epochMask) - check(length, epoch, epochMask) - - // Cases where the last number was close to the start of the range. - for last := uint64(0); last < 10; last++ { - // Small numbers should not wrap (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, j, last) - } - - // Large numbers should not wrap either (because we're near 0 already). - for j := uint64(0); j < 10; j++ { - check(length, epoch-1-j, last) - } - } - }) - - It("works near epoch end", func() { - // Cases where the last number was close to the end of the range - for i := uint64(0); i < 10; i++ { - last := epoch - i - - // Small numbers should wrap. - for j := uint64(0); j < 10; j++ { - check(length, epoch+j, last) - } - - // Large numbers should not (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, epoch-1-j, last) - } - } - }) - - // Next check where we're in a non-zero epoch to verify we handle - // reverse wrapping, too. - It("works near previous epoch", func() { - prevEpoch := 1 * epoch - curEpoch := 2 * epoch - // Cases where the last number was close to the start of the range - for i := uint64(0); i < 10; i++ { - last := curEpoch + i - // Small number should not wrap (even if they're out of order). - for j := uint64(0); j < 10; j++ { - check(length, curEpoch+j, last) - } - - // But large numbers should reverse wrap. - for j := uint64(0); j < 10; j++ { - num := epoch - 1 - j - check(length, prevEpoch+num, last) - } - } - }) - - It("works near next epoch", func() { - curEpoch := 2 * epoch - nextEpoch := 3 * epoch - // Cases where the last number was close to the end of the range - for i := uint64(0); i < 10; i++ { - last := nextEpoch - 1 - i - - // Small numbers should wrap. - for j := uint64(0); j < 10; j++ { - check(length, nextEpoch+j, last) - } - - // but large numbers should not (even if they're out of order). - for j := uint64(0); j < 10; j++ { - num := epoch - 1 - j - check(length, curEpoch+num, last) - } - } - }) - - Context("shortening a packet number for the header", func() { - Context("shortening", func() { - It("sends out low packet numbers as 2 byte", func() { - length := GetPacketNumberLengthForHeader(4, 2) - Expect(length).To(Equal(PacketNumberLen2)) - }) - - It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { - length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1) - Expect(length).To(Equal(PacketNumberLen2)) - }) - - It("sends out higher packet numbers as 3 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForHeader(40000, 2) - Expect(length).To(Equal(PacketNumberLen3)) - }) - - It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForHeader(40000000, 2) - Expect(length).To(Equal(PacketNumberLen4)) - }) - }) - - Context("self-consistency", func() { - It("works for small packet numbers", func() { - for i := uint64(1); i < 10000; i++ { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - - It("works for small packet numbers and increasing ACKed packets", func() { - for i := uint64(1); i < 10000; i++ { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(i / 2) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - epochMask := getEpoch(length) - 1 - wirePacketNumber := uint64(packetNumber) & epochMask - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - - It("also works for larger packet numbers", func() { - var increment uint64 - for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - epochMask := getEpoch(length) - 1 - wirePacketNumber := uint64(packetNumber) & epochMask - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - - increment = getEpoch(length) / 8 - } - }) - - It("works for packet numbers larger than 2^48", func() { - for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { - packetNumber := PacketNumber(i) - leastUnacked := PacketNumber(i - 1000) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) - wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - - decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) - Expect(decodedPacketNumber).To(Equal(packetNumber)) - } - }) - }) - }) - }) - } })