Merge pull request #1388 from lucas-clemente/varint-packet-numbers

use varint packet numbers in IETF QUIC headers
This commit is contained in:
Marten Seemann 2018-06-06 16:26:49 +02:00 committed by GitHub
commit 48ef98f641
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 523 additions and 263 deletions

View file

@ -511,6 +511,7 @@ var _ = Describe("Client", func() {
PayloadLen: 1000,
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
@ -530,6 +531,7 @@ var _ = Describe("Client", func() {
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
@ -545,6 +547,7 @@ var _ = Describe("Client", func() {
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed())
@ -561,7 +564,7 @@ var _ = Describe("Client", func() {
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: 1,
PacketNumberLen: protocol.PacketNumberLen1,
}).Write(buf, protocol.PerspectiveServer, versionGQUICFrames)
Expect(err).ToNot(HaveOccurred())
err = cl.handlePacket(addr, buf.Bytes())
@ -579,7 +582,7 @@ var _ = Describe("Client", func() {
DestConnectionID: connID2,
SrcConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: 1,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}).Write(buf, protocol.PerspectiveServer, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())

View file

@ -74,10 +74,12 @@ type sentPacketHandler struct {
alarm time.Time
logger utils.Logger
version protocol.VersionNumber
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
@ -92,6 +94,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
@ -494,7 +497,7 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
}
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked())
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {

View file

@ -49,7 +49,7 @@ var _ = Describe("SentPacketHandler", func() {
BeforeEach(func() {
rttStats := &congestion.RTTStats{}
handler = NewSentPacketHandler(rttStats, utils.DefaultLogger).(*sentPacketHandler)
handler = NewSentPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*sentPacketHandler)
handler.SetHandshakeComplete()
streamFrame = wire.StreamFrame{
StreamID: 5,

View file

@ -1,8 +1,25 @@
package protocol
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func InferPacketNumber(packetNumberLength PacketNumberLen, lastPacketNumber PacketNumber, wirePacketNumber PacketNumber) PacketNumber {
epochDelta := PacketNumber(1) << (uint8(packetNumberLength) * 8)
func InferPacketNumber(
packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber,
version VersionNumber,
) PacketNumber {
var epochDelta PacketNumber
if version.UsesVarintPacketNumbers() {
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30
}
} else {
epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8)
}
epoch := lastPacketNumber & ^(epochDelta - 1)
prevEpochBegin := epoch - epochDelta
nextEpochBegin := epoch + epochDelta
@ -29,9 +46,10 @@ func delta(a, b PacketNumber) PacketNumber {
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber PacketNumber, leastUnacked PacketNumber) PacketNumberLen {
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked)
if diff < (1 << (uint8(PacketNumberLen2)*8 - 1)) {
if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) ||
!version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) {
return PacketNumberLen2
}
return PacketNumberLen4

View file

@ -11,198 +11,219 @@ import (
// Tests taken and extended from chrome
var _ = Describe("packet number calculation", func() {
Context("infering a packet number", func() {
check := func(length PacketNumberLen, expected, last uint64) {
epoch := uint64(1) << (length * 8)
getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 {
if v.UsesVarintPacketNumbers() {
switch len {
case PacketNumberLen1:
return uint64(1) << 7
case PacketNumberLen2:
return uint64(1) << 14
case PacketNumberLen4:
return uint64(1) << 30
default:
Fail("invalid packet number len")
}
}
return uint64(1) << (len * 8)
}
check := func(length PacketNumberLen, expected, last uint64, v VersionNumber) {
epoch := getEpoch(length, v)
epochMask := epoch - 1
wirePacketNumber := expected & epochMask
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected)))
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected)))
}
for _, length := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4, PacketNumberLen6} {
Context(fmt.Sprintf("with %d bytes", length), func() {
epoch := uint64(1) << (length * 8)
epochMask := epoch - 1
It("works near epoch start", func() {
// A few quick manual sanity check
check(length, 1, 0)
check(length, epoch+1, epochMask)
check(length, epoch, epochMask)
for _, v := range []VersionNumber{Version39, VersionTLS} {
version := v
// Cases where the last number was close to the start of the range.
for last := uint64(0); last < 10; last++ {
// Small numbers should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, j, last)
}
Context(fmt.Sprintf("using varint packet numbers: %t", version.UsesVarintPacketNumbers()), func() {
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
length := l
// Large numbers should not wrap either (because we're near 0 already).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last)
}
}
})
Context(fmt.Sprintf("with %d bytes", length), func() {
epoch := getEpoch(length, version)
epochMask := epoch - 1
It("works near epoch end", func() {
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := epoch - i
It("works near epoch start", func() {
// A few quick manual sanity check
check(length, 1, 0, version)
check(length, epoch+1, epochMask, version)
check(length, epoch, epochMask, version)
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, epoch+j, last)
}
// Cases where the last number was close to the start of the range.
for last := uint64(0); last < 10; last++ {
// Small numbers should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, j, last, version)
}
// Large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last)
}
}
})
// Large numbers should not wrap either (because we're near 0 already).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, version)
}
}
})
// Next check where we're in a non-zero epoch to verify we handle
// reverse wrapping, too.
It("works near previous epoch", func() {
prevEpoch := 1 * epoch
curEpoch := 2 * epoch
// Cases where the last number was close to the start of the range
for i := uint64(0); i < 10; i++ {
last := curEpoch + i
// Small number should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, curEpoch+j, last)
}
It("works near epoch end", func() {
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := epoch - i
// But large numbers should reverse wrap.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, prevEpoch+num, last)
}
}
})
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, epoch+j, last, version)
}
It("works near next epoch", func() {
curEpoch := 2 * epoch
nextEpoch := 3 * epoch
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := nextEpoch - 1 - i
// Large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, version)
}
}
})
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, nextEpoch+j, last)
}
// Next check where we're in a non-zero epoch to verify we handle
// reverse wrapping, too.
It("works near previous epoch", func() {
prevEpoch := 1 * epoch
curEpoch := 2 * epoch
// Cases where the last number was close to the start of the range
for i := uint64(0); i < 10; i++ {
last := curEpoch + i
// Small number should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, curEpoch+j, last, version)
}
// but large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, curEpoch+num, last)
}
}
})
// But large numbers should reverse wrap.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, prevEpoch+num, last, version)
}
}
})
It("works near next max", func() {
maxNumber := uint64(math.MaxUint64)
maxEpoch := maxNumber & ^epochMask
It("works near next epoch", func() {
curEpoch := 2 * epoch
nextEpoch := 3 * epoch
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := nextEpoch - 1 - i
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
// Subtract 1, because the expected next packet number is 1 more than the
// last packet number.
last := maxNumber - i - 1
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, nextEpoch+j, last, version)
}
// Small numbers should not wrap, because they have nowhere to go.
for j := uint64(0); j < 10; j++ {
check(length, maxEpoch+j, last)
}
// but large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, curEpoch+num, last, version)
}
}
})
// Large numbers should not wrap either.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, maxEpoch+num, last)
}
}
It("works near next max", func() {
maxNumber := uint64(math.MaxUint64)
maxEpoch := maxNumber & ^epochMask
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
// Subtract 1, because the expected next packet number is 1 more than the
// last packet number.
last := maxNumber - i - 1
// Small numbers should not wrap, because they have nowhere to go.
for j := uint64(0); j < 10; j++ {
check(length, maxEpoch+j, last, version)
}
// Large numbers should not wrap either.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, maxEpoch+num, last, version)
}
}
})
})
}
Context("shortening a packet number for the header", func() {
Context("shortening", func() {
It("sends out low packet numbers as 2 byte", func() {
length := GetPacketNumberLengthForHeader(4, 2, version)
Expect(length).To(Equal(PacketNumberLen2))
})
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, version)
Expect(length).To(Equal(PacketNumberLen2))
})
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
length := GetPacketNumberLengthForHeader(40000, 2, version)
Expect(length).To(Equal(PacketNumberLen4))
})
})
Context("self-consistency", func() {
It("works for small packet numbers", func() {
for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
It("works for small packet numbers and increasing ACKed packets", func() {
for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i / 2)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
epochMask := getEpoch(length, version) - 1
wirePacketNumber := uint64(packetNumber) & epochMask
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
It("also works for larger packet numbers", func() {
var increment uint64
for i := uint64(1); i < getEpoch(PacketNumberLen4, version); i += increment {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
epochMask := getEpoch(length, version) - 1
wirePacketNumber := uint64(packetNumber) & epochMask
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
Expect(inferedPacketNumber).To(Equal(packetNumber))
increment = getEpoch(length, version) / 8
}
})
It("works for packet numbers larger than 2^48", func() {
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i - 1000)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
})
})
})
}
})
Context("shortening a packet number for the publicHeader", func() {
Context("shortening", func() {
It("sends out low packet numbers as 2 byte", func() {
length := GetPacketNumberLengthForHeader(4, 2)
Expect(length).To(Equal(PacketNumberLen2))
})
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
length := GetPacketNumberLengthForHeader(0xDEADBEEF, 0xDEADBEEF-1)
Expect(length).To(Equal(PacketNumberLen2))
})
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
length := GetPacketNumberLengthForHeader(40000, 2)
Expect(length).To(Equal(PacketNumberLen4))
})
})
Context("self-consistency", func() {
It("works for small packet numbers", func() {
for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
It("works for small packet numbers and increasing ACKed packets", func() {
for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i / 2)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
It("also works for larger packet numbers", func() {
increment := uint64(1 << (8 - 3))
for i := uint64(1); i < (2 << 31); i += increment {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
Expect(inferedPacketNumber).To(Equal(packetNumber))
switch length {
case PacketNumberLen2:
increment = 1 << (2*8 - 3)
case PacketNumberLen4:
increment = 1 << (4*8 - 3)
}
}
})
It("works for packet numbers larger than 2^48", func() {
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i - 1000)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
})
})
Context("determining the minimum length of a packet number", func() {
It("1 byte", func() {
Expect(GetPacketNumberLength(0xFF)).To(Equal(PacketNumberLen1))

View file

@ -86,6 +86,10 @@ func (vn VersionNumber) UsesStopWaitingFrames() bool {
return vn.isGQUIC()
}
func (vn VersionNumber) UsesVarintPacketNumbers() bool {
return !vn.isGQUIC()
}
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
if id == vn.CryptoStreamID() {

View file

@ -79,6 +79,13 @@ var _ = Describe("Version", func() {
Expect(VersionTLS.UsesIETFFrameFormat()).To(BeTrue())
})
It("tells if a version uses varint packet numbers", func() {
Expect(Version39.UsesVarintPacketNumbers()).To(BeFalse())
Expect(Version42.UsesVarintPacketNumbers()).To(BeFalse())
Expect(Version43.UsesVarintPacketNumbers()).To(BeFalse())
Expect(VersionTLS.UsesVarintPacketNumbers()).To(BeTrue())
})
It("tells if a version uses the IETF frame types", func() {
Expect(Version39.UsesIETFFrameFormat()).To(BeFalse())
Expect(Version42.UsesIETFFrameFormat()).To(BeFalse())

View file

@ -0,0 +1,50 @@
package utils
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// ReadVarIntPacketNumber reads a number in the QUIC varint packet number format
func ReadVarIntPacketNumber(b *bytes.Reader) (protocol.PacketNumber, protocol.PacketNumberLen, error) {
b1, err := b.ReadByte()
if err != nil {
return 0, 0, err
}
if b1&0x80 == 0 {
return protocol.PacketNumber(b1), protocol.PacketNumberLen1, nil
}
b2, err := b.ReadByte()
if err != nil {
return 0, 0, err
}
if b1&0x40 == 0 {
return protocol.PacketNumber(uint64(b1&0x3f)<<8 + uint64(b2)), protocol.PacketNumberLen2, nil
}
b3, err := b.ReadByte()
if err != nil {
return 0, 0, err
}
b4, err := b.ReadByte()
if err != nil {
return 0, 0, err
}
return protocol.PacketNumber(uint64(b1&0x3f)<<24 + uint64(b2)<<16 + uint64(b3)<<8 + uint64(b4)), protocol.PacketNumberLen4, nil
}
// WriteVarIntPacketNumber writes a packet number in the QUIC varint packet number format
func WriteVarIntPacketNumber(b *bytes.Buffer, i protocol.PacketNumber, len protocol.PacketNumberLen) error {
switch len {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(i & 0x7f))
case protocol.PacketNumberLen2:
b.Write([]byte{(uint8(i>>8) & 0x3f) | 0x80, uint8(i)})
case protocol.PacketNumberLen4:
b.Write([]byte{(uint8(i>>24) & 0x3f) | 0xc0, uint8(i >> 16), uint8(i >> 8), uint8(i)})
default:
return fmt.Errorf("invalid packet number length: %d", len)
}
return nil
}

View file

@ -0,0 +1,157 @@
package utils
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Varint packet number encoding / decoding", func() {
Context("Decoding", func() {
It("reads a 1 byte number", func() {
b := bytes.NewReader([]byte{0x19}) // 00011001
p, len, err := ReadVarIntPacketNumber(b)
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen1))
Expect(p).To(Equal(protocol.PacketNumber(0x19)))
})
It("errors when given an empty reader", func() {
_, _, err := ReadVarIntPacketNumber(bytes.NewReader(nil))
Expect(err).To(MatchError(io.EOF))
})
It("reads a 2 byte number", func() {
b := bytes.NewReader([]byte{0xb7, 0x19}) // first byte: 10110111
p, len, err := ReadVarIntPacketNumber(b)
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen2))
Expect(p).To(Equal(protocol.PacketNumber(0x3719)))
})
It("errors on EOF when reading a 2 byte number", func() {
b := bytes.NewReader([]byte{0xb7}) // first byte: 10110111
_, _, err := ReadVarIntPacketNumber(b)
Expect(err).To(MatchError(io.EOF))
})
It("reads a 4 byte number", func() {
b := bytes.NewReader([]byte{0xe5, 0x89, 0xfa, 0x19}) // first byte: 11100101
p, len, err := ReadVarIntPacketNumber(b)
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen4))
Expect(p).To(Equal(protocol.PacketNumber(0x2589fa19)))
})
It("errors on EOF after the 3rd byte when reading a 4 byte number", func() {
b := bytes.NewReader([]byte{0xe5, 0x89}) // first byte: 11100101
_, _, err := ReadVarIntPacketNumber(b)
Expect(err).To(MatchError(io.EOF))
})
It("errors on EOF after the 4th byte when reading a 4 byte number", func() {
b := bytes.NewReader([]byte{0xe5, 0x89, 0xfa}) // first byte: 11100101
_, _, err := ReadVarIntPacketNumber(b)
Expect(err).To(MatchError(io.EOF))
})
})
Context("Encoding", func() {
It("writes a 1 byte packet number", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0x42, protocol.PacketNumberLen1)
Expect(err).ToNot(HaveOccurred())
Expect(b.Len()).To(Equal(1))
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen1))
Expect(p).To(Equal(protocol.PacketNumber(0x42)))
})
It("only uses the least significant 7 bits when writing a 1 byte packet number", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0x1234ea, protocol.PacketNumberLen1)
Expect(err).ToNot(HaveOccurred())
Expect(b.Len()).To(Equal(1))
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen1))
Expect(p).To(Equal(protocol.PacketNumber(0x6a)))
})
It("writes a small 2 byte packet number", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0x42, protocol.PacketNumberLen2)
Expect(err).ToNot(HaveOccurred())
Expect(b.Len()).To(Equal(2))
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen2))
Expect(p).To(Equal(protocol.PacketNumber(0x42)))
})
It("writes a 2 byte packet number", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0x1337, protocol.PacketNumberLen2)
Expect(err).ToNot(HaveOccurred())
Expect(b.Len()).To(Equal(2))
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen2))
Expect(p).To(Equal(protocol.PacketNumber(0x1337)))
})
It("only uses the least significant 14 bits when writing a 2 byte packet number", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0x1234ff37, protocol.PacketNumberLen2)
Expect(err).ToNot(HaveOccurred())
Expect(b.Len()).To(Equal(2))
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen2))
Expect(p).To(Equal(protocol.PacketNumber(0x3f37)))
})
It("writes a small 4 byte packet number", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0xbeef, protocol.PacketNumberLen4)
Expect(err).ToNot(HaveOccurred())
Expect(b.Len()).To(Equal(4))
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen4))
Expect(p).To(Equal(protocol.PacketNumber(0xbeef)))
})
It("writes a 4 byte packet number", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0x12beef42, protocol.PacketNumberLen4)
Expect(err).ToNot(HaveOccurred())
Expect(b.Len()).To(Equal(4))
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen4))
Expect(p).To(Equal(protocol.PacketNumber(0x12beef42)))
})
It("only uses the least significant 30 bits when writing a 4 byte packet number", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0x1234deadbeef, protocol.PacketNumberLen4)
Expect(err).ToNot(HaveOccurred())
Expect(b.Len()).To(Equal(4))
p, len, err := ReadVarIntPacketNumber(bytes.NewReader(b.Bytes()))
Expect(err).ToNot(HaveOccurred())
Expect(len).To(Equal(protocol.PacketNumberLen4))
Expect(p).To(Equal(protocol.PacketNumber(0x1eadbeef)))
})
It("errors when encountering invalid packet number lengths", func() {
b := &bytes.Buffer{}
err := WriteVarIntPacketNumber(b, 0x1234deadbeef, 13)
Expect(err).To(MatchError("invalid packet number length: 13"))
})
})
})

