mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
Merge pull request #1388 from lucas-clemente/varint-packet-numbers
use varint packet numbers in IETF QUIC headers
This commit is contained in:
commit
48ef98f641
19 changed files with 523 additions and 263 deletions
|
@ -511,6 +511,7 @@ var _ = Describe("Client", func() {
|
|||
PayloadLen: 1000,
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
|
@ -530,6 +531,7 @@ var _ = Describe("Client", func() {
|
|||
PayloadLen: 123,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
|
@ -545,6 +547,7 @@ var _ = Describe("Client", func() {
|
|||
PayloadLen: 123,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed())
|
||||
|
@ -561,7 +564,7 @@ var _ = Describe("Client", func() {
|
|||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}).Write(buf, protocol.PerspectiveServer, versionGQUICFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = cl.handlePacket(addr, buf.Bytes())
|
||||
|
@ -579,7 +582,7 @@ var _ = Describe("Client", func() {
|
|||
DestConnectionID: connID2,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}).Write(buf, protocol.PerspectiveServer, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -74,10 +74,12 @@ type sentPacketHandler struct {
|
|||
alarm time.Time
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
// NewSentPacketHandler creates a new sentPacketHandler
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
|
@ -92,6 +94,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se
|
|||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -494,7 +497,7 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
|
|
|
@ -49,7 +49,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
rttStats := &congestion.RTTStats{}
|
||||
handler = NewSentPacketHandler(rttStats, utils.DefaultLogger).(*sentPacketHandler)
|
||||
handler = NewSentPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*sentPacketHandler)
|
||||
handler.SetHandshakeComplete()
|
||||
streamFrame = wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
|
|
|
@ -1,8 +1,25 @@
|
|||
package protocol
|
||||
|
||||
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
|
||||
func InferPacketNumber(packetNumberLength PacketNumberLen, lastPacketNumber PacketNumber, wirePacketNumber PacketNumber) PacketNumber {
|
||||
epochDelta := PacketNumber(1) << (uint8(packetNumberLength) * 8)
|
||||
func InferPacketNumber(
|
||||
packetNumberLength PacketNumberLen,
|
||||
lastPacketNumber PacketNumber,
|
||||
wirePacketNumber PacketNumber,
|
||||
version VersionNumber,
|
||||
) PacketNumber {
|
||||
var epochDelta PacketNumber
|
||||
if version.UsesVarintPacketNumbers() {
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 7
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 14
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 30
|
||||
}
|
||||
} else {
|
||||
epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8)
|
||||
}
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
prevEpochBegin := epoch - epochDelta
|
||||
nextEpochBegin := epoch + epochDelta
|
||||
|
@ -29,9 +46,10 @@ func delta(a, b PacketNumber) PacketNumber {
|
|||
|
||||
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||
func GetPacketNumberLengthForHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen {
|
||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
|
||||
diff := uint64(packetNumber - leastUnacked)
|
||||
if diff < (1 << (uint8(PacketNumberLen2)*8 - 1)) {
|
||||
if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) ||
|
||||
!version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
return PacketNumberLen4
|
||||
|
|
|
@ -11,198 +11,219 @@ import (
|
|||
// Tests taken and extended from chrome
|
||||
var _ = Describe("packet number calculation", func() {
|
||||
Context("infering a packet number", func() {
|
||||
check := func(length PacketNumberLen, expected, last uint64) {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 {
|
||||
if v.UsesVarintPacketNumbers() {
|
||||
switch len {
|
||||
case PacketNumberLen1:
|
||||
return uint64(1) << 7
|
||||
case PacketNumberLen2:
|
||||
return uint64(1) << 14
|
||||
case PacketNumberLen4:
|
||||
return uint64(1) << 30
|
||||
default:
|
||||
Fail("invalid packet number len")
|
||||
}
|
||||
}
|
||||
return uint64(1) << (len * 8)
|
||||
}
|
||||
check := func(length PacketNumberLen, expected, last uint64, v VersionNumber) {
|
||||
epoch := getEpoch(length, v)
|
||||
epochMask := epoch - 1
|
||||
wirePacketNumber := expected & epochMask
|
||||
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected)))
|
||||
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected)))
|
||||
}
|
||||
for _, length := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4, PacketNumberLen6} {
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
epochMask := epoch - 1
|
||||
|
||||
It("works near epoch start", func() {
|
||||
// A few quick manual sanity check
|
||||
check(length, 1, 0)
|
||||
check(length, epoch+1, epochMask)
|
||||
check(length, epoch, epochMask)
|
||||
for _, v := range []VersionNumber{Version39, VersionTLS} {
|
||||
version := v
|
||||
|
||||
// Cases where the last number was close to the start of the range.
|
||||
for last := uint64(0); last < 10; last++ {
|
||||
// Small numbers should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, j, last)
|
||||
}
|
||||
Context(fmt.Sprintf("using varint packet numbers: %t", version.UsesVarintPacketNumbers()), func() {
|
||||
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
|
||||
length := l
|
||||
|
||||
// Large numbers should not wrap either (because we're near 0 already).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
epoch := getEpoch(length, version)
|
||||
epochMask := epoch - 1
|
||||
|
||||
It("works near epoch end", func() {
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := epoch - i
|
||||
It("works near epoch start", func() {
|
||||
// A few quick manual sanity check
|
||||
check(length, 1, 0, version)
|
||||
check(length, epoch+1, epochMask, version)
|
||||
check(length, epoch, epochMask, version)
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch+j, last)
|
||||
}
|
||||
// Cases where the last number was close to the start of the range.
|
||||
for last := uint64(0); last < 10; last++ {
|
||||
// Small numbers should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, j, last, version)
|
||||
}
|
||||
|
||||
// Large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
// Large numbers should not wrap either (because we're near 0 already).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Next check where we're in a non-zero epoch to verify we handle
|
||||
// reverse wrapping, too.
|
||||
It("works near previous epoch", func() {
|
||||
prevEpoch := 1 * epoch
|
||||
curEpoch := 2 * epoch
|
||||
// Cases where the last number was close to the start of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := curEpoch + i
|
||||
// Small number should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, curEpoch+j, last)
|
||||
}
|
||||
It("works near epoch end", func() {
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := epoch - i
|
||||
|
||||
// But large numbers should reverse wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, prevEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch+j, last, version)
|
||||
}
|
||||
|
||||
It("works near next epoch", func() {
|
||||
curEpoch := 2 * epoch
|
||||
nextEpoch := 3 * epoch
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := nextEpoch - 1 - i
|
||||
// Large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, nextEpoch+j, last)
|
||||
}
|
||||
// Next check where we're in a non-zero epoch to verify we handle
|
||||
// reverse wrapping, too.
|
||||
It("works near previous epoch", func() {
|
||||
prevEpoch := 1 * epoch
|
||||
curEpoch := 2 * epoch
|
||||
// Cases where the last number was close to the start of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := curEpoch + i
|
||||
// Small number should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, curEpoch+j, last, version)
|
||||
}
|
||||
|
||||
// but large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, curEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
// But large numbers should reverse wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, prevEpoch+num, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near next max", func() {
|
||||
maxNumber := uint64(math.MaxUint64)
|
||||
maxEpoch := maxNumber & ^epochMask
|
||||
It("works near next epoch", func() {
|
||||
curEpoch := 2 * epoch
|
||||
nextEpoch := 3 * epoch
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := nextEpoch - 1 - i
|
||||
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
// Subtract 1, because the expected next packet number is 1 more than the
|
||||
// last packet number.
|
||||
last := maxNumber - i - 1
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, nextEpoch+j, last, version)
|
||||
}
|
||||
|
||||
// Small numbers should not wrap, because they have nowhere to go.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, maxEpoch+j, last)
|
||||
}
|
||||
// but large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, curEpoch+num, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Large numbers should not wrap either.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, maxEpoch+num, last)
|
||||
}
|
||||
}
|
||||
It("works near next max", func() {
|
||||
maxNumber := uint64(math.MaxUint64)
|
||||
maxEpoch := maxNumber & ^epochMask
|
||||
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
// Subtract 1, because the expected next packet number is 1 more than the
|
||||
// last packet number.
|
||||
last := maxNumber - i - 1
|
||||
|
||||
// Small numbers should not wrap, because they have nowhere to go.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, maxEpoch+j, last, version)
|
||||
}
|
||||
|
||||
// Large numbers should not wrap either.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, maxEpoch+num, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Context("shortening a packet number for the header", func() {
|
||||
Context("shortening", func() {
|
||||
It("sends out low packet numbers as 2 byte", func() {
|
||||
length := GetPacketNumberLengthForHeader(4, 2, version)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
|
||||
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, version)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
|
||||
length := GetPacketNumberLengthForHeader(40000, 2, version)
|
||||
Expect(length).To(Equal(PacketNumberLen4))
|
||||
})
|
||||
})
|
||||
|
||||
Context("self-consistency", func() {
|
||||
It("works for small packet numbers", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("works for small packet numbers and increasing ACKed packets", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(i / 2)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
||||
epochMask := getEpoch(length, version) - 1
|
||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("also works for larger packet numbers", func() {
|
||||
var increment uint64
|
||||
for i := uint64(1); i < getEpoch(PacketNumberLen4, version); i += increment {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
||||
epochMask := getEpoch(length, version) - 1
|
||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
|
||||
increment = getEpoch(length, version) / 8
|
||||
}
|
||||
})
|
||||
|
||||
It("works for packet numbers larger than 2^48", func() {
|
||||
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(i - 1000)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
Context("shortening a packet number for the publicHeader", func() {
|
||||
Context("shortening", func() {
|
||||
It("sends out low packet numbers as 2 byte", func() {
|
||||
length := GetPacketNumberLengthForHeader(4, 2)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
|
||||
length := GetPacketNumberLengthForHeader(0xDEADBEEF, 0xDEADBEEF-1)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
|
||||
length := GetPacketNumberLengthForHeader(40000, 2)
|
||||
Expect(length).To(Equal(PacketNumberLen4))
|
||||
})
|
||||
})
|
||||
|
||||
Context("self-consistency", func() {
|
||||
It("works for small packet numbers", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("works for small packet numbers and increasing ACKed packets", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(i / 2)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("also works for larger packet numbers", func() {
|
||||
increment := uint64(1 << (8 - 3))
|
||||
for i := uint64(1); i < (2 << 31); i += increment {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
|
||||
switch length {
|
||||
case PacketNumberLen2:
|
||||
increment = 1 << (2*8 - 3)
|
||||
case PacketNumberLen4:
|
||||
increment = 1 << (4*8 - 3)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works for packet numbers larger than 2^48", func() {
|
||||
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(i - 1000)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("determining the minimum length of a packet number", func() {
|
||||
It("1 byte", func() {
|
||||
Expect(GetPacketNumberLength(0xFF)).To(Equal(PacketNumberLen1))
|
||||
|
|
|
@ -86,6 +86,10 @@ func (vn VersionNumber) UsesStopWaitingFrames() bool {
|
|||
return vn.isGQUIC()
|
||||
}
|
||||
|
||||
func (vn VersionNumber) UsesVarintPacketNumbers() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
|
||||
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
|
||||
if id == vn.CryptoStreamID() {
|
||||
|
|
|
@ -79,6 +79,13 @@ var _ = Describe("Version", func() {
|
|||
Expect(VersionTLS.UsesIETFFrameFormat()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("tells if a version uses varint packet numbers", func() {
|
||||
Expect(Version39.UsesVarintPacketNumbers()).To(BeFalse())
|
||||
Expect(Version42.UsesVarintPacketNumbers()).To(BeFalse())
|
||||
Expect(Version43.UsesVarintPacketNumbers()).To(BeFalse())
|
||||
Expect(VersionTLS.UsesVarintPacketNumbers()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("tells if a version uses the IETF frame types", func() {
|
||||
Expect(Version39.UsesIETFFrameFormat()).To(BeFalse())
|
||||
Expect(Version42.UsesIETFFrameFormat()).To(BeFalse())
|
||||
|
|
50
internal/utils/varint_packetnumber.go
Normal file
50
internal/utils/varint_packetnumber.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// ReadVarIntPacketNumber reads a number in the QUIC varint packet number format
|
||||
func ReadVarIntPacketNumber(b *bytes.Reader) (protocol.PacketNumber, protocol.PacketNumberLen, error) {
|
||||
b1, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if b1&0x80 == 0 {
|
||||
return protocol.PacketNumber(b1), protocol.PacketNumberLen1, nil
|
||||
}
|
||||
b2, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if b1&0x40 == 0 {
|
||||
return protocol.PacketNumber(uint64(b1&0x3f)<<8 + uint64(b2)), protocol.PacketNumberLen2, nil
|
||||
}
|
||||
b3, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
b4, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return protocol.PacketNumber(uint64(b1&0x3f)<<24 + uint64(b2)<<16 + uint64(b3)<<8 + uint64(b4)), protocol.PacketNumberLen4, nil
|
||||
}
|
||||
|
||||
// WriteVarIntPacketNumber writes a packet number in the QUIC varint packet number format
|
||||
func WriteVarIntPacketNumber(b *bytes.Buffer, i protocol.PacketNumber, len protocol.PacketNumberLen) error {
|
||||
switch len {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(i & 0x7f))
|
||||
case protocol.PacketNumberLen2:
|
||||
b.Write([]byte{(uint8(i>>8) & 0x3f) | 0x80, uint8(i)})
|
||||
case protocol.PacketNumberLen4:
|
||||
b.Write([]byte{(uint8(i>>24) & 0x3f) | 0xc0, uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
default:
|
||||
return fmt.Errorf("invalid packet number length: %d", len)
|
||||
}
|
||||
return nil
|
||||
}
|
157
internal/utils/varint_packetnumber_test.go
Normal file
157
internal/utils/varint_packetnumber_test.go
Normal file
|
@ -0,0 +1,157 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Varint packet number encoding / decoding", func() {
|
||||
Context("Decoding", func() {
|
||||
It("reads a 1 byte number", func() {
|
||||
b := bytes.NewReader([]byte{0x19}) // 00011001
|
||||
p, len, err := ReadVarIntPacketNumber(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen1))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x19)))
|
||||
})
|
||||
|
||||
It("errors when given an empty reader", func() {
|
||||
_, _, err := ReadVarIntPacketNumber(bytes.NewReader(nil))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
|
||||
It("reads a 2 byte number", func() {
|
||||
b := bytes.NewReader([]byte{0xb7, 0x19}) // first byte: 10110111
|
||||
p, len, err := ReadVarIntPacketNumber(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen2))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x3719)))
|
||||
})
|
||||
|
||||
It("errors on EOF when reading a 2 byte number", func() {
|
||||
b := bytes.NewReader([]byte{0xb7}) // first byte: 10110111
|
||||
_, _, err := ReadVarIntPacketNumber(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
|
||||
It("reads a 4 byte number", func() {
|
||||
b := bytes.NewReader([]byte{0xe5, 0x89, 0xfa, 0x19}) // first byte: 11100101
|
||||
p, len, err := ReadVarIntPacketNumber(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x2589fa19)))
|
||||
})
|
||||
|
||||
It("errors on EOF after the 3rd byte when reading a 4 byte number", func() {
|
||||
b := bytes.NewReader([]byte{0xe5, 0x89}) // first byte: 11100101
|
||||
_, _, err := ReadVarIntPacketNumber(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
|
||||
It("errors on EOF after the 4th byte when reading a 4 byte number", func() {
|
||||
b := bytes.NewReader([]byte{0xe5, 0x89, 0xfa}) // first byte: 11100101
|
||||
_, _, err := ReadVarIntPacketNumber(b)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Encoding", func() {
|
||||
It("writes a 1 byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0x42, protocol.PacketNumberLen1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Len()).To(Equal(1))
|
||||
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen1))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x42)))
|
||||
})
|
||||
|
||||
It("only uses the least significant 7 bits when writing a 1 byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0x1234ea, protocol.PacketNumberLen1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Len()).To(Equal(1))
|
||||
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen1))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x6a)))
|
||||
})
|
||||
|
||||
It("writes a small 2 byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0x42, protocol.PacketNumberLen2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Len()).To(Equal(2))
|
||||
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen2))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x42)))
|
||||
})
|
||||
|
||||
It("writes a 2 byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0x1337, protocol.PacketNumberLen2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Len()).To(Equal(2))
|
||||
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen2))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x1337)))
|
||||
})
|
||||
|
||||
It("only uses the least significant 14 bits when writing a 2 byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0x1234ff37, protocol.PacketNumberLen2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Len()).To(Equal(2))
|
||||
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen2))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x3f37)))
|
||||
})
|
||||
|
||||
It("writes a small 4 byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0xbeef, protocol.PacketNumberLen4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Len()).To(Equal(4))
|
||||
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0xbeef)))
|
||||
})
|
||||
|
||||
It("writes a 4 byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0x12beef42, protocol.PacketNumberLen4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Len()).To(Equal(4))
|
||||
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x12beef42)))
|
||||
})
|
||||
|
||||
It("only uses the least significant 30 bits when writing a 4 byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0x1234deadbeef, protocol.PacketNumberLen4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Len()).To(Equal(4))
|
||||
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(p).To(Equal(protocol.PacketNumber(0x1eadbeef)))
|
||||
})
|
||||
|
||||
It("errors when encountering invalid packet number lengths", func() {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteVarIntPacketNumber(b, 0x1234deadbeef, 13)
|
||||
Expect(err).To(MatchError("invalid packet number length: 13"))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -46,6 +46,7 @@ var _ = Describe("Header", func() {
|
|||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
Type: protocol.PacketType0RTT,
|
||||
PacketNumber: 0x42,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
Version: 0x1234,
|
||||
}).writeHeader(buf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
|
@ -72,12 +72,12 @@ func parseLongHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
|
|||
return nil, err
|
||||
}
|
||||
h.PayloadLen = protocol.ByteCount(pl)
|
||||
pn, err := utils.BigEndian.ReadUint32(b)
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = protocol.PacketNumber(pn)
|
||||
h.PacketNumberLen = protocol.PacketNumberLen4
|
||||
h.PacketNumber = pn
|
||||
h.PacketNumberLen = pnLen
|
||||
h.Type = protocol.PacketType(typeByte & 0x7f)
|
||||
|
||||
if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
|
||||
|
@ -98,25 +98,14 @@ func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
|
|||
if typeByte&0x38 != 0x30 {
|
||||
return nil, errors.New("invalid bits 3, 4 and 5")
|
||||
}
|
||||
var pnLen protocol.PacketNumberLen
|
||||
switch typeByte & 0x3 {
|
||||
case 0x0:
|
||||
pnLen = protocol.PacketNumberLen1
|
||||
case 0x1:
|
||||
pnLen = protocol.PacketNumberLen2
|
||||
case 0x2:
|
||||
pnLen = protocol.PacketNumberLen4
|
||||
default:
|
||||
return nil, errors.New("invalid short header type")
|
||||
}
|
||||
pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen))
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Header{
|
||||
KeyPhase: int(typeByte&0x40) >> 6,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: protocol.PacketNumber(pn),
|
||||
PacketNumber: pn,
|
||||
PacketNumberLen: pnLen,
|
||||
}, nil
|
||||
}
|
||||
|
@ -144,39 +133,21 @@ func (h *Header) writeLongHeader(b *bytes.Buffer) error {
|
|||
b.Write(h.DestConnectionID.Bytes())
|
||||
b.Write(h.SrcConnectionID.Bytes())
|
||||
utils.WriteVarInt(b, uint64(h.PayloadLen))
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
||||
return nil
|
||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
}
|
||||
|
||||
func (h *Header) writeShortHeader(b *bytes.Buffer) error {
|
||||
typeByte := byte(0x30)
|
||||
typeByte |= byte(h.KeyPhase << 6)
|
||||
switch h.PacketNumberLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
case protocol.PacketNumberLen2:
|
||||
typeByte |= 0x1
|
||||
case protocol.PacketNumberLen4:
|
||||
typeByte |= 0x2
|
||||
default:
|
||||
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
|
||||
}
|
||||
b.WriteByte(typeByte)
|
||||
|
||||
b.Write(h.DestConnectionID.Bytes())
|
||||
switch h.PacketNumberLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(h.PacketNumber))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
||||
}
|
||||
return nil
|
||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
}
|
||||
|
||||
func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
|
||||
if h.IsLongHeader {
|
||||
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + 4 /* packet number */, nil
|
||||
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + protocol.ByteCount(h.PacketNumberLen), nil
|
||||
}
|
||||
|
||||
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
||||
|
@ -192,7 +163,7 @@ func (h *Header) logHeader(logger utils.Logger) {
|
|||
if h.Version == 0 {
|
||||
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
|
||||
} else {
|
||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PayloadLen, h.Version)
|
||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
|
||||
}
|
||||
} else {
|
||||
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
|
||||
|
|
|
@ -17,6 +17,12 @@ import (
|
|||
var _ = Describe("IETF QUIC Header", func() {
|
||||
srcConnID := protocol.ConnectionID(bytes.Repeat([]byte{'f'}, protocol.ConnectionIDLen))
|
||||
|
||||
appendPacketNumber := func(data []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
utils.WriteVarIntPacketNumber(buf, pn, pnLen)
|
||||
return append(data, buf.Bytes()...)
|
||||
}
|
||||
|
||||
Context("parsing", func() {
|
||||
Context("Version Negotiation Packets", func() {
|
||||
It("parses", func() {
|
||||
|
@ -58,21 +64,19 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
})
|
||||
|
||||
Context("long headers", func() {
|
||||
generatePacket := func(t protocol.PacketType) []byte {
|
||||
It("parses a long header", func() {
|
||||
data := []byte{
|
||||
0x80 ^ uint8(t),
|
||||
0x80 ^ uint8(protocol.PacketTypeInitial),
|
||||
0x1, 0x2, 0x3, 0x4, // version number
|
||||
0x55, // connection ID lengths
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // destination connection ID
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // source connection ID
|
||||
}
|
||||
data = append(data, encodeVarInt(0x1337)...) // payload length
|
||||
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) // packet number
|
||||
return data
|
||||
}
|
||||
data = append(data, encodeVarInt(0x1337)...) // payload length
|
||||
// packet number
|
||||
data = appendPacketNumber(data, 0xbeef, protocol.PacketNumberLen4)
|
||||
|
||||
It("parses a long header", func() {
|
||||
b := bytes.NewReader(generatePacket(protocol.PacketTypeInitial))
|
||||
b := bytes.NewReader(data)
|
||||
h, err := parseHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(h.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
|
@ -81,7 +85,7 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}))
|
||||
Expect(h.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}))
|
||||
Expect(h.PayloadLen).To(Equal(protocol.ByteCount(0x1337)))
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad)))
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef)))
|
||||
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(h.Version).To(Equal(protocol.VersionNumber(0x1020304)))
|
||||
Expect(h.IsVersionNegotiation).To(BeFalse())
|
||||
|
@ -104,6 +108,21 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
Expect(h.DestConnectionID).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("parses a long header with a 2 byte connection ID", func() {
|
||||
data := []byte{
|
||||
0x80 ^ uint8(protocol.PacketTypeInitial),
|
||||
0x1, 0x2, 0x3, 0x4, // version number
|
||||
0x0, // connection ID lengths
|
||||
}
|
||||
data = append(data, encodeVarInt(0x42)...) // payload length
|
||||
data = appendPacketNumber(data, 0x123, protocol.PacketNumberLen2)
|
||||
b := bytes.NewReader(data)
|
||||
h, err := parseHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x123)))
|
||||
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("parses a long header without a source connection ID", func() {
|
||||
data := []byte{
|
||||
0x80 ^ uint8(protocol.PacketTypeInitial),
|
||||
|
@ -127,6 +146,8 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
Type: 42,
|
||||
SrcConnectionID: srcConnID,
|
||||
Version: 0x10203040,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}).Write(buf, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := bytes.NewReader(buf.Bytes())
|
||||
|
@ -141,8 +162,8 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
0x55, // connection ID lengths
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // destination connection ID
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // source connection ID
|
||||
0xde, 0xca, 0xfb, 0xad, // packet number
|
||||
}
|
||||
data = appendPacketNumber(data, 0x1337, protocol.PacketNumberLen4)
|
||||
for i := 0; i < len(data); i++ {
|
||||
_, err := parseHeader(bytes.NewReader(data[:i]))
|
||||
Expect(err).To(Equal(io.EOF))
|
||||
|
@ -155,8 +176,8 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
data := []byte{
|
||||
0x30, // 1 byte packet number
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
||||
0x42, // packet number
|
||||
}
|
||||
data = appendPacketNumber(data, 0x42, protocol.PacketNumberLen1)
|
||||
b := bytes.NewReader(data)
|
||||
h, err := parseHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -174,8 +195,8 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
data := []byte{
|
||||
0x30 ^ 0x40,
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
||||
0x11,
|
||||
}
|
||||
data = appendPacketNumber(data, 11, protocol.PacketNumberLen1)
|
||||
b := bytes.NewReader(data)
|
||||
h, err := parseHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -188,8 +209,8 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
data := []byte{
|
||||
0x30 ^ 0x40 ^ 0x1,
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
||||
0x13, 0x37, // packet number
|
||||
}
|
||||
data = appendPacketNumber(data, 0x1337, protocol.PacketNumberLen2)
|
||||
b := bytes.NewReader(data)
|
||||
h, err := parseHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -203,34 +224,23 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
data := []byte{
|
||||
0x30 ^ 0x40 ^ 0x2,
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
||||
0xde, 0xad, 0xbe, 0xef, // packet number
|
||||
}
|
||||
data = appendPacketNumber(data, 0x99beef, protocol.PacketNumberLen4)
|
||||
b := bytes.NewReader(data)
|
||||
h, err := parseHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(h.IsLongHeader).To(BeFalse())
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xdeadbeef)))
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef)))
|
||||
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(b.Len()).To(BeZero())
|
||||
})
|
||||
|
||||
It("rejects headers that have an invalid type", func() {
|
||||
data := []byte{
|
||||
0x30 ^ 0x40 ^ 0x3,
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
||||
0xde, 0xad, 0xbe, 0xef, // packet number
|
||||
}
|
||||
b := bytes.NewReader(data)
|
||||
_, err := parseHeader(b)
|
||||
Expect(err).To(MatchError("invalid short header type"))
|
||||
})
|
||||
|
||||
It("rejects headers that have bit 3,4 and 5 set incorrectly", func() {
|
||||
data := []byte{
|
||||
0x38 ^ 0x2,
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
||||
0xde, 0xca, 0xfb, 0xad, // packet number
|
||||
}
|
||||
data = appendPacketNumber(data, 1234, protocol.PacketNumberLen2)
|
||||
b := bytes.NewReader(data)
|
||||
_, err := parseHeader(b)
|
||||
Expect(err).To(MatchError("invalid bits 3, 4 and 5"))
|
||||
|
@ -240,8 +250,8 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
data := []byte{
|
||||
0x30 ^ 0x2,
|
||||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
|
||||
0xde, 0xca, 0xfb, 0xad, // packet number
|
||||
}
|
||||
data = appendPacketNumber(data, 1234, protocol.PacketNumberLen2)
|
||||
for i := 0; i < len(data); i++ {
|
||||
_, err := parseHeader(bytes.NewReader(data[:i]))
|
||||
Expect(err).To(Equal(io.EOF))
|
||||
|
@ -265,7 +275,8 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe},
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37},
|
||||
PayloadLen: 0xcafe,
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumber: 0xdecaf,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
Version: 0x1020304,
|
||||
}).writeHeader(buf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -276,8 +287,8 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID
|
||||
0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37, // source connection ID
|
||||
}
|
||||
expected = append(expected, encodeVarInt(0xcafe)...) // payload length
|
||||
expected = append(expected, []byte{0xde, 0xca, 0xfb, 0xad}...) // packet number
|
||||
expected = append(expected, encodeVarInt(0xcafe)...) // payload length
|
||||
expected = appendPacketNumber(expected, 0xdecaf, protocol.PacketNumberLen4)
|
||||
Expect(buf.Bytes()).To(Equal(expected))
|
||||
})
|
||||
|
||||
|
@ -288,6 +299,7 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
SrcConnectionID: srcConnID,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3}, // connection IDs must be at least 4 bytes long
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
Version: 0x1020304,
|
||||
}).writeHeader(buf)
|
||||
Expect(err).To(MatchError("invalid connection ID length: 3 bytes"))
|
||||
|
@ -300,6 +312,7 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
SrcConnectionID: srcConnID,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, // connection IDs must be at most 18 bytes long
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
Version: 0x1020304,
|
||||
}).writeHeader(buf)
|
||||
Expect(err).To(MatchError("invalid connection ID length: 19 bytes"))
|
||||
|
@ -312,6 +325,7 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
SrcConnectionID: srcConnID,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, // connection IDs must be at most 18 bytes long
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
Version: 0x1020304,
|
||||
}).writeHeader(buf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -350,26 +364,24 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
err := (&Header{
|
||||
OmitConnectionID: true,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
PacketNumber: 0x1337,
|
||||
PacketNumber: 0x765,
|
||||
}).writeHeader(buf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(buf.Bytes()).To(Equal([]byte{
|
||||
0x30 | 0x1,
|
||||
0x13, 0x37, // packet number
|
||||
}))
|
||||
expected := []byte{0x30}
|
||||
expected = appendPacketNumber(expected, 0x765, protocol.PacketNumberLen2)
|
||||
Expect(buf.Bytes()).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("writes a header with a 4 byte packet number", func() {
|
||||
err := (&Header{
|
||||
OmitConnectionID: true,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumber: 0x123456,
|
||||
}).writeHeader(buf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(buf.Bytes()).To(Equal([]byte{
|
||||
0x30 | 0x2,
|
||||
0xde, 0xca, 0xfb, 0xad, // packet number
|
||||
}))
|
||||
expected := []byte{0x30}
|
||||
expected = appendPacketNumber(expected, 0x123456, protocol.PacketNumberLen4)
|
||||
Expect(buf.Bytes()).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("errors when given an invalid packet number length", func() {
|
||||
|
@ -410,8 +422,9 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
PayloadLen: 1,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 1 /* short payload len */ + 4 /* packet number */
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 1 /* short payload len */ + 1 /* packet number */
|
||||
Expect(h.getHeaderLength()).To(BeEquivalentTo(expectedLen))
|
||||
err := h.writeHeader(buf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -424,8 +437,9 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
PayloadLen: 1500,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 2 /* long payload len */ + 4 /* packet number */
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 2 /* long payload len */ + 2 /* packet number */
|
||||
Expect(h.getHeaderLength()).To(BeEquivalentTo(expectedLen))
|
||||
err := h.writeHeader(buf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -518,12 +532,13 @@ var _ = Describe("IETF QUIC Header", func() {
|
|||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
PacketNumber: 0x1337,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
PayloadLen: 54321,
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37},
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37},
|
||||
Version: 0xfeed,
|
||||
}).logHeader(logger)
|
||||
Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337, PacketNumber: 0x1337, PayloadLen: 54321, Version: 0xfeed}"))
|
||||
Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337, PacketNumber: 0x1337, PacketNumberLen: 2, PayloadLen: 54321, Version: 0xfeed}"))
|
||||
})
|
||||
|
||||
It("logs Short Headers containing a connection ID", func() {
|
||||
|
|
|
@ -23,6 +23,7 @@ var _ = Describe("Packing and unpacking Initial packets", func() {
|
|||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
PacketNumber: 0x42,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
Version: ver,
|
||||
|
|
|
@ -458,7 +458,6 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
|
|||
}
|
||||
|
||||
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
|
||||
header.PacketNumberLen = protocol.PacketNumberLen4
|
||||
header.IsLongHeader = true
|
||||
// Set the payload len to maximum size.
|
||||
// Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
|
||||
|
|
|
@ -216,7 +216,6 @@ var _ = Describe("Packet packer", func() {
|
|||
It("uses the Long Header format for non-forward-secure packets", func() {
|
||||
h := packer.getHeader(protocol.EncryptionSecure)
|
||||
Expect(h.IsLongHeader).To(BeTrue())
|
||||
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(h.Version).To(Equal(versionIETFHeader))
|
||||
})
|
||||
|
||||
|
|
|
@ -311,6 +311,7 @@ var _ = Describe("Server", func() {
|
|||
PayloadLen: 1000,
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
|
@ -332,6 +333,7 @@ var _ = Describe("Server", func() {
|
|||
PayloadLen: 123,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
|
@ -349,6 +351,7 @@ var _ = Describe("Server", func() {
|
|||
PayloadLen: 123,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
|
@ -507,6 +510,7 @@ var _ = Describe("Server", func() {
|
|||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 0x55,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: 0x1234,
|
||||
PayloadLen: protocol.MinInitialPacketSize,
|
||||
}
|
||||
|
@ -549,6 +553,7 @@ var _ = Describe("Server", func() {
|
|||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 0x55,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
|
@ -571,6 +576,7 @@ var _ = Describe("Server", func() {
|
|||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 0x55,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
|
|
|
@ -125,6 +125,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea
|
|||
SrcConnectionID: clientHdr.DestConnectionID,
|
||||
DestConnectionID: clientHdr.SrcConnectionID,
|
||||
PacketNumber: 1, // random packet number
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: clientHdr.Version,
|
||||
}
|
||||
data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer, s.logger)
|
||||
|
@ -193,6 +194,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
|
|||
SrcConnectionID: hdr.DestConnectionID,
|
||||
PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()),
|
||||
PacketNumber: hdr.PacketNumber, // echo the client's packet number
|
||||
PacketNumberLen: hdr.PacketNumberLen,
|
||||
Version: version,
|
||||
}
|
||||
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger)
|
||||
|
|
|
@ -52,6 +52,7 @@ var _ = Describe("Stateless TLS handling", func() {
|
|||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
SrcConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(hdrBuf, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
|
@ -85,6 +86,7 @@ var _ = Describe("Stateless TLS handling", func() {
|
|||
hdr := &wire.Header{
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: 0x1337,
|
||||
}
|
||||
server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))
|
||||
|
|
|
@ -407,7 +407,7 @@ var newTLSClientSession = func(
|
|||
|
||||
func (s *session) preSetup() {
|
||||
s.rttStats = &congestion.RTTStats{}
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version)
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.ReceiveConnectionFlowControlWindow,
|
||||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||
|
@ -621,6 +621,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
|
|||
hdr.PacketNumberLen,
|
||||
s.largestRcvdPacketNumber,
|
||||
hdr.PacketNumber,
|
||||
s.version,
|
||||
)
|
||||
|
||||
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue