diff --git a/frames/stream_frame.go b/frames/stream_frame.go index 647bdaf9..dc337691 100644 --- a/frames/stream_frame.go +++ b/frames/stream_frame.go @@ -14,7 +14,6 @@ import ( type StreamFrame struct { FinBit bool StreamID protocol.StreamID - streamIDLen protocol.ByteCount Offset protocol.ByteCount Data []byte DataLenPresent bool @@ -40,9 +39,9 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { if offsetLen != 0 { offsetLen++ } - frame.streamIDLen = protocol.ByteCount(typeByte&0x03 + 1) + streamIDLen := typeByte&0x03 + 1 - sid, err := utils.ReadUintN(r, uint8(frame.streamIDLen)) + sid, err := utils.ReadUintN(r, streamIDLen) if err != nil { return nil, err } @@ -104,14 +103,12 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err typeByte ^= (uint8(offsetLength) - 1) << 2 } - if f.streamIDLen == 0 { - f.calculateStreamIDLength() - } - typeByte ^= uint8(f.streamIDLen) - 1 + streamIDLen := f.calculateStreamIDLength() + typeByte ^= streamIDLen - 1 b.WriteByte(typeByte) - switch f.streamIDLen { + switch streamIDLen { case 1: b.WriteByte(uint8(f.StreamID)) case 2: @@ -153,16 +150,15 @@ func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) err return nil } -func (f *StreamFrame) calculateStreamIDLength() { +func (f *StreamFrame) calculateStreamIDLength() uint8 { if f.StreamID < (1 << 8) { - f.streamIDLen = 1 + return 1 } else if f.StreamID < (1 << 16) { - f.streamIDLen = 2 + return 2 } else if f.StreamID < (1 << 24) { - f.streamIDLen = 3 - } else { - f.streamIDLen = 4 + return 3 } + return 4 } func (f *StreamFrame) getOffsetLength() protocol.ByteCount { @@ -192,11 +188,7 @@ func (f *StreamFrame) getOffsetLength() protocol.ByteCount { // MinLength of a written frame func (f *StreamFrame) MinLength(protocol.VersionNumber) (protocol.ByteCount, error) { - if f.streamIDLen == 0 { - f.calculateStreamIDLength() - } - - length := protocol.ByteCount(1) + f.streamIDLen + f.getOffsetLength() + length := protocol.ByteCount(1) + protocol.ByteCount(f.calculateStreamIDLength()) + f.getOffsetLength() if f.DataLenPresent { length += 2 } diff --git a/frames/stream_frame_test.go b/frames/stream_frame_test.go index f013151d..35ac5cd0 100644 --- a/frames/stream_frame_test.go +++ b/frames/stream_frame_test.go @@ -248,16 +248,6 @@ var _ = Describe("StreamFrame", func() { }) Context("lengths of StreamIDs", func() { - It("returns an error for a non-valid StreamID length", func() { - b := &bytes.Buffer{} - err := (&StreamFrame{ - StreamID: 1, - streamIDLen: 13, - Data: []byte("foobar"), - }).Write(b, 0) - Expect(err).To(MatchError(errInvalidStreamIDLen)) - }) - It("writes a 1 byte StreamID", func() { b := &bytes.Buffer{} err := (&StreamFrame{ @@ -320,26 +310,22 @@ var _ = Describe("StreamFrame", func() { Context("shortening of StreamIDs", func() { It("determines the length of a 1 byte StreamID", func() { f := &StreamFrame{StreamID: 0xFF} - f.calculateStreamIDLength() - Expect(f.streamIDLen).To(Equal(protocol.ByteCount(1))) + Expect(f.calculateStreamIDLength()).To(Equal(uint8(1))) }) It("determines the length of a 2 byte StreamID", func() { f := &StreamFrame{StreamID: 0xFFFF} - f.calculateStreamIDLength() - Expect(f.streamIDLen).To(Equal(protocol.ByteCount(2))) + Expect(f.calculateStreamIDLength()).To(Equal(uint8(2))) }) It("determines the length of a 3 byte StreamID", func() { f := &StreamFrame{StreamID: 0xFFFFFF} - f.calculateStreamIDLength() - Expect(f.streamIDLen).To(Equal(protocol.ByteCount(3))) + Expect(f.calculateStreamIDLength()).To(Equal(uint8(3))) }) It("determines the length of a 4 byte StreamID", func() { f := &StreamFrame{StreamID: 0xFFFFFFFF} - f.calculateStreamIDLength() - Expect(f.streamIDLen).To(Equal(protocol.ByteCount(4))) + Expect(f.calculateStreamIDLength()).To(Equal(uint8(4))) }) }) diff --git a/integrationtests/chrome_test.go b/integrationtests/chrome_test.go index f4f018a9..55fae526 100644 --- a/integrationtests/chrome_test.go +++ b/integrationtests/chrome_test.go @@ -45,7 +45,7 @@ var _ = Describe("Chrome tests", func() { close(done) }, 5) - PIt("loads a large number of files", func(done Done) { + It("loads a large number of files", func(done Done) { err := wd.Get("https://quic.clemente.io/tiles") Expect(err).NotTo(HaveOccurred()) Eventually(func() error {