diff --git a/crypto_stream.go b/crypto_stream.go index bc75a18a..3886231d 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -12,10 +12,9 @@ type cryptoStreamI interface { io.Reader io.Writer AddStreamFrame(*wire.StreamFrame) error - HasDataForWriting() bool - GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool) - GetWriteOffset() protocol.ByteCount + PopStreamFrame(protocol.ByteCount) *wire.StreamFrame Cancel(error) + HasDataForWriting() bool SetReadOffset(protocol.ByteCount) // methods needed for flow control GetWindowUpdate() protocol.ByteCount @@ -27,6 +26,8 @@ type cryptoStream struct { *stream } +var _ cryptoStreamI = &cryptoStream{} + func newCryptoStream(onData func(), flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI { str := newStream(version.CryptoStreamID(), onData, nil, flowController, version) return &cryptoStream{str} @@ -39,3 +40,10 @@ func (s *cryptoStream) SetReadOffset(offset protocol.ByteCount) { s.readOffset = offset s.frameQueue.readPosition = offset } + +func (s *cryptoStream) HasDataForWriting() bool { + s.mutex.Lock() + hasData := s.dataForWriting != nil + s.mutex.Unlock() + return hasData +} diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 010db203..4b8f8da3 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -10,11 +10,23 @@ import ( var _ = Describe("Stream", func() { var str *cryptoStream - str = newCryptoStream(nil, nil, protocol.VersionWhatever).(*cryptoStream) + str = newCryptoStream(func() {}, nil, protocol.VersionWhatever).(*cryptoStream) It("sets the read offset", func() { str.SetReadOffset(0x42) Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42))) Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42))) }) + + It("says if it has data for writing", func() { + Expect(str.HasDataForWriting()).To(BeFalse()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + Eventually(str.HasDataForWriting).Should(BeTrue()) + }) }) diff --git a/internal/mocks/stream.go b/internal/mocks/stream.go index 1e6b4465..ab1829eb 100644 --- a/internal/mocks/stream.go +++ b/internal/mocks/stream.go @@ -94,19 +94,6 @@ func (_mr *MockStreamIMockRecorder) Finished() *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Finished", reflect.TypeOf((*MockStreamI)(nil).Finished)) } -// GetDataForWriting mocks base method -func (_m *MockStreamI) GetDataForWriting(_param0 protocol.ByteCount) ([]byte, bool) { - ret := _m.ctrl.Call(_m, "GetDataForWriting", _param0) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// GetDataForWriting indicates an expected call of GetDataForWriting -func (_mr *MockStreamIMockRecorder) GetDataForWriting(arg0 interface{}) *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetDataForWriting", reflect.TypeOf((*MockStreamI)(nil).GetDataForWriting), arg0) -} - // GetWindowUpdate mocks base method func (_m *MockStreamI) GetWindowUpdate() protocol.ByteCount { ret := _m.ctrl.Call(_m, "GetWindowUpdate") @@ -119,30 +106,6 @@ func (_mr *MockStreamIMockRecorder) GetWindowUpdate() *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).GetWindowUpdate)) } -// GetWriteOffset mocks base method -func (_m *MockStreamI) GetWriteOffset() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetWriteOffset") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetWriteOffset indicates an expected call of GetWriteOffset -func (_mr *MockStreamIMockRecorder) GetWriteOffset() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWriteOffset", reflect.TypeOf((*MockStreamI)(nil).GetWriteOffset)) -} - -// HasDataForWriting mocks base method -func (_m *MockStreamI) HasDataForWriting() bool { - ret := _m.ctrl.Call(_m, "HasDataForWriting") - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasDataForWriting indicates an expected call of HasDataForWriting -func (_mr *MockStreamIMockRecorder) HasDataForWriting() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "HasDataForWriting", reflect.TypeOf((*MockStreamI)(nil).HasDataForWriting)) -} - // IsFlowControlBlocked mocks base method func (_m *MockStreamI) IsFlowControlBlocked() bool { ret := _m.ctrl.Call(_m, "IsFlowControlBlocked") @@ -155,6 +118,18 @@ func (_mr *MockStreamIMockRecorder) IsFlowControlBlocked() *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "IsFlowControlBlocked", reflect.TypeOf((*MockStreamI)(nil).IsFlowControlBlocked)) } +// PopStreamFrame mocks base method +func (_m *MockStreamI) PopStreamFrame(_param0 protocol.ByteCount) *wire.StreamFrame { + ret := _m.ctrl.Call(_m, "PopStreamFrame", _param0) + ret0, _ := ret[0].(*wire.StreamFrame) + return ret0 +} + +// PopStreamFrame indicates an expected call of PopStreamFrame +func (_mr *MockStreamIMockRecorder) PopStreamFrame(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "PopStreamFrame", reflect.TypeOf((*MockStreamI)(nil).PopStreamFrame), arg0) +} + // Read mocks base method func (_m *MockStreamI) Read(_param0 []byte) (int, error) { ret := _m.ctrl.Call(_m, "Read", _param0) diff --git a/packet_packer.go b/packet_packer.go index badbc81f..484a08fd 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -176,7 +176,9 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { return nil, err } maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength - frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} + sf := p.streamFramer.PopCryptoStreamFrame(maxLen) + sf.DataLenPresent = false + frames := []wire.Frame{sf} raw, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err diff --git a/session_test.go b/session_test.go index e02b1761..ea2394a2 100644 --- a/session_test.go +++ b/session_test.go @@ -1036,7 +1036,7 @@ var _ = Describe("Session", func() { close(done) }() Eventually(sess.sendingScheduled).Should(Receive()) - s.(*stream).GetDataForWriting(1000) // unblock + s.(*stream).PopStreamFrame(1000) // unblock }) It("sets the timer to the ack timer", func() { @@ -1441,8 +1441,8 @@ var _ = Describe("Session", func() { s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) Expect(s.Close()).To(Succeed()) - _, sentFin := s.(*stream).GetDataForWriting(1000) // trigger "sending" of the FIN bit - Expect(sentFin).To(BeTrue()) + f := s.(*stream).PopStreamFrame(1000) // trigger "sending" of the FIN bit + Expect(f.FinBit).To(BeTrue()) s.(*stream).CloseRemote(0) _, err = s.Read([]byte("a")) Expect(err).To(MatchError(io.EOF)) diff --git a/stream.go b/stream.go index d1b95f03..cb24f6d7 100644 --- a/stream.go +++ b/stream.go @@ -19,9 +19,7 @@ type streamI interface { AddStreamFrame(*wire.StreamFrame) error RegisterRemoteError(error, protocol.ByteCount) error - HasDataForWriting() bool - GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool) - GetWriteOffset() protocol.ByteCount + PopStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFrame Finished() bool Cancel(error) // methods needed for flow control @@ -256,29 +254,37 @@ func (s *stream) GetWriteOffset() protocol.ByteCount { return s.writeOffset } -// HasDataForWriting says if there's stream available to be dequeued for writing -func (s *stream) HasDataForWriting() bool { - s.mutex.Lock() - hasData := s.err == nil && // nothing should be sent if an error occurred - (len(s.dataForWriting) > 0 || // there is data queued for sending - s.finishedWriting.Get() && !s.finSent.Get()) // if there is no data, but writing finished and the FIN hasn't been sent yet - s.mutex.Unlock() - return hasData -} - -func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { - data, shouldSendFin := s.getDataForWritingImpl(maxBytes) - if shouldSendFin { - s.finSent.Set(true) - } - return data, shouldSendFin -} - -func (s *stream) getDataForWritingImpl(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { +// PopStreamFrame returns the next STREAM frame that is supposed to be sent on this stream +// maxBytes is the maximum length this frame (including frame header) will have. +func (s *stream) PopStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFrame { s.mutex.Lock() defer s.mutex.Unlock() - if s.err != nil || s.dataForWriting == nil { + if s.err != nil { + return nil + } + + frame := &wire.StreamFrame{ + StreamID: s.streamID, + Offset: s.writeOffset, + DataLenPresent: true, + } + frameLen := frame.MinLength(s.version) + if frameLen >= maxBytes { // a STREAM frame must have at least one byte of data + return nil + } + frame.Data, frame.FinBit = s.getDataForWriting(maxBytes - frameLen) + if len(frame.Data) == 0 && !frame.FinBit { + return nil + } + if frame.FinBit { + s.finSent.Set(true) + } + return frame +} + +func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { + if s.dataForWriting == nil { return nil, s.finishedWriting.Get() && !s.finSent.Get() } diff --git a/stream_framer.go b/stream_framer.go index ded71ec6..9244bc3e 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -59,15 +59,7 @@ func (f *streamFramer) HasCryptoStreamFrame() bool { // TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { - if !f.HasCryptoStreamFrame() { - return nil - } - frame := &wire.StreamFrame{ - StreamID: f.version.CryptoStreamID(), - Offset: f.cryptoStream.GetWriteOffset(), - } - frame.Data, frame.FinBit = f.cryptoStream.GetDataForWriting(maxLen - frame.MinLength(f.version)) - return frame + return f.cryptoStream.PopStreamFrame(maxLen) } func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { @@ -96,7 +88,6 @@ func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.Byte } func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame) { - frame := &wire.StreamFrame{DataLenPresent: true} var currentLen protocol.ByteCount fn := func(s streamI) (bool, error) { @@ -104,23 +95,12 @@ func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) (res return true, nil } - frame.StreamID = s.StreamID() - // not perfect, but thread-safe since writeOffset is only written when getting data - frame.Offset = s.GetWriteOffset() - - frameHeaderBytes := frame.MinLength(f.version) - if currentLen+frameHeaderBytes > maxTotalLen { - return false, nil // theoretically, we could find another stream that fits, but this is quite unlikely, so we stop here - } maxLen := maxTotalLen - currentLen if maxLen < protocol.MinStreamFrameSize { // don't try to add new STREAM frames, if only little space is left in the packet return false, nil } - - if s.HasDataForWriting() { - frame.Data, frame.FinBit = s.GetDataForWriting(maxLen - frameHeaderBytes) - } - if len(frame.Data) == 0 && !frame.FinBit { + frame := s.PopStreamFrame(maxLen) + if frame == nil { return true, nil } @@ -133,7 +113,7 @@ func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) (res } res = append(res, frame) - currentLen += frameHeaderBytes + frame.DataLen() + currentLen += frame.MinLength(f.version) + frame.DataLen() if currentLen == maxTotalLen { return false, nil diff --git a/stream_framer_test.go b/stream_framer_test.go index 4403ae66..cb5ba7c3 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -26,6 +26,10 @@ var _ = Describe("Stream Framer", func() { connFC *mocks.MockConnectionFlowController ) + setNoData := func(str *mocks.MockStreamI) { + str.EXPECT().PopStreamFrame(gomock.Any()).AnyTimes() + } + BeforeEach(func() { retransmittedFrame1 = &wire.StreamFrame{ StreamID: 5, @@ -49,12 +53,6 @@ var _ = Describe("Stream Framer", func() { framer = newStreamFramer(nil, streamsMap, connFC, versionGQUICFrames) }) - setNoData := func(str *mocks.MockStreamI) { - str.EXPECT().HasDataForWriting().Return(false).AnyTimes() - str.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, false).AnyTimes() - str.EXPECT().GetWriteOffset().AnyTimes() - } - It("says if it has retransmissions", func() { Expect(framer.HasFramesForRetransmission()).To(BeFalse()) framer.AddFrameForRetransmission(retransmittedFrame1) @@ -70,18 +68,6 @@ var _ = Describe("Stream Framer", func() { Expect(fs[0].DataLenPresent).To(BeTrue()) }) - It("sets the DataLenPresent for dequeued normal frames", func() { - connFC.EXPECT().IsBlocked() - setNoData(stream2) - stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().IsFlowControlBlocked() - fs := framer.PopStreamFrames(protocol.MaxByteCount) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].DataLenPresent).To(BeTrue()) - }) - Context("Popping", func() { BeforeEach(func() { // nothing is blocked here @@ -102,15 +88,12 @@ var _ = Describe("Stream Framer", func() { framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame2) fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(2)) - Expect(fs[0]).To(Equal(retransmittedFrame1)) - Expect(fs[1]).To(Equal(retransmittedFrame2)) + Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, retransmittedFrame2})) + // make sure the frames are actually removed, and not returned a second time Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) }) - It("doesn't pop frames for retransmission, if the remaining space in the packet is too small, and the frame would be split", func() { - setNoData(stream1) - setNoData(stream2) + It("doesn't pop frames for retransmission, if the size would be smaller than the minimum STREAM frame size", func() { framer.AddFrameForRetransmission(&wire.StreamFrame{ StreamID: id1, Data: bytes.Repeat([]byte{'a'}, int(protocol.MinStreamFrameSize)), @@ -128,109 +111,91 @@ var _ = Describe("Stream Framer", func() { }) It("pops frames for retransmission, if the remaining size is the miniumum STREAM frame size", func() { - setNoData(stream1) - setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) + fs := framer.PopStreamFrames(protocol.MinStreamFrameSize) + Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1})) }) It("returns normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() setNoData(stream2) + f := &wire.StreamFrame{ + StreamID: id1, + Data: []byte("foobar"), + Offset: 42, + } + stream1.EXPECT().PopStreamFrame(gomock.Any()).Return(f) fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) - Expect(fs[0].Data).To(Equal([]byte("foobar"))) - Expect(fs[0].FinBit).To(BeFalse()) + Expect(fs).To(Equal([]*wire.StreamFrame{f})) }) It("returns multiple normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream2.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobaz"), false) - stream2.EXPECT().HasDataForWriting().Return(true) - stream2.EXPECT().GetWriteOffset() + f1 := &wire.StreamFrame{Data: []byte("foobar")} + f2 := &wire.StreamFrame{Data: []byte("foobaz")} + stream1.EXPECT().PopStreamFrame(gomock.Any()).Return(f1) + stream2.EXPECT().PopStreamFrame(gomock.Any()).Return(f2) fs := framer.PopStreamFrames(1000) Expect(fs).To(HaveLen(2)) - // Swap if we dequeued in other order - if fs[0].StreamID != stream1.StreamID() { - fs[0], fs[1] = fs[1], fs[0] - } - Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) - Expect(fs[0].Data).To(Equal([]byte("foobar"))) - Expect(fs[1].StreamID).To(Equal(stream2.StreamID())) - Expect(fs[1].Data).To(Equal([]byte("foobaz"))) + Expect(fs).To(ContainElement(f1)) + Expect(fs).To(ContainElement(f2)) }) It("returns retransmission frames before normal frames", func() { - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() setNoData(stream2) + f1 := &wire.StreamFrame{Data: []byte("foobar")} + stream1.EXPECT().PopStreamFrame(gomock.Any()).Return(f1) framer.AddFrameForRetransmission(retransmittedFrame1) fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(2)) - Expect(fs[0]).To(Equal(retransmittedFrame1)) - Expect(fs[1].StreamID).To(Equal(stream1.StreamID())) + Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, f1})) }) It("does not pop empty frames", func() { - stream1.EXPECT().HasDataForWriting().Return(false) - stream1.EXPECT().GetWriteOffset() + setNoData(stream1) setNoData(stream2) fs := framer.PopStreamFrames(500) Expect(fs).To(BeEmpty()) }) It("pops frames that have the minimum size", func() { - streamFrameHeaderLen := protocol.ByteCount(4) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().GetDataForWriting(protocol.MinStreamFrameSize-streamFrameHeaderLen).Return(bytes.Repeat([]byte{'f'}, int(protocol.MinStreamFrameSize-streamFrameHeaderLen)), false) - setNoData(stream2) - fs := framer.PopStreamFrames(protocol.MinStreamFrameSize) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].DataLen()).To(Equal(protocol.MinStreamFrameSize - streamFrameHeaderLen)) + stream1.EXPECT().PopStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}) + framer.PopStreamFrames(protocol.MinStreamFrameSize) }) It("does not pop frames smaller than the mimimum size", func() { - setNoData(stream1) - setNoData(stream2) - fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) - Expect(fs).To(BeEmpty()) + // don't expect a call to PopStreamFrame() + framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) }) It("uses the round-robin scheduling", func() { - streamFrameHeaderLen := protocol.ByteCount(4) - stream1.EXPECT().GetDataForWriting(1000-streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(1000-streamFrameHeaderLen)), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream2.EXPECT().GetDataForWriting(protocol.ByteCount(1000-streamFrameHeaderLen)).Return(bytes.Repeat([]byte("e"), int(1000-streamFrameHeaderLen)), false) - stream2.EXPECT().HasDataForWriting().Return(true) - stream2.EXPECT().GetWriteOffset() - fs := framer.PopStreamFrames(1000) + stream1.EXPECT().PopStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ + StreamID: id1, + Data: []byte("foobar"), + }) + stream1.EXPECT().PopStreamFrame(gomock.Any()).MaxTimes(1) + stream2.EXPECT().PopStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ + StreamID: id2, + Data: []byte("foobaz"), + }) + stream2.EXPECT().PopStreamFrame(gomock.Any()).MaxTimes(1) + fs := framer.PopStreamFrames(protocol.MinStreamFrameSize) Expect(fs).To(HaveLen(1)) // it doesn't matter here if this data is from stream1 or from stream2... firstStreamID := fs[0].StreamID - fs = framer.PopStreamFrames(1000) + fs = framer.PopStreamFrames(protocol.MinStreamFrameSize) Expect(fs).To(HaveLen(1)) // ... but the data popped this time has to be from the other stream Expect(fs[0].StreamID).ToNot(Equal(firstStreamID)) }) It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() { - streamFrameHeaderLen := protocol.ByteCount(4) // pop a frame such that the remaining size is one byte less than the minimum STREAM frame size - stream1.EXPECT().GetDataForWriting(1000-streamFrameHeaderLen).Return(bytes.Repeat([]byte("f"), int(1000-streamFrameHeaderLen-protocol.MinStreamFrameSize+1)), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() + f := &wire.StreamFrame{ + StreamID: id1, + Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)), + } + stream1.EXPECT().PopStreamFrame(protocol.ByteCount(500)).Return(f) setNoData(stream2) - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) + fs := framer.PopStreamFrames(500) + Expect(fs).To(Equal([]*wire.StreamFrame{f})) }) Context("splitting of frames", func() { @@ -268,8 +233,6 @@ var _ = Describe("Stream Framer", func() { }) It("splits a frame", func() { - setNoData(stream1) - setNoData(stream2) frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, 600)} framer.AddFrameForRetransmission(frame) fs := framer.PopStreamFrames(500) @@ -281,9 +244,9 @@ var _ = Describe("Stream Framer", func() { }) It("only removes a frame from the framer after returning all split parts", func() { - frameHeaderLen := protocol.ByteCount(4) setNoData(stream1) setNoData(stream2) + frameHeaderLen := protocol.ByteCount(4) frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, int(501-frameHeaderLen))} framer.AddFrameForRetransmission(frame) fs := framer.PopStreamFrames(500) @@ -295,37 +258,6 @@ var _ = Describe("Stream Framer", func() { Expect(framer.retransmissionQueue).To(BeEmpty()) }) }) - - Context("sending FINs", func() { - It("sends FINs when streams are closed", func() { - offset := protocol.ByteCount(42) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return(nil, true) - stream1.EXPECT().GetWriteOffset().Return(offset) - setNoData(stream2) - - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) - Expect(fs[0].Offset).To(Equal(offset)) - Expect(fs[0].FinBit).To(BeTrue()) - Expect(fs[0].Data).To(BeEmpty()) - }) - - It("bundles FINs with data", func() { - offset := protocol.ByteCount(42) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), true) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset().Return(offset) - setNoData(stream2) - - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(1)) - Expect(fs[0].StreamID).To(Equal(stream1.StreamID())) - Expect(fs[0].Data).To(Equal([]byte("foobar"))) - Expect(fs[0].FinBit).To(BeTrue()) - }) - }) }) Context("BLOCKED frames", func() { @@ -334,12 +266,13 @@ var _ = Describe("Stream Framer", func() { }) It("queues and pops BLOCKED frames for individually blocked streams", func() { - connFC.EXPECT().IsBlocked() - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foobar"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().IsFlowControlBlocked().Return(true) setNoData(stream2) + connFC.EXPECT().IsBlocked() + stream1.EXPECT().PopStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ + StreamID: id1, + Data: []byte("foobar"), + }) + stream1.EXPECT().IsFlowControlBlocked().Return(true) frames := framer.PopStreamFrames(1000) Expect(frames).To(HaveLen(1)) f := framer.PopBlockedFrame() @@ -349,27 +282,30 @@ var _ = Describe("Stream Framer", func() { Expect(framer.PopBlockedFrame()).To(BeNil()) }) - It("does not queue a stream-level BLOCKED frame after sending the FinBit frame", func() { - connFC.EXPECT().IsBlocked() - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), true) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() + It("doesn't queue a stream-level BLOCKED frame after sending the FIN bit frame", func() { setNoData(stream2) + f := &wire.StreamFrame{ + StreamID: id1, + Data: []byte("foobar"), + FinBit: true, + } + connFC.EXPECT().IsBlocked() + stream1.EXPECT().PopStreamFrame(gomock.Any()).Return(f) + // no call to IsFlowControlBlocked() frames := framer.PopStreamFrames(1000) - Expect(frames).To(HaveLen(1)) - Expect(frames[0].FinBit).To(BeTrue()) - Expect(frames[0].DataLen()).To(Equal(protocol.ByteCount(3))) + Expect(frames).To(Equal([]*wire.StreamFrame{f})) blockedFrame := framer.PopBlockedFrame() Expect(blockedFrame).To(BeNil()) }) It("queues and pops BLOCKED frames for connection blocked streams", func() { - connFC.EXPECT().IsBlocked().Return(true) - stream1.EXPECT().GetDataForWriting(gomock.Any()).Return([]byte("foo"), false) - stream1.EXPECT().HasDataForWriting().Return(true) - stream1.EXPECT().GetWriteOffset() - stream1.EXPECT().IsFlowControlBlocked().Return(false) setNoData(stream2) + connFC.EXPECT().IsBlocked().Return(true) + stream1.EXPECT().PopStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ + StreamID: id1, + Data: []byte("foo"), + }) + stream1.EXPECT().IsFlowControlBlocked().Return(false) framer.PopStreamFrames(1000) f := framer.PopBlockedFrame() Expect(f).To(BeAssignableToTypeOf(&wire.BlockedFrame{})) diff --git a/stream_test.go b/stream_test.go index 72e990a3..e5e1cfe5 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,6 +1,7 @@ package quic import ( + "bytes" "errors" "io" "runtime" @@ -600,6 +601,7 @@ var _ = Describe("Stream", func() { }) It("returns how much was written when recieving a remote error", func() { + frameHeaderSize := protocol.ByteCount(4) mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(10), true) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) @@ -612,7 +614,10 @@ var _ = Describe("Stream", func() { close(done) }() - Eventually(func() []byte { data, _ := str.GetDataForWriting(4); return data }).ShouldNot(BeEmpty()) + var frame *wire.StreamFrame + Eventually(func() *wire.StreamFrame { frame = str.PopStreamFrame(4 + frameHeaderSize); return frame }).ShouldNot(BeNil()) + Expect(frame).ToNot(BeNil()) + Expect(frame.DataLen()).To(BeEquivalentTo(4)) str.RegisterRemoteError(testErr, 10) Eventually(done).Should(BeClosed()) }) @@ -635,8 +640,8 @@ var _ = Describe("Stream", func() { It("doesn't call onReset if it already sent a FIN", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(0), true) str.Close() - _, sentFin := str.GetDataForWriting(1000) - Expect(sentFin).To(BeTrue()) + f := str.PopStreamFrame(100) + Expect(f.FinBit).To(BeTrue()) str.RegisterRemoteError(testErr, 0) Expect(resetCalled).To(BeFalse()) }) @@ -672,7 +677,7 @@ var _ = Describe("Stream", func() { }() Consistently(done).ShouldNot(BeClosed()) str.Reset(testErr) - Expect(str.GetDataForWriting(6)).To(BeNil()) + Expect(str.PopStreamFrame(1000)).To(BeNil()) Eventually(done).Should(BeClosed()) }) @@ -681,7 +686,7 @@ var _ = Describe("Stream", func() { n, err := strWithTimeout.Write([]byte("foobar")) Expect(n).To(BeZero()) Expect(err).To(MatchError(testErr)) - Expect(str.GetDataForWriting(6)).To(BeNil()) + Expect(str.PopStreamFrame(1000)).To(BeNil()) }) It("stops reading", func() { @@ -721,8 +726,8 @@ var _ = Describe("Stream", func() { It("doesn't call onReset if it already sent a FIN", func() { str.Close() - _, sentFin := str.GetDataForWriting(1000) - Expect(sentFin).To(BeTrue()) + f := str.PopStreamFrame(1000) + Expect(f.FinBit).To(BeTrue()) str.Reset(testErr) Expect(resetCalled).To(BeFalse()) }) @@ -771,18 +776,20 @@ var _ = Describe("Stream", func() { }).Should(Equal([]byte("foobar"))) Consistently(done).ShouldNot(BeClosed()) Expect(onDataCalled).To(BeTrue()) - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(1000) - Expect(data).To(Equal([]byte("foobar"))) - Expect(sendFin).To(BeFalse()) + f := str.PopStreamFrame(1000) + Expect(f.Data).To(Equal([]byte("foobar"))) + Expect(f.FinBit).To(BeFalse()) + Expect(f.Offset).To(BeZero()) + Expect(f.DataLenPresent).To(BeTrue()) Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) Expect(str.dataForWriting).To(BeNil()) Eventually(done).Should(BeClosed()) }) It("writes and gets data in two turns", func() { + frameHeaderLen := protocol.ByteCount(4) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any() /* protocol.ByteCount(3)*/).Times(2) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -797,29 +804,29 @@ var _ = Describe("Stream", func() { return str.dataForWriting }).Should(Equal([]byte("foobar"))) Consistently(done).ShouldNot(BeClosed()) - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(3) - Expect(data).To(Equal([]byte("foo"))) - Expect(sendFin).To(BeFalse()) - Expect(str.writeOffset).To(Equal(protocol.ByteCount(3))) - Expect(str.dataForWriting).ToNot(BeNil()) - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin = str.GetDataForWriting(3) - Expect(data).To(Equal([]byte("bar"))) - Expect(sendFin).To(BeFalse()) - Expect(str.writeOffset).To(Equal(protocol.ByteCount(6))) - Expect(str.dataForWriting).To(BeNil()) - Expect(str.HasDataForWriting()).To(BeFalse()) + f := str.PopStreamFrame(3 + frameHeaderLen) + Expect(f.Data).To(Equal([]byte("foo"))) + Expect(f.FinBit).To(BeFalse()) + Expect(f.Offset).To(BeZero()) + Expect(f.DataLenPresent).To(BeTrue()) + f = str.PopStreamFrame(100) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.FinBit).To(BeFalse()) + Expect(f.Offset).To(Equal(protocol.ByteCount(3))) + Expect(f.DataLenPresent).To(BeTrue()) + Expect(str.PopStreamFrame(1000)).To(BeNil()) Eventually(done).Should(BeClosed()) }) - It("getDataForWriting returns nil if no data is available", func() { - Expect(str.GetDataForWriting(1000)).To(BeNil()) + It("PopStreamFrame returns nil if no data is available", func() { + Expect(str.PopStreamFrame(1000)).To(BeNil()) }) It("copies the slice while writing", func() { - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) + frameHeaderSize := protocol.ByteCount(4) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) s := []byte("foo") go func() { defer GinkgoRecover() @@ -827,9 +834,13 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(3)) }() - Eventually(func() bool { return str.HasDataForWriting() }).Should(BeTrue()) - s[0] = 'v' - Expect(str.GetDataForWriting(3)).To(Equal([]byte("foo"))) + var frame *wire.StreamFrame + Eventually(func() *wire.StreamFrame { frame = str.PopStreamFrame(frameHeaderSize + 1); return frame }).ShouldNot(BeNil()) + Expect(frame.Data).To(Equal([]byte("f"))) + s[1] = 'e' + f := str.PopStreamFrame(100) + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("oo"))) }) It("returns when given a nil input", func() { @@ -913,50 +924,40 @@ var _ = Describe("Stream", func() { It("allows FIN", func() { str.Close() - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(1000) - Expect(data).To(BeEmpty()) - Expect(sendFin).To(BeTrue()) + f := str.PopStreamFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(BeEmpty()) + Expect(f.FinBit).To(BeTrue()) }) - It("does not allow FIN when there's still data", func() { + It("doesn't allow FIN when there's still data", func() { + frameHeaderLen := protocol.ByteCount(4) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) str.dataForWriting = []byte("foobar") str.Close() - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(3) - Expect(data).To(Equal([]byte("foo"))) - Expect(sendFin).To(BeFalse()) - data, sendFin = str.GetDataForWriting(3) - Expect(data).To(Equal([]byte("bar"))) - Expect(sendFin).To(BeTrue()) + f := str.PopStreamFrame(3 + frameHeaderLen) + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("foo"))) + Expect(f.FinBit).To(BeFalse()) + f = str.PopStreamFrame(100) + Expect(f.Data).To(Equal([]byte("bar"))) + Expect(f.FinBit).To(BeTrue()) }) - It("does not allow FIN when the stream is not closed", func() { - Expect(str.HasDataForWriting()).To(BeFalse()) - _, sendFin := str.GetDataForWriting(3) - Expect(sendFin).To(BeFalse()) - }) - - It("does not allow FIN after an error", func() { + It("doesn't allow FIN after an error", func() { str.Cancel(errors.New("test")) - Expect(str.HasDataForWriting()).To(BeFalse()) - data, sendFin := str.GetDataForWriting(1000) - Expect(data).To(BeEmpty()) - Expect(sendFin).To(BeFalse()) + f := str.PopStreamFrame(1000) + Expect(f).To(BeNil()) }) - It("does not allow FIN twice", func() { + It("doesn't allow FIN twice", func() { str.Close() - Expect(str.HasDataForWriting()).To(BeTrue()) - data, sendFin := str.GetDataForWriting(1000) - Expect(data).To(BeEmpty()) - Expect(sendFin).To(BeTrue()) - Expect(str.HasDataForWriting()).To(BeFalse()) - data, sendFin = str.GetDataForWriting(1000) - Expect(data).To(BeEmpty()) - Expect(sendFin).To(BeFalse()) + f := str.PopStreamFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(BeEmpty()) + Expect(f.FinBit).To(BeTrue()) + Expect(str.PopStreamFrame(1000)).To(BeNil()) }) }) @@ -971,18 +972,19 @@ var _ = Describe("Stream", func() { }) It("doesn't get data for writing if an error occurred", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := strWithTimeout.Write([]byte("foobar")) + _, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 500)) Expect(err).To(MatchError(testErr)) + close(done) }() - Eventually(func() []byte { return str.dataForWriting }).ShouldNot(BeNil()) - Expect(str.HasDataForWriting()).To(BeTrue()) + Eventually(func() *wire.StreamFrame { return str.PopStreamFrame(50) }).ShouldNot(BeNil()) // get a STREAM frame containing some data, but not all str.Cancel(testErr) - data, sendFin := str.GetDataForWriting(6) - Expect(data).To(BeNil()) - Expect(sendFin).To(BeFalse()) - Expect(str.HasDataForWriting()).To(BeFalse()) + Expect(str.PopStreamFrame(1000)).To(BeNil()) + Eventually(done).Should(BeClosed()) }) }) }) @@ -1017,8 +1019,8 @@ var _ = Describe("Stream", func() { It("is not finished if it is only closed for writing", func() { str.Close() - _, sentFin := str.GetDataForWriting(1000) - Expect(sentFin).To(BeTrue()) + f := str.PopStreamFrame(1000) + Expect(f.FinBit).To(BeTrue()) Expect(str.Finished()).To(BeFalse()) }) @@ -1060,8 +1062,8 @@ var _ = Describe("Stream", func() { It("is finished after finishing writing and receiving a RST", func() { mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(13), true) str.Close() - _, sentFin := str.GetDataForWriting(1000) - Expect(sentFin).To(BeTrue()) + f := str.PopStreamFrame(1000) + Expect(f.FinBit).To(BeTrue()) str.RegisterRemoteError(testErr, 13) Expect(str.Finished()).To(BeTrue()) })