diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index c098f8c5..9aa0890b 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -2,7 +2,6 @@ package protocol import ( "fmt" - "math" ) // A PacketNumber in QUIC @@ -57,13 +56,13 @@ func (t PacketType) String() string { type ConnectionID uint64 // A StreamID in QUIC -type StreamID uint32 +type StreamID uint64 // A ByteCount in QUIC type ByteCount uint64 // MaxByteCount is the maximum value of a ByteCount -const MaxByteCount = ByteCount(math.MaxUint64) +const MaxByteCount = ByteCount(1<<62 - 1) // MaxReceivePacketSize maximum packet size of any QUIC packet, based on // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index 9ca26550..33c5cf38 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -56,6 +56,11 @@ var _ = Describe("Version", func() { Expect(VersionTLS.UsesIETFFrameFormat()).To(BeTrue()) }) + It("tells if a version uses the IETF frame types", func() { + Expect(Version39.UsesIETFFrameFormat()).To(BeFalse()) + Expect(VersionTLS.UsesIETFFrameFormat()).To(BeTrue()) + }) + It("says if a stream contributes to connection-level flowcontrol, for gQUIC", func() { Expect(Version39.StreamContributesToConnectionFlowControl(1)).To(BeFalse()) Expect(Version39.StreamContributesToConnectionFlowControl(2)).To(BeTrue()) diff --git a/internal/wire/stream_frame.go b/internal/wire/stream_frame.go index 75be8880..fc38acd0 100644 --- a/internal/wire/stream_frame.go +++ b/internal/wire/stream_frame.go @@ -19,13 +19,12 @@ type StreamFrame struct { Data []byte } -var ( - errInvalidStreamIDLen = errors.New("StreamFrame: Invalid StreamID length") - errInvalidOffsetLen = errors.New("StreamFrame: Invalid offset length") -) - -// ParseStreamFrame reads a stream frame. The type byte must not have been read yet. +// ParseStreamFrame reads a STREAM frame func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) { + if !version.UsesIETFFrameFormat() { + return parseLegacyStreamFrame(r, version) + } + frame := &StreamFrame{} typeByte, err := r.ReadByte() @@ -33,44 +32,39 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF return nil, err } - frame.FinBit = typeByte&0x40 > 0 - frame.DataLenPresent = typeByte&0x20 > 0 - offsetLen := typeByte & 0x1c >> 2 - if offsetLen != 0 { - offsetLen++ - } - streamIDLen := typeByte&0x3 + 1 + frame.FinBit = typeByte&0x1 > 0 + frame.DataLenPresent = typeByte&0x2 > 0 + hasOffset := typeByte&0x4 > 0 - sid, err := utils.GetByteOrder(version).ReadUintN(r, streamIDLen) + streamID, err := utils.ReadVarInt(r) if err != nil { return nil, err } - frame.StreamID = protocol.StreamID(sid) - - offset, err := utils.GetByteOrder(version).ReadUintN(r, offsetLen) - if err != nil { - return nil, err - } - frame.Offset = protocol.ByteCount(offset) - - var dataLen uint16 - if frame.DataLenPresent { - dataLen, err = utils.GetByteOrder(version).ReadUint16(r) + frame.StreamID = protocol.StreamID(streamID) + if hasOffset { + offset, err := utils.ReadVarInt(r) if err != nil { return nil, err } + frame.Offset = protocol.ByteCount(offset) } - // shortcut to prevent the unneccessary allocation of dataLen bytes - // if the dataLen is larger than the remaining length of the packet - // reading the packet contents would result in EOF when attempting to READ - if int(dataLen) > r.Len() { - return nil, io.EOF - } - - if !frame.DataLenPresent { + var dataLen uint64 + if frame.DataLenPresent { + var err error + dataLen, err = utils.ReadVarInt(r) + if err != nil { + return nil, err + } + // shortcut to prevent the unneccessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the packet contents would result in EOF when attempting to READ + if dataLen > uint64(r.Len()) { + return nil, io.EOF + } + } else { // The rest of the packet is data - dataLen = uint16(r.Len()) + dataLen = uint64(r.Len()) } if dataLen != 0 { frame.Data = make([]byte, dataLen) @@ -79,8 +73,7 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF return nil, err } } - - if frame.Offset+frame.DataLen() < frame.Offset { + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") } if !frame.FinBit && frame.DataLen() == 0 { @@ -89,118 +82,51 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF return frame, nil } -// WriteStreamFrame writes a stream frame. +// Write writes a STREAM frame func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if !version.UsesIETFFrameFormat() { + return f.writeLegacy(b, version) + } + if len(f.Data) == 0 && !f.FinBit { return errors.New("StreamFrame: attempting to write empty frame without FIN") } - typeByte := uint8(0x80) // sets the leftmost bit to 1 + typeByte := byte(0x10) if f.FinBit { - typeByte ^= 0x40 + typeByte ^= 0x1 } + hasOffset := f.Offset != 0 if f.DataLenPresent { - typeByte ^= 0x20 + typeByte ^= 0x2 } - - offsetLength := f.getOffsetLength() - if offsetLength > 0 { - typeByte ^= (uint8(offsetLength) - 1) << 2 + if hasOffset { + typeByte ^= 0x4 } - - streamIDLen := f.calculateStreamIDLength() - typeByte ^= streamIDLen - 1 - b.WriteByte(typeByte) - - switch streamIDLen { - case 1: - b.WriteByte(uint8(f.StreamID)) - case 2: - utils.GetByteOrder(version).WriteUint16(b, uint16(f.StreamID)) - case 3: - utils.GetByteOrder(version).WriteUint24(b, uint32(f.StreamID)) - case 4: - utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) - default: - return errInvalidStreamIDLen + utils.WriteVarInt(b, uint64(f.StreamID)) + if hasOffset { + utils.WriteVarInt(b, uint64(f.Offset)) } - - switch offsetLength { - case 0: - case 2: - utils.GetByteOrder(version).WriteUint16(b, uint16(f.Offset)) - case 3: - utils.GetByteOrder(version).WriteUint24(b, uint32(f.Offset)) - case 4: - utils.GetByteOrder(version).WriteUint32(b, uint32(f.Offset)) - case 5: - utils.GetByteOrder(version).WriteUint40(b, uint64(f.Offset)) - case 6: - utils.GetByteOrder(version).WriteUint48(b, uint64(f.Offset)) - case 7: - utils.GetByteOrder(version).WriteUint56(b, uint64(f.Offset)) - case 8: - utils.GetByteOrder(version).WriteUint64(b, uint64(f.Offset)) - default: - return errInvalidOffsetLen - } - if f.DataLenPresent { - utils.GetByteOrder(version).WriteUint16(b, uint16(len(f.Data))) + utils.WriteVarInt(b, uint64(f.DataLen())) } - b.Write(f.Data) return nil } -func (f *StreamFrame) calculateStreamIDLength() uint8 { - if f.StreamID < (1 << 8) { - return 1 - } else if f.StreamID < (1 << 16) { - return 2 - } else if f.StreamID < (1 << 24) { - return 3 - } - return 4 -} - -func (f *StreamFrame) getOffsetLength() protocol.ByteCount { - if f.Offset == 0 { - return 0 - } - if f.Offset < (1 << 16) { - return 2 - } - if f.Offset < (1 << 24) { - return 3 - } - if f.Offset < (1 << 32) { - return 4 - } - if f.Offset < (1 << 40) { - return 5 - } - if f.Offset < (1 << 48) { - return 6 - } - if f.Offset < (1 << 56) { - return 7 - } - return 8 -} - // MinLength returns the length of the header of a StreamFrame -// the total length of the StreamFrame is frame.MinLength() + frame.DataLen() -func (f *StreamFrame) MinLength(protocol.VersionNumber) (protocol.ByteCount, error) { - length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength() +// the total length of the frame is frame.MinLength() + frame.DataLen() +func (f *StreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) { + if !version.UsesIETFFrameFormat() { + return f.minLengthLegacy(version) + } + length := 1 + utils.VarIntLen(uint64(f.StreamID)) + if f.Offset != 0 { + length += utils.VarIntLen(uint64(f.Offset)) + } if f.DataLenPresent { - length += 2 + length += utils.VarIntLen(uint64(f.DataLen())) } return length, nil } - -// DataLen gives the length of data in bytes -func (f *StreamFrame) DataLen() protocol.ByteCount { - return protocol.ByteCount(len(f.Data)) -} diff --git a/internal/wire/stream_frame_legacy.go b/internal/wire/stream_frame_legacy.go new file mode 100644 index 00000000..e41d9978 --- /dev/null +++ b/internal/wire/stream_frame_legacy.go @@ -0,0 +1,197 @@ +package wire + +import ( + "bytes" + "errors" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +var ( + errInvalidStreamIDLen = errors.New("StreamFrame: Invalid StreamID length") + errInvalidOffsetLen = errors.New("StreamFrame: Invalid offset length") +) + +// parseLegacyStreamFrame reads a stream frame. The type byte must not have been read yet. +func parseLegacyStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) { + frame := &StreamFrame{} + + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + frame.FinBit = typeByte&0x40 > 0 + frame.DataLenPresent = typeByte&0x20 > 0 + offsetLen := typeByte & 0x1c >> 2 + if offsetLen != 0 { + offsetLen++ + } + streamIDLen := typeByte&0x3 + 1 + + sid, err := utils.GetByteOrder(version).ReadUintN(r, streamIDLen) + if err != nil { + return nil, err + } + frame.StreamID = protocol.StreamID(sid) + + offset, err := utils.GetByteOrder(version).ReadUintN(r, offsetLen) + if err != nil { + return nil, err + } + frame.Offset = protocol.ByteCount(offset) + + var dataLen uint16 + if frame.DataLenPresent { + dataLen, err = utils.GetByteOrder(version).ReadUint16(r) + if err != nil { + return nil, err + } + } + + // shortcut to prevent the unneccessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the packet contents would result in EOF when attempting to READ + if int(dataLen) > r.Len() { + return nil, io.EOF + } + + if !frame.DataLenPresent { + // The rest of the packet is data + dataLen = uint16(r.Len()) + } + if dataLen != 0 { + frame.Data = make([]byte, dataLen) + if _, err := io.ReadFull(r, frame.Data); err != nil { + // this should never happen, since we already checked the dataLen earlier + return nil, err + } + } + + // MaxByteCount is the highest value that can be encoded with the IETF QUIC variable integer encoding (2^62-1). + // Note that this value is smaller than the maximum value that could be encoded in the gQUIC STREAM frame (2^64-1). + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { + return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") + } + if !frame.FinBit && frame.DataLen() == 0 { + return nil, qerr.EmptyStreamFrameNoFin + } + return frame, nil +} + +// writeLegacy writes a stream frame. +func (f *StreamFrame) writeLegacy(b *bytes.Buffer, version protocol.VersionNumber) error { + if len(f.Data) == 0 && !f.FinBit { + return errors.New("StreamFrame: attempting to write empty frame without FIN") + } + + typeByte := uint8(0x80) // sets the leftmost bit to 1 + if f.FinBit { + typeByte ^= 0x40 + } + if f.DataLenPresent { + typeByte ^= 0x20 + } + + offsetLength := f.getOffsetLength() + if offsetLength > 0 { + typeByte ^= (uint8(offsetLength) - 1) << 2 + } + + streamIDLen := f.calculateStreamIDLength() + typeByte ^= streamIDLen - 1 + + b.WriteByte(typeByte) + + switch streamIDLen { + case 1: + b.WriteByte(uint8(f.StreamID)) + case 2: + utils.GetByteOrder(version).WriteUint16(b, uint16(f.StreamID)) + case 3: + utils.GetByteOrder(version).WriteUint24(b, uint32(f.StreamID)) + case 4: + utils.GetByteOrder(version).WriteUint32(b, uint32(f.StreamID)) + default: + return errInvalidStreamIDLen + } + + switch offsetLength { + case 0: + case 2: + utils.GetByteOrder(version).WriteUint16(b, uint16(f.Offset)) + case 3: + utils.GetByteOrder(version).WriteUint24(b, uint32(f.Offset)) + case 4: + utils.GetByteOrder(version).WriteUint32(b, uint32(f.Offset)) + case 5: + utils.GetByteOrder(version).WriteUint40(b, uint64(f.Offset)) + case 6: + utils.GetByteOrder(version).WriteUint48(b, uint64(f.Offset)) + case 7: + utils.GetByteOrder(version).WriteUint56(b, uint64(f.Offset)) + case 8: + utils.GetByteOrder(version).WriteUint64(b, uint64(f.Offset)) + default: + return errInvalidOffsetLen + } + + if f.DataLenPresent { + utils.GetByteOrder(version).WriteUint16(b, uint16(len(f.Data))) + } + + b.Write(f.Data) + return nil +} + +func (f *StreamFrame) calculateStreamIDLength() uint8 { + if f.StreamID < (1 << 8) { + return 1 + } else if f.StreamID < (1 << 16) { + return 2 + } else if f.StreamID < (1 << 24) { + return 3 + } + return 4 +} + +func (f *StreamFrame) getOffsetLength() protocol.ByteCount { + if f.Offset == 0 { + return 0 + } + if f.Offset < (1 << 16) { + return 2 + } + if f.Offset < (1 << 24) { + return 3 + } + if f.Offset < (1 << 32) { + return 4 + } + if f.Offset < (1 << 40) { + return 5 + } + if f.Offset < (1 << 48) { + return 6 + } + if f.Offset < (1 << 56) { + return 7 + } + return 8 +} + +func (f *StreamFrame) minLengthLegacy(protocol.VersionNumber) (protocol.ByteCount, error) { + length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength() + if f.DataLenPresent { + length += 2 + } + return length, nil +} + +// DataLen gives the length of data in bytes +func (f *StreamFrame) DataLen() protocol.ByteCount { + return protocol.ByteCount(len(f.Data)) +} diff --git a/internal/wire/stream_frame_legacy_test.go b/internal/wire/stream_frame_legacy_test.go new file mode 100644 index 00000000..b7b8d256 --- /dev/null +++ b/internal/wire/stream_frame_legacy_test.go @@ -0,0 +1,483 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/qerr" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STREAM frame (for gQUIC)", func() { + Context("when parsing", func() { + It("accepts a sample frame", func() { + // a STREAM frame, plus 3 additional bytes, not belonging to this frame + b := bytes.NewReader([]byte{0x80 ^ 0x20, + 0x1, // stream id + 0x0, 0x6, // data length + 'f', 'o', 'o', 'b', 'a', 'r', + 'f', 'o', 'o', // additional bytes + }) + frame, err := ParseStreamFrame(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.FinBit).To(BeFalse()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(1))) + Expect(frame.Offset).To(BeZero()) + Expect(frame.DataLenPresent).To(BeTrue()) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(b.Len()).To(Equal(3)) + }) + + It("accepts frames with offsets", func() { + b := bytes.NewReader([]byte{0x80 ^ 0x20 /* 2 byte offset */ ^ 0x4, + 0x1, // stream id + 0x0, 0x42, // offset + 0x0, 0x6, // data length + 'f', 'o', 'o', 'b', 'a', 'r', + }) + frame, err := ParseStreamFrame(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.FinBit).To(BeFalse()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(1))) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0x42))) + Expect(frame.DataLenPresent).To(BeTrue()) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x80 ^ 0x20 ^ 0x4, + 0x1, // stream id + 0x0, 0x2a, // offset + 0x0, 0x6, // data length, + 'f', 'o', 'o', 'b', 'a', 'r', + } + _, err := ParseStreamFrame(bytes.NewReader(data), versionBigEndian) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseStreamFrame(bytes.NewReader(data[0:i]), versionBigEndian) + Expect(err).To(HaveOccurred()) + } + }) + + It("accepts frame without data length", func() { + b := bytes.NewReader([]byte{0x80, + 0x1, // stream id + 'f', 'o', 'o', 'b', 'a', 'r', + }) + frame, err := ParseStreamFrame(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.FinBit).To(BeFalse()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(1))) + Expect(frame.Offset).To(BeZero()) + Expect(frame.DataLenPresent).To(BeFalse()) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts an empty frame with FinBit set, with data length set", func() { + // the STREAM frame, plus 3 additional bytes, not belonging to this frame + b := bytes.NewReader([]byte{0x80 ^ 0x40 ^ 0x20, + 0x1, // stream id + 0, 0, // data length + 'f', 'o', 'o', // additional bytes + }) + frame, err := ParseStreamFrame(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.FinBit).To(BeTrue()) + Expect(frame.DataLenPresent).To(BeTrue()) + Expect(frame.Data).To(BeEmpty()) + Expect(b.Len()).To(Equal(3)) + }) + + It("accepts an empty frame with the FinBit set", func() { + b := bytes.NewReader([]byte{0x80 ^ 0x40, + 0x1, // stream id + 'f', 'o', 'o', 'b', 'a', 'r', + }) + frame, err := ParseStreamFrame(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.FinBit).To(BeTrue()) + Expect(frame.DataLenPresent).To(BeFalse()) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on empty stream frames that don't have the FinBit set", func() { + b := bytes.NewReader([]byte{0x80 ^ 0x20, + 0x1, // stream id + 0, 0, // data length + }) + _, err := ParseStreamFrame(b, versionBigEndian) + Expect(err).To(MatchError(qerr.EmptyStreamFrameNoFin)) + }) + + It("rejects frames to too large dataLen", func() { + b := bytes.NewReader([]byte{0xa0, 0x1, 0xff, 0xff}) + _, err := ParseStreamFrame(b, versionBigEndian) + Expect(err).To(MatchError(io.EOF)) + }) + + It("rejects frames that overflow the offset", func() { + // Offset + len(Data) overflows MaxByteCount + f := &StreamFrame{ + StreamID: 1, + Offset: protocol.MaxByteCount, + Data: []byte{'f'}, + } + b := &bytes.Buffer{} + err := f.Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + _, err = ParseStreamFrame(bytes.NewReader(b.Bytes()), versionBigEndian) + Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset"))) + }) + }) + + Context("when writing", func() { + Context("in big endian", func() { + It("writes sample frame", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + DataLenPresent: true, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x80 ^ 0x20, + 0x1, // stream id + 0x0, 0x6, // data length + 'f', 'o', 'o', 'b', 'a', 'r', + })) + }) + }) + + It("sets the FinBit", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + FinBit: true, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x40).To(Equal(byte(0x40))) + }) + + It("errors when length is zero and FIN is not set", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + }).Write(b, versionBigEndian) + Expect(err).To(MatchError("StreamFrame: attempting to write empty frame without FIN")) + }) + + It("has proper min length for a short StreamID and a short offset", func() { + b := &bytes.Buffer{} + f := &StreamFrame{ + StreamID: 1, + Data: []byte{}, + Offset: 0, + FinBit: true, + } + err := f.Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(f.MinLength(0)).To(Equal(protocol.ByteCount(b.Len()))) + }) + + It("has proper min length for a long StreamID and a big offset", func() { + b := &bytes.Buffer{} + f := &StreamFrame{ + StreamID: 0xdecafbad, + Data: []byte{}, + Offset: 0xdeadbeefcafe, + FinBit: true, + } + err := f.Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(f.MinLength(versionBigEndian)).To(Equal(protocol.ByteCount(b.Len()))) + }) + + Context("data length field", func() { + It("writes the data length", func() { + dataLen := 0x1337 + b := &bytes.Buffer{} + f := &StreamFrame{ + StreamID: 1, + Data: bytes.Repeat([]byte{'f'}, dataLen), + DataLenPresent: true, + Offset: 0, + } + err := f.Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + minLength, _ := f.MinLength(0) + Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0x20))) + Expect(b.Bytes()[minLength-2 : minLength]).To(Equal([]byte{0x13, 0x37})) + }) + }) + + It("omits the data length field", func() { + dataLen := 0x1337 + b := &bytes.Buffer{} + f := &StreamFrame{ + StreamID: 1, + Data: bytes.Repeat([]byte{'f'}, dataLen), + DataLenPresent: false, + Offset: 0, + } + err := f.Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0))) + Expect(b.Bytes()[1 : b.Len()-dataLen]).ToNot(ContainSubstring(string([]byte{0x37, 0x13}))) + minLength, _ := f.MinLength(versionBigEndian) + f.DataLenPresent = true + minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian) + Expect(minLength).To(Equal(minLengthWithoutDataLen - 2)) + }) + + It("calculates the correct min-length", func() { + f := &StreamFrame{ + StreamID: 0xcafe, + Data: []byte("foobar"), + DataLenPresent: false, + Offset: 0xdeadbeef, + } + minLengthWithoutDataLen, _ := f.MinLength(versionBigEndian) + f.DataLenPresent = true + Expect(f.MinLength(versionBigEndian)).To(Equal(minLengthWithoutDataLen + 2)) + }) + + Context("offset lengths", func() { + It("does not write an offset if the offset is 0", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + Offset: 0, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x0))) + }) + + It("writes a 2-byte offset if the offset is larger than 0", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + Offset: 0x1337, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x1 << 2))) + Expect(b.Bytes()[2:4]).To(Equal([]byte{0x13, 0x37})) + }) + + It("writes a 3-byte offset if the offset", func() { + b := &bytes.Buffer{} + (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + Offset: 0x13cafe, + }).Write(b, versionBigEndian) + Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x2 << 2))) + Expect(b.Bytes()[2:5]).To(Equal([]byte{0x13, 0xca, 0xfe})) + }) + + It("writes a 4-byte offset if the offset", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + Offset: 0xdeadbeef, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x3 << 2))) + Expect(b.Bytes()[2:6]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + }) + + It("writes a 5-byte offset if the offset", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + Offset: 0x13deadbeef, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x4 << 2))) + Expect(b.Bytes()[2:7]).To(Equal([]byte{0x13, 0xde, 0xad, 0xbe, 0xef})) + }) + + It("writes a 6-byte offset if the offset", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + Offset: 0xdeadbeefcafe, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x5 << 2))) + Expect(b.Bytes()[2:8]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe})) + }) + + It("writes a 7-byte offset if the offset", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + Offset: 0x13deadbeefcafe, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x6 << 2))) + Expect(b.Bytes()[2:9]).To(Equal([]byte{0x13, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe})) + }) + + It("writes a 8-byte offset if the offset", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + Offset: 0x1337deadbeefcafe, + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x7 << 2))) + Expect(b.Bytes()[2:10]).To(Equal([]byte{0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe})) + }) + }) + + Context("lengths of StreamIDs", func() { + It("writes a 1 byte StreamID", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 13, + Data: []byte("foobar"), + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x0))) + Expect(b.Bytes()[1]).To(Equal(uint8(13))) + }) + + It("writes a 2 byte StreamID", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 0xcafe, + Data: []byte("foobar"), + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x1))) + Expect(b.Bytes()[1:3]).To(Equal([]byte{0xca, 0xfe})) + }) + + It("writes a 3 byte StreamID", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 0x13beef, + Data: []byte("foobar"), + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x2))) + Expect(b.Bytes()[1:4]).To(Equal([]byte{0x13, 0xbe, 0xef})) + }) + + It("writes a 4 byte StreamID", func() { + b := &bytes.Buffer{} + err := (&StreamFrame{ + StreamID: 0xdecafbad, + Data: []byte("foobar"), + }).Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x3))) + Expect(b.Bytes()[1:5]).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) + }) + + It("writes a multiple byte StreamID, after the Stream length was already determined by MinLenght()", func() { + b := &bytes.Buffer{} + frame := &StreamFrame{ + StreamID: 0xdecafbad, + Data: []byte("foobar"), + } + frame.MinLength(0) + err := frame.Write(b, versionBigEndian) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x3))) + Expect(b.Bytes()[1:5]).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) + }) + }) + }) + + Context("shortening of StreamIDs", func() { + It("determines the length of a 1 byte StreamID", func() { + f := &StreamFrame{StreamID: 0xFF} + Expect(f.calculateStreamIDLength()).To(Equal(uint8(1))) + }) + + It("determines the length of a 2 byte StreamID", func() { + f := &StreamFrame{StreamID: 0xFFFF} + Expect(f.calculateStreamIDLength()).To(Equal(uint8(2))) + }) + + It("determines the length of a 3 byte StreamID", func() { + f := &StreamFrame{StreamID: 0xFFFFFF} + Expect(f.calculateStreamIDLength()).To(Equal(uint8(3))) + }) + + It("determines the length of a 4 byte StreamID", func() { + f := &StreamFrame{StreamID: 0xFFFFFFFF} + Expect(f.calculateStreamIDLength()).To(Equal(uint8(4))) + }) + }) + + Context("shortening of Offsets", func() { + It("determines length 0 of offset 0", func() { + f := &StreamFrame{Offset: 0} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(0))) + }) + + It("determines the length of a 2 byte offset", func() { + f := &StreamFrame{Offset: 0xFFFF} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(2))) + }) + + It("determines the length of a 2 byte offset, even if it would fit into 1 byte", func() { + f := &StreamFrame{Offset: 0x1} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(2))) + }) + + It("determines the length of a 3 byte offset", func() { + f := &StreamFrame{Offset: 0xFFFFFF} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(3))) + }) + + It("determines the length of a 4 byte offset", func() { + f := &StreamFrame{Offset: 0xFFFFFFFF} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(4))) + }) + + It("determines the length of a 5 byte offset", func() { + f := &StreamFrame{Offset: 0xFFFFFFFFFF} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(5))) + }) + + It("determines the length of a 6 byte offset", func() { + f := &StreamFrame{Offset: 0xFFFFFFFFFFFF} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(6))) + }) + + It("determines the length of a 7 byte offset", func() { + f := &StreamFrame{Offset: 0xFFFFFFFFFFFFFF} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(7))) + }) + + It("determines the length of an 8 byte offset", func() { + f := &StreamFrame{Offset: 0xFFFFFFFFFFFFFFFF} + Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(8))) + }) + }) + + Context("DataLen", func() { + It("determines the length of the data", func() { + frame := StreamFrame{ + Data: []byte("foobar"), + } + Expect(frame.DataLen()).To(Equal(protocol.ByteCount(6))) + }) + }) +}) diff --git a/internal/wire/stream_frame_test.go b/internal/wire/stream_frame_test.go index 1a455d6d..f116de3e 100644 --- a/internal/wire/stream_frame_test.go +++ b/internal/wire/stream_frame_test.go @@ -2,490 +2,210 @@ package wire import ( "bytes" - "io" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("StreamFrame", func() { +var _ = Describe("STREAM frame (for IETF QUIC)", func() { Context("when parsing", func() { - Context("in big endian", func() { - It("accepts a sample frame", func() { - // a STREAM frame, plus 3 additional bytes, not belonging to this frame - b := bytes.NewReader([]byte{0x80 ^ 0x20, - 0x1, // stream id - 0x0, 0x6, // data length - 'f', 'o', 'o', 'b', 'a', 'r', - 'f', 'o', 'o', // additional bytes - }) - frame, err := ParseStreamFrame(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.FinBit).To(BeFalse()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(1))) - Expect(frame.Offset).To(BeZero()) - Expect(frame.DataLenPresent).To(BeTrue()) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(b.Len()).To(Equal(3)) - }) - - It("accepts frames with offsets", func() { - b := bytes.NewReader([]byte{0x80 ^ 0x20 /* 2 byte offset */ ^ 0x4, - 0x1, // stream id - 0x0, 0x42, // offset - 0x0, 0x6, // data length - 'f', 'o', 'o', 'b', 'a', 'r', - }) - frame, err := ParseStreamFrame(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.FinBit).To(BeFalse()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(1))) - Expect(frame.Offset).To(Equal(protocol.ByteCount(0x42))) - Expect(frame.DataLenPresent).To(BeTrue()) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x80 ^ 0x20 ^ 0x4, - 0x1, // stream id - 0x0, 0x2a, // offset - 0x0, 0x6, // data length, - 'f', 'o', 'o', 'b', 'a', 'r', - } - _, err := ParseStreamFrame(bytes.NewReader(data), versionBigEndian) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := ParseStreamFrame(bytes.NewReader(data[0:i]), versionBigEndian) - Expect(err).To(HaveOccurred()) - } - }) - }) - - It("accepts frame without data length", func() { - b := bytes.NewReader([]byte{0x80, - 0x1, // stream id - 'f', 'o', 'o', 'b', 'a', 'r', - }) - frame, err := ParseStreamFrame(b, protocol.VersionWhatever) + It("parses a frame with OFF bit", func() { + data := []byte{0x10 ^ 0x4} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := ParseStreamFrame(r, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal([]byte("foobar"))) + Expect(frame.FinBit).To(BeFalse()) + Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) + Expect(r.Len()).To(BeZero()) + }) + + It("respects the LEN when parsing the frame", func() { + data := []byte{0x10 ^ 0x2} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(4)...) // data length + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := ParseStreamFrame(r, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) + Expect(frame.Data).To(Equal([]byte("foob"))) Expect(frame.FinBit).To(BeFalse()) - Expect(frame.StreamID).To(Equal(protocol.StreamID(1))) Expect(frame.Offset).To(BeZero()) - Expect(frame.DataLenPresent).To(BeFalse()) - Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(b.Len()).To(BeZero()) + Expect(r.Len()).To(Equal(2)) }) - It("accepts an empty frame with FinBit set, with data length set", func() { - // the STREAM frame, plus 3 additional bytes, not belonging to this frame - b := bytes.NewReader([]byte{0x80 ^ 0x40 ^ 0x20, - 0x1, // stream id - 0, 0, // data length - 'f', 'o', 'o', // additional bytes - }) - frame, err := ParseStreamFrame(b, protocol.VersionWhatever) + It("parses a frame with FIN bit", func() { + data := []byte{0x10 ^ 0x1} + data = append(data, encodeVarInt(9)...) // stream ID + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + frame, err := ParseStreamFrame(r, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - Expect(frame.FinBit).To(BeTrue()) - Expect(frame.DataLenPresent).To(BeTrue()) - Expect(frame.Data).To(BeEmpty()) - Expect(b.Len()).To(Equal(3)) - }) - - It("accepts an empty frame with the FinBit set", func() { - b := bytes.NewReader([]byte{0x80 ^ 0x40, - 0x1, // stream id - 'f', 'o', 'o', 'b', 'a', 'r', - }) - frame, err := ParseStreamFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(frame.FinBit).To(BeTrue()) - Expect(frame.DataLenPresent).To(BeFalse()) + Expect(frame.StreamID).To(Equal(protocol.StreamID(9))) Expect(frame.Data).To(Equal([]byte("foobar"))) - Expect(b.Len()).To(BeZero()) + Expect(frame.FinBit).To(BeTrue()) + Expect(frame.Offset).To(BeZero()) + Expect(r.Len()).To(BeZero()) }) - It("errors on empty stream frames that don't have the FinBit set", func() { - b := bytes.NewReader([]byte{0x80 ^ 0x20, - 0x1, // stream id - 0, 0, // data length - }) - _, err := ParseStreamFrame(b, protocol.VersionWhatever) + It("rejects empty frames than don't have the FIN bit set", func() { + data := []byte{0x10} + data = append(data, encodeVarInt(0x1337)...) // stream ID + r := bytes.NewReader(data) + _, err := ParseStreamFrame(r, versionIETFFrames) Expect(err).To(MatchError(qerr.EmptyStreamFrameNoFin)) }) - It("rejects frames to too large dataLen", func() { - b := bytes.NewReader([]byte{0xa0, 0x1, 0xff, 0xff}) - _, err := ParseStreamFrame(b, protocol.VersionWhatever) - Expect(err).To(MatchError(io.EOF)) + It("rejects frames that overflow the maximum offset", func() { + data := []byte{0x10 ^ 0x4} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset + data = append(data, []byte("foobar")...) + r := bytes.NewReader(data) + _, err := ParseStreamFrame(r, versionIETFFrames) + Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset"))) }) - It("rejects frames that overflow the offset", func() { - // Offset + len(Data) overflows MaxByteCount - f := &StreamFrame{ - StreamID: 1, - Offset: protocol.MaxByteCount, - Data: []byte{'f'}, + It("errors on EOFs", func() { + data := []byte{0x10 ^ 0x4 ^ 0x2} + data = append(data, encodeVarInt(0x12345)...) // stream ID + data = append(data, encodeVarInt(0xdecafbad)...) // offset + data = append(data, encodeVarInt(6)...) // data length + data = append(data, []byte("foobar")...) + _, err := ParseStreamFrame(bytes.NewReader(data), versionIETFFrames) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := ParseStreamFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + Expect(err).To(HaveOccurred()) } - b := &bytes.Buffer{} - err := f.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - _, err = ParseStreamFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) - Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset"))) }) }) Context("when writing", func() { - Context("in big endian", func() { - It("writes sample frame", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - DataLenPresent: true, - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x80 ^ 0x20, - 0x1, // stream id - 0x0, 0x6, // data length - 'f', 'o', 'o', 'b', 'a', 'r', - })) - }) - }) - - It("sets the FinBit", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, + It("writes a frame without offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, Data: []byte("foobar"), - FinBit: true, - }).Write(b, protocol.VersionWhatever) + } + b := &bytes.Buffer{} + err := f.Write(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x40).To(Equal(byte(0x40))) + expected := []byte{0x10} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) }) - It("errors when length is zero and FIN is not set", func() { + It("writes a frame with offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x123456, + Data: []byte("foobar"), + } b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - }).Write(b, protocol.VersionWhatever) + err := f.Write(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x10 ^ 0x4} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with FIN bit", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x123456, + FinBit: true, + } + b := &bytes.Buffer{} + err := f.Write(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x10 ^ 0x4 ^ 0x1} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with data length", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + DataLenPresent: true, + } + b := &bytes.Buffer{} + err := f.Write(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x10 ^ 0x2} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(6)...) // data length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("writes a frame with data length and offset", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Data: []byte("foobar"), + DataLenPresent: true, + Offset: 0x123456, + } + b := &bytes.Buffer{} + err := f.Write(b, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + expected := []byte{0x10 ^ 0x4 ^ 0x2} + expected = append(expected, encodeVarInt(0x1337)...) // stream ID + expected = append(expected, encodeVarInt(0x123456)...) // offset + expected = append(expected, encodeVarInt(6)...) // data length + expected = append(expected, []byte("foobar")...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("refuses to write an empty frame without FIN", func() { + f := &StreamFrame{ + StreamID: 0x42, + Offset: 0x1337, + } + b := &bytes.Buffer{} + err := f.Write(b, versionIETFFrames) Expect(err).To(MatchError("StreamFrame: attempting to write empty frame without FIN")) }) + }) - It("has proper min length for a short StreamID and a short offset", func() { - b := &bytes.Buffer{} + Context("length", func() { + It("has the right length for a frame without offset and data length", func() { f := &StreamFrame{ - StreamID: 1, - Data: []byte{}, - Offset: 0, - FinBit: true, + StreamID: 0x1337, + Data: []byte("foobar"), } - err := f.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.MinLength(0)).To(Equal(protocol.ByteCount(b.Len()))) + Expect(f.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x1337))) }) - It("has proper min length for a long StreamID and a big offset", func() { - b := &bytes.Buffer{} + It("has the right length for a frame with offset", func() { f := &StreamFrame{ - StreamID: 0xdecafbad, - Data: []byte{}, - Offset: 0xdeadbeefcafe, - FinBit: true, + StreamID: 0x1337, + Offset: 0x42, + Data: []byte("foobar"), } - err := f.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.MinLength(0)).To(Equal(protocol.ByteCount(b.Len()))) + Expect(f.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0x42))) }) - Context("data length field", func() { - Context("in big endian", func() { - It("writes the data length", func() { - dataLen := 0x1337 - b := &bytes.Buffer{} - f := &StreamFrame{ - StreamID: 1, - Data: bytes.Repeat([]byte{'f'}, dataLen), - DataLenPresent: true, - Offset: 0, - } - err := f.Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - minLength, _ := f.MinLength(0) - Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0x20))) - Expect(b.Bytes()[minLength-2 : minLength]).To(Equal([]byte{0x13, 0x37})) - }) - }) - - It("omits the data length field", func() { - dataLen := 0x1337 - b := &bytes.Buffer{} - f := &StreamFrame{ - StreamID: 1, - Data: bytes.Repeat([]byte{'f'}, dataLen), - DataLenPresent: false, - Offset: 0, - } - err := f.Write(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x20).To(Equal(uint8(0))) - Expect(b.Bytes()[1 : b.Len()-dataLen]).ToNot(ContainSubstring(string([]byte{0x37, 0x13}))) - minLength, _ := f.MinLength(0) - f.DataLenPresent = true - minLengthWithoutDataLen, _ := f.MinLength(0) - Expect(minLength).To(Equal(minLengthWithoutDataLen - 2)) - }) - - It("calculates the correcct min-length", func() { - f := &StreamFrame{ - StreamID: 0xCAFE, - Data: []byte("foobar"), - DataLenPresent: false, - Offset: 0xDEADBEEF, - } - minLengthWithoutDataLen, _ := f.MinLength(0) - f.DataLenPresent = true - Expect(f.MinLength(0)).To(Equal(minLengthWithoutDataLen + 2)) - }) - }) - - Context("offset lengths", func() { - Context("in big endian", func() { - It("does not write an offset if the offset is 0", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - Offset: 0, - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x0))) - }) - - It("writes a 2-byte offset if the offset is larger than 0", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - Offset: 0x1337, - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x1 << 2))) - Expect(b.Bytes()[2:4]).To(Equal([]byte{0x13, 0x37})) - }) - - It("writes a 3-byte offset if the offset", func() { - b := &bytes.Buffer{} - (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - Offset: 0x13cafe, - }).Write(b, versionBigEndian) - Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x2 << 2))) - Expect(b.Bytes()[2:5]).To(Equal([]byte{0x13, 0xca, 0xfe})) - }) - - It("writes a 4-byte offset if the offset", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - Offset: 0xdeadbeef, - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x3 << 2))) - Expect(b.Bytes()[2:6]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - }) - - It("writes a 5-byte offset if the offset", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - Offset: 0x13deadbeef, - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x4 << 2))) - Expect(b.Bytes()[2:7]).To(Equal([]byte{0x13, 0xde, 0xad, 0xbe, 0xef})) - }) - - It("writes a 6-byte offset if the offset", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - Offset: 0xdeadbeefcafe, - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x5 << 2))) - Expect(b.Bytes()[2:8]).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe})) - }) - - It("writes a 7-byte offset if the offset", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - Offset: 0x13deadbeefcafe, - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x6 << 2))) - Expect(b.Bytes()[2:9]).To(Equal([]byte{0x13, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe})) - }) - - It("writes a 8-byte offset if the offset", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - Offset: 0x1337deadbeefcafe, - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x1c).To(Equal(uint8(0x7 << 2))) - Expect(b.Bytes()[2:10]).To(Equal([]byte{0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe})) - }) - }) - }) - - Context("lengths of StreamIDs", func() { - Context("in big endian", func() { - It("writes a 1 byte StreamID", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 13, - Data: []byte("foobar"), - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x0))) - Expect(b.Bytes()[1]).To(Equal(uint8(13))) - }) - - It("writes a 2 byte StreamID", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 0xcafe, - Data: []byte("foobar"), - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x1))) - Expect(b.Bytes()[1:3]).To(Equal([]byte{0xca, 0xfe})) - }) - - It("writes a 3 byte StreamID", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 0x13beef, - Data: []byte("foobar"), - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x2))) - Expect(b.Bytes()[1:4]).To(Equal([]byte{0x13, 0xbe, 0xef})) - }) - - It("writes a 4 byte StreamID", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 0xdecafbad, - Data: []byte("foobar"), - }).Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x3))) - Expect(b.Bytes()[1:5]).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) - }) - - It("writes a multiple byte StreamID, after the Stream length was already determined by MinLenght()", func() { - b := &bytes.Buffer{} - frame := &StreamFrame{ - StreamID: 0xdecafbad, - Data: []byte("foobar"), - } - frame.MinLength(0) - err := frame.Write(b, versionBigEndian) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x3))) - Expect(b.Bytes()[1:5]).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) - }) - }) - }) - }) - - Context("shortening of StreamIDs", func() { - It("determines the length of a 1 byte StreamID", func() { - f := &StreamFrame{StreamID: 0xFF} - Expect(f.calculateStreamIDLength()).To(Equal(uint8(1))) - }) - - It("determines the length of a 2 byte StreamID", func() { - f := &StreamFrame{StreamID: 0xFFFF} - Expect(f.calculateStreamIDLength()).To(Equal(uint8(2))) - }) - - It("determines the length of a 3 byte StreamID", func() { - f := &StreamFrame{StreamID: 0xFFFFFF} - Expect(f.calculateStreamIDLength()).To(Equal(uint8(3))) - }) - - It("determines the length of a 4 byte StreamID", func() { - f := &StreamFrame{StreamID: 0xFFFFFFFF} - Expect(f.calculateStreamIDLength()).To(Equal(uint8(4))) - }) - }) - - Context("shortening of Offsets", func() { - It("determines length 0 of offset 0", func() { - f := &StreamFrame{Offset: 0} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(0))) - }) - - It("determines the length of a 2 byte offset", func() { - f := &StreamFrame{Offset: 0xFFFF} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(2))) - }) - - It("determines the length of a 2 byte offset, even if it would fit into 1 byte", func() { - f := &StreamFrame{Offset: 0x1} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(2))) - }) - - It("determines the length of a 3 byte offset", func() { - f := &StreamFrame{Offset: 0xFFFFFF} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(3))) - }) - - It("determines the length of a 4 byte offset", func() { - f := &StreamFrame{Offset: 0xFFFFFFFF} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(4))) - }) - - It("determines the length of a 5 byte offset", func() { - f := &StreamFrame{Offset: 0xFFFFFFFFFF} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(5))) - }) - - It("determines the length of a 6 byte offset", func() { - f := &StreamFrame{Offset: 0xFFFFFFFFFFFF} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(6))) - }) - - It("determines the length of a 7 byte offset", func() { - f := &StreamFrame{Offset: 0xFFFFFFFFFFFFFF} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(7))) - }) - - It("determines the length of an 8 byte offset", func() { - f := &StreamFrame{Offset: 0xFFFFFFFFFFFFFFFF} - Expect(f.getOffsetLength()).To(Equal(protocol.ByteCount(8))) - }) - }) - - Context("DataLen", func() { - It("determines the length of the data", func() { - frame := StreamFrame{ - Data: []byte("foobar"), + It("has the right length for a frame with data length", func() { + f := &StreamFrame{ + StreamID: 0x1337, + Offset: 0x1234567, + DataLenPresent: true, + Data: []byte("foobar"), } - Expect(frame.DataLen()).To(Equal(protocol.ByteCount(6))) + Expect(f.MinLength(versionIETFFrames)).To(Equal(1 + utils.VarIntLen(0x1337) + utils.VarIntLen(0x1234567) + utils.VarIntLen(6))) }) }) }) diff --git a/internal/wire/wire_suite_test.go b/internal/wire/wire_suite_test.go index b0145420..491a93ba 100644 --- a/internal/wire/wire_suite_test.go +++ b/internal/wire/wire_suite_test.go @@ -1,6 +1,8 @@ package wire import ( + "bytes" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" @@ -21,6 +23,12 @@ const ( versionIETFFrames = protocol.VersionTLS ) +func encodeVarInt(i uint64) []byte { + b := &bytes.Buffer{} + utils.WriteVarInt(b, i) + return b.Bytes() +} + var _ = BeforeSuite(func() { Expect(utils.GetByteOrder(versionBigEndian)).To(Equal(utils.BigEndian)) Expect(utils.GetByteOrder(versionIETFFrames)).To(Equal(utils.BigEndian)) diff --git a/packet_packer.go b/packet_packer.go index 603f692f..ed75db2f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -236,10 +236,16 @@ func (p *packetPacker) composeNextPacket( return payloadFrames, nil } - // temporarily increase the maxFrameSize by 2 bytes + // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field // this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set - // however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size - maxFrameSize += 2 + // however, for the last StreamFrame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size + // for gQUIC STREAM frames, DataLen is always 2 bytes + // for IETF draft style STREAM frames, the length is encoded to either 1 or 2 bytes + if p.version.UsesIETFFrameFormat() { + maxFrameSize++ + } else { + maxFrameSize += 2 + } fs := p.streamFramer.PopStreamFrames(maxFrameSize - payloadLength) if len(fs) != 0 { diff --git a/packet_packer_test.go b/packet_packer_test.go index d2f8ab17..28669b26 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -60,9 +60,10 @@ var _ = Describe("Packet packer", func() { ) BeforeEach(func() { - cryptoStream = &stream{flowController: flowcontrol.NewStreamFlowController(1, false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} - streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, protocol.VersionWhatever) - streamFramer = newStreamFramer(cryptoStream, streamsMap, nil) + version := versionGQUICFrames + cryptoStream = &stream{streamID: version.CryptoStreamID(), flowController: flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} + streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) + streamFramer = newStreamFramer(cryptoStream, streamsMap, nil, versionGQUICFrames) packer = &packetPacker{ cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, @@ -73,8 +74,8 @@ var _ = Describe("Packet packer", func() { } publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen - packer.version = protocol.VersionWhatever packer.hasSentPacket = true + packer.version = version }) It("returns nil when no packet is queued", func() { @@ -93,7 +94,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b := &bytes.Buffer{} - f.Write(b, 0) + f.Write(b, packer.version) Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) }) @@ -327,7 +328,7 @@ var _ = Describe("Packet packer", func() { It("packs a lot of control frames into 2 packets if they don't fit into one", func() { blockedFrame := &wire.BlockedFrame{} - minLength, _ := blockedFrame.MinLength(0) + minLength, _ := blockedFrame.MinLength(packer.version) maxFramesPerPacket := int(maxFrameSize) / int(minLength) var controlFrames []wire.Frame for i := 0; i < maxFramesPerPacket+10; i++ { @@ -360,14 +361,14 @@ var _ = Describe("Packet packer", func() { Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2))) }) - Context("Stream Frame handling", func() { - It("does not splits a stream frame with maximum size", func() { + Context("STREAM Frame handling", func() { + It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() { f := &wire.StreamFrame{ Offset: 1, StreamID: 5, DataLenPresent: false, } - minLength, _ := f.MinLength(0) + minLength, _ := f.MinLength(packer.version) maxStreamFrameDataLen := maxFrameSize - minLength f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)) streamFramer.AddFrameForRetransmission(f) @@ -380,7 +381,30 @@ var _ = Describe("Packet packer", func() { Expect(payloadFrames).To(BeEmpty()) }) - It("correctly handles a stream frame with one byte less than maximum size", func() { + It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() { + packer.version = versionIETFFrames + streamFramer.version = versionIETFFrames + f := &wire.StreamFrame{ + Offset: 1, + StreamID: 5, + DataLenPresent: true, + } + minLength, _ := f.MinLength(packer.version) + // for IETF draft style STREAM frames, we don't know the size of the DataLen, because it is a variable length integer + // in the general case, we therefore use a STREAM frame that is 1 byte smaller than the maximum size + maxStreamFrameDataLen := maxFrameSize - minLength - 1 + f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)) + streamFramer.AddFrameForRetransmission(f) + payloadFrames, err := packer.composeNextPacket(maxFrameSize, true) + Expect(err).ToNot(HaveOccurred()) + Expect(payloadFrames).To(HaveLen(1)) + Expect(payloadFrames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + payloadFrames, err = packer.composeNextPacket(maxFrameSize, true) + Expect(err).ToNot(HaveOccurred()) + Expect(payloadFrames).To(BeEmpty()) + }) + + It("correctly handles a STREAM frame with one byte less than maximum size", func() { maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) - 1 f1 := &wire.StreamFrame{ StreamID: 5, @@ -405,7 +429,7 @@ var _ = Describe("Packet packer", func() { Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) }) - It("packs multiple small stream frames into single packet", func() { + It("packs multiple small STREAM frames into single packet", func() { f1 := &wire.StreamFrame{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, @@ -437,12 +461,12 @@ var _ = Describe("Packet packer", func() { Expect(p.raw).To(ContainSubstring(string(f3.Data))) }) - It("splits one stream frame larger than maximum size", func() { + It("splits one STREAM frame larger than maximum size", func() { f := &wire.StreamFrame{ StreamID: 7, Offset: 1, } - minLength, _ := f.MinLength(0) + minLength, _ := f.MinLength(packer.version) maxStreamFrameDataLen := maxFrameSize - minLength f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200) streamFramer.AddFrameForRetransmission(f) @@ -461,7 +485,7 @@ var _ = Describe("Packet packer", func() { Expect(payloadFrames).To(BeEmpty()) }) - It("packs 2 stream frames that are too big for one packet correctly", func() { + It("packs 2 STREAM frames that are too big for one packet correctly", func() { maxStreamFrameDataLen := maxFrameSize - (1 + 1 + 2) f1 := &wire.StreamFrame{ StreamID: 5, @@ -496,12 +520,12 @@ var _ = Describe("Packet packer", func() { Expect(p).To(BeNil()) }) - It("packs a packet that has the maximum packet size when given a large enough stream frame", func() { + It("packs a packet that has the maximum packet size when given a large enough STREAM frame", func() { f := &wire.StreamFrame{ StreamID: 5, Offset: 1, } - minLength, _ := f.MinLength(0) + minLength, _ := f.MinLength(packer.version) f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header streamFramer.AddFrameForRetransmission(f) p, err := packer.PackPacket() @@ -510,12 +534,12 @@ var _ = Describe("Packet packer", func() { Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) }) - It("splits a stream frame larger than the maximum size", func() { + It("splits a STREAM frame larger than the maximum size", func() { f := &wire.StreamFrame{ StreamID: 5, Offset: 1, } - minLength, _ := f.MinLength(0) + minLength, _ := f.MinLength(packer.version) f.Data = bytes.Repeat([]byte{'f'}, int(maxFrameSize-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header streamFramer.AddFrameForRetransmission(f) diff --git a/packet_unpacker.go b/packet_unpacker.go index 5b8d90ae..3e186c14 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -77,8 +77,7 @@ func (u *packetUnpacker) parseFrame(r *bytes.Reader, typeByte byte, hdr *wire.He func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wire.Header) (wire.Frame, error) { var frame wire.Frame var err error - // TODO: implement the IETF STREAM frame - if typeByte&0x80 == 0x80 { + if typeByte&0xf8 == 0x10 { frame, err = wire.ParseStreamFrame(r, u.version) if err != nil { err = qerr.Error(qerr.InvalidStreamData, err.Error()) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index c1832adb..b7794a0b 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -377,11 +377,55 @@ var _ = Describe("Packet unpacker", func() { 0x04: qerr.InvalidWindowUpdateData, 0x05: qerr.InvalidWindowUpdateData, 0x09: qerr.InvalidBlockedData, + 0x10: qerr.InvalidStreamData, } { setData([]byte{b}) _, err := unpacker.Unpack(hdrBin, hdr, data) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(e)) } }) + + Context("unpacking STREAM frames", func() { + It("unpacks unencrypted STREAM frames on the crypto stream", func() { + unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted + f := &wire.StreamFrame{ + StreamID: versionIETFFrames.CryptoStreamID(), + Data: []byte("foobar"), + } + err := f.Write(buf, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + + It("unpacks encrypted STREAM frames on the crypto stream", func() { + unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionSecure + f := &wire.StreamFrame{ + StreamID: versionIETFFrames.CryptoStreamID(), + Data: []byte("foobar"), + } + err := f.Write(buf, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + packet, err := unpacker.Unpack(hdrBin, hdr, data) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.frames).To(Equal([]wire.Frame{f})) + }) + + It("does not unpack unencrypted STREAM frames on higher streams", func() { + unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted + f := &wire.StreamFrame{ + StreamID: 3, + Data: []byte("foobar"), + } + err := f.Write(buf, versionIETFFrames) + Expect(err).ToNot(HaveOccurred()) + setData(buf.Bytes()) + _, err = unpacker.Unpack(hdrBin, hdr, data) + Expect(err).To(MatchError(qerr.Error(qerr.UnencryptedStreamData, "received unencrypted stream data on stream 3"))) + }) + }) }) }) diff --git a/session.go b/session.go index 439edaa9..7cc190bd 100644 --- a/session.go +++ b/session.go @@ -320,7 +320,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) - s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController, s.version) s.packer = newPacketPacker(s.connectionID, initialPacketNumber, diff --git a/stream_framer.go b/stream_framer.go index 8928e490..e16a01e9 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -9,6 +9,7 @@ import ( type streamFramer struct { streamsMap *streamsMap cryptoStream streamI + version protocol.VersionNumber connFlowController flowcontrol.ConnectionFlowController @@ -20,11 +21,13 @@ func newStreamFramer( cryptoStream streamI, streamsMap *streamsMap, cfc flowcontrol.ConnectionFlowController, + v protocol.VersionNumber, ) *streamFramer { return &streamFramer{ streamsMap: streamsMap, cryptoStream: cryptoStream, connFlowController: cfc, + version: v, } } @@ -63,7 +66,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str StreamID: f.cryptoStream.StreamID(), Offset: f.cryptoStream.GetWriteOffset(), } - frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error + frameHeaderBytes, _ := frame.MinLength(f.version) // can never error frame.Data = f.cryptoStream.GetDataForWriting(maxLen - frameHeaderBytes) return frame } @@ -73,7 +76,7 @@ func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount frame := f.retransmissionQueue[0] frame.DataLenPresent = true - frameHeaderLen, _ := frame.MinLength(protocol.VersionWhatever) // can never error + frameHeaderLen, _ := frame.MinLength(f.version) // can never error if currentLen+frameHeaderLen >= maxLen { break } @@ -106,7 +109,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] frame.StreamID = s.StreamID() frame.Offset = s.GetWriteOffset() // not perfect, but thread-safe since writeOffset is only written when getting data - frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error + frameHeaderBytes, _ := frame.MinLength(f.version) // can never error if currentLen+frameHeaderBytes > maxBytes { return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here } diff --git a/stream_framer_test.go b/stream_framer_test.go index e2d5e6a2..84897b9b 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -41,12 +41,12 @@ var _ = Describe("Stream Framer", func() { stream2 = mocks.NewMockStreamI(mockCtrl) stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() - streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, protocol.VersionWhatever) + streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) streamsMap.putStream(stream1) streamsMap.putStream(stream2) connFC = mocks.NewMockConnectionFlowController(mockCtrl) - framer = newStreamFramer(nil, streamsMap, connFC) + framer = newStreamFramer(nil, streamsMap, connFC, versionGQUICFrames) }) setNoData := func(str *mocks.MockStreamI) { @@ -227,7 +227,7 @@ var _ = Describe("Stream Framer", func() { origlen := retransmittedFrame2.DataLen() fs := framer.PopStreamFrames(6) Expect(fs).To(HaveLen(1)) - minLength, _ := fs[0].MinLength(0) + minLength, _ := fs[0].MinLength(framer.version) Expect(minLength + fs[0].DataLen()).To(Equal(protocol.ByteCount(6))) Expect(framer.retransmissionQueue[0].Data).To(HaveLen(int(origlen - fs[0].DataLen()))) Expect(framer.retransmissionQueue[0].Offset).To(Equal(fs[0].DataLen()))