refactor header writing to append to a byte slice (#3646)

This avoids having to allocate a bytes.Buffer.
This commit is contained in:
Marten Seemann 2023-01-17 01:56:06 -08:00 committed by GitHub
parent 3d4bbc28ba
commit c24fbb094c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 282 additions and 279 deletions

View file

@ -571,8 +571,8 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().Close()
streamManager.EXPECT().CloseWithError(gomock.Any())
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
b, err := wire.AppendShortHeader(nil, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version)
@ -589,7 +589,7 @@ var _ = Describe("Connection", func() {
rcvTime: time.Now(),
remoteAddr: &net.UDPAddr{},
buffer: getPacketBuffer(),
data: buf.Bytes(),
data: b,
})
// Consistently(pack).ShouldNot(Receive())
Eventually(conn.Context().Done()).Should(BeClosed())
@ -649,20 +649,20 @@ var _ = Describe("Connection", func() {
})
getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket {
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
return &receivedPacket{
data: append(buf.Bytes(), data...),
data: append(b, data...),
buffer: getPacketBuffer(),
rcvTime: time.Now(),
}
}
getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
buf := &bytes.Buffer{}
Expect(extHdr.Write(buf, conn.version)).To(Succeed())
b, err := extHdr.Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
return &receivedPacket{
data: append(buf.Bytes(), data...),
data: append(b, data...),
buffer: getPacketBuffer(),
rcvTime: time.Now(),
}
@ -2365,10 +2365,10 @@ var _ = Describe("Client Connection", func() {
destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket {
buf := &bytes.Buffer{}
Expect(hdr.Write(buf, conn.version)).To(Succeed())
b, err := hdr.Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
return &receivedPacket{
data: append(buf.Bytes(), data...),
data: append(b, data...),
buffer: getPacketBuffer(),
}
}
@ -2643,9 +2643,9 @@ var _ = Describe("Client Connection", func() {
})
getRetryTag := func(hdr *wire.ExtendedHeader) []byte {
buf := &bytes.Buffer{}
hdr.Write(buf, conn.version)
return handshake.GetRetryIntegrityTag(buf.Bytes(), origDestConnID, hdr.Version)[:]
b, err := hdr.Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
return handshake.GetRetryIntegrityTag(b, origDestConnID, hdr.Version)[:]
}
It("handles Retry packets", func() {
@ -2858,10 +2858,10 @@ var _ = Describe("Client Connection", func() {
var unpacker *MockUnpacker
getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
buf := &bytes.Buffer{}
Expect(extHdr.Write(buf, conn.version)).To(Succeed())
b, err := extHdr.Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
return &receivedPacket{
data: append(buf.Bytes(), data...),
data: append(b, data...),
buffer: getPacketBuffer(),
}
}

View file

@ -1,7 +1,6 @@
package main
import (
"bytes"
"log"
"math/rand"
@ -85,28 +84,28 @@ func main() {
PacketNumberLen: protocol.PacketNumberLen(rand.Intn(4) + 1),
PacketNumber: protocol.PacketNumber(rand.Uint64()),
}
b := &bytes.Buffer{}
if err := extHdr.Write(b, version); err != nil {
b, err := extHdr.Append(nil, version)
if err != nil {
log.Fatal(err)
}
if h.Type == protocol.PacketTypeRetry {
b.Write([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})
b = append(b, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}...)
}
if h.Length > 0 {
b.Write(make([]byte, h.Length))
b = append(b, make([]byte, h.Length)...)
}
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil {
if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil {
log.Fatal(err)
}
}
// short header
b := &bytes.Buffer{}
if err := wire.WriteShortHeader(b, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne); err != nil {
b, err := wire.AppendShortHeader(nil, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
if err != nil {
log.Fatal(err)
}
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil {
if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil {
log.Fatal(err)
}

View file

@ -64,8 +64,8 @@ func Fuzz(data []byte) int {
if hdr.Length > 16383 {
return 1
}
b := &bytes.Buffer{}
if err := extHdr.Write(b, version); err != nil {
b, err := extHdr.Append(nil, version)
if err != nil {
// We are able to parse packets with connection IDs longer than 20 bytes,
// but in QUIC version 1, we don't write headers with longer connection IDs.
if hdr.DestConnectionID.Len() <= protocol.MaxConnIDLen &&
@ -76,8 +76,8 @@ func Fuzz(data []byte) int {
}
// GetLength is not implemented for Retry packets
if hdr.Type != protocol.PacketTypeRetry {
if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(b.Len()) {
panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, b.Len()))
if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(len(b)) {
panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, len(b)))
}
}
return 1

View file

@ -1,7 +1,6 @@
package self_test
import (
"bytes"
"context"
"errors"
"fmt"
@ -116,12 +115,12 @@ var _ = Describe("MITM test", func() {
for i := 0; i < numPackets; i++ {
payloadLen := mrand.Int31n(100)
replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1))
buf := &bytes.Buffer{}
Expect(replyHdr.Write(buf, version)).To(Succeed())
b := make([]byte, payloadLen)
mrand.Read(b)
buf.Write(b)
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
b, err := replyHdr.Append(nil, version)
Expect(err).ToNot(HaveOccurred())
r := make([]byte, payloadLen)
mrand.Read(r)
b = append(b, r...)
if _, err := conn.WriteTo(b, remoteAddr); err != nil {
return
}
<-ticker.C
@ -134,13 +133,13 @@ var _ = Describe("MITM test", func() {
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
}
for i := 0; i < numPackets; i++ {
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))).To(Succeed())
b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))
Expect(err).ToNot(HaveOccurred())
payloadLen := mrand.Int31n(100)
b := make([]byte, payloadLen)
mrand.Read(b)
buf.Write(b)
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
r := make([]byte, payloadLen)
mrand.Read(r)
b = append(b, r...)
if _, err := conn.WriteTo(b, remoteAddr); err != nil {
return
}
<-ticker.C

View file

@ -28,7 +28,6 @@ func isProxyRunning() bool {
var _ = Describe("QUIC Proxy", func() {
makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
b := &bytes.Buffer{}
hdr := wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeInitial,
@ -40,10 +39,10 @@ var _ = Describe("QUIC Proxy", func() {
PacketNumber: p,
PacketNumberLen: protocol.PacketNumberLen4,
}
Expect(hdr.Write(b, protocol.VersionWhatever)).To(Succeed())
raw := b.Bytes()
raw = append(raw, payload...)
return raw
b, err := hdr.Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
b = append(b, payload...)
return b
}
readPacketNumber := func(b []byte) protocol.PacketNumber {

View file

@ -1,7 +1,7 @@
package testutils
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -13,9 +13,11 @@ import (
// writePacket returns a new raw packet with the specified header and payload
func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte {
buf := &bytes.Buffer{}
hdr.Write(buf, hdr.Version)
return append(buf.Bytes(), data...)
b, err := hdr.Append(nil, hdr.Version)
if err != nil {
panic(fmt.Sprintf("failed to write header: %s", err))
}
return append(b, data...)
}
// packRawPayload returns a new raw payload containing given frames

View file

@ -2,6 +2,7 @@ package wire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
@ -93,20 +94,17 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
return nil
}
// Write writes the Header.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
// Append appends the Header.
func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) {
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
}
if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
}
return h.writeLongHeader(b, ver)
}
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error {
var packetType uint8
if version == protocol.Version2 {
if v == protocol.Version2 {
//nolint:exhaustive
switch h.Type {
case protocol.PacketTypeInitial:
@ -137,24 +135,25 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.Versi
firstByte |= uint8(h.PacketNumberLen - 1)
}
b.WriteByte(firstByte)
utils.BigEndian.WriteUint32(b, uint32(h.Version))
b.WriteByte(uint8(h.DestConnectionID.Len()))
b.Write(h.DestConnectionID.Bytes())
b.WriteByte(uint8(h.SrcConnectionID.Len()))
b.Write(h.SrcConnectionID.Bytes())
b = append(b, firstByte)
b = append(b, make([]byte, 4)...)
binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version))
b = append(b, uint8(h.DestConnectionID.Len()))
b = append(b, h.DestConnectionID.Bytes()...)
b = append(b, uint8(h.SrcConnectionID.Len()))
b = append(b, h.SrcConnectionID.Bytes()...)
//nolint:exhaustive
switch h.Type {
case protocol.PacketTypeRetry:
b.Write(h.Token)
return nil
b = append(b, h.Token...)
return b, nil
case protocol.PacketTypeInitial:
quicvarint.Write(b, uint64(len(h.Token)))
b.Write(h.Token)
b = quicvarint.Append(b, uint64(len(h.Token)))
b = append(b, h.Token...)
}
quicvarint.WriteWithLen(b, uint64(h.Length), 2)
return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen)
b = quicvarint.AppendWithLen(b, uint64(h.Length), 2)
return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
// ParsedLen returns the number of bytes that were consumed when parsing the header
@ -188,18 +187,24 @@ func (h *ExtendedHeader) Log(logger utils.Logger) {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
}
func writePacketNumber(b *bytes.Buffer, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) error {
func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) {
switch pnLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(pn))
b = append(b, uint8(pn))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(pn))
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(pn))
b = append(b, buf...)
case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(pn))
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(pn))
b = append(b, buf[1:]...)
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(pn))
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(pn))
b = append(b, buf...)
default:
return fmt.Errorf("invalid packet number length: %d", pnLen)
return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
}
return nil
return b, nil
}

View file

@ -14,20 +14,12 @@ import (
)
var _ = Describe("Header", func() {
const versionIETFHeader = protocol.Version1
Context("Writing", func() {
var buf *bytes.Buffer
BeforeEach(func() {
buf = &bytes.Buffer{}
})
Context("Long Header", func() {
Context("Long Header, version 1", func() {
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
It("writes", func() {
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}),
@ -37,7 +29,8 @@ var _ = Describe("Header", func() {
},
PacketNumber: 0xdecaf,
PacketNumberLen: protocol.PacketNumberLen3,
}).Write(buf, versionIETFHeader)).To(Succeed())
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
expected := []byte{
0xc0 | 0x2<<4 | 0x2,
0x1, 0x2, 0x3, 0x4, // version number
@ -48,11 +41,11 @@ var _ = Describe("Header", func() {
}
expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length
expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number
Expect(buf.Bytes()).To(Equal(expected))
Expect(b).To(Equal(expected))
})
It("writes a header with a 20 byte connection ID", func() {
err := (&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
SrcConnectionID: srcConnID,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long
@ -61,14 +54,14 @@ var _ = Describe("Header", func() {
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, versionIETFHeader)
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20})))
Expect(b).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20})))
})
It("writes an Initial containing a token", func() {
token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Version: 0x1020304,
Type: protocol.PacketTypeInitial,
@ -76,14 +69,15 @@ var _ = Describe("Header", func() {
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Bytes()[0]>>4&0b11 == 0)
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b[0]>>4&0b11 == 0)
expectedSubstring := append(encodeVarInt(uint64(len(token))), token...)
Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring)))
Expect(b).To(ContainSubstring(string(expectedSubstring)))
})
It("uses a 2-byte encoding for the length on Initial packets", func() {
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Version: 0x1020304,
Type: protocol.PacketTypeInitial,
@ -91,89 +85,88 @@ var _ = Describe("Header", func() {
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, versionIETFHeader)).To(Succeed())
b := &bytes.Buffer{}
quicvarint.WriteWithLen(b, 37, 2)
Expect(buf.Bytes()[buf.Len()-6 : buf.Len()-4]).To(Equal(b.Bytes()))
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
lengthEncoded := quicvarint.AppendWithLen(nil, 37, 2)
Expect(b[len(b)-6 : len(b)-4]).To(Equal(lengthEncoded))
})
It("writes a Retry packet", func() {
token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")
Expect((&ExtendedHeader{Header: Header{
b, err := (&ExtendedHeader{Header: Header{
Version: protocol.Version1,
Type: protocol.PacketTypeRetry,
Token: token,
}}).Write(buf, versionIETFHeader)).To(Succeed())
}}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0xc0 | 0b11<<4}
expected = appendVersion(expected, protocol.Version1)
expected = append(expected, 0x0) // dest connection ID length
expected = append(expected, 0x0) // src connection ID length
expected = append(expected, token...)
Expect(buf.Bytes()).To(Equal(expected))
Expect(b).To(Equal(expected))
})
})
Context("long header, version 2", func() {
It("writes an Initial", func() {
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Version: protocol.Version2,
Type: protocol.PacketTypeInitial,
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.Version2)).To(Succeed())
Expect(buf.Bytes()[0]>>4&0b11 == 0b01)
}).Append(nil, protocol.Version2)
Expect(err).ToNot(HaveOccurred())
Expect(b[0]>>4&0b11 == 0b01)
})
It("writes a Retry packet", func() {
token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")
Expect((&ExtendedHeader{Header: Header{
b, err := (&ExtendedHeader{Header: Header{
Version: protocol.Version2,
Type: protocol.PacketTypeRetry,
Token: token,
}}).Write(buf, versionIETFHeader)).To(Succeed())
expected := []byte{0xc0 | 0b11<<4}
}}).Append(nil, protocol.Version2)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0xc0 | 0b00<<4}
expected = appendVersion(expected, protocol.Version2)
expected = append(expected, 0x0) // dest connection ID length
expected = append(expected, 0x0) // src connection ID length
expected = append(expected, token...)
Expect(buf.Bytes()).To(Equal(expected))
Expect(b).To(Equal(expected))
})
It("writes a Handshake Packet", func() {
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Version: protocol.Version2,
Type: protocol.PacketTypeHandshake,
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.Version2)).To(Succeed())
Expect(buf.Bytes()[0]>>4&0b11 == 0b11)
}).Append(nil, protocol.Version2)
Expect(err).ToNot(HaveOccurred())
Expect(b[0]>>4&0b11 == 0b11)
})
It("writes a 0-RTT Packet", func() {
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Version: protocol.Version2,
Type: protocol.PacketType0RTT,
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.Version2)).To(Succeed())
Expect(buf.Bytes()[0]>>4&0b11 == 0b10)
}).Append(nil, protocol.Version2)
Expect(err).ToNot(HaveOccurred())
Expect(b[0]>>4&0b11 == 0b10)
})
})
})
Context("getting the length", func() {
var buf *bytes.Buffer
BeforeEach(func() {
buf = &bytes.Buffer{}
})
It("has the right length for the Long Header, for a short length", func() {
h := &ExtendedHeader{
Header: Header{
@ -185,9 +178,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen1,
}
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(expectedLen))
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
b, err := h.Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
})
It("has the right length for the Long Header, for a long length", func() {
@ -201,9 +195,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen2,
}
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(expectedLen))
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
b, err := h.Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
})
It("has the right length for an Initial that has a short length", func() {
@ -217,9 +212,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen2,
}
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(expectedLen))
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
b, err := h.Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
})
It("has the right length for an Initial not containing a Token", func() {
@ -233,9 +229,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen2,
}
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(expectedLen))
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
b, err := h.Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
})
It("has the right length for an Initial containing a Token", func() {
@ -250,9 +247,10 @@ var _ = Describe("Header", func() {
PacketNumberLen: protocol.PacketNumberLen2,
}
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(expectedLen))
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
b, err := h.Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(expectedLen))
})
})

