mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
protocol: use the packet number decoding / encoding logic from the RFC (#4661)
This commit is contained in:
parent
135b8c0e75
commit
4be70e0146
3 changed files with 56 additions and 234 deletions
|
@ -756,7 +756,7 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel)
|
|||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
pn := pnSpace.pns.Peek()
|
||||
// See section 17.1 of RFC 9000.
|
||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
|
||||
return pn, protocol.PacketNumberLengthForHeader(pn, pnSpace.largestAcked)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
|
||||
|
|
|
@ -21,58 +21,36 @@ const (
|
|||
PacketNumberLen4 PacketNumberLen = 4
|
||||
)
|
||||
|
||||
// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
|
||||
func DecodePacketNumber(
|
||||
packetNumberLength PacketNumberLen,
|
||||
lastPacketNumber PacketNumber,
|
||||
wirePacketNumber PacketNumber,
|
||||
) PacketNumber {
|
||||
var epochDelta PacketNumber
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 8
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 16
|
||||
case PacketNumberLen3:
|
||||
epochDelta = PacketNumber(1) << 24
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 32
|
||||
// DecodePacketNumber calculates the packet number based its length and the last seen packet number
|
||||
// This function is taken from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3.
|
||||
func DecodePacketNumber(length PacketNumberLen, largest PacketNumber, truncated PacketNumber) PacketNumber {
|
||||
expected := largest + 1
|
||||
win := PacketNumber(1 << (length * 8))
|
||||
hwin := win / 2
|
||||
mask := win - 1
|
||||
candidate := (expected & ^mask) | truncated
|
||||
if candidate <= expected-hwin && candidate < 1<<62-win {
|
||||
return candidate + win
|
||||
}
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
var prevEpochBegin PacketNumber
|
||||
if epoch > epochDelta {
|
||||
prevEpochBegin = epoch - epochDelta
|
||||
if candidate > expected+hwin && candidate >= win {
|
||||
return candidate - win
|
||||
}
|
||||
nextEpochBegin := epoch + epochDelta
|
||||
return closestTo(
|
||||
lastPacketNumber+1,
|
||||
epoch+wirePacketNumber,
|
||||
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
|
||||
)
|
||||
return candidate
|
||||
}
|
||||
|
||||
func closestTo(target, a, b PacketNumber) PacketNumber {
|
||||
if delta(target, a) < delta(target, b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func delta(a, b PacketNumber) PacketNumber {
|
||||
if a < b {
|
||||
return b - a
|
||||
}
|
||||
return a - b
|
||||
}
|
||||
|
||||
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||
// PacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
|
||||
diff := uint64(packetNumber - leastUnacked)
|
||||
if diff < (1 << (16 - 1)) {
|
||||
func PacketNumberLengthForHeader(pn, largestAcked PacketNumber) PacketNumberLen {
|
||||
var numUnacked PacketNumber
|
||||
if largestAcked == InvalidPacketNumber {
|
||||
numUnacked = pn + 1
|
||||
} else {
|
||||
numUnacked = pn - largestAcked
|
||||
}
|
||||
if numUnacked < 1<<(16-1) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
if diff < (1 << (24 - 1)) {
|
||||
if numUnacked < 1<<(24-1) {
|
||||
return PacketNumberLen3
|
||||
}
|
||||
return PacketNumberLen4
|
||||
|
|
|
@ -1,204 +1,48 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Tests taken and extended from chrome
|
||||
var _ = Describe("packet number calculation", func() {
|
||||
var _ = Describe("Packet Number", func() {
|
||||
It("InvalidPacketNumber is smaller than all valid packet numbers", func() {
|
||||
Expect(InvalidPacketNumber).To(BeNumerically("<", 0))
|
||||
})
|
||||
|
||||
It("works with the example from the draft", func() {
|
||||
It("PacketNumberLen has the correct value", func() {
|
||||
Expect(PacketNumberLen1).To(BeEquivalentTo(1))
|
||||
Expect(PacketNumberLen2).To(BeEquivalentTo(2))
|
||||
Expect(PacketNumberLen3).To(BeEquivalentTo(3))
|
||||
Expect(PacketNumberLen4).To(BeEquivalentTo(4))
|
||||
})
|
||||
|
||||
It("decodes the packet number", func() {
|
||||
Expect(DecodePacketNumber(PacketNumberLen1, 10, 255)).To(Equal(PacketNumber(255)))
|
||||
Expect(DecodePacketNumber(PacketNumberLen1, 10, 0)).To(Equal(PacketNumber(0)))
|
||||
Expect(DecodePacketNumber(PacketNumberLen1, 127, 0)).To(Equal(PacketNumber(256)))
|
||||
Expect(DecodePacketNumber(PacketNumberLen1, 128, 0)).To(Equal(PacketNumber(256)))
|
||||
Expect(DecodePacketNumber(PacketNumberLen1, 256+126, 0)).To(Equal(PacketNumber(256)))
|
||||
Expect(DecodePacketNumber(PacketNumberLen1, 256+127, 0)).To(Equal(PacketNumber(512)))
|
||||
Expect(DecodePacketNumber(PacketNumberLen2, 0xffff, 0xffff)).To(Equal(PacketNumber(0xffff)))
|
||||
Expect(DecodePacketNumber(PacketNumberLen2, 0xffff+1, 0xffff)).To(Equal(PacketNumber(0xffff)))
|
||||
|
||||
// example from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3
|
||||
Expect(DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)).To(Equal(PacketNumber(0xa82f9b32)))
|
||||
})
|
||||
|
||||
It("works with the examples from the draft", func() {
|
||||
Expect(GetPacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2))
|
||||
Expect(GetPacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3))
|
||||
It("encodes the packet number, with the examples from the RFC", func() {
|
||||
Expect(PacketNumberLengthForHeader(1, InvalidPacketNumber)).To(Equal(PacketNumberLen2))
|
||||
Expect(PacketNumberLengthForHeader(1<<15-2, InvalidPacketNumber)).To(Equal(PacketNumberLen2))
|
||||
Expect(PacketNumberLengthForHeader(1<<15-1, InvalidPacketNumber)).To(Equal(PacketNumberLen3))
|
||||
Expect(PacketNumberLengthForHeader(1<<23-2, InvalidPacketNumber)).To(Equal(PacketNumberLen3))
|
||||
Expect(PacketNumberLengthForHeader(1<<23-1, InvalidPacketNumber)).To(Equal(PacketNumberLen4))
|
||||
Expect(PacketNumberLengthForHeader(1<<15+9, 10)).To(Equal(PacketNumberLen2))
|
||||
Expect(PacketNumberLengthForHeader(1<<15+10, 10)).To(Equal(PacketNumberLen3))
|
||||
Expect(PacketNumberLengthForHeader(1<<23+99, 100)).To(Equal(PacketNumberLen3))
|
||||
Expect(PacketNumberLengthForHeader(1<<23+100, 100)).To(Equal(PacketNumberLen4))
|
||||
// examples from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.2
|
||||
Expect(PacketNumberLengthForHeader(0xac5c02, 0xabe8b3)).To(Equal(PacketNumberLen2))
|
||||
Expect(PacketNumberLengthForHeader(0xace8fe, 0xabe8b3)).To(Equal(PacketNumberLen3))
|
||||
})
|
||||
|
||||
getEpoch := func(len PacketNumberLen) uint64 {
|
||||
if len > 4 {
|
||||
Fail("invalid packet number len")
|
||||
}
|
||||
return uint64(1) << (len * 8)
|
||||
}
|
||||
|
||||
check := func(length PacketNumberLen, expected, last uint64) {
|
||||
epoch := getEpoch(length)
|
||||
epochMask := epoch - 1
|
||||
wirePacketNumber := expected & epochMask
|
||||
ExpectWithOffset(1, DecodePacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected)))
|
||||
}
|
||||
|
||||
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen3, PacketNumberLen4} {
|
||||
length := l
|
||||
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
epoch := getEpoch(length)
|
||||
epochMask := epoch - 1
|
||||
|
||||
It("works near epoch start", func() {
|
||||
// A few quick manual sanity check
|
||||
check(length, 1, 0)
|
||||
check(length, epoch+1, epochMask)
|
||||
check(length, epoch, epochMask)
|
||||
|
||||
// Cases where the last number was close to the start of the range.
|
||||
for last := uint64(0); last < 10; last++ {
|
||||
// Small numbers should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not wrap either (because we're near 0 already).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near epoch end", func() {
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := epoch - i
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch+j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Next check where we're in a non-zero epoch to verify we handle
|
||||
// reverse wrapping, too.
|
||||
It("works near previous epoch", func() {
|
||||
prevEpoch := 1 * epoch
|
||||
curEpoch := 2 * epoch
|
||||
// Cases where the last number was close to the start of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := curEpoch + i
|
||||
// Small number should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, curEpoch+j, last)
|
||||
}
|
||||
|
||||
// But large numbers should reverse wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, prevEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near next epoch", func() {
|
||||
curEpoch := 2 * epoch
|
||||
nextEpoch := 3 * epoch
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := nextEpoch - 1 - i
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, nextEpoch+j, last)
|
||||
}
|
||||
|
||||
// but large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, curEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
Context("shortening a packet number for the header", func() {
|
||||
Context("shortening", func() {
|
||||
It("sends out low packet numbers as 2 byte", func() {
|
||||
length := GetPacketNumberLengthForHeader(4, 2)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
|
||||
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("sends out higher packet numbers as 3 bytes, if a lot of ACKs are missing", func() {
|
||||
length := GetPacketNumberLengthForHeader(40000, 2)
|
||||
Expect(length).To(Equal(PacketNumberLen3))
|
||||
})
|
||||
|
||||
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
|
||||
length := GetPacketNumberLengthForHeader(40000000, 2)
|
||||
Expect(length).To(Equal(PacketNumberLen4))
|
||||
})
|
||||
})
|
||||
|
||||
Context("self-consistency", func() {
|
||||
It("works for small packet numbers", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||
Expect(decodedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("works for small packet numbers and increasing ACKed packets", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(i / 2)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||
epochMask := getEpoch(length) - 1
|
||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
||||
|
||||
decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||
Expect(decodedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("also works for larger packet numbers", func() {
|
||||
var increment uint64
|
||||
for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||
epochMask := getEpoch(length) - 1
|
||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
||||
|
||||
decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||
Expect(decodedPacketNumber).To(Equal(packetNumber))
|
||||
|
||||
increment = getEpoch(length) / 8
|
||||
}
|
||||
})
|
||||
|
||||
It("works for packet numbers larger than 2^48", func() {
|
||||
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(i - 1000)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
decodedPacketNumber := DecodePacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||
Expect(decodedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue