diff --git a/connection_test.go b/connection_test.go index f7608def..e13fb049 100644 --- a/connection_test.go +++ b/connection_test.go @@ -561,11 +561,11 @@ var _ = Describe("Connection", func() { } Expect(hdr.Write(buf, conn.version)).To(Succeed()) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { - buf := &bytes.Buffer{} - Expect((&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(buf, conn.version)).To(Succeed()) + b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(nil, conn.version) + Expect(err).ToNot(HaveOccurred()) return &unpackedPacket{ hdr: hdr, - data: buf.Bytes(), + data: b, encryptionLevel: protocol.Encryption1RTT, }, nil }) @@ -754,15 +754,15 @@ var _ = Describe("Connection", func() { PacketNumberLen: protocol.PacketNumberLen1, } rcvTime := time.Now().Add(-10 * time.Second) - buf := &bytes.Buffer{} - Expect((&wire.PingFrame{}).Write(buf, conn.version)).To(Succeed()) + b, err := (&wire.PingFrame{}).Write(nil, conn.version) + Expect(err).ToNot(HaveOccurred()) packet := getPacket(hdr, nil) packet.ecn = protocol.ECT1 unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ packetNumber: 0x1337, encryptionLevel: protocol.Encryption1RTT, hdr: hdr, - data: buf.Bytes(), + data: b, }, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) gomock.InOrder( diff --git a/fuzzing/frames/cmd/corpus.go b/fuzzing/frames/cmd/corpus.go index 78ef754f..55aed6ed 100644 --- a/fuzzing/frames/cmd/corpus.go +++ b/fuzzing/frames/cmd/corpus.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "log" "math/rand" "time" @@ -253,11 +252,11 @@ func getFrames() []wire.Frame { func main() { for _, f := range getFrames() { - b := &bytes.Buffer{} - if err := f.Write(b, version); err != nil { + b, err := f.Write(nil, version) + if err != nil { log.Fatal(err) } - if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), 1); err != nil { + if err := helper.WriteCorpusFileWithPrefix("corpus", b, 1); err != nil { log.Fatal(err) } } @@ -265,20 +264,22 @@ func main() { for i := 0; i < 30; i++ { frames := getFrames() - b := &bytes.Buffer{} + var b []byte for j := 0; j < rand.Intn(30)+2; j++ { if rand.Intn(10) == 0 { // write a PADDING frame - b.WriteByte(0x0) + b = append(b, 0) } f := frames[rand.Intn(len(frames))] - if err := f.Write(b, version); err != nil { + var err error + b, err = f.Write(b, version) + if err != nil { log.Fatal(err) } if rand.Intn(10) == 0 { // write a PADDING frame - b.WriteByte(0x0) + b = append(b, 0) } } - if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), 1); err != nil { + if err := helper.WriteCorpusFileWithPrefix("corpus", b, 1); err != nil { log.Fatal(err) } } diff --git a/fuzzing/frames/fuzz.go b/fuzzing/frames/fuzz.go index cf1f0e53..bf5217f2 100644 --- a/fuzzing/frames/fuzz.go +++ b/fuzzing/frames/fuzz.go @@ -55,10 +55,10 @@ func Fuzz(data []byte) int { return 0 } - b := &bytes.Buffer{} + var b []byte for _, f := range frames { if f == nil { // PADDING frame - b.WriteByte(0x0) + b = append(b, 0) continue } // We accept empty STREAM frames, but we don't write them. @@ -68,11 +68,12 @@ func Fuzz(data []byte) int { continue } } - lenBefore := b.Len() - if err := f.Write(b, version); err != nil { + lenBefore := len(b) + b, err := f.Write(b, version) + if err != nil { panic(fmt.Sprintf("Error writing frame %#v: %s", f, err)) } - frameLen := b.Len() - lenBefore + frameLen := len(b) - lenBefore if f.Length(version) != protocol.ByteCount(frameLen) { panic(fmt.Sprintf("Inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version))) } @@ -80,8 +81,8 @@ func Fuzz(data []byte) int { sf.PutBack() } } - if b.Len() > parsedLen { - panic(fmt.Sprintf("Serialized length (%d) is longer than parsed length (%d)", b.Len(), parsedLen)) + if len(b) > parsedLen { + panic(fmt.Sprintf("Serialized length (%d) is longer than parsed length (%d)", len(b), parsedLen)) } return 1 } diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go index 87e2b975..84db6161 100644 --- a/internal/testutils/testutils.go +++ b/internal/testutils/testutils.go @@ -20,11 +20,15 @@ func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte { // packRawPayload returns a new raw payload containing given frames func packRawPayload(version protocol.VersionNumber, frames []wire.Frame) []byte { - buf := new(bytes.Buffer) + var b []byte for _, cf := range frames { - cf.Write(buf, version) + var err error + b, err = cf.Write(b, version) + if err != nil { + panic(err) + } } - return buf.Bytes() + return b } // ComposeInitialPacket returns an Initial packet encrypted under key diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go index e5280a73..3c7afad0 100644 --- a/internal/wire/ack_frame.go +++ b/internal/wire/ack_frame.go @@ -107,40 +107,40 @@ func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNu } // Write writes an ACK frame. -func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (f *AckFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 if hasECN { - b.WriteByte(0x3) + b = append(b, 0b11) } else { - b.WriteByte(0x2) + b = append(b, 0b10) } - quicvarint.Write(b, uint64(f.LargestAcked())) - quicvarint.Write(b, encodeAckDelay(f.DelayTime)) + b = quicvarint.Append(b, uint64(f.LargestAcked())) + b = quicvarint.Append(b, encodeAckDelay(f.DelayTime)) numRanges := f.numEncodableAckRanges() - quicvarint.Write(b, uint64(numRanges-1)) + b = quicvarint.Append(b, uint64(numRanges-1)) // write the first range _, firstRange := f.encodeAckRange(0) - quicvarint.Write(b, firstRange) + b = quicvarint.Append(b, firstRange) // write all the other range for i := 1; i < numRanges; i++ { gap, len := f.encodeAckRange(i) - quicvarint.Write(b, gap) - quicvarint.Write(b, len) + b = quicvarint.Append(b, gap) + b = quicvarint.Append(b, len) } if hasECN { - quicvarint.Write(b, f.ECT0) - quicvarint.Write(b, f.ECT1) - quicvarint.Write(b, f.ECNCE) + b = quicvarint.Append(b, f.ECT0) + b = quicvarint.Append(b, f.ECT1) + b = quicvarint.Append(b, f.ECNCE) } - return nil + return b, nil } // Length of a written frame -func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { largestAcked := f.AckRanges[0].Largest numRanges := f.numEncodableAckRanges() diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go index aa00f92a..4bd2a9e4 100644 --- a/internal/wire/ack_frame_test.go +++ b/internal/wire/ack_frame_test.go @@ -116,15 +116,15 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { It("uses the ack delay exponent", func() { const delayTime = 1 << 10 * time.Millisecond - buf := &bytes.Buffer{} f := &AckFrame{ AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, DelayTime: delayTime, } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) for i := uint8(0); i < 8; i++ { - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, protocol.Version1) + r := bytes.NewReader(b) + frame, err := parseAckFrame(r, protocol.AckDelayExponent+i, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) } @@ -202,29 +202,29 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { Context("when writing", func() { It("writes a simple frame", func() { - buf := &bytes.Buffer{} f := &AckFrame{ AckRanges: []AckRange{{Smallest: 100, Largest: 1337}}, } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x2} expected = append(expected, encodeVarInt(1337)...) // largest acked expected = append(expected, 0) // delay expected = append(expected, encodeVarInt(0)...) // num ranges expected = append(expected, encodeVarInt(1337-100)...) - Expect(buf.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes an ACK-ECN frame", func() { - buf := &bytes.Buffer{} f := &AckFrame{ AckRanges: []AckRange{{Smallest: 10, Largest: 2000}}, ECT0: 13, ECT1: 37, ECNCE: 12345, } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) expected := []byte{0x3} expected = append(expected, encodeVarInt(2000)...) // largest acked expected = append(expected, 0) // delay @@ -233,43 +233,42 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { expected = append(expected, encodeVarInt(13)...) expected = append(expected, encodeVarInt(37)...) expected = append(expected, encodeVarInt(12345)...) - Expect(buf.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame that acks a single packet", func() { - buf := &bytes.Buffer{} f := &AckFrame{ AckRanges: []AckRange{{Smallest: 0x2eadbeef, Largest: 0x2eadbeef}}, DelayTime: 18 * time.Millisecond, } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) + r := bytes.NewReader(b) + frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) Expect(frame.DelayTime).To(Equal(f.DelayTime)) - Expect(b.Len()).To(BeZero()) + Expect(r.Len()).To(BeZero()) }) It("writes a frame that acks many packets", func() { - buf := &bytes.Buffer{} f := &AckFrame{ AckRanges: []AckRange{{Smallest: 0x1337, Largest: 0x2eadbeef}}, } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) + r := bytes.NewReader(b) + frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) - Expect(b.Len()).To(BeZero()) + Expect(r.Len()).To(BeZero()) }) It("writes a frame with a a single gap", func() { - buf := &bytes.Buffer{} f := &AckFrame{ AckRanges: []AckRange{ {Smallest: 400, Largest: 1000}, @@ -277,19 +276,18 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { }, } Expect(f.validateAckRanges()).To(BeTrue()) - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) + r := bytes.NewReader(b) + frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) + Expect(r.Len()).To(BeZero()) }) It("writes a frame with multiple ranges", func() { - buf := &bytes.Buffer{} f := &AckFrame{ AckRanges: []AckRange{ {Smallest: 10, Largest: 10}, @@ -299,18 +297,18 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { }, } Expect(f.validateAckRanges()).To(BeTrue()) - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) + r := bytes.NewReader(b) + frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) + Expect(r.Len()).To(BeZero()) }) It("limits the maximum size of the ACK frame", func() { - buf := &bytes.Buffer{} const numRanges = 1000 ackRanges := make([]AckRange, numRanges) for i := protocol.PacketNumber(1); i <= numRanges; i++ { @@ -318,16 +316,17 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { } f := &AckFrame{AckRanges: ackRanges} Expect(f.validateAckRanges()).To(BeTrue()) - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) // make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize - Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) - Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize)) - b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) + Expect(len(b)).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) + Expect(len(b)).To(BeNumerically("<=", protocol.MaxAckFrameSize)) + r := bytes.NewReader(b) + frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.HasMissingRanges()).To(BeTrue()) - Expect(b.Len()).To(BeZero()) + Expect(r.Len()).To(BeZero()) Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges }) }) diff --git a/internal/wire/connection_close_frame.go b/internal/wire/connection_close_frame.go index 4ce49af6..391eca27 100644 --- a/internal/wire/connection_close_frame.go +++ b/internal/wire/connection_close_frame.go @@ -66,18 +66,18 @@ func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount return length } -func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *ConnectionCloseFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { if f.IsApplicationError { - b.WriteByte(0x1d) + b = append(b, 0x1d) } else { - b.WriteByte(0x1c) + b = append(b, 0x1c) } - quicvarint.Write(b, f.ErrorCode) + b = quicvarint.Append(b, f.ErrorCode) if !f.IsApplicationError { - quicvarint.Write(b, f.FrameType) + b = quicvarint.Append(b, f.FrameType) } - quicvarint.Write(b, uint64(len(f.ReasonPhrase))) - b.WriteString(f.ReasonPhrase) - return nil + b = quicvarint.Append(b, uint64(len(f.ReasonPhrase))) + b = append(b, []byte(f.ReasonPhrase)...) + return b, nil } diff --git a/internal/wire/connection_close_frame_test.go b/internal/wire/connection_close_frame_test.go index 9c5e6661..cca9831c 100644 --- a/internal/wire/connection_close_frame_test.go +++ b/internal/wire/connection_close_frame_test.go @@ -84,70 +84,69 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { Context("when writing", func() { It("writes a frame without a reason phrase", func() { - b := &bytes.Buffer{} frame := &ConnectionCloseFrame{ ErrorCode: 0xbeef, FrameType: 0x12345, } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x1c} expected = append(expected, encodeVarInt(0xbeef)...) expected = append(expected, encodeVarInt(0x12345)...) // frame type expected = append(expected, encodeVarInt(0)...) // reason phrase length - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame with a reason phrase", func() { - b := &bytes.Buffer{} frame := &ConnectionCloseFrame{ ErrorCode: 0xdead, ReasonPhrase: "foobar", } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x1c} expected = append(expected, encodeVarInt(0xdead)...) expected = append(expected, encodeVarInt(0)...) // frame type expected = append(expected, encodeVarInt(6)...) // reason phrase length expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame with an application error code", func() { - b := &bytes.Buffer{} frame := &ConnectionCloseFrame{ IsApplicationError: true, ErrorCode: 0xdead, ReasonPhrase: "foobar", } - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x1d} expected = append(expected, encodeVarInt(0xdead)...) expected = append(expected, encodeVarInt(6)...) // reason phrase length expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("has proper min length, for a frame containing a QUIC error code", func() { - b := &bytes.Buffer{} f := &ConnectionCloseFrame{ ErrorCode: 0xcafe, FrameType: 0xdeadbeef, ReasonPhrase: "foobar", } - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) }) It("has proper min length, for a frame containing an application error code", func() { - b := &bytes.Buffer{} f := &ConnectionCloseFrame{ IsApplicationError: true, ErrorCode: 0xcafe, ReasonPhrase: "foobar", } - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) + Expect(b).To(HaveLen(int(f.Length(protocol.Version1)))) }) }) }) diff --git a/internal/wire/crypto_frame.go b/internal/wire/crypto_frame.go index 6301c878..7a886309 100644 --- a/internal/wire/crypto_frame.go +++ b/internal/wire/crypto_frame.go @@ -42,12 +42,12 @@ func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, return frame, nil } -func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x6) - quicvarint.Write(b, uint64(f.Offset)) - quicvarint.Write(b, uint64(len(f.Data))) - b.Write(f.Data) - return nil +func (f *CryptoFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x6) + b = quicvarint.Append(b, uint64(f.Offset)) + b = quicvarint.Append(b, uint64(len(f.Data))) + b = append(b, f.Data...) + return b, nil } // Length of a written frame diff --git a/internal/wire/crypto_frame_test.go b/internal/wire/crypto_frame_test.go index c3798101..0369e676 100644 --- a/internal/wire/crypto_frame_test.go +++ b/internal/wire/crypto_frame_test.go @@ -45,14 +45,13 @@ var _ = Describe("CRYPTO frame", func() { Offset: 0x123456, Data: []byte("foobar"), } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x6} expected = append(expected, encodeVarInt(0x123456)...) // offset expected = append(expected, encodeVarInt(6)...) // length expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) }) @@ -64,31 +63,29 @@ var _ = Describe("CRYPTO frame", func() { f := &CryptoFrame{ Offset: 0xdeadbeef, } - b := &bytes.Buffer{} var frameOneByteTooSmallCounter int for i := 1; i < maxSize; i++ { - b.Reset() f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i)) if maxDataLen == 0 { // 0 means that no valid CRYTPO frame can be written // check that writing a minimal size CRYPTO frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) + Expect(len(b)).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { + if len(b) == i-1 { frameOneByteTooSmallCounter++ continue } - Expect(b.Len()).To(Equal(i)) + Expect(len(b)).To(Equal(i)) } Expect(frameOneByteTooSmallCounter).To(Equal(1)) }) diff --git a/internal/wire/data_blocked_frame.go b/internal/wire/data_blocked_frame.go index 459f04d1..59b94214 100644 --- a/internal/wire/data_blocked_frame.go +++ b/internal/wire/data_blocked_frame.go @@ -25,11 +25,10 @@ func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBloc }, nil } -func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - typeByte := uint8(0x14) - b.WriteByte(typeByte) - quicvarint.Write(b, uint64(f.MaximumData)) - return nil +func (f *DataBlockedFrame) Write(b []byte, version protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x14) + b = quicvarint.Append(b, uint64(f.MaximumData)) + return b, nil } // Length of a written frame diff --git a/internal/wire/data_blocked_frame_test.go b/internal/wire/data_blocked_frame_test.go index 57ffd9a8..531c70d9 100644 --- a/internal/wire/data_blocked_frame_test.go +++ b/internal/wire/data_blocked_frame_test.go @@ -37,13 +37,12 @@ var _ = Describe("DATA_BLOCKED frame", func() { Context("when writing", func() { It("writes a sample frame", func() { - b := &bytes.Buffer{} frame := DataBlockedFrame{MaximumData: 0xdeadbeef} - err := frame.Write(b, protocol.VersionWhatever) + b, err := frame.Write(nil, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x14} expected = append(expected, encodeVarInt(0xdeadbeef)...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("has the correct min length", func() { diff --git a/internal/wire/datagram_frame.go b/internal/wire/datagram_frame.go index 9d6e55cb..2e3c7b0e 100644 --- a/internal/wire/datagram_frame.go +++ b/internal/wire/datagram_frame.go @@ -44,17 +44,17 @@ func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFra return f, nil } -func (f *DatagramFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (f *DatagramFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { typeByte := uint8(0x30) if f.DataLenPresent { - typeByte ^= 0x1 + typeByte ^= 0b1 } - b.WriteByte(typeByte) + b = append(b, typeByte) if f.DataLenPresent { - quicvarint.Write(b, uint64(len(f.Data))) + b = quicvarint.Append(b, uint64(len(f.Data))) } - b.Write(f.Data) - return nil + b = append(b, f.Data...) + return b, nil } // MaxDataLen returns the maximum data length diff --git a/internal/wire/datagram_frame_test.go b/internal/wire/datagram_frame_test.go index 4431eb8b..32790a94 100644 --- a/internal/wire/datagram_frame_test.go +++ b/internal/wire/datagram_frame_test.go @@ -64,21 +64,21 @@ var _ = Describe("STREAM frame", func() { DataLenPresent: true, Data: []byte("foobar"), } - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x30 ^ 0x1} expected = append(expected, encodeVarInt(0x6)...) expected = append(expected, []byte("foobar")...) - Expect(buf.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame without length", func() { f := &DatagramFrame{Data: []byte("Lorem ipsum")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x30} expected = append(expected, []byte("Lorem ipsum")...) - Expect(buf.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) }) @@ -111,42 +111,44 @@ var _ = Describe("STREAM frame", func() { if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(BeNumerically(">", i)) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(len(b)).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(Equal(i)) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(i)) } }) It("always returns a data length such that the resulting frame has the right size, if data length is present", func() { data := make([]byte, maxSize) f := &DatagramFrame{DataLenPresent: true} - b := &bytes.Buffer{} var frameOneByteTooSmallCounter int for i := 1; i < 3000; i++ { - b.Reset() f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - Expect(f.Write(b, protocol.Version1)).To(Succeed()) - Expect(b.Len()).To(BeNumerically(">", i)) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(len(b)).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - Expect(f.Write(b, protocol.Version1)).To(Succeed()) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { + if len(b) == i-1 { frameOneByteTooSmallCounter++ continue } - Expect(b.Len()).To(Equal(i)) + Expect(b).To(HaveLen(i)) } Expect(frameOneByteTooSmallCounter).To(Equal(1)) }) diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 8d52ebc7..1f8fd605 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -11,13 +11,9 @@ import ( ) var _ = Describe("Frame parsing", func() { - var ( - buf *bytes.Buffer - parser FrameParser - ) + var parser FrameParser BeforeEach(func() { - buf = &bytes.Buffer{} parser = NewFrameParser(true, protocol.Version1) }) @@ -28,9 +24,10 @@ var _ = Describe("Frame parsing", func() { }) It("skips PADDING frames", func() { - buf.Write([]byte{0}) // PADDING frame - (&PingFrame{}).Write(buf, protocol.Version1) - f, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + b := []byte{0} // PADDING frame + b, err := (&PingFrame{}).Write(b, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + f, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(f).To(Equal(&PingFrame{})) }) @@ -45,9 +42,9 @@ var _ = Describe("Frame parsing", func() { It("unpacks ACK frames", func() { f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -60,8 +57,9 @@ var _ = Describe("Frame parsing", func() { AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, DelayTime: time.Second, } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) // The ACK frame is always written using the protocol.AckDelayExponent. // That's why we expect a different value when parsing. @@ -74,8 +72,9 @@ var _ = Describe("Frame parsing", func() { AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, DelayTime: time.Second, } - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.EncryptionHandshake) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.EncryptionHandshake) Expect(err).ToNot(HaveOccurred()) Expect(frame.(*AckFrame).DelayTime).To(Equal(time.Second)) }) @@ -86,19 +85,18 @@ var _ = Describe("Frame parsing", func() { FinalSize: 0xdecafbad1234, ErrorCode: 0x1337, } - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("unpacks STOP_SENDING frames", func() { f := &StopSendingFrame{StreamID: 0x42} - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -108,9 +106,9 @@ var _ = Describe("Frame parsing", func() { Offset: 0x1337, Data: []byte("lorem ipsum"), } - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -118,9 +116,9 @@ var _ = Describe("Frame parsing", func() { It("unpacks NEW_TOKEN frames", func() { f := &NewTokenFrame{Token: []byte("foobar")} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -133,9 +131,9 @@ var _ = Describe("Frame parsing", func() { Fin: true, Data: []byte("foobar"), } - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(Equal(f)) @@ -145,10 +143,9 @@ var _ = Describe("Frame parsing", func() { f := &MaxDataFrame{ MaximumData: 0xcafe, } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -158,10 +155,9 @@ var _ = Describe("Frame parsing", func() { StreamID: 0xdeadbeef, MaximumStreamData: 0xdecafbad, } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -171,20 +167,18 @@ var _ = Describe("Frame parsing", func() { Type: protocol.StreamTypeBidi, MaxStreamNum: 0x1337, } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("unpacks DATA_BLOCKED frames", func() { f := &DataBlockedFrame{MaximumData: 0x1234} - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -194,9 +188,9 @@ var _ = Describe("Frame parsing", func() { StreamID: 0xdeadbeef, MaximumStreamData: 0xdead, } - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -206,10 +200,9 @@ var _ = Describe("Frame parsing", func() { Type: protocol.StreamTypeBidi, StreamLimit: 0x1234567, } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -220,27 +213,27 @@ var _ = Describe("Frame parsing", func() { ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, } - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("unpacks RETIRE_CONNECTION_ID frames", func() { f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("unpacks PATH_CHALLENGE frames", func() { f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -249,9 +242,9 @@ var _ = Describe("Frame parsing", func() { It("unpacks PATH_RESPONSE frames", func() { f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).ToNot(BeNil()) Expect(frame).To(BeAssignableToTypeOf(f)) @@ -263,28 +256,27 @@ var _ = Describe("Frame parsing", func() { IsApplicationError: true, ReasonPhrase: "foobar", } - buf := &bytes.Buffer{} - err := f.Write(buf, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("unpacks HANDSHAKE_DONE frames", func() { f := &HandshakeDoneFrame{} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("unpacks DATAGRAM frames", func() { f := &DatagramFrame{Data: []byte("foobar")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + frame, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -292,9 +284,9 @@ var _ = Describe("Frame parsing", func() { It("errors when DATAGRAM frames are not supported", func() { parser = NewFrameParser(false, protocol.Version1) f := &DatagramFrame{Data: []byte("foobar")} - buf := &bytes.Buffer{} - Expect(f.Write(buf, protocol.Version1)).To(Succeed()) - _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + _, err = parser.ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) Expect(err).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, FrameType: 0x30, @@ -316,9 +308,9 @@ var _ = Describe("Frame parsing", func() { StreamID: 0x1337, MaximumStreamData: 0xdeadbeef, } - b := &bytes.Buffer{} - f.Write(b, protocol.Version1) - _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) + b, err := f.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + _, err = parser.ParseNext(bytes.NewReader(b[:len(b)-2]), protocol.Encryption1RTT) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) }) @@ -352,9 +344,9 @@ var _ = Describe("Frame parsing", func() { BeforeEach(func() { framesSerialized = nil for _, frame := range frames { - buf := &bytes.Buffer{} - Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) - framesSerialized = append(framesSerialized, buf.Bytes()) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + framesSerialized = append(framesSerialized, b) } }) diff --git a/internal/wire/handshake_done_frame.go b/internal/wire/handshake_done_frame.go index 158d659f..91826201 100644 --- a/internal/wire/handshake_done_frame.go +++ b/internal/wire/handshake_done_frame.go @@ -17,9 +17,8 @@ func parseHandshakeDoneFrame(r *bytes.Reader, _ protocol.VersionNumber) (*Handsh return &HandshakeDoneFrame{}, nil } -func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1e) - return nil +func (f *HandshakeDoneFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + return append(b, 0x1e), nil } // Length of a written frame diff --git a/internal/wire/handshake_done_frame_test.go b/internal/wire/handshake_done_frame_test.go new file mode 100644 index 00000000..c66c63dd --- /dev/null +++ b/internal/wire/handshake_done_frame_test.go @@ -0,0 +1,39 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("HANDSHAKE_DONE frame", func() { + Context("when parsing", func() { + It("accepts sample frame", func() { + b := bytes.NewReader([]byte{0x1e}) + _, err := parseHandshakeDoneFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + _, err := parseHandshakeDoneFrame(bytes.NewReader(nil), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when writing", func() { + It("writes a sample frame", func() { + frame := HandshakeDoneFrame{} + b, err := frame.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte{0x1e})) + }) + + It("has the correct min length", func() { + frame := HandshakeDoneFrame{} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(1))) + }) + }) +}) diff --git a/internal/wire/interface.go b/internal/wire/interface.go index 99fdc80f..7175cde2 100644 --- a/internal/wire/interface.go +++ b/internal/wire/interface.go @@ -8,7 +8,7 @@ import ( // A Frame in QUIC type Frame interface { - Write(b *bytes.Buffer, version protocol.VersionNumber) error + Write(b []byte, version protocol.VersionNumber) ([]byte, error) Length(version protocol.VersionNumber) protocol.ByteCount } diff --git a/internal/wire/max_data_frame.go b/internal/wire/max_data_frame.go index a9a09248..02a7aaba 100644 --- a/internal/wire/max_data_frame.go +++ b/internal/wire/max_data_frame.go @@ -28,13 +28,13 @@ func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame } // Write writes a MAX_STREAM_DATA frame -func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x10) - quicvarint.Write(b, uint64(f.MaximumData)) - return nil +func (f *MaxDataFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x10) + b = quicvarint.Append(b, uint64(f.MaximumData)) + return b, nil } // Length of a written frame -func (f *MaxDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *MaxDataFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.MaximumData)) } diff --git a/internal/wire/max_data_frame_test.go b/internal/wire/max_data_frame_test.go index a5ee0222..afeff439 100644 --- a/internal/wire/max_data_frame_test.go +++ b/internal/wire/max_data_frame_test.go @@ -35,7 +35,7 @@ var _ = Describe("MAX_DATA frame", func() { }) Context("writing", func() { - It("has proper min length", func() { + It("has proper length", func() { f := &MaxDataFrame{ MaximumData: 0xdeadbeef, } @@ -43,15 +43,14 @@ var _ = Describe("MAX_DATA frame", func() { }) It("writes a MAX_DATA frame", func() { - b := &bytes.Buffer{} f := &MaxDataFrame{ MaximumData: 0xdeadbeefcafe, } - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x10} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) }) }) diff --git a/internal/wire/max_stream_data_frame.go b/internal/wire/max_stream_data_frame.go index 728ecbe8..a6cfa0df 100644 --- a/internal/wire/max_stream_data_frame.go +++ b/internal/wire/max_stream_data_frame.go @@ -33,11 +33,11 @@ func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStr }, nil } -func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x11) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.MaximumStreamData)) - return nil +func (f *MaxStreamDataFrame) Write(b []byte, version protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x11) + b = quicvarint.Append(b, uint64(f.StreamID)) + b = quicvarint.Append(b, uint64(f.MaximumStreamData)) + return b, nil } // Length of a written frame diff --git a/internal/wire/max_stream_data_frame_test.go b/internal/wire/max_stream_data_frame_test.go index f12aac87..43ecd297 100644 --- a/internal/wire/max_stream_data_frame_test.go +++ b/internal/wire/max_stream_data_frame_test.go @@ -38,7 +38,7 @@ var _ = Describe("MAX_STREAM_DATA frame", func() { }) Context("writing", func() { - It("has proper min length", func() { + It("has proper length", func() { f := &MaxStreamDataFrame{ StreamID: 0x1337, MaximumStreamData: 0xdeadbeef, @@ -47,7 +47,6 @@ var _ = Describe("MAX_STREAM_DATA frame", func() { }) It("writes a sample frame", func() { - b := &bytes.Buffer{} f := &MaxStreamDataFrame{ StreamID: 0xdecafbad, MaximumStreamData: 0xdeadbeefcafe42, @@ -55,9 +54,9 @@ var _ = Describe("MAX_STREAM_DATA frame", func() { expected := []byte{0x11} expected = append(expected, encodeVarInt(0xdecafbad)...) expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) }) }) diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go index 73d7e13e..841a2ee6 100644 --- a/internal/wire/max_streams_frame.go +++ b/internal/wire/max_streams_frame.go @@ -38,15 +38,15 @@ func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStream return f, nil } -func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (f *MaxStreamsFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: - b.WriteByte(0x12) + b = append(b, 0x12) case protocol.StreamTypeUni: - b.WriteByte(0x13) + b = append(b, 0x13) } - quicvarint.Write(b, uint64(f.MaxStreamNum)) - return nil + b = quicvarint.Append(b, uint64(f.MaxStreamNum)) + return b, nil } // Length of a written frame diff --git a/internal/wire/max_streams_frame_test.go b/internal/wire/max_streams_frame_test.go index d05e9448..7e344506 100644 --- a/internal/wire/max_streams_frame_test.go +++ b/internal/wire/max_streams_frame_test.go @@ -54,9 +54,9 @@ var _ = Describe("MAX_STREAMS frame", func() { Type: streamType, MaxStreamNum: protocol.MaxStreamCount, } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + frame, err := parseMaxStreamsFrame(bytes.NewReader(b), protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -66,9 +66,9 @@ var _ = Describe("MAX_STREAMS frame", func() { Type: streamType, MaxStreamNum: protocol.MaxStreamCount + 1, } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - _, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + _, err = parseMaxStreamsFrame(bytes.NewReader(b), protocol.VersionWhatever) Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) }) } @@ -80,11 +80,11 @@ var _ = Describe("MAX_STREAMS frame", func() { Type: protocol.StreamTypeBidi, MaxStreamNum: 0xdeadbeef, } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x12} expected = append(expected, encodeVarInt(0xdeadbeef)...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("for a unidirectional stream", func() { @@ -92,14 +92,14 @@ var _ = Describe("MAX_STREAMS frame", func() { Type: protocol.StreamTypeUni, MaxStreamNum: 0xdecafbad, } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x13} expected = append(expected, encodeVarInt(0xdecafbad)...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) - It("has the correct min length", func() { + It("has the correct length", func() { frame := MaxStreamsFrame{MaxStreamNum: 0x1337} Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + quicvarint.Len(0x1337))) }) diff --git a/internal/wire/new_connection_id_frame.go b/internal/wire/new_connection_id_frame.go index 1a017ba9..e3a1eecd 100644 --- a/internal/wire/new_connection_id_frame.go +++ b/internal/wire/new_connection_id_frame.go @@ -60,18 +60,18 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC return frame, nil } -func (f *NewConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x18) - quicvarint.Write(b, f.SequenceNumber) - quicvarint.Write(b, f.RetirePriorTo) +func (f *NewConnectionIDFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x18) + b = quicvarint.Append(b, f.SequenceNumber) + b = quicvarint.Append(b, f.RetirePriorTo) connIDLen := f.ConnectionID.Len() if connIDLen > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d", connIDLen) + return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) } - b.WriteByte(uint8(connIDLen)) - b.Write(f.ConnectionID.Bytes()) - b.Write(f.StatelessResetToken[:]) - return nil + b = append(b, uint8(connIDLen)) + b = append(b, f.ConnectionID.Bytes()...) + b = append(b, f.StatelessResetToken[:]...) + return b, nil } // Length of a written frame diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go index 75fe85c5..ba7334e2 100644 --- a/internal/wire/new_connection_id_frame_test.go +++ b/internal/wire/new_connection_id_frame_test.go @@ -77,15 +77,15 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, StatelessResetToken: token, } - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x18} expected = append(expected, encodeVarInt(0x1337)...) expected = append(expected, encodeVarInt(0x42)...) expected = append(expected, 6) expected = append(expected, []byte{1, 2, 3, 4, 5, 6}...) expected = append(expected, token[:]...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("has the correct length", func() { @@ -96,9 +96,9 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, StatelessResetToken: token, } - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(int(frame.Length(protocol.Version1)))) }) }) }) diff --git a/internal/wire/new_token_frame.go b/internal/wire/new_token_frame.go index 3d5d5c3a..90df7906 100644 --- a/internal/wire/new_token_frame.go +++ b/internal/wire/new_token_frame.go @@ -35,11 +35,11 @@ func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFra return &NewTokenFrame{Token: token}, nil } -func (f *NewTokenFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x7) - quicvarint.Write(b, uint64(len(f.Token))) - b.Write(f.Token) - return nil +func (f *NewTokenFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x7) + b = quicvarint.Append(b, uint64(len(f.Token))) + b = append(b, f.Token...) + return b, nil } // Length of a written frame diff --git a/internal/wire/new_token_frame_test.go b/internal/wire/new_token_frame_test.go index 47eb6242..c2c88c1c 100644 --- a/internal/wire/new_token_frame_test.go +++ b/internal/wire/new_token_frame_test.go @@ -50,12 +50,12 @@ var _ = Describe("NEW_TOKEN frame", func() { It("writes a sample frame", func() { token := "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat." f := &NewTokenFrame{Token: []byte(token)} - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x7} expected = append(expected, encodeVarInt(uint64(len(token)))...) expected = append(expected, token...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("has the correct min length", func() { diff --git a/internal/wire/path_challenge_frame.go b/internal/wire/path_challenge_frame.go index 5ec82177..1260abe7 100644 --- a/internal/wire/path_challenge_frame.go +++ b/internal/wire/path_challenge_frame.go @@ -26,10 +26,10 @@ func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathCh return frame, nil } -func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1a) - b.Write(f.Data[:]) - return nil +func (f *PathChallengeFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x1a) + b = append(b, f.Data[:]...) + return b, nil } // Length of a written frame diff --git a/internal/wire/path_challenge_frame_test.go b/internal/wire/path_challenge_frame_test.go index e7d8a970..4f0ba5dc 100644 --- a/internal/wire/path_challenge_frame_test.go +++ b/internal/wire/path_challenge_frame_test.go @@ -33,14 +33,13 @@ var _ = Describe("PATH_CHALLENGE frame", func() { Context("when writing", func() { It("writes a sample frame", func() { - b := &bytes.Buffer{} frame := PathChallengeFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} - err := frame.Write(b, protocol.VersionWhatever) + b, err := frame.Write(nil, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x1a, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) + Expect(b).To(Equal([]byte{0x1a, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) }) - It("has the correct min length", func() { + It("has the correct length", func() { frame := PathChallengeFrame{} Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) }) diff --git a/internal/wire/path_response_frame.go b/internal/wire/path_response_frame.go index 262819f8..fb32d3c2 100644 --- a/internal/wire/path_response_frame.go +++ b/internal/wire/path_response_frame.go @@ -26,10 +26,10 @@ func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathRes return frame, nil } -func (f *PathResponseFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1b) - b.Write(f.Data[:]) - return nil +func (f *PathResponseFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x1b) + b = append(b, f.Data[:]...) + return b, nil } // Length of a written frame diff --git a/internal/wire/path_response_frame_test.go b/internal/wire/path_response_frame_test.go index 11e9d67a..17ef3b4d 100644 --- a/internal/wire/path_response_frame_test.go +++ b/internal/wire/path_response_frame_test.go @@ -32,14 +32,13 @@ var _ = Describe("PATH_RESPONSE frame", func() { Context("when writing", func() { It("writes a sample frame", func() { - b := &bytes.Buffer{} frame := PathResponseFrame{Data: [8]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}} - err := frame.Write(b, protocol.VersionWhatever) + b, err := frame.Write(nil, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x1b, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) + Expect(b).To(Equal([]byte{0x1b, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) }) - It("has the correct min length", func() { + It("has the correct length", func() { frame := PathResponseFrame{} Expect(frame.Length(protocol.VersionWhatever)).To(Equal(protocol.ByteCount(9))) }) diff --git a/internal/wire/ping_frame.go b/internal/wire/ping_frame.go index dc029e45..2de051f6 100644 --- a/internal/wire/ping_frame.go +++ b/internal/wire/ping_frame.go @@ -16,12 +16,11 @@ func parsePingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PingFrame, erro return &PingFrame{}, nil } -func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x1) - return nil +func (f *PingFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + return append(b, 0x1), nil } // Length of a written frame -func (f *PingFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *PingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { return 1 } diff --git a/internal/wire/ping_frame_test.go b/internal/wire/ping_frame_test.go index d43efe34..5ea955c6 100644 --- a/internal/wire/ping_frame_test.go +++ b/internal/wire/ping_frame_test.go @@ -8,7 +8,7 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("PingFrame", func() { +var _ = Describe("PING frame", func() { Context("when parsing", func() { It("accepts sample frame", func() { b := bytes.NewReader([]byte{0x1}) @@ -25,13 +25,13 @@ var _ = Describe("PingFrame", func() { Context("when writing", func() { It("writes a sample frame", func() { - b := &bytes.Buffer{} frame := PingFrame{} - frame.Write(b, protocol.VersionWhatever) - Expect(b.Bytes()).To(Equal([]byte{0x1})) + b, err := frame.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte{0x1})) }) - It("has the correct min length", func() { + It("has the correct length", func() { frame := PingFrame{} Expect(frame.Length(0)).To(Equal(protocol.ByteCount(1))) }) diff --git a/internal/wire/reset_stream_frame.go b/internal/wire/reset_stream_frame.go index 69bbc2b9..57c62062 100644 --- a/internal/wire/reset_stream_frame.go +++ b/internal/wire/reset_stream_frame.go @@ -44,12 +44,12 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr }, nil } -func (f *ResetStreamFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x4) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.ErrorCode)) - quicvarint.Write(b, uint64(f.FinalSize)) - return nil +func (f *ResetStreamFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x4) + b = quicvarint.Append(b, uint64(f.StreamID)) + b = quicvarint.Append(b, uint64(f.ErrorCode)) + b = quicvarint.Append(b, uint64(f.FinalSize)) + return b, nil } // Length of a written frame diff --git a/internal/wire/reset_stream_frame_test.go b/internal/wire/reset_stream_frame_test.go index e241a8e6..1f5ea762 100644 --- a/internal/wire/reset_stream_frame_test.go +++ b/internal/wire/reset_stream_frame_test.go @@ -47,17 +47,16 @@ var _ = Describe("RESET_STREAM frame", func() { FinalSize: 0x11223344decafbad, ErrorCode: 0xcafe, } - b := &bytes.Buffer{} - err := frame.Write(b, protocol.Version1) + b, err := frame.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x4} expected = append(expected, encodeVarInt(0x1337)...) expected = append(expected, encodeVarInt(0xcafe)...) expected = append(expected, encodeVarInt(0x11223344decafbad)...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) - It("has the correct min length", func() { + It("has the correct length", func() { rst := ResetStreamFrame{ StreamID: 0x1337, FinalSize: 0x1234567, diff --git a/internal/wire/retire_connection_id_frame.go b/internal/wire/retire_connection_id_frame.go index 0f7e58c8..55b3d49d 100644 --- a/internal/wire/retire_connection_id_frame.go +++ b/internal/wire/retire_connection_id_frame.go @@ -24,10 +24,10 @@ func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*R return &RetireConnectionIDFrame{SequenceNumber: seq}, nil } -func (f *RetireConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x19) - quicvarint.Write(b, f.SequenceNumber) - return nil +func (f *RetireConnectionIDFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x19) + b = quicvarint.Append(b, f.SequenceNumber) + return b, nil } // Length of a written frame diff --git a/internal/wire/retire_connection_id_frame_test.go b/internal/wire/retire_connection_id_frame_test.go index 0338b6cc..9700a031 100644 --- a/internal/wire/retire_connection_id_frame_test.go +++ b/internal/wire/retire_connection_id_frame_test.go @@ -36,18 +36,18 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { Context("when writing", func() { It("writes a sample frame", func() { frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337} - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x19} expected = append(expected, encodeVarInt(0x1337)...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("has the correct length", func() { frame := &RetireConnectionIDFrame{SequenceNumber: 0xdecafbad} - b := &bytes.Buffer{} - Expect(frame.Write(b, protocol.Version1)).To(Succeed()) - Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(HaveLen(int(frame.Length(protocol.Version1)))) }) }) }) diff --git a/internal/wire/stop_sending_frame.go b/internal/wire/stop_sending_frame.go index fb1160c1..3e52519f 100644 --- a/internal/wire/stop_sending_frame.go +++ b/internal/wire/stop_sending_frame.go @@ -40,9 +40,9 @@ func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) } -func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x5) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.ErrorCode)) - return nil +func (f *StopSendingFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x5) + b = quicvarint.Append(b, uint64(f.StreamID)) + b = quicvarint.Append(b, uint64(f.ErrorCode)) + return b, nil } diff --git a/internal/wire/stop_sending_frame_test.go b/internal/wire/stop_sending_frame_test.go index 7b6793c2..59f06c00 100644 --- a/internal/wire/stop_sending_frame_test.go +++ b/internal/wire/stop_sending_frame_test.go @@ -44,12 +44,12 @@ var _ = Describe("STOP_SENDING frame", func() { StreamID: 0xdeadbeefcafe, ErrorCode: 0xdecafbad, } - buf := &bytes.Buffer{} - Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) + b, err := frame.Write(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x5} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) expected = append(expected, encodeVarInt(0xdecafbad)...) - Expect(buf.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("has the correct min length", func() { diff --git a/internal/wire/stream_data_blocked_frame.go b/internal/wire/stream_data_blocked_frame.go index dc6d631a..a75a5c2a 100644 --- a/internal/wire/stream_data_blocked_frame.go +++ b/internal/wire/stream_data_blocked_frame.go @@ -33,11 +33,11 @@ func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*St }, nil } -func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x15) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.MaximumStreamData)) - return nil +func (f *StreamDataBlockedFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x15) + b = quicvarint.Append(b, uint64(f.StreamID)) + b = quicvarint.Append(b, uint64(f.MaximumStreamData)) + return b, nil } // Length of a written frame diff --git a/internal/wire/stream_data_blocked_frame_test.go b/internal/wire/stream_data_blocked_frame_test.go index b5ec7cbd..7d54d52e 100644 --- a/internal/wire/stream_data_blocked_frame_test.go +++ b/internal/wire/stream_data_blocked_frame_test.go @@ -47,17 +47,16 @@ var _ = Describe("STREAM_DATA_BLOCKED frame", func() { }) It("writes a sample frame", func() { - b := &bytes.Buffer{} f := &StreamDataBlockedFrame{ StreamID: 0xdecafbad, MaximumStreamData: 0x1337, } - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x15} expected = append(expected, encodeVarInt(uint64(f.StreamID))...) expected = append(expected, encodeVarInt(uint64(f.MaximumStreamData))...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) }) }) diff --git a/internal/wire/stream_frame.go b/internal/wire/stream_frame.go index 66340d16..b1b457f2 100644 --- a/internal/wire/stream_frame.go +++ b/internal/wire/stream_frame.go @@ -26,9 +26,9 @@ func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, return nil, err } - hasOffset := typeByte&0x4 > 0 - fin := typeByte&0x1 > 0 - hasDataLen := typeByte&0x2 > 0 + hasOffset := typeByte&0b100 > 0 + fin := typeByte&0b1 > 0 + hasDataLen := typeByte&0b10 > 0 streamID, err := quicvarint.Read(r) if err != nil { @@ -84,32 +84,32 @@ func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, } // Write writes a STREAM frame -func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *StreamFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { if len(f.Data) == 0 && !f.Fin { - return errors.New("StreamFrame: attempting to write empty frame without FIN") + return nil, errors.New("StreamFrame: attempting to write empty frame without FIN") } typeByte := byte(0x8) if f.Fin { - typeByte ^= 0x1 + typeByte ^= 0b1 } hasOffset := f.Offset != 0 if f.DataLenPresent { - typeByte ^= 0x2 + typeByte ^= 0b10 } if hasOffset { - typeByte ^= 0x4 + typeByte ^= 0b100 } - b.WriteByte(typeByte) - quicvarint.Write(b, uint64(f.StreamID)) + b = append(b, typeByte) + b = quicvarint.Append(b, uint64(f.StreamID)) if hasOffset { - quicvarint.Write(b, uint64(f.Offset)) + b = quicvarint.Append(b, uint64(f.Offset)) } if f.DataLenPresent { - quicvarint.Write(b, uint64(f.DataLen())) + b = quicvarint.Append(b, uint64(f.DataLen())) } - b.Write(f.Data) - return nil + b = append(b, f.Data...) + return b, nil } // Length returns the total length of the STREAM frame diff --git a/internal/wire/stream_frame_test.go b/internal/wire/stream_frame_test.go index 92f863a5..3bb3c717 100644 --- a/internal/wire/stream_frame_test.go +++ b/internal/wire/stream_frame_test.go @@ -145,13 +145,12 @@ var _ = Describe("STREAM frame", func() { StreamID: 0x1337, Data: []byte("foobar"), } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame with offset", func() { @@ -160,14 +159,13 @@ var _ = Describe("STREAM frame", func() { Offset: 0x123456, Data: []byte("foobar"), } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8 ^ 0x4} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, encodeVarInt(0x123456)...) // offset expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame with FIN bit", func() { @@ -176,13 +174,12 @@ var _ = Describe("STREAM frame", func() { Offset: 0x123456, Fin: true, } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8 ^ 0x4 ^ 0x1} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, encodeVarInt(0x123456)...) // offset - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame with data length", func() { @@ -191,14 +188,13 @@ var _ = Describe("STREAM frame", func() { Data: []byte("foobar"), DataLenPresent: true, } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8 ^ 0x2} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, encodeVarInt(6)...) // data length expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame with data length and offset", func() { @@ -208,15 +204,14 @@ var _ = Describe("STREAM frame", func() { DataLenPresent: true, Offset: 0x123456, } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8 ^ 0x4 ^ 0x2} expected = append(expected, encodeVarInt(0x1337)...) // stream ID expected = append(expected, encodeVarInt(0x123456)...) // offset expected = append(expected, encodeVarInt(6)...) // data length expected = append(expected, []byte("foobar")...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("refuses to write an empty frame without FIN", func() { @@ -224,8 +219,7 @@ var _ = Describe("STREAM frame", func() { StreamID: 0x42, Offset: 0x1337, } - b := &bytes.Buffer{} - err := f.Write(b, protocol.Version1) + _, err := f.Write(nil, protocol.Version1) Expect(err).To(MatchError("StreamFrame: attempting to write empty frame without FIN")) }) }) @@ -268,23 +262,21 @@ var _ = Describe("STREAM frame", func() { StreamID: 0x1337, Offset: 0xdeadbeef, } - b := &bytes.Buffer{} for i := 1; i < 3000; i++ { - b.Reset() f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) + Expect(len(b)).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(Equal(i)) + Expect(len(b)).To(Equal(i)) } }) @@ -295,31 +287,29 @@ var _ = Describe("STREAM frame", func() { Offset: 0xdeadbeef, DataLenPresent: true, } - b := &bytes.Buffer{} var frameOneByteTooSmallCounter int for i := 1; i < 3000; i++ { - b.Reset() f.Data = nil maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(b.Len()).To(BeNumerically(">", i)) + Expect(len(b)).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - err := f.Write(b, protocol.Version1) + b, err := f.Write(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a STREAM frame of the desired size - if b.Len() == i-1 { + if len(b) == i-1 { frameOneByteTooSmallCounter++ continue } - Expect(b.Len()).To(Equal(i)) + Expect(len(b)).To(Equal(i)) } Expect(frameOneByteTooSmallCounter).To(Equal(1)) }) diff --git a/internal/wire/streams_blocked_frame.go b/internal/wire/streams_blocked_frame.go index f4066071..69a825c0 100644 --- a/internal/wire/streams_blocked_frame.go +++ b/internal/wire/streams_blocked_frame.go @@ -38,15 +38,15 @@ func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*Strea return f, nil } -func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (f *StreamsBlockedFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: - b.WriteByte(0x16) + b = append(b, 0x16) case protocol.StreamTypeUni: - b.WriteByte(0x17) + b = append(b, 0x17) } - quicvarint.Write(b, uint64(f.StreamLimit)) - return nil + b = quicvarint.Append(b, uint64(f.StreamLimit)) + return b, nil } // Length of a written frame diff --git a/internal/wire/streams_blocked_frame_test.go b/internal/wire/streams_blocked_frame_test.go index 3247e2c8..3e09dcd7 100644 --- a/internal/wire/streams_blocked_frame_test.go +++ b/internal/wire/streams_blocked_frame_test.go @@ -55,9 +55,9 @@ var _ = Describe("STREAMS_BLOCKED frame", func() { Type: streamType, StreamLimit: protocol.MaxStreamCount, } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - frame, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + frame, err := parseStreamsBlockedFrame(bytes.NewReader(b), protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) @@ -67,9 +67,9 @@ var _ = Describe("STREAMS_BLOCKED frame", func() { Type: streamType, StreamLimit: protocol.MaxStreamCount + 1, } - b := &bytes.Buffer{} - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) - _, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + _, err = parseStreamsBlockedFrame(bytes.NewReader(b), protocol.VersionWhatever) Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) }) } @@ -77,27 +77,27 @@ var _ = Describe("STREAMS_BLOCKED frame", func() { Context("writing", func() { It("writes a frame for bidirectional streams", func() { - b := &bytes.Buffer{} f := StreamsBlockedFrame{ Type: protocol.StreamTypeBidi, StreamLimit: 0xdeadbeefcafe, } - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x16} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("writes a frame for unidirectional streams", func() { - b := &bytes.Buffer{} f := StreamsBlockedFrame{ Type: protocol.StreamTypeUni, StreamLimit: 0xdeadbeefcafe, } - Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + b, err := f.Write(nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) expected := []byte{0x17} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) - Expect(b.Bytes()).To(Equal(expected)) + Expect(b).To(Equal(expected)) }) It("has the correct min length", func() { diff --git a/packet_packer.go b/packet_packer.go index aec485f0..80ade8be 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -830,35 +830,38 @@ func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedH return nil, err } payloadOffset := buf.Len() + raw := buffer.Data[:payloadOffset] if payload.ack != nil { - if err := payload.ack.Write(buf, p.version); err != nil { + var err error + raw, err = payload.ack.Write(raw, p.version) + if err != nil { return nil, err } } if paddingLen > 0 { - buf.Write(make([]byte, paddingLen)) + raw = append(raw, make([]byte, paddingLen)...) } for _, frame := range payload.frames { - if err := frame.Write(buf, p.version); err != nil { + var err error + raw, err = frame.Write(raw, p.version) + if err != nil { return nil, err } } - if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { + if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != payload.length { return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) } if !isMTUProbePacket { - if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { + if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize { return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } } - raw := buffer.Data // encrypt the packet - raw = raw[:buf.Len()] _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset]) - raw = raw[0 : buf.Len()+sealer.Overhead()] + raw = raw[0 : len(raw)+sealer.Overhead()] // apply header protection pnOffset := payloadOffset - int(header.PacketNumberLen) sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[hdrOffset], raw[pnOffset:payloadOffset]) diff --git a/packet_packer_test.go b/packet_packer_test.go index 6fb7d030..4a6ce672 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -499,10 +499,10 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - b := &bytes.Buffer{} - f.Write(b, packer.version) + b, err := f.Write(nil, packer.version) + Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal([]ackhandler.Frame{{Frame: f}})) - Expect(p.buffer.Data).To(ContainSubstring(b.String())) + Expect(p.buffer.Data).To(ContainSubstring(string(b))) }) It("stores the encryption level a packet was sealed with", func() { diff --git a/server.go b/server.go index ac68681c..b5bc2bdb 100644 --- a/server.go +++ b/server.go @@ -639,13 +639,14 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han } payloadOffset := buf.Len() - if err := ccf.Write(buf, hdr.Version); err != nil { + raw := buf.Bytes() + raw, err := ccf.Write(raw, hdr.Version) + if err != nil { return err } - raw := buf.Bytes() _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset]) - raw = raw[0 : buf.Len()+sealer.Overhead()] + raw = raw[0 : len(raw)+sealer.Overhead()] pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) sealer.EncryptHeader( @@ -659,7 +660,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.config.Tracer != nil { s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) } - _, err := s.conn.WritePacket(raw, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(raw, remoteAddr, info.OOB()) return err } diff --git a/streams_map_incoming_test.go b/streams_map_incoming_test.go index 3001ad20..cf642327 100644 --- a/streams_map_incoming_test.go +++ b/streams_map_incoming_test.go @@ -43,9 +43,9 @@ var _ = Describe("Streams Map (incoming)", func() { // check that the frame can be serialized and deserialized checkFrameSerialization := func(f wire.Frame) { - b := &bytes.Buffer{} - ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed()) - frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) + b, err := f.Write(nil, protocol.VersionTLS) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + frame, err := wire.NewFrameParser(false, protocol.VersionTLS).ParseNext(bytes.NewReader(b), protocol.Encryption1RTT) ExpectWithOffset(1, err).ToNot(HaveOccurred()) Expect(f).To(Equal(frame)) }