refactor header writing to append to a byte slice (#3646)

This avoids having to allocate a bytes.Buffer.
This commit is contained in:
Marten Seemann 2023-01-17 01:56:06 -08:00 committed by GitHub
parent 3d4bbc28ba
commit c24fbb094c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 282 additions and 279 deletions

View file

@ -571,8 +571,8 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
buf := &bytes.Buffer{} b, err := wire.AppendShortHeader(nil, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
Expect(wire.WriteShortHeader(buf, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version) b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version)
@ -589,7 +589,7 @@ var _ = Describe("Connection", func() {
rcvTime: time.Now(), rcvTime: time.Now(),
remoteAddr: &net.UDPAddr{}, remoteAddr: &net.UDPAddr{},
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
data: buf.Bytes(), data: b,
}) })
// Consistently(pack).ShouldNot(Receive()) // Consistently(pack).ShouldNot(Receive())
Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed())
@ -649,20 +649,20 @@ var _ = Describe("Connection", func() {
}) })
getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket { getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket {
buf := &bytes.Buffer{} b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
Expect(wire.WriteShortHeader(buf, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
return &receivedPacket{ return &receivedPacket{
data: append(buf.Bytes(), data...), data: append(b, data...),
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
rcvTime: time.Now(), rcvTime: time.Now(),
} }
} }
getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
buf := &bytes.Buffer{} b, err := extHdr.Append(nil, conn.version)
Expect(extHdr.Write(buf, conn.version)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
return &receivedPacket{ return &receivedPacket{
data: append(buf.Bytes(), data...), data: append(b, data...),
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
rcvTime: time.Now(), rcvTime: time.Now(),
} }
@ -2365,10 +2365,10 @@ var _ = Describe("Client Connection", func() {
destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket {
buf := &bytes.Buffer{} b, err := hdr.Append(nil, conn.version)
Expect(hdr.Write(buf, conn.version)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
return &receivedPacket{ return &receivedPacket{
data: append(buf.Bytes(), data...), data: append(b, data...),
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
} }
} }
@ -2643,9 +2643,9 @@ var _ = Describe("Client Connection", func() {
}) })
getRetryTag := func(hdr *wire.ExtendedHeader) []byte { getRetryTag := func(hdr *wire.ExtendedHeader) []byte {
buf := &bytes.Buffer{} b, err := hdr.Append(nil, conn.version)
hdr.Write(buf, conn.version) Expect(err).ToNot(HaveOccurred())
return handshake.GetRetryIntegrityTag(buf.Bytes(), origDestConnID, hdr.Version)[:] return handshake.GetRetryIntegrityTag(b, origDestConnID, hdr.Version)[:]
} }
It("handles Retry packets", func() { It("handles Retry packets", func() {
@ -2858,10 +2858,10 @@ var _ = Describe("Client Connection", func() {
var unpacker *MockUnpacker var unpacker *MockUnpacker
getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
buf := &bytes.Buffer{} b, err := extHdr.Append(nil, conn.version)
Expect(extHdr.Write(buf, conn.version)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
return &receivedPacket{ return &receivedPacket{
data: append(buf.Bytes(), data...), data: append(b, data...),
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
} }
} }

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"log" "log"
"math/rand" "math/rand"
@ -85,28 +84,28 @@ func main() {
PacketNumberLen: protocol.PacketNumberLen(rand.Intn(4) + 1), PacketNumberLen: protocol.PacketNumberLen(rand.Intn(4) + 1),
PacketNumber: protocol.PacketNumber(rand.Uint64()), PacketNumber: protocol.PacketNumber(rand.Uint64()),
} }
b := &bytes.Buffer{} b, err := extHdr.Append(nil, version)
if err := extHdr.Write(b, version); err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if h.Type == protocol.PacketTypeRetry { if h.Type == protocol.PacketTypeRetry {
b.Write([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) b = append(b, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}...)
} }
if h.Length > 0 { if h.Length > 0 {
b.Write(make([]byte, h.Length)) b = append(b, make([]byte, h.Length)...)
} }
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil { if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
// short header // short header
b := &bytes.Buffer{} b, err := wire.AppendShortHeader(nil, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
if err := wire.WriteShortHeader(b, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne); err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil { if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil {
log.Fatal(err) log.Fatal(err)
} }

View file

@ -64,8 +64,8 @@ func Fuzz(data []byte) int {
if hdr.Length > 16383 { if hdr.Length > 16383 {
return 1 return 1
} }
b := &bytes.Buffer{} b, err := extHdr.Append(nil, version)
if err := extHdr.Write(b, version); err != nil { if err != nil {
// We are able to parse packets with connection IDs longer than 20 bytes, // We are able to parse packets with connection IDs longer than 20 bytes,
// but in QUIC version 1, we don't write headers with longer connection IDs. // but in QUIC version 1, we don't write headers with longer connection IDs.
if hdr.DestConnectionID.Len() <= protocol.MaxConnIDLen && if hdr.DestConnectionID.Len() <= protocol.MaxConnIDLen &&
@ -76,8 +76,8 @@ func Fuzz(data []byte) int {
} }
// GetLength is not implemented for Retry packets // GetLength is not implemented for Retry packets
if hdr.Type != protocol.PacketTypeRetry { if hdr.Type != protocol.PacketTypeRetry {
if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(b.Len()) { if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(len(b)) {
panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, b.Len())) panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, len(b)))
} }
} }
return 1 return 1

View file

@ -1,7 +1,6 @@
package self_test package self_test
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -116,12 +115,12 @@ var _ = Describe("MITM test", func() {
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
payloadLen := mrand.Int31n(100) payloadLen := mrand.Int31n(100)
replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1)) replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1))
buf := &bytes.Buffer{} b, err := replyHdr.Append(nil, version)
Expect(replyHdr.Write(buf, version)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
b := make([]byte, payloadLen) r := make([]byte, payloadLen)
mrand.Read(b) mrand.Read(r)
buf.Write(b) b = append(b, r...)
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { if _, err := conn.WriteTo(b, remoteAddr); err != nil {
return return
} }
<-ticker.C <-ticker.C
@ -134,13 +133,13 @@ var _ = Describe("MITM test", func() {
Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
} }
for i := 0; i < numPackets; i++ { for i := 0; i < numPackets; i++ {
buf := &bytes.Buffer{} b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))
Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))).To(Succeed()) Expect(err).ToNot(HaveOccurred())
payloadLen := mrand.Int31n(100) payloadLen := mrand.Int31n(100)
b := make([]byte, payloadLen) r := make([]byte, payloadLen)
mrand.Read(b) mrand.Read(r)
buf.Write(b) b = append(b, r...)
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { if _, err := conn.WriteTo(b, remoteAddr); err != nil {
return return
} }
<-ticker.C <-ticker.C

View file

@ -28,7 +28,6 @@ func isProxyRunning() bool {
var _ = Describe("QUIC Proxy", func() { var _ = Describe("QUIC Proxy", func() {
makePacket := func(p protocol.PacketNumber, payload []byte) []byte { makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
b := &bytes.Buffer{}
hdr := wire.ExtendedHeader{ hdr := wire.ExtendedHeader{
Header: wire.Header{ Header: wire.Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -40,10 +39,10 @@ var _ = Describe("QUIC Proxy", func() {
PacketNumber: p, PacketNumber: p,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
} }
Expect(hdr.Write(b, protocol.VersionWhatever)).To(Succeed()) b, err := hdr.Append(nil, protocol.Version1)
raw := b.Bytes() Expect(err).ToNot(HaveOccurred())
raw = append(raw, payload...) b = append(b, payload...)
return raw return b
} }
readPacketNumber := func(b []byte) protocol.PacketNumber { readPacketNumber := func(b []byte) protocol.PacketNumber {

View file

@ -1,7 +1,7 @@
package testutils package testutils
import ( import (
"bytes" "fmt"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -13,9 +13,11 @@ import (
// writePacket returns a new raw packet with the specified header and payload // writePacket returns a new raw packet with the specified header and payload
func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte { func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte {
buf := &bytes.Buffer{} b, err := hdr.Append(nil, hdr.Version)
hdr.Write(buf, hdr.Version) if err != nil {
return append(buf.Bytes(), data...) panic(fmt.Sprintf("failed to write header: %s", err))
}
return append(b, data...)
} }
// packRawPayload returns a new raw payload containing given frames // packRawPayload returns a new raw payload containing given frames

View file

@ -2,6 +2,7 @@ package wire
import ( import (
"bytes" "bytes"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -93,20 +94,17 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
return nil return nil
} }
// Write writes the Header. // Append appends the Header.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) {
if h.DestConnectionID.Len() > protocol.MaxConnIDLen { if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
} }
if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
}
return h.writeLongHeader(b, ver)
} }
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error {
var packetType uint8 var packetType uint8
if version == protocol.Version2 { if v == protocol.Version2 {
//nolint:exhaustive //nolint:exhaustive
switch h.Type { switch h.Type {
case protocol.PacketTypeInitial: case protocol.PacketTypeInitial:
@ -137,24 +135,25 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.Versi
firstByte |= uint8(h.PacketNumberLen - 1) firstByte |= uint8(h.PacketNumberLen - 1)
} }
b.WriteByte(firstByte) b = append(b, firstByte)
utils.BigEndian.WriteUint32(b, uint32(h.Version)) b = append(b, make([]byte, 4)...)
b.WriteByte(uint8(h.DestConnectionID.Len())) binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version))
b.Write(h.DestConnectionID.Bytes()) b = append(b, uint8(h.DestConnectionID.Len()))
b.WriteByte(uint8(h.SrcConnectionID.Len())) b = append(b, h.DestConnectionID.Bytes()...)
b.Write(h.SrcConnectionID.Bytes()) b = append(b, uint8(h.SrcConnectionID.Len()))
b = append(b, h.SrcConnectionID.Bytes()...)
//nolint:exhaustive //nolint:exhaustive
switch h.Type { switch h.Type {
case protocol.PacketTypeRetry: case protocol.PacketTypeRetry:
b.Write(h.Token) b = append(b, h.Token...)
return nil return b, nil
case protocol.PacketTypeInitial: case protocol.PacketTypeInitial:
quicvarint.Write(b, uint64(len(h.Token))) b = quicvarint.Append(b, uint64(len(h.Token)))
b.Write(h.Token) b = append(b, h.Token...)
} }
quicvarint.WriteWithLen(b, uint64(h.Length), 2) b = quicvarint.AppendWithLen(b, uint64(h.Length), 2)
return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen) return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
} }
// ParsedLen returns the number of bytes that were consumed when parsing the header // ParsedLen returns the number of bytes that were consumed when parsing the header
@ -188,18 +187,24 @@ func (h *ExtendedHeader) Log(logger utils.Logger) {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
} }
func writePacketNumber(b *bytes.Buffer, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) error { func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) {
switch pnLen { switch pnLen {
case protocol.PacketNumberLen1: case protocol.PacketNumberLen1:
b.WriteByte(uint8(pn)) b = append(b, uint8(pn))
case protocol.PacketNumberLen2: case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(pn)) buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(pn))
b = append(b, buf...)
case protocol.PacketNumberLen3: case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(pn)) buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(pn))
b = append(b, buf[1:]...)
case protocol.PacketNumberLen4: case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(pn)) buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(pn))
b = append(b, buf...)
default: default:
return fmt.Errorf("invalid packet number length: %d", pnLen) return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
} }
return nil return b, nil
} }

View file

@ -14,20 +14,12 @@ import (
) )
var _ = Describe("Header", func() { var _ = Describe("Header", func() {
const versionIETFHeader = protocol.Version1
Context("Writing", func() { Context("Writing", func() {
var buf *bytes.Buffer Context("Long Header, version 1", func() {
BeforeEach(func() {
buf = &bytes.Buffer{}
})
Context("Long Header", func() {
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
It("writes", func() { It("writes", func() {
Expect((&ExtendedHeader{ b, err := (&ExtendedHeader{
Header: Header{ Header: Header{
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}), DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}),
@ -37,7 +29,8 @@ var _ = Describe("Header", func() {
}, },
PacketNumber: 0xdecaf, PacketNumber: 0xdecaf,
PacketNumberLen: protocol.PacketNumberLen3, PacketNumberLen: protocol.PacketNumberLen3,
}).Write(buf, versionIETFHeader)).To(Succeed()) }).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
expected := []byte{ expected := []byte{
0xc0 | 0x2<<4 | 0x2, 0xc0 | 0x2<<4 | 0x2,
0x1, 0x2, 0x3, 0x4, // version number 0x1, 0x2, 0x3, 0x4, // version number
@ -48,11 +41,11 @@ var _ = Describe("Header", func() {
} }
expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length
expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number
Expect(buf.Bytes()).To(Equal(expected)) Expect(b).To(Equal(expected))
}) })
It("writes a header with a 20 byte connection ID", func() { It("writes a header with a 20 byte connection ID", func() {
err := (&ExtendedHeader{ b, err := (&ExtendedHeader{
Header: Header{ Header: Header{
SrcConnectionID: srcConnID, SrcConnectionID: srcConnID,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long
@ -61,14 +54,14 @@ var _ = Describe("Header", func() {
}, },
PacketNumber: 0xdecafbad, PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, versionIETFHeader) }).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}))) Expect(b).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20})))
}) })
It("writes an Initial containing a token", func() { It("writes an Initial containing a token", func() {
token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
Expect((&ExtendedHeader{ b, err := (&ExtendedHeader{
Header: Header{ Header: Header{
Version: 0x1020304, Version: 0x1020304,
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -76,14 +69,15 @@ var _ = Describe("Header", func() {
}, },
PacketNumber: 0xdecafbad, PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, versionIETFHeader)).To(Succeed()) }).Append(nil, protocol.Version1)
Expect(buf.Bytes()[0]>>4&0b11 == 0) Expect(err).ToNot(HaveOccurred())
Expect(b[0]>>4&0b11 == 0)
expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) expectedSubstring := append(encodeVarInt(uint64(len(token))), token...)
Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring))) Expect(b).To(ContainSubstring(string(expectedSubstring)))
}) })
It("uses a 2-byte encoding for the length on Initial packets", func() { It("uses a 2-byte encoding for the length on Initial packets", func() {
Expect((&ExtendedHeader{ b, err := (&ExtendedHeader{
Header: Header{ Header: Header{
Version: 0x1020304, Version: 0x1020304,
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -91,89 +85,88 @@ var _ = Describe("Header", func() {
}, },
PacketNumber: 0xdecafbad, PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, versionIETFHeader)).To(Succeed()) }).Append(nil, protocol.Version1)
b := &bytes.Buffer{} Expect(err).ToNot(HaveOccurred())
quicvarint.WriteWithLen(b, 37, 2) lengthEncoded := quicvarint.AppendWithLen(nil, 37, 2)
Expect(buf.Bytes()[buf.Len()-6 : buf.Len()-4]).To(Equal(b.Bytes())) Expect(b[len(b)-6 : len(b)-4]).To(Equal(lengthEncoded))
}) })
It("writes a Retry packet", func() { It("writes a Retry packet", func() {
token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")
Expect((&ExtendedHeader{Header: Header{ b, err := (&ExtendedHeader{Header: Header{
Version: protocol.Version1, Version: protocol.Version1,
Type: protocol.PacketTypeRetry, Type: protocol.PacketTypeRetry,
Token: token, Token: token,
}}).Write(buf, versionIETFHeader)).To(Succeed()) }}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0xc0 | 0b11<<4} expected := []byte{0xc0 | 0b11<<4}
expected = appendVersion(expected, protocol.Version1) expected = appendVersion(expected, protocol.Version1)
expected = append(expected, 0x0) // dest connection ID length expected = append(expected, 0x0) // dest connection ID length
expected = append(expected, 0x0) // src connection ID length expected = append(expected, 0x0) // src connection ID length
expected = append(expected, token...) expected = append(expected, token...)
Expect(buf.Bytes()).To(Equal(expected)) Expect(b).To(Equal(expected))
}) })
}) })
Context("long header, version 2", func() { Context("long header, version 2", func() {
It("writes an Initial", func() { It("writes an Initial", func() {
Expect((&ExtendedHeader{ b, err := (&ExtendedHeader{
Header: Header{ Header: Header{
Version: protocol.Version2, Version: protocol.Version2,
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
}, },
PacketNumber: 0xdecafbad, PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.Version2)).To(Succeed()) }).Append(nil, protocol.Version2)
Expect(buf.Bytes()[0]>>4&0b11 == 0b01) Expect(err).ToNot(HaveOccurred())
Expect(b[0]>>4&0b11 == 0b01)
}) })
It("writes a Retry packet", func() { It("writes a Retry packet", func() {
token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")
Expect((&ExtendedHeader{Header: Header{ b, err := (&ExtendedHeader{Header: Header{
Version: protocol.Version2, Version: protocol.Version2,
Type: protocol.PacketTypeRetry, Type: protocol.PacketTypeRetry,
Token: token, Token: token,
}}).Write(buf, versionIETFHeader)).To(Succeed()) }}).Append(nil, protocol.Version2)
expected := []byte{0xc0 | 0b11<<4} Expect(err).ToNot(HaveOccurred())
expected := []byte{0xc0 | 0b00<<4}
expected = appendVersion(expected, protocol.Version2) expected = appendVersion(expected, protocol.Version2)
expected = append(expected, 0x0) // dest connection ID length expected = append(expected, 0x0) // dest connection ID length
expected = append(expected, 0x0) // src connection ID length expected = append(expected, 0x0) // src connection ID length
expected = append(expected, token...) expected = append(expected, token...)
Expect(buf.Bytes()).To(Equal(expected)) Expect(b).To(Equal(expected))
}) })
It("writes a Handshake Packet", func() { It("writes a Handshake Packet", func() {
Expect((&ExtendedHeader{ b, err := (&ExtendedHeader{
Header: Header{ Header: Header{
Version: protocol.Version2, Version: protocol.Version2,
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
}, },
PacketNumber: 0xdecafbad, PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.Version2)).To(Succeed()) }).Append(nil, protocol.Version2)
Expect(buf.Bytes()[0]>>4&0b11 == 0b11) Expect(err).ToNot(HaveOccurred())
Expect(b[0]>>4&0b11 == 0b11)
}) })
It("writes a 0-RTT Packet", func() { It("writes a 0-RTT Packet", func() {
Expect((&ExtendedHeader{ b, err := (&ExtendedHeader{
Header: Header{ Header: Header{
Version: protocol.Version2, Version: protocol.Version2,
Type: protocol.PacketType0RTT, Type: protocol.PacketType0RTT,
}, },
PacketNumber: 0xdecafbad, PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.Version2)).To(Succeed()) }).Append(nil, protocol.Version2)
Expect(buf.Bytes()[0]>>4&0b11 == 0b10) Expect(err).ToNot(HaveOccurred())
Expect(b[0]>>4&0b11 == 0b10)
}) })
}) })
}) })
Context("getting the length", func() { Context("getting the length", func() {
var buf *bytes.Buffer
BeforeEach(func() {
buf = &bytes.Buffer{}
})
It("has the right length for the Long Header, for a short length", func() { It("has the right length for the Long Header, for a short length", func() {
h := &ExtendedHeader{ h := &ExtendedHeader{
Header: Header{ Header: Header{
@ -185,9 +178,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */ expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) b, err := h.Append(nil, protocol.Version1)
Expect(buf.Len()).To(Equal(expectedLen)) Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
}) })
It("has the right length for the Long Header, for a long length", func() { It("has the right length for the Long Header, for a long length", func() {
@ -201,9 +195,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen2, PacketNumberLen: protocol.PacketNumberLen2,
} }
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */ expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) b, err := h.Append(nil, protocol.Version1)
Expect(buf.Len()).To(Equal(expectedLen)) Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
}) })
It("has the right length for an Initial that has a short length", func() { It("has the right length for an Initial that has a short length", func() {
@ -217,9 +212,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen2, PacketNumberLen: protocol.PacketNumberLen2,
} }
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) b, err := h.Append(nil, protocol.Version1)
Expect(buf.Len()).To(Equal(expectedLen)) Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
}) })
It("has the right length for an Initial not containing a Token", func() { It("has the right length for an Initial not containing a Token", func() {
@ -233,9 +229,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen2, PacketNumberLen: protocol.PacketNumberLen2,
} }
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) b, err := h.Append(nil, protocol.Version1)
Expect(buf.Len()).To(Equal(expectedLen)) Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
}) })
It("has the right length for an Initial containing a Token", func() { It("has the right length for an Initial containing a Token", func() {
@ -250,9 +247,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen2, PacketNumberLen: protocol.PacketNumberLen2,
} }
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */ expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) b, err := h.Append(nil, protocol.Version1)
Expect(buf.Len()).To(Equal(expectedLen)) Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
}) })
}) })