View file

@ -15,8 +15,7 @@ import (
var _ = Describe("Header Parsing", func() {
Context("Parsing the Connection ID", func() {
It("parses the connection ID of a long header packet", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
@ -24,15 +23,15 @@ var _ = Describe("Header Parsing", func() {
Version: protocol.Version1,
},
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
connID, err := ParseConnectionID(buf.Bytes(), 8)
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
connID, err := ParseConnectionID(b, 8)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
})
It("errors on EOF, for long header packets", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}),
@ -40,9 +39,10 @@ var _ = Describe("Header Parsing", func() {
Version: protocol.Version1,
},
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
data := buf.Bytes()[:buf.Len()-2] // cut the packet number
_, err := ParseConnectionID(data, 8)
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
data := b[:len(b)-2] // cut the packet number
_, err = ParseConnectionID(data, 8)
Expect(err).ToNot(HaveOccurred())
for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ {
b := make([]byte, i)
@ -419,22 +419,22 @@ var _ = Describe("Header Parsing", func() {
Context("coalesced packets", func() {
It("cuts packets", func() {
buf := &bytes.Buffer{}
hdr := Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
Length: 2 + 6,
Version: protocol.Version1,
}
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: hdr,
PacketNumber: 0x1337,
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
hdrRaw := append([]byte{}, buf.Bytes()...)
buf.Write([]byte("foobar")) // payload of the first packet
buf.Write([]byte("raboof")) // second packet
parsedHdr, data, rest, err := ParsePacket(buf.Bytes())
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
hdrRaw := append([]byte{}, b...)
b = append(b, []byte("foobar")...) // payload of the first packet
b = append(b, []byte("raboof")...) // second packet
parsedHdr, data, rest, err := ParsePacket(b)
Expect(err).ToNot(HaveOccurred())
Expect(parsedHdr.Type).To(Equal(hdr.Type))
Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID))
@ -443,8 +443,7 @@ var _ = Describe("Header Parsing", func() {
})
It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
@ -453,15 +452,15 @@ var _ = Describe("Header Parsing", func() {
},
PacketNumber: 0x1337,
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
_, _, _, err := ParsePacket(buf.Bytes())
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
_, _, _, err = ParsePacket(b)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)"))
})
It("errors on packets that are smaller than the length in the packet header, for too small payload", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
b, err := (&ExtendedHeader{
Header: Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
@ -470,9 +469,10 @@ var _ = Describe("Header Parsing", func() {
},
PacketNumber: 0x1337,
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
buf.Write(make([]byte, 500-2 /* for packet number length */))
_, _, _, err := ParsePacket(buf.Bytes())
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
b = append(b, make([]byte, 500-2 /* for packet number length */)...)
_, _, _, err = ParsePacket(b)
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
})
})

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
@ -54,15 +53,15 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet
return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
}
// WriteShortHeader writes a short header.
func WriteShortHeader(b *bytes.Buffer, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) error {
// AppendShortHeader writes a short header.
func AppendShortHeader(b []byte, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) ([]byte, error) {
typeByte := 0x40 | uint8(pnLen-1)
if kp == protocol.KeyPhaseOne {
typeByte |= byte(1 << 2)
}
b.WriteByte(typeByte)
b.Write(connID.Bytes())
return writePacketNumber(b, pn, pnLen)
b = append(b, typeByte)
b = append(b, connID.Bytes()...)
return appendPacketNumber(b, pn, pnLen)
}
func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount {

View file

@ -5,6 +5,7 @@ import (
"io"
"log"
"os"
"testing"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -77,15 +78,15 @@ var _ = Describe("Short Header", func() {
Context("writing", func() {
It("writes a short header packet", func() {
b := &bytes.Buffer{}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
Expect(WriteShortHeader(b, connID, 1337, 4, protocol.KeyPhaseOne)).To(Succeed())
l, pn, pnLen, kp, err := ParseShortHeader(b.Bytes(), 4)
b, err := AppendShortHeader(nil, connID, 1337, 4, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
l, pn, pnLen, kp, err := ParseShortHeader(b, 4)
Expect(err).ToNot(HaveOccurred())
Expect(pn).To(Equal(protocol.PacketNumber(1337)))
Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
Expect(kp).To(Equal(protocol.KeyPhaseOne))
Expect(l).To(Equal(b.Len()))
Expect(l).To(Equal(len(b)))
})
})
@ -113,3 +114,17 @@ var _ = Describe("Short Header", func() {
})
})
})
func BenchmarkWriteShortHeader(b *testing.B) {
b.ReportAllocs()
buf := make([]byte, 100)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
for i := 0; i < b.N; i++ {
var err error
buf, err = AppendShortHeader(buf, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
if err != nil {
b.Fatalf("failed to write short header: %s", err)
}
buf = buf[:0]
}
}

View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"crypto/rand"
"errors"
"net"
@ -14,7 +13,6 @@ import (
"github.com/lucas-clemente/quic-go/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@ -37,8 +35,7 @@ var _ = Describe("Packet Handler Map", func() {
)
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
b, err := (&wire.ExtendedHeader{
Header: wire.Header{
Type: t,
DestConnectionID: connID,
@ -46,8 +43,9 @@ var _ = Describe("Packet Handler Map", func() {
Version: protocol.VersionTLS,
},
PacketNumberLen: protocol.PacketNumberLen2,
}).Write(buf, protocol.VersionWhatever)).To(Succeed())
return buf.Bytes()
}).Append(nil, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
return b
}
getPacket := func(connID protocol.ConnectionID) []byte {

View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"errors"
"fmt"
"net"
@ -800,13 +799,12 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
paddingLen += padding
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen
raw := buffer.Data[len(buffer.Data):]
buf := bytes.NewBuffer(buffer.Data)
startLen := buf.Len()
if err := header.Write(buf, p.version); err != nil {
startLen := len(buffer.Data)
raw := buffer.Data[startLen:]
raw, err := header.Append(raw, p.version)
if err != nil {
return nil, err
}
raw = raw[:buf.Len()-startLen]
payloadOffset := protocol.ByteCount(len(raw))
pn := p.pnManager.PopPacketNumber(encLevel)
@ -814,7 +812,7 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err := p.appendPacketPayload(raw, payload, paddingLen)
raw, err = p.appendPacketPayload(raw, payload, paddingLen)
if err != nil {
return nil, err
}
@ -846,20 +844,19 @@ func (p *packetPacker) appendShortHeaderPacket(
}
paddingLen += padding
raw := buffer.Data[len(buffer.Data):]
buf := bytes.NewBuffer(buffer.Data)
startLen := buf.Len()
if err := wire.WriteShortHeader(buf, connID, pn, pnLen, kp); err != nil {
startLen := len(buffer.Data)
raw := buffer.Data[startLen:]
raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp)
if err != nil {
return nil, err
}
raw = raw[:buf.Len()-startLen]
payloadOffset := protocol.ByteCount(len(raw))
if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err := p.appendPacketPayload(raw, payload, paddingLen)
raw, err = p.appendPacketPayload(raw, payload, paddingLen)
if err != nil {
return nil, err
}

View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"errors"
"time"
@ -12,7 +11,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
@ -28,21 +26,22 @@ var _ = Describe("Packet Unpacker", func() {
)
getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
buf := &bytes.Buffer{}
ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed())
hdrLen := buf.Len()
if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) {
buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen)))
}
hdr, _, _, err := wire.ParsePacket(buf.Bytes())
b, err := extHdr.Append(nil, version)
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, err).ToNot(HaveOccurred())
return hdr, buf.Bytes()[:hdrLen]
hdrLen := len(b)
if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) {
b = append(b, make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))...)
}
hdr, _, _, err := wire.ParsePacket(b)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
return hdr, b[:hdrLen]
}
getShortHeader := func(connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) []byte {
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, kp)).To(Succeed())
return buf.Bytes()
b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, kp)
Expect(err).ToNot(HaveOccurred())
return b
}
BeforeEach(func() {
@ -72,12 +71,12 @@ var _ = Describe("Packet Unpacker", func() {
})
It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() {
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
data := append(buf.Bytes(), make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
b, err := wire.AppendShortHeader(nil, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
data := append(b, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data)
_, _, _, _, err = unpacker.UnpackShortHeader(time.Now(), data)
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19"))
})

View file

@ -107,32 +107,32 @@ func Append(b []byte, i uint64) []byte {
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
}
// WriteWithLen writes i in the QUIC varint format with the desired length to w.
func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) {
// AppendWithLen append i in the QUIC varint format with the desired length.
func AppendWithLen(b []byte, i uint64, length protocol.ByteCount) []byte {
if length != 1 && length != 2 && length != 4 && length != 8 {
panic("invalid varint length")
}
l := Len(i)
if l == length {
Write(w, i)
return
return Append(b, i)
}
if l > length {
panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length))
}
if length == 2 {
w.WriteByte(0b01000000)
b = append(b, 0b01000000)
} else if length == 4 {
w.WriteByte(0b10000000)
b = append(b, 0b10000000)
} else if length == 8 {
w.WriteByte(0b11000000)
b = append(b, 0b11000000)
}
for j := protocol.ByteCount(1); j < length-l; j++ {
w.WriteByte(0)
b = append(b, 0)
}
for j := protocol.ByteCount(0); j < l; j++ {
w.WriteByte(uint8(i >> (8 * (l - 1 - j))))
b = append(b, uint8(i>>(8*(l-1-j))))
}
return b
}
// Len determines the number of bytes that will be needed to write the number i.

View file

@ -142,54 +142,47 @@ var _ = Describe("Varint encoding / decoding", func() {
Context("with fixed length", func() {
It("panics when given an invalid length", func() {
Expect(func() { WriteWithLen(&bytes.Buffer{}, 25, 3) }).Should(Panic())
Expect(func() { AppendWithLen(nil, 25, 3) }).Should(Panic())
})
It("panics when given a too short length", func() {
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt1+1, 1) }).Should(Panic())
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt2+1, 2) }).Should(Panic())
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt4+1, 4) }).Should(Panic())
Expect(func() { AppendWithLen(nil, maxVarInt1+1, 1) }).Should(Panic())
Expect(func() { AppendWithLen(nil, maxVarInt2+1, 2) }).Should(Panic())
Expect(func() { AppendWithLen(nil, maxVarInt4+1, 4) }).Should(Panic())
})
It("writes a 1-byte number in minimal encoding", func() {
b := &bytes.Buffer{}
WriteWithLen(b, 37, 1)
Expect(b.Bytes()).To(Equal([]byte{0x25}))
Expect(AppendWithLen(nil, 37, 1)).To(Equal([]byte{0x25}))
})
It("writes a 1-byte number in 2 bytes", func() {
b := &bytes.Buffer{}
WriteWithLen(b, 37, 2)
Expect(b.Bytes()).To(Equal([]byte{0b01000000, 0x25}))
Expect(Read(b)).To(BeEquivalentTo(37))
b := AppendWithLen(nil, 37, 2)
Expect(b).To(Equal([]byte{0b01000000, 0x25}))
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
})
It("writes a 1-byte number in 4 bytes", func() {
b := &bytes.Buffer{}
WriteWithLen(b, 37, 4)
Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0, 0x25}))
Expect(Read(b)).To(BeEquivalentTo(37))
b := AppendWithLen(nil, 37, 4)
Expect(b).To(Equal([]byte{0b10000000, 0, 0, 0x25}))
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
})
It("writes a 1-byte number in 8 bytes", func() {
b := &bytes.Buffer{}
WriteWithLen(b, 37, 8)
Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25}))
Expect(Read(b)).To(BeEquivalentTo(37))
b := AppendWithLen(nil, 37, 8)
Expect(b).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25}))
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
})
It("writes a 2-byte number in 4 bytes", func() {
b := &bytes.Buffer{}
WriteWithLen(b, 15293, 4)
Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd}))
Expect(Read(b)).To(BeEquivalentTo(15293))
b := AppendWithLen(nil, 15293, 4)
Expect(b).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd}))
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(15293))
})
It("write a 4-byte number in 8 bytes", func() {
b := &bytes.Buffer{}
WriteWithLen(b, 494878333, 8)
Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d}))
Expect(Read(b)).To(BeEquivalentTo(494878333))
b := AppendWithLen(nil, 494878333, 8)
Expect(b).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d}))
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(494878333))
})
})