View file

@ -46,6 +46,7 @@ var _ = Describe("Header", func() {
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
Type: protocol.PacketType0RTT,
PacketNumber: 0x42,
PacketNumberLen: protocol.PacketNumberLen2,
Version: 0x1234,
}).writeHeader(buf)
Expect(err).ToNot(HaveOccurred())

View file

@ -72,12 +72,12 @@ func parseLongHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
return nil, err
}
h.PayloadLen = protocol.ByteCount(pl)
pn, err := utils.BigEndian.ReadUint32(b)
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = protocol.PacketNumber(pn)
h.PacketNumberLen = protocol.PacketNumberLen4
h.PacketNumber = pn
h.PacketNumberLen = pnLen
h.Type = protocol.PacketType(typeByte & 0x7f)
if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
@ -98,25 +98,14 @@ func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
if typeByte&0x38 != 0x30 {
return nil, errors.New("invalid bits 3, 4 and 5")
}
var pnLen protocol.PacketNumberLen
switch typeByte & 0x3 {
case 0x0:
pnLen = protocol.PacketNumberLen1
case 0x1:
pnLen = protocol.PacketNumberLen2
case 0x2:
pnLen = protocol.PacketNumberLen4
default:
return nil, errors.New("invalid short header type")
}
pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen))
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
return &Header{
KeyPhase: int(typeByte&0x40) >> 6,
DestConnectionID: connID,
PacketNumber: protocol.PacketNumber(pn),
PacketNumber: pn,
PacketNumberLen: pnLen,
}, nil
}
@ -144,39 +133,21 @@ func (h *Header) writeLongHeader(b *bytes.Buffer) error {
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
utils.WriteVarInt(b, uint64(h.PayloadLen))
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
return nil
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
func (h *Header) writeShortHeader(b *bytes.Buffer) error {
typeByte := byte(0x30)
typeByte |= byte(h.KeyPhase << 6)
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
case protocol.PacketNumberLen2:
typeByte |= 0x1
case protocol.PacketNumberLen4:
typeByte |= 0x2
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
}
return nil
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
if h.IsLongHeader {
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + 4 /* packet number */, nil
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + protocol.ByteCount(h.PacketNumberLen), nil
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
@ -192,7 +163,7 @@ func (h *Header) logHeader(logger utils.Logger) {
if h.Version == 0 {
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
} else {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PayloadLen, h.Version)
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
}
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)

View file

@ -17,6 +17,12 @@ import (
var _ = Describe("IETF QUIC Header", func() {
srcConnID := protocol.ConnectionID(bytes.Repeat([]byte{'f'}, protocol.ConnectionIDLen))
appendPacketNumber := func(data []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) []byte {
buf := &bytes.Buffer{}
utils.WriteVarIntPacketNumber(buf, pn, pnLen)
return append(data, buf.Bytes()...)
}
Context("parsing", func() {
Context("Version Negotiation Packets", func() {
It("parses", func() {
@ -58,21 +64,19 @@ var _ = Describe("IETF QUIC Header", func() {
})
Context("long headers", func() {
generatePacket := func(t protocol.PacketType) []byte {
It("parses a long header", func() {
data := []byte{
0x80 ^ uint8(t),
0x80 ^ uint8(protocol.PacketTypeInitial),
0x1, 0x2, 0x3, 0x4, // version number
0x55, // connection ID lengths
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // destination connection ID
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // source connection ID
}
data = append(data, encodeVarInt(0x1337)...) // payload length
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) // packet number
return data
}
data = append(data, encodeVarInt(0x1337)...) // payload length
// packet number
data = appendPacketNumber(data, 0xbeef, protocol.PacketNumberLen4)
It("parses a long header", func() {
b := bytes.NewReader(generatePacket(protocol.PacketTypeInitial))
b := bytes.NewReader(data)
h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred())
Expect(h.Type).To(Equal(protocol.PacketTypeInitial))
@ -81,7 +85,7 @@ var _ = Describe("IETF QUIC Header", func() {
Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}))
Expect(h.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}))
Expect(h.PayloadLen).To(Equal(protocol.ByteCount(0x1337)))
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad)))
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef)))
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
Expect(h.Version).To(Equal(protocol.VersionNumber(0x1020304)))
Expect(h.IsVersionNegotiation).To(BeFalse())
@ -104,6 +108,21 @@ var _ = Describe("IETF QUIC Header", func() {
Expect(h.DestConnectionID).To(BeEmpty())
})
It("parses a long header with a 2 byte connection ID", func() {
data := []byte{
0x80 ^ uint8(protocol.PacketTypeInitial),
0x1, 0x2, 0x3, 0x4, // version number
0x0, // connection ID lengths
}
data = append(data, encodeVarInt(0x42)...) // payload length
data = appendPacketNumber(data, 0x123, protocol.PacketNumberLen2)
b := bytes.NewReader(data)
h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred())
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x123)))
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen2))
})
It("parses a long header without a source connection ID", func() {
data := []byte{
0x80 ^ uint8(protocol.PacketTypeInitial),
@ -127,6 +146,8 @@ var _ = Describe("IETF QUIC Header", func() {
Type: 42,
SrcConnectionID: srcConnID,
Version: 0x10203040,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen1,
}).Write(buf, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(buf.Bytes())
@ -141,8 +162,8 @@ var _ = Describe("IETF QUIC Header", func() {
0x55, // connection ID lengths
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // destination connection ID
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // source connection ID
0xde, 0xca, 0xfb, 0xad, // packet number
}
data = appendPacketNumber(data, 0x1337, protocol.PacketNumberLen4)
for i := 0; i < len(data); i++ {
_, err := parseHeader(bytes.NewReader(data[:i]))
Expect(err).To(Equal(io.EOF))
@ -155,8 +176,8 @@ var _ = Describe("IETF QUIC Header", func() {
data := []byte{
0x30, // 1 byte packet number
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0x42, // packet number
}
data = appendPacketNumber(data, 0x42, protocol.PacketNumberLen1)
b := bytes.NewReader(data)
h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred())
@ -174,8 +195,8 @@ var _ = Describe("IETF QUIC Header", func() {
data := []byte{
0x30 ^ 0x40,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0x11,
}
data = appendPacketNumber(data, 11, protocol.PacketNumberLen1)
b := bytes.NewReader(data)
h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred())
@ -188,8 +209,8 @@ var _ = Describe("IETF QUIC Header", func() {
data := []byte{
0x30 ^ 0x40 ^ 0x1,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0x13, 0x37, // packet number
}
data = appendPacketNumber(data, 0x1337, protocol.PacketNumberLen2)
b := bytes.NewReader(data)
h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred())
@ -203,34 +224,23 @@ var _ = Describe("IETF QUIC Header", func() {
data := []byte{
0x30 ^ 0x40 ^ 0x2,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0xde, 0xad, 0xbe, 0xef, // packet number
}
data = appendPacketNumber(data, 0x99beef, protocol.PacketNumberLen4)
b := bytes.NewReader(data)
h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred())
Expect(h.IsLongHeader).To(BeFalse())
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xdeadbeef)))
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef)))
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
Expect(b.Len()).To(BeZero())
})
It("rejects headers that have an invalid type", func() {
data := []byte{
0x30 ^ 0x40 ^ 0x3,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0xde, 0xad, 0xbe, 0xef, // packet number
}
b := bytes.NewReader(data)
_, err := parseHeader(b)
Expect(err).To(MatchError("invalid short header type"))
})
It("rejects headers that have bit 3,4 and 5 set incorrectly", func() {
data := []byte{
0x38 ^ 0x2,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0xde, 0xca, 0xfb, 0xad, // packet number
}
data = appendPacketNumber(data, 1234, protocol.PacketNumberLen2)
b := bytes.NewReader(data)
_, err := parseHeader(b)
Expect(err).To(MatchError("invalid bits 3, 4 and 5"))
@ -240,8 +250,8 @@ var _ = Describe("IETF QUIC Header", func() {
data := []byte{
0x30 ^ 0x2,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0xde, 0xca, 0xfb, 0xad, // packet number
}
data = appendPacketNumber(data, 1234, protocol.PacketNumberLen2)
for i := 0; i < len(data); i++ {
_, err := parseHeader(bytes.NewReader(data[:i]))
Expect(err).To(Equal(io.EOF))
@ -265,7 +275,8 @@ var _ = Describe("IETF QUIC Header", func() {
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe},
SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37},
PayloadLen: 0xcafe,
PacketNumber: 0xdecafbad,
PacketNumber: 0xdecaf,
PacketNumberLen: protocol.PacketNumberLen4,
Version: 0x1020304,
}).writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
@ -276,8 +287,8 @@ var _ = Describe("IETF QUIC Header", func() {
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // dest connection ID
0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37, // source connection ID
}
expected = append(expected, encodeVarInt(0xcafe)...) // payload length
expected = append(expected, []byte{0xde, 0xca, 0xfb, 0xad}...) // packet number
expected = append(expected, encodeVarInt(0xcafe)...) // payload length
expected = appendPacketNumber(expected, 0xdecaf, protocol.PacketNumberLen4)
Expect(buf.Bytes()).To(Equal(expected))
})
@ -288,6 +299,7 @@ var _ = Describe("IETF QUIC Header", func() {
SrcConnectionID: srcConnID,
DestConnectionID: protocol.ConnectionID{1, 2, 3}, // connection IDs must be at least 4 bytes long
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
Version: 0x1020304,
}).writeHeader(buf)
Expect(err).To(MatchError("invalid connection ID length: 3 bytes"))
@ -300,6 +312,7 @@ var _ = Describe("IETF QUIC Header", func() {
SrcConnectionID: srcConnID,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, // connection IDs must be at most 18 bytes long
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
Version: 0x1020304,
}).writeHeader(buf)
Expect(err).To(MatchError("invalid connection ID length: 19 bytes"))
@ -312,6 +325,7 @@ var _ = Describe("IETF QUIC Header", func() {
SrcConnectionID: srcConnID,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, // connection IDs must be at most 18 bytes long
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
Version: 0x1020304,
}).writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
@ -350,26 +364,24 @@ var _ = Describe("IETF QUIC Header", func() {
err := (&Header{
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen2,
PacketNumber: 0x1337,
PacketNumber: 0x765,
}).writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
Expect(buf.Bytes()).To(Equal([]byte{
0x30 | 0x1,
0x13, 0x37, // packet number
}))
expected := []byte{0x30}
expected = appendPacketNumber(expected, 0x765, protocol.PacketNumberLen2)
Expect(buf.Bytes()).To(Equal(expected))
})
It("writes a header with a 4 byte packet number", func() {
err := (&Header{
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen4,
PacketNumber: 0xdecafbad,
PacketNumber: 0x123456,
}).writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
Expect(buf.Bytes()).To(Equal([]byte{
0x30 | 0x2,
0xde, 0xca, 0xfb, 0xad, // packet number
}))
expected := []byte{0x30}
expected = appendPacketNumber(expected, 0x123456, protocol.PacketNumberLen4)
Expect(buf.Bytes()).To(Equal(expected))
})
It("errors when given an invalid packet number length", func() {
@ -410,8 +422,9 @@ var _ = Describe("IETF QUIC Header", func() {
PayloadLen: 1,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
PacketNumberLen: protocol.PacketNumberLen1,
}
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 1 /* short payload len */ + 4 /* packet number */
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 1 /* short payload len */ + 1 /* packet number */
Expect(h.getHeaderLength()).To(BeEquivalentTo(expectedLen))
err := h.writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
@ -424,8 +437,9 @@ var _ = Describe("IETF QUIC Header", func() {
PayloadLen: 1500,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
PacketNumberLen: protocol.PacketNumberLen2,
}
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 2 /* long payload len */ + 4 /* packet number */
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* conn ID len */ + 8 /* dest conn id */ + 8 /* src conn id */ + 2 /* long payload len */ + 2 /* packet number */
Expect(h.getHeaderLength()).To(BeEquivalentTo(expectedLen))
err := h.writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
@ -518,12 +532,13 @@ var _ = Describe("IETF QUIC Header", func() {
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PacketNumber: 0x1337,
PacketNumberLen: protocol.PacketNumberLen2,
PayloadLen: 54321,
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37},
SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37},
Version: 0xfeed,
}).logHeader(logger)
Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337, PacketNumber: 0x1337, PayloadLen: 54321, Version: 0xfeed}"))
Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337, PacketNumber: 0x1337, PacketNumberLen: 2, PayloadLen: 54321, Version: 0xfeed}"))
})
It("logs Short Headers containing a connection ID", func() {

View file

@ -23,6 +23,7 @@ var _ = Describe("Packing and unpacking Initial packets", func() {
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
PacketNumber: 0x42,
PacketNumberLen: protocol.PacketNumberLen1,
DestConnectionID: connID,
SrcConnectionID: connID,
Version: ver,

View file

@ -458,7 +458,6 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
}
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
header.PacketNumberLen = protocol.PacketNumberLen4
header.IsLongHeader = true
// Set the payload len to maximum size.
// Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.

View file

@ -216,7 +216,6 @@ var _ = Describe("Packet packer", func() {
It("uses the Long Header format for non-forward-secure packets", func() {
h := packer.getHeader(protocol.EncryptionSecure)
Expect(h.IsLongHeader).To(BeTrue())
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
Expect(h.Version).To(Equal(versionIETFHeader))
})

View file

@ -311,6 +311,7 @@ var _ = Describe("Server", func() {
PayloadLen: 1000,
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
@ -332,6 +333,7 @@ var _ = Describe("Server", func() {
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
@ -349,6 +351,7 @@ var _ = Describe("Server", func() {
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
@ -507,6 +510,7 @@ var _ = Describe("Server", func() {
DestConnectionID: connID,
SrcConnectionID: connID,
PacketNumber: 0x55,
PacketNumberLen: protocol.PacketNumberLen1,
Version: 0x1234,
PayloadLen: protocol.MinInitialPacketSize,
}
@ -549,6 +553,7 @@ var _ = Describe("Server", func() {
DestConnectionID: connID,
SrcConnectionID: connID,
PacketNumber: 0x55,
PacketNumberLen: protocol.PacketNumberLen1,
Version: protocol.VersionTLS,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
@ -571,6 +576,7 @@ var _ = Describe("Server", func() {
DestConnectionID: connID,
SrcConnectionID: connID,
PacketNumber: 0x55,
PacketNumberLen: protocol.PacketNumberLen1,
Version: protocol.VersionTLS,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)

View file

@ -125,6 +125,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea
SrcConnectionID: clientHdr.DestConnectionID,
DestConnectionID: clientHdr.SrcConnectionID,
PacketNumber: 1, // random packet number
PacketNumberLen: protocol.PacketNumberLen1,
Version: clientHdr.Version,
}
data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer, s.logger)
@ -193,6 +194,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
SrcConnectionID: hdr.DestConnectionID,
PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()),
PacketNumber: hdr.PacketNumber, // echo the client's packet number
PacketNumberLen: hdr.PacketNumberLen,
Version: version,
}
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger)

View file

@ -52,6 +52,7 @@ var _ = Describe("Stateless TLS handling", func() {
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
SrcConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen1,
Version: protocol.VersionTLS,
}
err := hdr.Write(hdrBuf, protocol.PerspectiveClient, protocol.VersionTLS)
@ -85,6 +86,7 @@ var _ = Describe("Stateless TLS handling", func() {
hdr := &wire.Header{
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
PacketNumberLen: protocol.PacketNumberLen1,
Version: 0x1337,
}
server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))

View file

@ -407,7 +407,7 @@ var newTLSClientSession = func(
func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
@ -621,6 +621,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
hdr.PacketNumberLen,
s.largestRcvdPacketNumber,
hdr.PacketNumber,
s.version,
)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)