diff --git a/frames/stream_frame.go b/frames/stream_frame.go index 8f5d9323..0be80904 100644 --- a/frames/stream_frame.go +++ b/frames/stream_frame.go @@ -10,10 +10,11 @@ import ( // A StreamFrame of QUIC type StreamFrame struct { - FinBit bool - StreamID protocol.StreamID - Offset protocol.ByteCount - Data []byte + FinBit bool + StreamID protocol.StreamID + streamIDLen protocol.ByteCount + Offset protocol.ByteCount + Data []byte } // ParseStreamFrame reads a stream frame. The type byte must not have been read yet. @@ -31,9 +32,9 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { if offsetLen != 0 { offsetLen++ } - streamIDLen := typeByte&0x03 + 1 + frame.streamIDLen = protocol.ByteCount(typeByte&0x03 + 1) - sid, err := utils.ReadUintN(r, streamIDLen) + sid, err := utils.ReadUintN(r, uint8(frame.streamIDLen)) if err != nil { return nil, err } @@ -71,17 +72,43 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { // WriteStreamFrame writes a stream frame. func (f *StreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen, version protocol.VersionNumber) error { - typeByte := uint8(0x80) + typeByte := uint8(0x80) // sets the leftmost bit to 1 if f.FinBit { typeByte ^= 0x40 } - typeByte ^= 0x20 + typeByte ^= 0x20 // dataLenPresent if f.Offset != 0 { typeByte ^= 0x1c // TODO: Send shorter offset if possible } - typeByte ^= 0x03 // TODO: Send shorter stream ID if possible + + if f.streamIDLen == 0 { + f.calculateStreamIDLength() + } + + switch f.streamIDLen { + case 1: + typeByte ^= 0x0 + case 2: + typeByte ^= 0x01 + case 3: + typeByte ^= 0x02 + case 4: + typeByte ^= 0x03 + } + b.WriteByte(typeByte) - utils.WriteUint32(b, uint32(f.StreamID)) + + switch f.streamIDLen { + case 1: + b.WriteByte(uint8(f.StreamID)) + case 2: + utils.WriteUint16(b, uint16(f.StreamID)) + case 3: + utils.WriteUint24(b, uint32(f.StreamID)) + case 4: + utils.WriteUint32(b, uint32(f.StreamID)) + } + if f.Offset != 0 { utils.WriteUint64(b, uint64(f.Offset)) } @@ -90,9 +117,25 @@ func (f *StreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, return nil } +func (f *StreamFrame) calculateStreamIDLength() { + if f.StreamID < (1 << 8) { + f.streamIDLen = 1 + } else if f.StreamID < (1 << 16) { + f.streamIDLen = 2 + } else if f.StreamID < (1 << 24) { + f.streamIDLen = 3 + } else { + f.streamIDLen = 4 + } +} + // MinLength of a written frame func (f *StreamFrame) MinLength() protocol.ByteCount { - return 1 + 4 + 8 + 2 + 1 + if f.streamIDLen == 0 { + f.calculateStreamIDLength() + } + + return 1 + f.streamIDLen + 8 + 2 + 1 } // MaybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(n), nil is returned and nothing is modified. diff --git a/frames/stream_frame_test.go b/frames/stream_frame_test.go index 1ed7790d..d6166989 100644 --- a/frames/stream_frame_test.go +++ b/frames/stream_frame_test.go @@ -38,7 +38,7 @@ var _ = Describe("StreamFrame", func() { StreamID: 1, Data: []byte("foobar"), }).Write(b, 1, protocol.PacketNumberLen6, 0) - Expect(b.Bytes()).To(Equal([]byte{0xa3, 0x1, 0, 0, 0, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'})) + Expect(b.Bytes()).To(Equal([]byte{0xa0, 0x1, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'})) }) It("writes offsets", func() { @@ -48,7 +48,7 @@ var _ = Describe("StreamFrame", func() { Offset: 16, Data: []byte("foobar"), }).Write(b, 1, protocol.PacketNumberLen6, 0) - Expect(b.Bytes()).To(Equal([]byte{0xbf, 0x1, 0, 0, 0, 0x10, 0, 0, 0, 0, 0, 0, 0, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'})) + Expect(b.Bytes()).To(Equal([]byte{0xbc, 0x1, 0x10, 0, 0, 0, 0, 0, 0, 0, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'})) }) It("has proper min length", func() { @@ -61,6 +61,74 @@ var _ = Describe("StreamFrame", func() { f.Write(b, 1, protocol.PacketNumberLen6, 0) Expect(f.MinLength()).To(Equal(protocol.ByteCount(b.Len()))) }) + + Context("lengths of StreamIDs", func() { + It("writes a 2 byte StreamID", func() { + b := &bytes.Buffer{} + (&StreamFrame{ + StreamID: 13, + Data: []byte("foobar"), + }).Write(b, 1, protocol.PacketNumberLen6, 0) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x0))) + Expect(b.Bytes()[1]).To(Equal(uint8(13))) + }) + + It("writes a 2 byte StreamID", func() { + b := &bytes.Buffer{} + (&StreamFrame{ + StreamID: 0xCAFE, + Data: []byte("foobar"), + }).Write(b, 1, protocol.PacketNumberLen6, 0) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x1))) + Expect(b.Bytes()[1:3]).To(Equal([]byte{0xFE, 0xCA})) + }) + + It("writes a 3 byte StreamID", func() { + b := &bytes.Buffer{} + (&StreamFrame{ + StreamID: 0x13BEEF, + Data: []byte("foobar"), + }).Write(b, 1, protocol.PacketNumberLen6, 0) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x2))) + Expect(b.Bytes()[1:4]).To(Equal([]byte{0xEF, 0xBE, 0x13})) + }) + + It("writes a 4 byte StreamID", func() { + b := &bytes.Buffer{} + (&StreamFrame{ + StreamID: 0xDECAFBAD, + Data: []byte("foobar"), + }).Write(b, 1, protocol.PacketNumberLen6, 0) + Expect(b.Bytes()[0] & 0x3).To(Equal(uint8(0x3))) + Expect(b.Bytes()[1:5]).To(Equal([]byte{0xAD, 0xFB, 0xCA, 0xDE})) + }) + }) + }) + + Context("shortening of StreamIDs", func() { + It("determines the length of a 1 byte StreamID", func() { + f := &StreamFrame{StreamID: 0xFF} + f.calculateStreamIDLength() + Expect(f.streamIDLen).To(Equal(protocol.ByteCount(1))) + }) + + It("determines the length of a 2 byte StreamID", func() { + f := &StreamFrame{StreamID: 0xFFFF} + f.calculateStreamIDLength() + Expect(f.streamIDLen).To(Equal(protocol.ByteCount(2))) + }) + + It("determines the length of a 1 byte StreamID", func() { + f := &StreamFrame{StreamID: 0xFFFFFF} + f.calculateStreamIDLength() + Expect(f.streamIDLen).To(Equal(protocol.ByteCount(3))) + }) + + It("determines the length of a 1 byte StreamID", func() { + f := &StreamFrame{StreamID: 0xFFFFFFFF} + f.calculateStreamIDLength() + Expect(f.streamIDLen).To(Equal(protocol.ByteCount(4))) + }) }) Context("splitting off earlier stream frames", func() { diff --git a/packet_packer_test.go b/packet_packer_test.go index 55f427e8..7956ad7f 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -188,11 +188,12 @@ var _ = Describe("Packet packer", func() { It("splits one stream frame larger than maximum size", func() { publicHeaderLength := protocol.ByteCount(5) - maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength - (1 + 4 + 8 + 2) f := frames.StreamFrame{ - Data: bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200), - Offset: 1, + StreamID: 7, + Offset: 1, } + maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength - f.MinLength() + 1 // + 1 since MinceLength is 1 bigger than the actual StreamFrame header + f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200) packer.AddStreamFrame(f) payloadFrames, err := packer.composeNextPacket(nil, []frames.Frame{}, publicHeaderLength, true) Expect(err).ToNot(HaveOccurred()) @@ -250,9 +251,11 @@ var _ = Describe("Packet packer", func() { It("splits a stream frame larger than the maximum size", func() { publicHeaderLength := protocol.ByteCount(13) f := frames.StreamFrame{ - Data: bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLength-(1+4+8+2)+1)), - Offset: 1, + StreamID: 5, + Offset: 1, } + f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLength-f.MinLength()+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header + packer.AddStreamFrame(f) payloadFrames, err := packer.composeNextPacket(nil, []frames.Frame{}, publicHeaderLength, true) Expect(err).ToNot(HaveOccurred()) diff --git a/session_test.go b/session_test.go index c9db8038..6b0c999b 100644 --- a/session_test.go +++ b/session_test.go @@ -420,12 +420,12 @@ var _ = Describe("Session", func() { }) It("should call OnSent", func() { - session.QueueStreamFrame(&frames.StreamFrame{}) + session.QueueStreamFrame(&frames.StreamFrame{StreamID: 5}) session.sendPacket() Expect(cong.nCalls).To(Equal(2)) // OnPacketSent + GetCongestionWindow - Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(30))) + Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(27))) Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1))) - Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(30))) + Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(27))) Expect(cong.argsOnPacketSent[4]).To(BeTrue()) }) diff --git a/utils/utils.go b/utils/utils.go index 829d28de..30b32307 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -127,6 +127,13 @@ func WriteUint32(b *bytes.Buffer, i uint32) { b.WriteByte(uint8((i >> 24) & 0xff)) } +// WriteUint24 writes 24 bit of a uint32 +func WriteUint24(b *bytes.Buffer, i uint32) { + b.WriteByte(uint8(i & 0xff)) + b.WriteByte(uint8((i >> 8) & 0xff)) + b.WriteByte(uint8((i >> 16) & 0xff)) +} + // WriteUint16 writes a uint16 func WriteUint16(b *bytes.Buffer, i uint16) { b.WriteByte(uint8(i & 0xff)) diff --git a/utils/utils_test.go b/utils/utils_test.go index 0f56599d..0d69342a 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -69,6 +69,21 @@ var _ = Describe("Utils", func() { }) }) + Context("WriteUint24", func() { + It("outputs 3 bytes", func() { + b := &bytes.Buffer{} + WriteUint24(b, uint32(1)) + Expect(b.Len()).To(Equal(3)) + }) + + It("outputs a little endian", func() { + num := uint32(0xEFAC3512) + b := &bytes.Buffer{} + WriteUint24(b, num) + Expect(b.Bytes()).To(Equal([]byte{0x12, 0x35, 0xAC})) + }) + }) + Context("WriteUint32", func() { It("outputs 4 bytes", func() { b := &bytes.Buffer{}