View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
@ -577,19 +576,19 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
replyHdr.Log(s.logger)
}
packetBuffer := getPacketBuffer()
defer packetBuffer.Release()
buf := bytes.NewBuffer(packetBuffer.Data)
if err := replyHdr.Write(buf, hdr.Version); err != nil {
buf := getPacketBuffer()
defer buf.Release()
buf.Data, err = replyHdr.Append(buf.Data, hdr.Version)
if err != nil {
return err
}
// append the Retry integrity tag
tag := handshake.GetRetryIntegrityTag(buf.Bytes(), hdr.DestConnectionID, hdr.Version)
buf.Write(tag[:])
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
buf.Data = append(buf.Data, tag[:]...)
if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil)
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
}
_, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB())
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
return err
}
@ -627,9 +626,8 @@ func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header
// sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error {
packetBuffer := getPacketBuffer()
defer packetBuffer.Release()
buf := bytes.NewBuffer(packetBuffer.Data)
b := getPacketBuffer()
defer b.Release()
ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)}
@ -640,33 +638,34 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
replyHdr.DestConnectionID = hdr.SrcConnectionID
replyHdr.PacketNumberLen = protocol.PacketNumberLen4
replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead())
if err := replyHdr.Write(buf, hdr.Version); err != nil {
var err error
b.Data, err = replyHdr.Append(b.Data, hdr.Version)
if err != nil {
return err
}
payloadOffset := buf.Len()
payloadOffset := len(b.Data)
raw := buf.Bytes()
raw, err := ccf.Append(raw, hdr.Version)
b.Data, err = ccf.Append(b.Data, hdr.Version)
if err != nil {
return err
}
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset])
raw = raw[0 : len(raw)+sealer.Overhead()]
_ = sealer.Seal(b.Data[payloadOffset:payloadOffset], b.Data[payloadOffset:], replyHdr.PacketNumber, b.Data[:payloadOffset])
b.Data = b.Data[0 : len(b.Data)+sealer.Overhead()]
pnOffset := payloadOffset - int(replyHdr.PacketNumberLen)
sealer.EncryptHeader(
raw[pnOffset+4:pnOffset+4+16],
&raw[0],
raw[pnOffset:payloadOffset],
b.Data[pnOffset+4:pnOffset+4+16],
&b.Data[0],
b.Data[pnOffset:payloadOffset],
)
replyHdr.Log(s.logger)
wire.LogFrame(s.logger, ccf, true)
if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf})
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
}
_, err = s.conn.WritePacket(raw, remoteAddr, info.OOB())
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
return err
}

View file

@ -42,17 +42,18 @@ var _ = Describe("Server", func() {
)
getPacket := func(hdr *wire.Header, p []byte) *receivedPacket {
buffer := getPacketBuffer()
buf := bytes.NewBuffer(buffer.Data)
buf := getPacketBuffer()
hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
Expect((&wire.ExtendedHeader{
var err error
buf.Data, err = (&wire.ExtendedHeader{
Header: *hdr,
PacketNumber: 0x42,
PacketNumberLen: protocol.PacketNumberLen4,
}).Write(buf, protocol.VersionTLS)).To(Succeed())
n := buf.Len()
buf.Write(p)
data := buffer.Data[:buf.Len()]
}).Append(buf.Data, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
n := len(buf.Data)
buf.Data = append(buf.Data, p...)
data := buf.Data
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
data = data[:len(data)+16]
@ -60,7 +61,7 @@ var _ = Describe("Server", func() {
return &receivedPacket{
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
data: data,
buffer: buffer,
buffer: buf,
}
}