View file

@ -15,8 +15,7 @@ import (
var _ = Describe("Header Parsing", func() { var _ = Describe("Header Parsing", func() {
Context("Parsing the Connection ID", func() { Context("Parsing the Connection ID", func() {
It("parses the connection ID of a long header packet", func() { It("parses the connection ID of a long header packet", func() {
buf := &bytes.Buffer{} b, err := (&ExtendedHeader{
Expect((&ExtendedHeader{
Header: Header{ Header: Header{
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
@ -24,15 +23,15 @@ var _ = Describe("Header Parsing", func() {
Version: protocol.Version1, Version: protocol.Version1,
}, },
PacketNumberLen: 2, PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed()) }).Append(nil, protocol.Version1)
connID, err := ParseConnectionID(buf.Bytes(), 8) Expect(err).ToNot(HaveOccurred())
connID, err := ParseConnectionID(b, 8)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
}) })
It("errors on EOF, for long header packets", func() { It("errors on EOF, for long header packets", func() {
buf := &bytes.Buffer{} b, err := (&ExtendedHeader{
Expect((&ExtendedHeader{
Header: Header{ Header: Header{
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}), DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}),
@ -40,9 +39,10 @@ var _ = Describe("Header Parsing", func() {
Version: protocol.Version1, Version: protocol.Version1,
}, },
PacketNumberLen: 2, PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed()) }).Append(nil, protocol.Version1)
data := buf.Bytes()[:buf.Len()-2] // cut the packet number Expect(err).ToNot(HaveOccurred())
_, err := ParseConnectionID(data, 8) data := b[:len(b)-2] // cut the packet number
_, err = ParseConnectionID(data, 8)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ {
b := make([]byte, i) b := make([]byte, i)
@ -419,22 +419,22 @@ var _ = Describe("Header Parsing", func() {
Context("coalesced packets", func() { Context("coalesced packets", func() {
It("cuts packets", func() { It("cuts packets", func() {
buf := &bytes.Buffer{}
hdr := Header{ hdr := Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
Length: 2 + 6, Length: 2 + 6,
Version: protocol.Version1, Version: protocol.Version1,
} }
Expect((&ExtendedHeader{ b, err := (&ExtendedHeader{
Header: hdr, Header: hdr,
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed()) }).Append(nil, protocol.Version1)
hdrRaw := append([]byte{}, buf.Bytes()...) Expect(err).ToNot(HaveOccurred())
buf.Write([]byte("foobar")) // payload of the first packet hdrRaw := append([]byte{}, b...)
buf.Write([]byte("raboof")) // second packet b = append(b, []byte("foobar")...) // payload of the first packet
parsedHdr, data, rest, err := ParsePacket(buf.Bytes()) b = append(b, []byte("raboof")...) // second packet
parsedHdr, data, rest, err := ParsePacket(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(parsedHdr.Type).To(Equal(hdr.Type)) Expect(parsedHdr.Type).To(Equal(hdr.Type))
Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID)) Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID))
@ -443,8 +443,7 @@ var _ = Describe("Header Parsing", func() {
}) })
It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() {
buf := &bytes.Buffer{} b, err := (&ExtendedHeader{
Expect((&ExtendedHeader{
Header: Header{ Header: Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
@ -453,15 +452,15 @@ var _ = Describe("Header Parsing", func() {
}, },
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed()) }).Append(nil, protocol.Version1)
_, _, _, err := ParsePacket(buf.Bytes()) Expect(err).ToNot(HaveOccurred())
_, _, _, err = ParsePacket(b)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)")) Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)"))
}) })
It("errors on packets that are smaller than the length in the packet header, for too small payload", func() { It("errors on packets that are smaller than the length in the packet header, for too small payload", func() {
buf := &bytes.Buffer{} b, err := (&ExtendedHeader{
Expect((&ExtendedHeader{
Header: Header{ Header: Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
@ -470,9 +469,10 @@ var _ = Describe("Header Parsing", func() {
}, },
PacketNumber: 0x1337, PacketNumber: 0x1337,
PacketNumberLen: 2, PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed()) }).Append(nil, protocol.Version1)
buf.Write(make([]byte, 500-2 /* for packet number length */)) Expect(err).ToNot(HaveOccurred())
_, _, _, err := ParsePacket(buf.Bytes()) b = append(b, make([]byte, 500-2 /* for packet number length */)...)
_, _, _, err = ParsePacket(b)
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
}) })
}) })

View file

@ -1,7 +1,6 @@
package wire package wire
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -54,15 +53,15 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet
return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
} }
// WriteShortHeader writes a short header. // AppendShortHeader writes a short header.
func WriteShortHeader(b *bytes.Buffer, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) error { func AppendShortHeader(b []byte, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) ([]byte, error) {
typeByte := 0x40 | uint8(pnLen-1) typeByte := 0x40 | uint8(pnLen-1)
if kp == protocol.KeyPhaseOne { if kp == protocol.KeyPhaseOne {
typeByte |= byte(1 << 2) typeByte |= byte(1 << 2)
} }
b.WriteByte(typeByte) b = append(b, typeByte)
b.Write(connID.Bytes()) b = append(b, connID.Bytes()...)
return writePacketNumber(b, pn, pnLen) return appendPacketNumber(b, pn, pnLen)
} }
func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount { func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount {

View file

@ -5,6 +5,7 @@ import (
"io" "io"
"log" "log"
"os" "os"
"testing"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
@ -77,15 +78,15 @@ var _ = Describe("Short Header", func() {
Context("writing", func() { Context("writing", func() {
It("writes a short header packet", func() { It("writes a short header packet", func() {
b := &bytes.Buffer{}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
Expect(WriteShortHeader(b, connID, 1337, 4, protocol.KeyPhaseOne)).To(Succeed()) b, err := AppendShortHeader(nil, connID, 1337, 4, protocol.KeyPhaseOne)
l, pn, pnLen, kp, err := ParseShortHeader(b.Bytes(), 4) Expect(err).ToNot(HaveOccurred())
l, pn, pnLen, kp, err := ParseShortHeader(b, 4)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(pn).To(Equal(protocol.PacketNumber(1337))) Expect(pn).To(Equal(protocol.PacketNumber(1337)))
Expect(pnLen).To(Equal(protocol.PacketNumberLen4)) Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
Expect(kp).To(Equal(protocol.KeyPhaseOne)) Expect(kp).To(Equal(protocol.KeyPhaseOne))
Expect(l).To(Equal(b.Len())) Expect(l).To(Equal(len(b)))
}) })
}) })
@ -113,3 +114,17 @@ var _ = Describe("Short Header", func() {
}) })
}) })
}) })
func BenchmarkWriteShortHeader(b *testing.B) {
b.ReportAllocs()
buf := make([]byte, 100)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
for i := 0; i < b.N; i++ {
var err error
buf, err = AppendShortHeader(buf, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
if err != nil {
b.Fatalf("failed to write short header: %s", err)
}
buf = buf[:0]
}
}

View file

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"errors" "errors"
"net" "net"
@ -14,7 +13,6 @@ import (
"github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/logging"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -37,8 +35,7 @@ var _ = Describe("Packet Handler Map", func() {
) )
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
buf := &bytes.Buffer{} b, err := (&wire.ExtendedHeader{
Expect((&wire.ExtendedHeader{
Header: wire.Header{ Header: wire.Header{
Type: t, Type: t,
DestConnectionID: connID, DestConnectionID: connID,
@ -46,8 +43,9 @@ var _ = Describe("Packet Handler Map", func() {
Version: protocol.VersionTLS, Version: protocol.VersionTLS,
}, },
PacketNumberLen: protocol.PacketNumberLen2, PacketNumberLen: protocol.PacketNumberLen2,
}).Write(buf, protocol.VersionWhatever)).To(Succeed()) }).Append(nil, protocol.VersionWhatever)
return buf.Bytes() Expect(err).ToNot(HaveOccurred())
return b
} }
getPacket := func(connID protocol.ConnectionID) []byte { getPacket := func(connID protocol.ConnectionID) []byte {

View file

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -800,13 +799,12 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
paddingLen += padding paddingLen += padding
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen
raw := buffer.Data[len(buffer.Data):] startLen := len(buffer.Data)
buf := bytes.NewBuffer(buffer.Data) raw := buffer.Data[startLen:]
startLen := buf.Len() raw, err := header.Append(raw, p.version)
if err := header.Write(buf, p.version); err != nil { if err != nil {
return nil, err return nil, err
} }
raw = raw[:buf.Len()-startLen]
payloadOffset := protocol.ByteCount(len(raw)) payloadOffset := protocol.ByteCount(len(raw))
pn := p.pnManager.PopPacketNumber(encLevel) pn := p.pnManager.PopPacketNumber(encLevel)
@ -814,7 +812,7 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
} }
raw, err := p.appendPacketPayload(raw, payload, paddingLen) raw, err = p.appendPacketPayload(raw, payload, paddingLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -846,20 +844,19 @@ func (p *packetPacker) appendShortHeaderPacket(
} }
paddingLen += padding paddingLen += padding
raw := buffer.Data[len(buffer.Data):] startLen := len(buffer.Data)
buf := bytes.NewBuffer(buffer.Data) raw := buffer.Data[startLen:]
startLen := buf.Len() raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp)
if err := wire.WriteShortHeader(buf, connID, pn, pnLen, kp); err != nil { if err != nil {
return nil, err return nil, err
} }
raw = raw[:buf.Len()-startLen]
payloadOffset := protocol.ByteCount(len(raw)) payloadOffset := protocol.ByteCount(len(raw))
if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) { if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
} }
raw, err := p.appendPacketPayload(raw, payload, paddingLen) raw, err = p.appendPacketPayload(raw, payload, paddingLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"bytes"
"errors" "errors"
"time" "time"
@ -12,7 +11,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -28,21 +26,22 @@ var _ = Describe("Packet Unpacker", func() {
) )
getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
buf := &bytes.Buffer{} b, err := extHdr.Append(nil, version)
ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
hdrLen := buf.Len()
if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) {
buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen)))
}
hdr, _, _, err := wire.ParsePacket(buf.Bytes())
ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
return hdr, buf.Bytes()[:hdrLen] hdrLen := len(b)
if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) {
b = append(b, make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))...)
}
hdr, _, _, err := wire.ParsePacket(b)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
return hdr, b[:hdrLen]
} }
getShortHeader := func(connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) []byte { getShortHeader := func(connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) []byte {
buf := &bytes.Buffer{} b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, kp)
Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, kp)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
return buf.Bytes() return b
} }
BeforeEach(func() { BeforeEach(func() {
@ -72,12 +71,12 @@ var _ = Describe("Packet Unpacker", func() {
}) })
It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() { It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() {
buf := &bytes.Buffer{} b, err := wire.AppendShortHeader(nil, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
Expect(wire.WriteShortHeader(buf, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) Expect(err).ToNot(HaveOccurred())
data := append(buf.Bytes(), make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) data := append(b, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil) cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data) _, _, _, _, err = unpacker.UnpackShortHeader(time.Now(), data)
Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19")) Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19"))
}) })

View file

@ -107,32 +107,32 @@ func Append(b []byte, i uint64) []byte {
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
} }
// WriteWithLen writes i in the QUIC varint format with the desired length to w. // AppendWithLen append i in the QUIC varint format with the desired length.
func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) { func AppendWithLen(b []byte, i uint64, length protocol.ByteCount) []byte {
if length != 1 && length != 2 && length != 4 && length != 8 { if length != 1 && length != 2 && length != 4 && length != 8 {
panic("invalid varint length") panic("invalid varint length")
} }
l := Len(i) l := Len(i)
if l == length { if l == length {
Write(w, i) return Append(b, i)
return
} }
if l > length { if l > length {
panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length))
} }
if length == 2 { if length == 2 {
w.WriteByte(0b01000000) b = append(b, 0b01000000)
} else if length == 4 { } else if length == 4 {
w.WriteByte(0b10000000) b = append(b, 0b10000000)
} else if length == 8 { } else if length == 8 {
w.WriteByte(0b11000000) b = append(b, 0b11000000)
} }
for j := protocol.ByteCount(1); j < length-l; j++ { for j := protocol.ByteCount(1); j < length-l; j++ {
w.WriteByte(0) b = append(b, 0)
} }
for j := protocol.ByteCount(0); j < l; j++ { for j := protocol.ByteCount(0); j < l; j++ {
w.WriteByte(uint8(i >> (8 * (l - 1 - j)))) b = append(b, uint8(i>>(8*(l-1-j))))
} }
return b
} }
// Len determines the number of bytes that will be needed to write the number i. // Len determines the number of bytes that will be needed to write the number i.

