From c24fbb094cde74cbc590ebe13f3deec95887d08a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 17 Jan 2023 01:56:06 -0800 Subject: [PATCH] refactor header writing to append to a byte slice (#3646) This avoids having to allocate a bytes.Buffer. --- connection_test.go | 36 +++---- fuzzing/header/cmd/corpus.go | 17 ++- fuzzing/header/fuzz.go | 8 +- integrationtests/self/mitm_test.go | 25 +++-- integrationtests/tools/proxy/proxy_test.go | 9 +- internal/testutils/testutils.go | 10 +- internal/wire/extended_header.go | 59 +++++----- internal/wire/extended_header_test.go | 120 ++++++++++----------- internal/wire/header_test.go | 50 ++++----- internal/wire/short_header.go | 11 +- internal/wire/short_header_test.go | 23 +++- packet_handler_map_test.go | 10 +- packet_packer.go | 23 ++-- packet_unpacker_test.go | 33 +++--- quicvarint/varint.go | 18 ++-- quicvarint/varint_test.go | 47 ++++---- server.go | 45 ++++---- server_test.go | 17 +-- 18 files changed, 282 insertions(+), 279 deletions(-) diff --git a/connection_test.go b/connection_test.go index 532cc609..320f123b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -571,8 +571,8 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - buf := &bytes.Buffer{} - Expect(wire.WriteShortHeader(buf, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) + b, err := wire.AppendShortHeader(nil, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version) @@ -589,7 +589,7 @@ var _ = Describe("Connection", func() { rcvTime: time.Now(), remoteAddr: &net.UDPAddr{}, buffer: getPacketBuffer(), - data: buf.Bytes(), + data: b, }) // Consistently(pack).ShouldNot(Receive()) Eventually(conn.Context().Done()).Should(BeClosed()) @@ -649,20 +649,20 @@ var _ = Describe("Connection", func() { }) getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket { - buf := &bytes.Buffer{} - Expect(wire.WriteShortHeader(buf, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) + b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) return &receivedPacket{ - data: append(buf.Bytes(), data...), + data: append(b, data...), buffer: getPacketBuffer(), rcvTime: time.Now(), } } getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { - buf := &bytes.Buffer{} - Expect(extHdr.Write(buf, conn.version)).To(Succeed()) + b, err := extHdr.Append(nil, conn.version) + Expect(err).ToNot(HaveOccurred()) return &receivedPacket{ - data: append(buf.Bytes(), data...), + data: append(b, data...), buffer: getPacketBuffer(), rcvTime: time.Now(), } @@ -2365,10 +2365,10 @@ var _ = Describe("Client Connection", func() { destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { - buf := &bytes.Buffer{} - Expect(hdr.Write(buf, conn.version)).To(Succeed()) + b, err := hdr.Append(nil, conn.version) + Expect(err).ToNot(HaveOccurred()) return &receivedPacket{ - data: append(buf.Bytes(), data...), + data: append(b, data...), buffer: getPacketBuffer(), } } @@ -2643,9 +2643,9 @@ var _ = Describe("Client Connection", func() { }) getRetryTag := func(hdr *wire.ExtendedHeader) []byte { - buf := &bytes.Buffer{} - hdr.Write(buf, conn.version) - return handshake.GetRetryIntegrityTag(buf.Bytes(), origDestConnID, hdr.Version)[:] + b, err := hdr.Append(nil, conn.version) + Expect(err).ToNot(HaveOccurred()) + return handshake.GetRetryIntegrityTag(b, origDestConnID, hdr.Version)[:] } It("handles Retry packets", func() { @@ -2858,10 +2858,10 @@ var _ = Describe("Client Connection", func() { var unpacker *MockUnpacker getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { - buf := &bytes.Buffer{} - Expect(extHdr.Write(buf, conn.version)).To(Succeed()) + b, err := extHdr.Append(nil, conn.version) + Expect(err).ToNot(HaveOccurred()) return &receivedPacket{ - data: append(buf.Bytes(), data...), + data: append(b, data...), buffer: getPacketBuffer(), } } diff --git a/fuzzing/header/cmd/corpus.go b/fuzzing/header/cmd/corpus.go index 9d0b8307..cc966dcf 100644 --- a/fuzzing/header/cmd/corpus.go +++ b/fuzzing/header/cmd/corpus.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "log" "math/rand" @@ -85,28 +84,28 @@ func main() { PacketNumberLen: protocol.PacketNumberLen(rand.Intn(4) + 1), PacketNumber: protocol.PacketNumber(rand.Uint64()), } - b := &bytes.Buffer{} - if err := extHdr.Write(b, version); err != nil { + b, err := extHdr.Append(nil, version) + if err != nil { log.Fatal(err) } if h.Type == protocol.PacketTypeRetry { - b.Write([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) + b = append(b, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}...) } if h.Length > 0 { - b.Write(make([]byte, h.Length)) + b = append(b, make([]byte, h.Length)...) } - if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil { + if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil { log.Fatal(err) } } // short header - b := &bytes.Buffer{} - if err := wire.WriteShortHeader(b, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne); err != nil { + b, err := wire.AppendShortHeader(nil, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne) + if err != nil { log.Fatal(err) } - if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil { + if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil { log.Fatal(err) } diff --git a/fuzzing/header/fuzz.go b/fuzzing/header/fuzz.go index 2c8b2abe..34cdedfd 100644 --- a/fuzzing/header/fuzz.go +++ b/fuzzing/header/fuzz.go @@ -64,8 +64,8 @@ func Fuzz(data []byte) int { if hdr.Length > 16383 { return 1 } - b := &bytes.Buffer{} - if err := extHdr.Write(b, version); err != nil { + b, err := extHdr.Append(nil, version) + if err != nil { // We are able to parse packets with connection IDs longer than 20 bytes, // but in QUIC version 1, we don't write headers with longer connection IDs. if hdr.DestConnectionID.Len() <= protocol.MaxConnIDLen && @@ -76,8 +76,8 @@ func Fuzz(data []byte) int { } // GetLength is not implemented for Retry packets if hdr.Type != protocol.PacketTypeRetry { - if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(b.Len()) { - panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, b.Len())) + if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(len(b)) { + panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, len(b))) } } return 1 diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 45e8dbde..dcc0761d 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -1,7 +1,6 @@ package self_test import ( - "bytes" "context" "errors" "fmt" @@ -116,12 +115,12 @@ var _ = Describe("MITM test", func() { for i := 0; i < numPackets; i++ { payloadLen := mrand.Int31n(100) replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1)) - buf := &bytes.Buffer{} - Expect(replyHdr.Write(buf, version)).To(Succeed()) - b := make([]byte, payloadLen) - mrand.Read(b) - buf.Write(b) - if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { + b, err := replyHdr.Append(nil, version) + Expect(err).ToNot(HaveOccurred()) + r := make([]byte, payloadLen) + mrand.Read(r) + b = append(b, r...) + if _, err := conn.WriteTo(b, remoteAddr); err != nil { return } <-ticker.C @@ -134,13 +133,13 @@ var _ = Describe("MITM test", func() { Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) } for i := 0; i < numPackets; i++ { - buf := &bytes.Buffer{} - Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))).To(Succeed()) + b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2))) + Expect(err).ToNot(HaveOccurred()) payloadLen := mrand.Int31n(100) - b := make([]byte, payloadLen) - mrand.Read(b) - buf.Write(b) - if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { + r := make([]byte, payloadLen) + mrand.Read(r) + b = append(b, r...) + if _, err := conn.WriteTo(b, remoteAddr); err != nil { return } <-ticker.C diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index 9136a9ae..bd6eb8fd 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -28,7 +28,6 @@ func isProxyRunning() bool { var _ = Describe("QUIC Proxy", func() { makePacket := func(p protocol.PacketNumber, payload []byte) []byte { - b := &bytes.Buffer{} hdr := wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeInitial, @@ -40,10 +39,10 @@ var _ = Describe("QUIC Proxy", func() { PacketNumber: p, PacketNumberLen: protocol.PacketNumberLen4, } - Expect(hdr.Write(b, protocol.VersionWhatever)).To(Succeed()) - raw := b.Bytes() - raw = append(raw, payload...) - return raw + b, err := hdr.Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + b = append(b, payload...) + return b } readPacketNumber := func(b []byte) protocol.PacketNumber { diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go index e0d8b8ed..cd5994eb 100644 --- a/internal/testutils/testutils.go +++ b/internal/testutils/testutils.go @@ -1,7 +1,7 @@ package testutils import ( - "bytes" + "fmt" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -13,9 +13,11 @@ import ( // writePacket returns a new raw packet with the specified header and payload func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte { - buf := &bytes.Buffer{} - hdr.Write(buf, hdr.Version) - return append(buf.Bytes(), data...) + b, err := hdr.Append(nil, hdr.Version) + if err != nil { + panic(fmt.Sprintf("failed to write header: %s", err)) + } + return append(b, data...) } // packRawPayload returns a new raw payload containing given frames diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 02527255..6febae02 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "encoding/binary" "errors" "fmt" "io" @@ -93,20 +94,17 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { return nil } -// Write writes the Header. -func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { +// Append appends the Header. +func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) { if h.DestConnectionID.Len() > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) + return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) } if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) + return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) } - return h.writeLongHeader(b, ver) -} -func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error { var packetType uint8 - if version == protocol.Version2 { + if v == protocol.Version2 { //nolint:exhaustive switch h.Type { case protocol.PacketTypeInitial: @@ -137,24 +135,25 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.Versi firstByte |= uint8(h.PacketNumberLen - 1) } - b.WriteByte(firstByte) - utils.BigEndian.WriteUint32(b, uint32(h.Version)) - b.WriteByte(uint8(h.DestConnectionID.Len())) - b.Write(h.DestConnectionID.Bytes()) - b.WriteByte(uint8(h.SrcConnectionID.Len())) - b.Write(h.SrcConnectionID.Bytes()) + b = append(b, firstByte) + b = append(b, make([]byte, 4)...) + binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version)) + b = append(b, uint8(h.DestConnectionID.Len())) + b = append(b, h.DestConnectionID.Bytes()...) + b = append(b, uint8(h.SrcConnectionID.Len())) + b = append(b, h.SrcConnectionID.Bytes()...) //nolint:exhaustive switch h.Type { case protocol.PacketTypeRetry: - b.Write(h.Token) - return nil + b = append(b, h.Token...) + return b, nil case protocol.PacketTypeInitial: - quicvarint.Write(b, uint64(len(h.Token))) - b.Write(h.Token) + b = quicvarint.Append(b, uint64(len(h.Token))) + b = append(b, h.Token...) } - quicvarint.WriteWithLen(b, uint64(h.Length), 2) - return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen) + b = quicvarint.AppendWithLen(b, uint64(h.Length), 2) + return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen) } // ParsedLen returns the number of bytes that were consumed when parsing the header @@ -188,18 +187,24 @@ func (h *ExtendedHeader) Log(logger utils.Logger) { logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) } -func writePacketNumber(b *bytes.Buffer, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) error { +func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) { switch pnLen { case protocol.PacketNumberLen1: - b.WriteByte(uint8(pn)) + b = append(b, uint8(pn)) case protocol.PacketNumberLen2: - utils.BigEndian.WriteUint16(b, uint16(pn)) + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, uint16(pn)) + b = append(b, buf...) case protocol.PacketNumberLen3: - utils.BigEndian.WriteUint24(b, uint32(pn)) + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(pn)) + b = append(b, buf[1:]...) case protocol.PacketNumberLen4: - utils.BigEndian.WriteUint32(b, uint32(pn)) + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(pn)) + b = append(b, buf...) default: - return fmt.Errorf("invalid packet number length: %d", pnLen) + return nil, fmt.Errorf("invalid packet number length: %d", pnLen) } - return nil + return b, nil } diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go index 44e5ff7b..c13a0395 100644 --- a/internal/wire/extended_header_test.go +++ b/internal/wire/extended_header_test.go @@ -14,20 +14,12 @@ import ( ) var _ = Describe("Header", func() { - const versionIETFHeader = protocol.Version1 - Context("Writing", func() { - var buf *bytes.Buffer - - BeforeEach(func() { - buf = &bytes.Buffer{} - }) - - Context("Long Header", func() { + Context("Long Header, version 1", func() { srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) It("writes", func() { - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}), @@ -37,7 +29,8 @@ var _ = Describe("Header", func() { }, PacketNumber: 0xdecaf, PacketNumberLen: protocol.PacketNumberLen3, - }).Write(buf, versionIETFHeader)).To(Succeed()) + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{ 0xc0 | 0x2<<4 | 0x2, 0x1, 0x2, 0x3, 0x4, // version number @@ -48,11 +41,11 @@ var _ = Describe("Header", func() { } expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number - Expect(buf.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a header with a 20 byte connection ID", func() { - err := (&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ SrcConnectionID: srcConnID, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long @@ -61,14 +54,14 @@ var _ = Describe("Header", func() { }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader) + }).Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}))) + Expect(b).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}))) }) It("writes an Initial containing a token", func() { token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Version: 0x1020304, Type: protocol.PacketTypeInitial, @@ -76,14 +69,15 @@ var _ = Describe("Header", func() { }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0) + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b[0]>>4&0b11 == 0) expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) - Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring))) + Expect(b).To(ContainSubstring(string(expectedSubstring))) }) It("uses a 2-byte encoding for the length on Initial packets", func() { - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Version: 0x1020304, Type: protocol.PacketTypeInitial, @@ -91,89 +85,88 @@ var _ = Describe("Header", func() { }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader)).To(Succeed()) - b := &bytes.Buffer{} - quicvarint.WriteWithLen(b, 37, 2) - Expect(buf.Bytes()[buf.Len()-6 : buf.Len()-4]).To(Equal(b.Bytes())) + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + lengthEncoded := quicvarint.AppendWithLen(nil, 37, 2) + Expect(b[len(b)-6 : len(b)-4]).To(Equal(lengthEncoded)) }) It("writes a Retry packet", func() { token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") - Expect((&ExtendedHeader{Header: Header{ + b, err := (&ExtendedHeader{Header: Header{ Version: protocol.Version1, Type: protocol.PacketTypeRetry, Token: token, - }}).Write(buf, versionIETFHeader)).To(Succeed()) + }}).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0xc0 | 0b11<<4} expected = appendVersion(expected, protocol.Version1) expected = append(expected, 0x0) // dest connection ID length expected = append(expected, 0x0) // src connection ID length expected = append(expected, token...) - Expect(buf.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) }) Context("long header, version 2", func() { It("writes an Initial", func() { - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Version: protocol.Version2, Type: protocol.PacketTypeInitial, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b01) + }).Append(nil, protocol.Version2) + Expect(err).ToNot(HaveOccurred()) + Expect(b[0]>>4&0b11 == 0b01) }) It("writes a Retry packet", func() { token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") - Expect((&ExtendedHeader{Header: Header{ + b, err := (&ExtendedHeader{Header: Header{ Version: protocol.Version2, Type: protocol.PacketTypeRetry, Token: token, - }}).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0xc0 | 0b11<<4} + }}).Append(nil, protocol.Version2) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0xc0 | 0b00<<4} expected = appendVersion(expected, protocol.Version2) expected = append(expected, 0x0) // dest connection ID length expected = append(expected, 0x0) // src connection ID length expected = append(expected, token...) - Expect(buf.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a Handshake Packet", func() { - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Version: protocol.Version2, Type: protocol.PacketTypeHandshake, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b11) + }).Append(nil, protocol.Version2) + Expect(err).ToNot(HaveOccurred()) + Expect(b[0]>>4&0b11 == 0b11) }) It("writes a 0-RTT Packet", func() { - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Version: protocol.Version2, Type: protocol.PacketType0RTT, }, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.Version2)).To(Succeed()) - Expect(buf.Bytes()[0]>>4&0b11 == 0b10) + }).Append(nil, protocol.Version2) + Expect(err).ToNot(HaveOccurred()) + Expect(b[0]>>4&0b11 == 0b10) }) }) }) Context("getting the length", func() { - var buf *bytes.Buffer - - BeforeEach(func() { - buf = &bytes.Buffer{} - }) - It("has the right length for the Long Header, for a short length", func() { h := &ExtendedHeader{ Header: Header{ @@ -185,9 +178,10 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen1, } expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) + Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen)) + b, err := h.Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(expectedLen)) }) It("has the right length for the Long Header, for a long length", func() { @@ -201,9 +195,10 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen2, } expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) + Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen)) + b, err := h.Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(expectedLen)) }) It("has the right length for an Initial that has a short length", func() { @@ -217,9 +212,10 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen2, } expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) + Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen)) + b, err := h.Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(expectedLen)) }) It("has the right length for an Initial not containing a Token", func() { @@ -233,9 +229,10 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen2, } expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) + Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen)) + b, err := h.Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(expectedLen)) }) It("has the right length for an Initial containing a Token", func() { @@ -250,9 +247,10 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen2, } expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */ - Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen)) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(expectedLen)) + Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen)) + b, err := h.Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(expectedLen)) }) }) diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index a58365c9..8468303b 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -15,8 +15,7 @@ import ( var _ = Describe("Header Parsing", func() { Context("Parsing the Connection ID", func() { It("parses the connection ID of a long header packet", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), @@ -24,15 +23,15 @@ var _ = Describe("Header Parsing", func() { Version: protocol.Version1, }, PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - connID, err := ParseConnectionID(buf.Bytes(), 8) + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + connID, err := ParseConnectionID(b, 8) Expect(err).ToNot(HaveOccurred()) Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) }) It("errors on EOF, for long header packets", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}), @@ -40,9 +39,10 @@ var _ = Describe("Header Parsing", func() { Version: protocol.Version1, }, PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - data := buf.Bytes()[:buf.Len()-2] // cut the packet number - _, err := ParseConnectionID(data, 8) + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + data := b[:len(b)-2] // cut the packet number + _, err = ParseConnectionID(data, 8) Expect(err).ToNot(HaveOccurred()) for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ { b := make([]byte, i) @@ -419,22 +419,22 @@ var _ = Describe("Header Parsing", func() { Context("coalesced packets", func() { It("cuts packets", func() { - buf := &bytes.Buffer{} hdr := Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 2 + 6, Version: protocol.Version1, } - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: hdr, PacketNumber: 0x1337, PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - hdrRaw := append([]byte{}, buf.Bytes()...) - buf.Write([]byte("foobar")) // payload of the first packet - buf.Write([]byte("raboof")) // second packet - parsedHdr, data, rest, err := ParsePacket(buf.Bytes()) + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + hdrRaw := append([]byte{}, b...) + b = append(b, []byte("foobar")...) // payload of the first packet + b = append(b, []byte("raboof")...) // second packet + parsedHdr, data, rest, err := ParsePacket(b) Expect(err).ToNot(HaveOccurred()) Expect(parsedHdr.Type).To(Equal(hdr.Type)) Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID)) @@ -443,8 +443,7 @@ var _ = Describe("Header Parsing", func() { }) It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), @@ -453,15 +452,15 @@ var _ = Describe("Header Parsing", func() { }, PacketNumber: 0x1337, PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - _, _, _, err := ParsePacket(buf.Bytes()) + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + _, _, _, err = ParsePacket(b) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)")) }) It("errors on packets that are smaller than the length in the packet header, for too small payload", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ + b, err := (&ExtendedHeader{ Header: Header{ Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), @@ -470,9 +469,10 @@ var _ = Describe("Header Parsing", func() { }, PacketNumber: 0x1337, PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - buf.Write(make([]byte, 500-2 /* for packet number length */)) - _, _, _, err := ParsePacket(buf.Bytes()) + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + b = append(b, make([]byte, 500-2 /* for packet number length */)...) + _, _, _, err = ParsePacket(b) Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) }) }) diff --git a/internal/wire/short_header.go b/internal/wire/short_header.go index 77308ad2..dfe95739 100644 --- a/internal/wire/short_header.go +++ b/internal/wire/short_header.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "errors" "fmt" "io" @@ -54,15 +53,15 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err } -// WriteShortHeader writes a short header. -func WriteShortHeader(b *bytes.Buffer, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) error { +// AppendShortHeader writes a short header. +func AppendShortHeader(b []byte, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) ([]byte, error) { typeByte := 0x40 | uint8(pnLen-1) if kp == protocol.KeyPhaseOne { typeByte |= byte(1 << 2) } - b.WriteByte(typeByte) - b.Write(connID.Bytes()) - return writePacketNumber(b, pn, pnLen) + b = append(b, typeByte) + b = append(b, connID.Bytes()...) + return appendPacketNumber(b, pn, pnLen) } func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount { diff --git a/internal/wire/short_header_test.go b/internal/wire/short_header_test.go index 74dd55da..63fe5ca5 100644 --- a/internal/wire/short_header_test.go +++ b/internal/wire/short_header_test.go @@ -5,6 +5,7 @@ import ( "io" "log" "os" + "testing" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -77,15 +78,15 @@ var _ = Describe("Short Header", func() { Context("writing", func() { It("writes a short header packet", func() { - b := &bytes.Buffer{} connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - Expect(WriteShortHeader(b, connID, 1337, 4, protocol.KeyPhaseOne)).To(Succeed()) - l, pn, pnLen, kp, err := ParseShortHeader(b.Bytes(), 4) + b, err := AppendShortHeader(nil, connID, 1337, 4, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) + l, pn, pnLen, kp, err := ParseShortHeader(b, 4) Expect(err).ToNot(HaveOccurred()) Expect(pn).To(Equal(protocol.PacketNumber(1337))) Expect(pnLen).To(Equal(protocol.PacketNumberLen4)) Expect(kp).To(Equal(protocol.KeyPhaseOne)) - Expect(l).To(Equal(b.Len())) + Expect(l).To(Equal(len(b))) }) }) @@ -113,3 +114,17 @@ var _ = Describe("Short Header", func() { }) }) }) + +func BenchmarkWriteShortHeader(b *testing.B) { + b.ReportAllocs() + buf := make([]byte, 100) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) + for i := 0; i < b.N; i++ { + var err error + buf, err = AppendShortHeader(buf, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne) + if err != nil { + b.Fatalf("failed to write short header: %s", err) + } + buf = buf[:0] + } +} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 6c63cd80..6ab1abdc 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "crypto/rand" "errors" "net" @@ -14,7 +13,6 @@ import ( "github.com/lucas-clemente/quic-go/logging" "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -37,8 +35,7 @@ var _ = Describe("Packet Handler Map", func() { ) getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { - buf := &bytes.Buffer{} - Expect((&wire.ExtendedHeader{ + b, err := (&wire.ExtendedHeader{ Header: wire.Header{ Type: t, DestConnectionID: connID, @@ -46,8 +43,9 @@ var _ = Describe("Packet Handler Map", func() { Version: protocol.VersionTLS, }, PacketNumberLen: protocol.PacketNumberLen2, - }).Write(buf, protocol.VersionWhatever)).To(Succeed()) - return buf.Bytes() + }).Append(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + return b } getPacket := func(connID protocol.ConnectionID) []byte { diff --git a/packet_packer.go b/packet_packer.go index 3a0c2e2c..8c707838 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "errors" "fmt" "net" @@ -800,13 +799,12 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire paddingLen += padding header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen - raw := buffer.Data[len(buffer.Data):] - buf := bytes.NewBuffer(buffer.Data) - startLen := buf.Len() - if err := header.Write(buf, p.version); err != nil { + startLen := len(buffer.Data) + raw := buffer.Data[startLen:] + raw, err := header.Append(raw, p.version) + if err != nil { return nil, err } - raw = raw[:buf.Len()-startLen] payloadOffset := protocol.ByteCount(len(raw)) pn := p.pnManager.PopPacketNumber(encLevel) @@ -814,7 +812,7 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - raw, err := p.appendPacketPayload(raw, payload, paddingLen) + raw, err = p.appendPacketPayload(raw, payload, paddingLen) if err != nil { return nil, err } @@ -846,20 +844,19 @@ func (p *packetPacker) appendShortHeaderPacket( } paddingLen += padding - raw := buffer.Data[len(buffer.Data):] - buf := bytes.NewBuffer(buffer.Data) - startLen := buf.Len() - if err := wire.WriteShortHeader(buf, connID, pn, pnLen, kp); err != nil { + startLen := len(buffer.Data) + raw := buffer.Data[startLen:] + raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp) + if err != nil { return nil, err } - raw = raw[:buf.Len()-startLen] payloadOffset := protocol.ByteCount(len(raw)) if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - raw, err := p.appendPacketPayload(raw, payload, paddingLen) + raw, err = p.appendPacketPayload(raw, payload, paddingLen) if err != nil { return nil, err } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 02bfe553..2b93a41e 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "errors" "time" @@ -12,7 +11,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -28,21 +26,22 @@ var _ = Describe("Packet Unpacker", func() { ) getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { - buf := &bytes.Buffer{} - ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed()) - hdrLen := buf.Len() - if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) { - buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))) - } - hdr, _, _, err := wire.ParsePacket(buf.Bytes()) + b, err := extHdr.Append(nil, version) + Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred()) - return hdr, buf.Bytes()[:hdrLen] + hdrLen := len(b) + if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) { + b = append(b, make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))...) + } + hdr, _, _, err := wire.ParsePacket(b) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + return hdr, b[:hdrLen] } getShortHeader := func(connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) []byte { - buf := &bytes.Buffer{} - Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, kp)).To(Succeed()) - return buf.Bytes() + b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, kp) + Expect(err).ToNot(HaveOccurred()) + return b } BeforeEach(func() { @@ -72,12 +71,12 @@ var _ = Describe("Packet Unpacker", func() { }) It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() { - buf := &bytes.Buffer{} - Expect(wire.WriteShortHeader(buf, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) - data := append(buf.Bytes(), make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) + b, err := wire.AppendShortHeader(nil, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) + data := append(b, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil) - _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data) + _, _, _, _, err = unpacker.UnpackShortHeader(time.Now(), data) Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19")) }) diff --git a/quicvarint/varint.go b/quicvarint/varint.go index 28dcaa93..99eb23be 100644 --- a/quicvarint/varint.go +++ b/quicvarint/varint.go @@ -107,32 +107,32 @@ func Append(b []byte, i uint64) []byte { panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) } -// WriteWithLen writes i in the QUIC varint format with the desired length to w. -func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) { +// AppendWithLen append i in the QUIC varint format with the desired length. +func AppendWithLen(b []byte, i uint64, length protocol.ByteCount) []byte { if length != 1 && length != 2 && length != 4 && length != 8 { panic("invalid varint length") } l := Len(i) if l == length { - Write(w, i) - return + return Append(b, i) } if l > length { panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) } if length == 2 { - w.WriteByte(0b01000000) + b = append(b, 0b01000000) } else if length == 4 { - w.WriteByte(0b10000000) + b = append(b, 0b10000000) } else if length == 8 { - w.WriteByte(0b11000000) + b = append(b, 0b11000000) } for j := protocol.ByteCount(1); j < length-l; j++ { - w.WriteByte(0) + b = append(b, 0) } for j := protocol.ByteCount(0); j < l; j++ { - w.WriteByte(uint8(i >> (8 * (l - 1 - j)))) + b = append(b, uint8(i>>(8*(l-1-j)))) } + return b } // Len determines the number of bytes that will be needed to write the number i. diff --git a/quicvarint/varint_test.go b/quicvarint/varint_test.go index aa90809d..155cf283 100644 --- a/quicvarint/varint_test.go +++ b/quicvarint/varint_test.go @@ -142,54 +142,47 @@ var _ = Describe("Varint encoding / decoding", func() { Context("with fixed length", func() { It("panics when given an invalid length", func() { - Expect(func() { WriteWithLen(&bytes.Buffer{}, 25, 3) }).Should(Panic()) + Expect(func() { AppendWithLen(nil, 25, 3) }).Should(Panic()) }) It("panics when given a too short length", func() { - Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt1+1, 1) }).Should(Panic()) - Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt2+1, 2) }).Should(Panic()) - Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt4+1, 4) }).Should(Panic()) + Expect(func() { AppendWithLen(nil, maxVarInt1+1, 1) }).Should(Panic()) + Expect(func() { AppendWithLen(nil, maxVarInt2+1, 2) }).Should(Panic()) + Expect(func() { AppendWithLen(nil, maxVarInt4+1, 4) }).Should(Panic()) }) It("writes a 1-byte number in minimal encoding", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 37, 1) - Expect(b.Bytes()).To(Equal([]byte{0x25})) + Expect(AppendWithLen(nil, 37, 1)).To(Equal([]byte{0x25})) }) It("writes a 1-byte number in 2 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 37, 2) - Expect(b.Bytes()).To(Equal([]byte{0b01000000, 0x25})) - Expect(Read(b)).To(BeEquivalentTo(37)) + b := AppendWithLen(nil, 37, 2) + Expect(b).To(Equal([]byte{0b01000000, 0x25})) + Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37)) }) It("writes a 1-byte number in 4 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 37, 4) - Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0, 0x25})) - Expect(Read(b)).To(BeEquivalentTo(37)) + b := AppendWithLen(nil, 37, 4) + Expect(b).To(Equal([]byte{0b10000000, 0, 0, 0x25})) + Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37)) }) It("writes a 1-byte number in 8 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 37, 8) - Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25})) - Expect(Read(b)).To(BeEquivalentTo(37)) + b := AppendWithLen(nil, 37, 8) + Expect(b).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25})) + Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37)) }) It("writes a 2-byte number in 4 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 15293, 4) - Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd})) - Expect(Read(b)).To(BeEquivalentTo(15293)) + b := AppendWithLen(nil, 15293, 4) + Expect(b).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd})) + Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(15293)) }) It("write a 4-byte number in 8 bytes", func() { - b := &bytes.Buffer{} - WriteWithLen(b, 494878333, 8) - Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d})) - Expect(Read(b)).To(BeEquivalentTo(494878333)) + b := AppendWithLen(nil, 494878333, 8) + Expect(b).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d})) + Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(494878333)) }) }) diff --git a/server.go b/server.go index bf062a95..c4d3a05f 100644 --- a/server.go +++ b/server.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "context" "crypto/rand" "crypto/tls" @@ -577,19 +576,19 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack replyHdr.Log(s.logger) } - packetBuffer := getPacketBuffer() - defer packetBuffer.Release() - buf := bytes.NewBuffer(packetBuffer.Data) - if err := replyHdr.Write(buf, hdr.Version); err != nil { + buf := getPacketBuffer() + defer buf.Release() + buf.Data, err = replyHdr.Append(buf.Data, hdr.Version) + if err != nil { return err } // append the Retry integrity tag - tag := handshake.GetRetryIntegrityTag(buf.Bytes(), hdr.DestConnectionID, hdr.Version) - buf.Write(tag[:]) + tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) + buf.Data = append(buf.Data, tag[:]...) if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil) + s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) return err } @@ -627,9 +626,8 @@ func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header // sendError sends the error as a response to the packet received with header hdr func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { - packetBuffer := getPacketBuffer() - defer packetBuffer.Release() - buf := bytes.NewBuffer(packetBuffer.Data) + b := getPacketBuffer() + defer b.Release() ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} @@ -640,33 +638,34 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.PacketNumberLen = protocol.PacketNumberLen4 replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) - if err := replyHdr.Write(buf, hdr.Version); err != nil { + var err error + b.Data, err = replyHdr.Append(b.Data, hdr.Version) + if err != nil { return err } - payloadOffset := buf.Len() + payloadOffset := len(b.Data) - raw := buf.Bytes() - raw, err := ccf.Append(raw, hdr.Version) + b.Data, err = ccf.Append(b.Data, hdr.Version) if err != nil { return err } - _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset]) - raw = raw[0 : len(raw)+sealer.Overhead()] + _ = sealer.Seal(b.Data[payloadOffset:payloadOffset], b.Data[payloadOffset:], replyHdr.PacketNumber, b.Data[:payloadOffset]) + b.Data = b.Data[0 : len(b.Data)+sealer.Overhead()] pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) sealer.EncryptHeader( - raw[pnOffset+4:pnOffset+4+16], - &raw[0], - raw[pnOffset:payloadOffset], + b.Data[pnOffset+4:pnOffset+4+16], + &b.Data[0], + b.Data[pnOffset:payloadOffset], ) replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) + s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(raw, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) return err } diff --git a/server_test.go b/server_test.go index 2d8fb3c4..76903637 100644 --- a/server_test.go +++ b/server_test.go @@ -42,17 +42,18 @@ var _ = Describe("Server", func() { ) getPacket := func(hdr *wire.Header, p []byte) *receivedPacket { - buffer := getPacketBuffer() - buf := bytes.NewBuffer(buffer.Data) + buf := getPacketBuffer() hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 - Expect((&wire.ExtendedHeader{ + var err error + buf.Data, err = (&wire.ExtendedHeader{ Header: *hdr, PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, protocol.VersionTLS)).To(Succeed()) - n := buf.Len() - buf.Write(p) - data := buffer.Data[:buf.Len()] + }).Append(buf.Data, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + n := len(buf.Data) + buf.Data = append(buf.Data, p...) + data := buf.Data sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version) _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) data = data[:len(data)+16] @@ -60,7 +61,7 @@ var _ = Describe("Server", func() { return &receivedPacket{ remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, data: data, - buffer: buffer, + buffer: buf, } }