View file

@ -142,54 +142,47 @@ var _ = Describe("Varint encoding / decoding", func() {
Context("with fixed length", func() { Context("with fixed length", func() {
It("panics when given an invalid length", func() { It("panics when given an invalid length", func() {
Expect(func() { WriteWithLen(&bytes.Buffer{}, 25, 3) }).Should(Panic()) Expect(func() { AppendWithLen(nil, 25, 3) }).Should(Panic())
}) })
It("panics when given a too short length", func() { It("panics when given a too short length", func() {
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt1+1, 1) }).Should(Panic()) Expect(func() { AppendWithLen(nil, maxVarInt1+1, 1) }).Should(Panic())
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt2+1, 2) }).Should(Panic()) Expect(func() { AppendWithLen(nil, maxVarInt2+1, 2) }).Should(Panic())
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt4+1, 4) }).Should(Panic()) Expect(func() { AppendWithLen(nil, maxVarInt4+1, 4) }).Should(Panic())
}) })
It("writes a 1-byte number in minimal encoding", func() { It("writes a 1-byte number in minimal encoding", func() {
b := &bytes.Buffer{} Expect(AppendWithLen(nil, 37, 1)).To(Equal([]byte{0x25}))
WriteWithLen(b, 37, 1)
Expect(b.Bytes()).To(Equal([]byte{0x25}))
}) })
It("writes a 1-byte number in 2 bytes", func() { It("writes a 1-byte number in 2 bytes", func() {
b := &bytes.Buffer{} b := AppendWithLen(nil, 37, 2)
WriteWithLen(b, 37, 2) Expect(b).To(Equal([]byte{0b01000000, 0x25}))
Expect(b.Bytes()).To(Equal([]byte{0b01000000, 0x25})) Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
Expect(Read(b)).To(BeEquivalentTo(37))
}) })
It("writes a 1-byte number in 4 bytes", func() { It("writes a 1-byte number in 4 bytes", func() {
b := &bytes.Buffer{} b := AppendWithLen(nil, 37, 4)
WriteWithLen(b, 37, 4) Expect(b).To(Equal([]byte{0b10000000, 0, 0, 0x25}))
Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0, 0x25})) Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
Expect(Read(b)).To(BeEquivalentTo(37))
}) })
It("writes a 1-byte number in 8 bytes", func() { It("writes a 1-byte number in 8 bytes", func() {
b := &bytes.Buffer{} b := AppendWithLen(nil, 37, 8)
WriteWithLen(b, 37, 8) Expect(b).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25}))
Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25})) Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
Expect(Read(b)).To(BeEquivalentTo(37))
}) })
It("writes a 2-byte number in 4 bytes", func() { It("writes a 2-byte number in 4 bytes", func() {
b := &bytes.Buffer{} b := AppendWithLen(nil, 15293, 4)
WriteWithLen(b, 15293, 4) Expect(b).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd}))
Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd})) Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(15293))
Expect(Read(b)).To(BeEquivalentTo(15293))
}) })
It("write a 4-byte number in 8 bytes", func() { It("write a 4-byte number in 8 bytes", func() {
b := &bytes.Buffer{} b := AppendWithLen(nil, 494878333, 8)
WriteWithLen(b, 494878333, 8) Expect(b).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d}))
Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d})) Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(494878333))
Expect(Read(b)).To(BeEquivalentTo(494878333))
}) })
}) })

View file

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
@ -577,19 +576,19 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
replyHdr.Log(s.logger) replyHdr.Log(s.logger)
} }
packetBuffer := getPacketBuffer() buf := getPacketBuffer()
defer packetBuffer.Release() defer buf.Release()
buf := bytes.NewBuffer(packetBuffer.Data) buf.Data, err = replyHdr.Append(buf.Data, hdr.Version)
if err := replyHdr.Write(buf, hdr.Version); err != nil { if err != nil {
return err return err
} }
// append the Retry integrity tag // append the Retry integrity tag
tag := handshake.GetRetryIntegrityTag(buf.Bytes(), hdr.DestConnectionID, hdr.Version) tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
buf.Write(tag[:]) buf.Data = append(buf.Data, tag[:]...)
if s.config.Tracer != nil { if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil) s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
} }
_, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB()) _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
return err return err
} }
@ -627,9 +626,8 @@ func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header
// sendError sends the error as a response to the packet received with header hdr // sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error {
packetBuffer := getPacketBuffer() b := getPacketBuffer()
defer packetBuffer.Release() defer b.Release()
buf := bytes.NewBuffer(packetBuffer.Data)
ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)}
@ -640,33 +638,34 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.DestConnectionID = hdr.SrcConnectionID
replyHdr.PacketNumberLen = protocol.PacketNumberLen4 replyHdr.PacketNumberLen = protocol.PacketNumberLen4
replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead())
if err := replyHdr.Write(buf, hdr.Version); err != nil { var err error
b.Data, err = replyHdr.Append(b.Data, hdr.Version)
if err != nil {
return err return err
} }
payloadOffset := buf.Len() payloadOffset := len(b.Data)
raw := buf.Bytes() b.Data, err = ccf.Append(b.Data, hdr.Version)
raw, err := ccf.Append(raw, hdr.Version)
if err != nil { if err != nil {
return err return err
} }
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset]) _ = sealer.Seal(b.Data[payloadOffset:payloadOffset], b.Data[payloadOffset:], replyHdr.PacketNumber, b.Data[:payloadOffset])
raw = raw[0 : len(raw)+sealer.Overhead()] b.Data = b.Data[0 : len(b.Data)+sealer.Overhead()]
pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) pnOffset := payloadOffset - int(replyHdr.PacketNumberLen)
sealer.EncryptHeader( sealer.EncryptHeader(
raw[pnOffset+4:pnOffset+4+16], b.Data[pnOffset+4:pnOffset+4+16],
&raw[0], &b.Data[0],
raw[pnOffset:payloadOffset], b.Data[pnOffset:payloadOffset],
) )
replyHdr.Log(s.logger) replyHdr.Log(s.logger)
wire.LogFrame(s.logger, ccf, true) wire.LogFrame(s.logger, ccf, true)
if s.config.Tracer != nil { if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
} }
_, err = s.conn.WritePacket(raw, remoteAddr, info.OOB()) _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
return err return err
} }

View file

@ -42,17 +42,18 @@ var _ = Describe("Server", func() {
) )
getPacket := func(hdr *wire.Header, p []byte) *receivedPacket { getPacket := func(hdr *wire.Header, p []byte) *receivedPacket {
buffer := getPacketBuffer() buf := getPacketBuffer()
buf := bytes.NewBuffer(buffer.Data)
hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
Expect((&wire.ExtendedHeader{ var err error
buf.Data, err = (&wire.ExtendedHeader{
Header: *hdr, Header: *hdr,
PacketNumber: 0x42, PacketNumber: 0x42,
PacketNumberLen: protocol.PacketNumberLen4, PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.VersionTLS)).To(Succeed()) }).Append(buf.Data, protocol.VersionTLS)
n := buf.Len() Expect(err).ToNot(HaveOccurred())
buf.Write(p) n := len(buf.Data)
data := buffer.Data[:buf.Len()] buf.Data = append(buf.Data, p...)
data := buf.Data
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
data = data[:len(data)+16] data = data[:len(data)+16]
@ -60,7 +61,7 @@ var _ = Describe("Server", func() {
return &receivedPacket{ return &receivedPacket{
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
data: data, data: data,
buffer: buffer, buffer: buf,
} }
} }