diff --git a/crypto_stream.go b/crypto_stream.go index de9275b7..16294b6d 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -13,7 +13,7 @@ type cryptoStreamI interface { io.Reader io.Writer handleStreamFrame(*wire.StreamFrame) error - popStreamFrame(protocol.ByteCount) *wire.StreamFrame + popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool) closeForShutdown(error) hasDataForWriting() bool setReadOffset(protocol.ByteCount) diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index d87fb9af..f856e1ba 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -130,10 +130,11 @@ func (mr *MockCryptoStreamMockRecorder) hasDataForWriting() *gomock.Call { } // popStreamFrame mocks base method -func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { +func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { ret := m.ctrl.Call(m, "popStreamFrame", arg0) ret0, _ := ret[0].(*wire.StreamFrame) - return ret0 + ret1, _ := ret[1].(bool) + return ret0, ret1 } // popStreamFrame indicates an expected call of popStreamFrame diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 7d1f2a6e..0722acd0 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -238,10 +238,11 @@ func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.C } // popStreamFrame mocks base method -func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) *wire.StreamFrame { +func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { ret := m.ctrl.Call(m, "popStreamFrame", arg0) ret0, _ := ret[0].(*wire.StreamFrame) - return ret0 + ret1, _ := ret[1].(bool) + return ret0, ret1 } // popStreamFrame indicates an expected call of popStreamFrame diff --git a/packet_packer_test.go b/packet_packer_test.go index 93b85493..8966733b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -65,8 +65,7 @@ var _ = Describe("Packet packer", func() { mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() cryptoStream = newCryptoStream(mockSender, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version) - streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) - streamFramer = newStreamFramer(cryptoStream, streamsMap, versionGQUICFrames) + streamFramer = newStreamFramer(cryptoStream, nil, versionGQUICFrames) packer = &packetPacker{ cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, diff --git a/send_stream.go b/send_stream.go index 3584febb..392e3cef 100644 --- a/send_stream.go +++ b/send_stream.go @@ -122,12 +122,12 @@ func (s *sendStream) Write(p []byte) (int, error) { // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. -func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFrame { +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { s.mutex.Lock() defer s.mutex.Unlock() if s.closeForShutdownErr != nil { - return nil + return nil, false } frame := &wire.StreamFrame{ @@ -137,11 +137,11 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFra } frameLen := frame.MinLength(s.version) if frameLen >= maxBytes { // a STREAM frame must have at least one byte of data - return nil + return nil, s.dataForWriting != nil } frame.Data, frame.FinBit = s.getDataForWriting(maxBytes - frameLen) if len(frame.Data) == 0 && !frame.FinBit { - return nil + return nil, s.dataForWriting != nil } if frame.FinBit { s.finSent = true @@ -153,7 +153,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFra }) } } - return frame + return frame, s.dataForWriting != nil } func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) { diff --git a/send_stream_test.go b/send_stream_test.go index 51b0258f..b8bfbed5 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -57,7 +57,7 @@ var _ = Describe("Send Stream", func() { Consistently(done).ShouldNot(BeClosed()) var f *wire.StreamFrame Eventually(func() *wire.StreamFrame { - f = str.popStreamFrame(1000) + f, _ = str.popStreamFrame(1000) return f }).ShouldNot(BeNil()) Expect(f.Data).To(Equal([]byte("foobar"))) @@ -86,14 +86,14 @@ var _ = Describe("Send Stream", func() { Consistently(done).ShouldNot(BeClosed()) var f *wire.StreamFrame Eventually(func() *wire.StreamFrame { - f = str.popStreamFrame(3 + frameHeaderLen) + f, _ = str.popStreamFrame(3 + frameHeaderLen) return f }).ShouldNot(BeNil()) Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.FinBit).To(BeFalse()) Expect(f.Offset).To(BeZero()) Expect(f.DataLenPresent).To(BeTrue()) - f = str.popStreamFrame(100) + f, _ = str.popStreamFrame(100) Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.FinBit).To(BeFalse()) Expect(f.Offset).To(Equal(protocol.ByteCount(3))) @@ -103,7 +103,37 @@ var _ = Describe("Send Stream", func() { }) It("popStreamFrame returns nil if no data is available", func() { - Expect(str.popStreamFrame(1000)).To(BeNil()) + frame, hasMoreData := str.popStreamFrame(1000) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + }) + + It("says if it has more data for writing", func() { + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) + mockFC.EXPECT().IsNewlyBlocked().Times(2) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(100)) + close(done) + }() + var hasMoreData bool + Eventually(func() *wire.StreamFrame { + var frame *wire.StreamFrame + frame, hasMoreData = str.popStreamFrame(50) + return frame + }).ShouldNot(BeNil()) + Expect(hasMoreData).To(BeTrue()) + frame, hasMoreData := str.popStreamFrame(1000) + Expect(frame).ToNot(BeNil()) + Expect(hasMoreData).To(BeFalse()) + frame, _ = str.popStreamFrame(1000) + Expect(frame).To(BeNil()) + Eventually(done).Should(BeClosed()) }) It("copies the slice while writing", func() { @@ -121,10 +151,10 @@ var _ = Describe("Send Stream", func() { Expect(n).To(Equal(3)) }() var frame *wire.StreamFrame - Eventually(func() *wire.StreamFrame { frame = str.popStreamFrame(frameHeaderSize + 1); return frame }).ShouldNot(BeNil()) + Eventually(func() *wire.StreamFrame { frame, _ = str.popStreamFrame(frameHeaderSize + 1); return frame }).ShouldNot(BeNil()) Expect(frame.Data).To(Equal([]byte("f"))) s[1] = 'e' - f := str.popStreamFrame(100) + f, _ := str.popStreamFrame(100) Expect(f).ToNot(BeNil()) Expect(f.Data).To(Equal([]byte("oo"))) }) @@ -168,7 +198,7 @@ var _ = Describe("Send Stream", func() { }() var f *wire.StreamFrame Eventually(func() *wire.StreamFrame { - f = str.popStreamFrame(1000) + f, _ = str.popStreamFrame(1000) return f }).ShouldNot(BeNil()) Eventually(done).Should(BeClosed()) @@ -191,7 +221,7 @@ var _ = Describe("Send Stream", func() { Expect(str.Close()).To(Succeed()) var f *wire.StreamFrame Eventually(func() *wire.StreamFrame { - f = str.popStreamFrame(1000) + f, _ = str.popStreamFrame(1000) return f }).ShouldNot(BeNil()) Expect(f.FinBit).To(BeTrue()) @@ -236,7 +266,7 @@ var _ = Describe("Send Stream", func() { }() var frame *wire.StreamFrame Eventually(func() *wire.StreamFrame { - frame = str.popStreamFrame(50) + frame, _ = str.popStreamFrame(50) return frame }).ShouldNot(BeNil()) Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) @@ -259,11 +289,13 @@ var _ = Describe("Send Stream", func() { }() var frame *wire.StreamFrame Eventually(func() *wire.StreamFrame { - frame = str.popStreamFrame(50) + frame, _ = str.popStreamFrame(50) return frame }).ShouldNot(BeNil()) Eventually(writeReturned, scaleDuration(80*time.Millisecond)).Should(BeClosed()) - Expect(str.popStreamFrame(50)).To(BeNil()) + frame, hasMoreData := str.popStreamFrame(50) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) }) It("doesn't unblock if the deadline is changed before the first one expires", func() { @@ -315,10 +347,11 @@ var _ = Describe("Send Stream", func() { It("allows FIN", func() { mockSender.EXPECT().onHasStreamData(streamID) str.Close() - f := str.popStreamFrame(1000) + f, hasMoreData := str.popStreamFrame(1000) Expect(f).ToNot(BeNil()) Expect(f.Data).To(BeEmpty()) Expect(f.FinBit).To(BeTrue()) + Expect(hasMoreData).To(BeFalse()) }) It("doesn't send a FIN when there's still data", func() { @@ -329,29 +362,32 @@ var _ = Describe("Send Stream", func() { mockFC.EXPECT().IsNewlyBlocked() str.dataForWriting = []byte("foobar") Expect(str.Close()).To(Succeed()) - f := str.popStreamFrame(3 + frameHeaderLen) + f, _ := str.popStreamFrame(3 + frameHeaderLen) Expect(f).ToNot(BeNil()) Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.FinBit).To(BeFalse()) - f = str.popStreamFrame(100) + f, _ = str.popStreamFrame(100) Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.FinBit).To(BeTrue()) }) - It("doesn't allow FIN after an error", func() { + It("doesn't allow FIN after it is closed for shutdown", func() { str.closeForShutdown(errors.New("test")) - f := str.popStreamFrame(1000) + f, hasMoreData := str.popStreamFrame(1000) Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) }) It("doesn't allow FIN twice", func() { mockSender.EXPECT().onHasStreamData(streamID) str.Close() - f := str.popStreamFrame(1000) + f, _ := str.popStreamFrame(1000) Expect(f).ToNot(BeNil()) Expect(f.Data).To(BeEmpty()) Expect(f.FinBit).To(BeTrue()) - Expect(str.popStreamFrame(1000)).To(BeNil()) + f, hasMoreData := str.popStreamFrame(1000) + Expect(f).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) }) }) @@ -377,7 +413,12 @@ var _ = Describe("Send Stream", func() { Expect(err).To(MatchError(testErr)) close(done) }() - Eventually(func() *wire.StreamFrame { return str.popStreamFrame(50) }).ShouldNot(BeNil()) // get a STREAM frame containing some data, but not all + Eventually(func() *wire.StreamFrame { + defer GinkgoRecover() + // get a STREAM frame containing some data, but not all + frame, _ := str.popStreamFrame(50) + return frame + }).ShouldNot(BeNil()) str.closeForShutdown(testErr) Expect(str.popStreamFrame(1000)).To(BeNil()) Eventually(done).Should(BeClosed()) @@ -421,7 +462,7 @@ var _ = Describe("Send Stream", func() { }() var frame *wire.StreamFrame Eventually(func() *wire.StreamFrame { - frame = str.popStreamFrame(50) + frame, _ = str.popStreamFrame(50) return frame }).ShouldNot(BeNil()) err := str.CancelWrite(1234) @@ -519,7 +560,7 @@ var _ = Describe("Send Stream", func() { It("is finished after Close()", func() { mockSender.EXPECT().onHasStreamData(streamID) str.Close() - f := str.popStreamFrame(1000) + f, _ := str.popStreamFrame(1000) Expect(f.FinBit).To(BeTrue()) Expect(str.finished()).To(BeTrue()) }) diff --git a/session.go b/session.go index 7d76790b..17e84303 100644 --- a/session.go +++ b/session.go @@ -961,7 +961,8 @@ func (s *session) onHasWindowUpdate(id protocol.StreamID) { s.scheduleSending() } -func (s *session) onHasStreamData(streamID protocol.StreamID) { +func (s *session) onHasStreamData(id protocol.StreamID) { + s.streamFramer.AddActiveStream(id) s.scheduleSending() } diff --git a/session_test.go b/session_test.go index 58bcb23f..817d91a1 100644 --- a/session_test.go +++ b/session_test.go @@ -1171,90 +1171,6 @@ var _ = Describe("Session", func() { Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero()) Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) }) - - Context("bundling of small packets", func() { - It("bundles two small frames of different streams into one packet", func() { - s1, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - s2, err := sess.GetOrOpenStream(7) - Expect(err).NotTo(HaveOccurred()) - - done1 := make(chan struct{}) - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := s1.Write([]byte("foobar1")) - Expect(err).ToNot(HaveOccurred()) - close(done1) - }() - go func() { - defer GinkgoRecover() - s2.Write([]byte("foobar2")) - Expect(err).ToNot(HaveOccurred()) - close(done2) - }() - time.Sleep(100 * time.Millisecond) // make sure the both writes are active - - sess.scheduleSending() - go sess.run() - defer sess.Close(nil) - - Eventually(mconn.written).Should(HaveLen(1)) - packet := <-mconn.written - Expect(packet).To(ContainSubstring("foobar1")) - Expect(packet).To(ContainSubstring("foobar2")) - Eventually(done1).Should(BeClosed()) - Eventually(done2).Should(BeClosed()) - }) - - It("sends out two big frames in two packets", func() { - s1, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - s2, err := sess.GetOrOpenStream(7) - Expect(err).NotTo(HaveOccurred()) - go sess.run() - defer sess.Close(nil) - go func() { - defer GinkgoRecover() - s1.Write(bytes.Repeat([]byte{'e'}, 1000)) - }() - _, err = s2.Write(bytes.Repeat([]byte{'e'}, 1000)) - Expect(err).ToNot(HaveOccurred()) - Eventually(mconn.written).Should(HaveLen(2)) - }) - - It("sends out two small frames that are written to long after one another into two packets", func() { - s, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - go sess.run() - defer sess.Close(nil) - _, err = s.Write([]byte("foobar1")) - Expect(err).NotTo(HaveOccurred()) - Eventually(mconn.written).Should(HaveLen(1)) - _, err = s.Write([]byte("foobar2")) - Expect(err).NotTo(HaveOccurred()) - Eventually(mconn.written).Should(HaveLen(2)) - }) - - It("sends a queued ACK frame only once", func() { - packetNumber := protocol.PacketNumber(0x1337) - sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) - - s, err := sess.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - go sess.run() - defer sess.Close(nil) - _, err = s.Write([]byte("foobar1")) - Expect(err).NotTo(HaveOccurred()) - Eventually(mconn.written).Should(HaveLen(1)) - _, err = s.Write([]byte("foobar2")) - Expect(err).NotTo(HaveOccurred()) - - Eventually(mconn.written).Should(HaveLen(2)) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) - Expect(mconn.written).ToNot(Receive(ContainSubstring(string([]byte{0x13, 0x37})))) - }) - }) }) It("closes when crypto stream errors", func() { @@ -1577,7 +1493,7 @@ var _ = Describe("Session", func() { s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) Expect(s.Close()).To(Succeed()) - f := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit + f, _ := s.(*stream).popStreamFrame(1000) // trigger "sending" of the FIN bit Expect(f.FinBit).To(BeTrue()) s.(*stream).CloseRemote(0) _, err = s.Read([]byte("a")) diff --git a/stream.go b/stream.go index 69611ac2..80cbb7f1 100644 --- a/stream.go +++ b/stream.go @@ -27,7 +27,7 @@ type streamI interface { handleStreamFrame(*wire.StreamFrame) error handleRstStreamFrame(*wire.RstStreamFrame) error handleStopSendingFrame(*wire.StopSendingFrame) - popStreamFrame(maxBytes protocol.ByteCount) *wire.StreamFrame + popStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool) finished() bool closeForShutdown(error) // methods needed for flow control diff --git a/stream_framer.go b/stream_framer.go index 4cf0a54d..eeb09327 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -1,27 +1,34 @@ package quic import ( + "sync" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" ) type streamFramer struct { - streamsMap *streamsMap + streamGetter streamGetter cryptoStream cryptoStreamI version protocol.VersionNumber retransmissionQueue []*wire.StreamFrame + + streamQueueMutex sync.Mutex + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID } func newStreamFramer( cryptoStream cryptoStreamI, - streamsMap *streamsMap, + streamGetter streamGetter, v protocol.VersionNumber, ) *streamFramer { return &streamFramer{ - streamsMap: streamsMap, - cryptoStream: cryptoStream, - version: v, + streamGetter: streamGetter, + cryptoStream: cryptoStream, + activeStreams: make(map[protocol.StreamID]struct{}), + version: v, } } @@ -29,6 +36,18 @@ func (f *streamFramer) AddFrameForRetransmission(frame *wire.StreamFrame) { f.retransmissionQueue = append(f.retransmissionQueue, frame) } +func (f *streamFramer) AddActiveStream(id protocol.StreamID) { + if id == f.version.CryptoStreamID() { // the crypto stream is handled separately + return + } + f.streamQueueMutex.Lock() + if _, ok := f.activeStreams[id]; !ok { + f.streamQueue = append(f.streamQueue, id) + f.activeStreams[id] = struct{}{} + } + f.streamQueueMutex.Unlock() +} + func (f *streamFramer) PopStreamFrames(maxLen protocol.ByteCount) []*wire.StreamFrame { fs, currentLen := f.maybePopFramesForRetransmission(maxLen) return append(fs, f.maybePopNormalFrames(maxLen-currentLen)...) @@ -38,13 +57,15 @@ func (f *streamFramer) HasFramesForRetransmission() bool { return len(f.retransmissionQueue) > 0 } +// TODO: don't need to ask the crypto stream here, just record this information in AddActiveStream func (f *streamFramer) HasCryptoStreamFrame() bool { return f.cryptoStream.hasDataForWriting() } // TODO(lclemente): This is somewhat duplicate with the normal path for generating frames. func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.StreamFrame { - return f.cryptoStream.popStreamFrame(maxLen) + frame, _ := f.cryptoStream.popStreamFrame(maxLen) + return frame } func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame, currentLen protocol.ByteCount) { @@ -72,32 +93,37 @@ func (f *streamFramer) maybePopFramesForRetransmission(maxTotalLen protocol.Byte return } -func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) (res []*wire.StreamFrame) { +func (f *streamFramer) maybePopNormalFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame { var currentLen protocol.ByteCount - - fn := func(s streamI) (bool, error) { - if s == nil { - return true, nil + var frames []*wire.StreamFrame + f.streamQueueMutex.Lock() + // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet + numActiveStreams := len(f.streamQueue) + for i := 0; i < numActiveStreams; i++ { + if maxTotalLen-currentLen < protocol.MinStreamFrameSize { + break } - - maxLen := maxTotalLen - currentLen - if maxLen < protocol.MinStreamFrameSize { // don't try to add new STREAM frames, if only little space is left in the packet - return false, nil + id := f.streamQueue[0] + f.streamQueue = f.streamQueue[1:] + str, err := f.streamGetter.GetOrOpenStream(id) + if err != nil { // can happen if the stream completed after it said it had data + delete(f.activeStreams, id) + continue } - frame := s.popStreamFrame(maxLen) - if frame == nil { - return true, nil + frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen) + if hasMoreData { // put the stream back in the queue (at the end) + f.streamQueue = append(f.streamQueue, id) + } else { // no more data to send. Stream is not active any more + delete(f.activeStreams, id) } - res = append(res, frame) + if frame == nil { // can happen if the receiveStream was canceled after it said it had data + continue + } + frames = append(frames, frame) currentLen += frame.MinLength(f.version) + frame.DataLen() - if currentLen == maxTotalLen { - return false, nil - } - return true, nil } - - f.streamsMap.RoundRobinIterate(fn) - return + f.streamQueueMutex.Unlock() + return frames } // maybeSplitOffFrame removes the first n bytes and returns them as a separate frame. If n >= len(frame), nil is returned and nothing is modified. diff --git a/stream_framer_test.go b/stream_framer_test.go index 03f43d54..18e5b88a 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "github.com/golang/mock/gomock" @@ -20,15 +21,12 @@ var _ = Describe("Stream Framer", func() { var ( retransmittedFrame1, retransmittedFrame2 *wire.StreamFrame framer *streamFramer - streamsMap *streamsMap stream1, stream2 *MockStreamI + streamGetter *MockStreamGetter ) - setNoData := func(str *MockStreamI) { - str.EXPECT().popStreamFrame(gomock.Any()).AnyTimes() - } - BeforeEach(func() { + streamGetter = NewMockStreamGetter(mockCtrl) retransmittedFrame1 = &wire.StreamFrame{ StreamID: 5, Data: []byte{0x13, 0x37}, @@ -43,11 +41,7 @@ var _ = Describe("Stream Framer", func() { stream2 = NewMockStreamI(mockCtrl) stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() - streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) - streamsMap.putStream(stream1) - streamsMap.putStream(stream2) - - framer = newStreamFramer(nil, streamsMap, versionGQUICFrames) + framer = newStreamFramer(nil, streamGetter, versionGQUICFrames) }) It("says if it has retransmissions", func() { @@ -57,8 +51,6 @@ var _ = Describe("Stream Framer", func() { }) It("sets the DataLenPresent for dequeued retransmitted frames", func() { - setNoData(stream1) - setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) fs := framer.PopStreamFrames(protocol.MaxByteCount) Expect(fs).To(HaveLen(1)) @@ -67,14 +59,10 @@ var _ = Describe("Stream Framer", func() { Context("Popping", func() { It("returns nil when popping an empty framer", func() { - setNoData(stream1) - setNoData(stream2) Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) }) It("pops frames for retransmission", func() { - setNoData(stream1) - setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) framer.AddFrameForRetransmission(retransmittedFrame2) fs := framer.PopStreamFrames(1000) @@ -93,8 +81,6 @@ var _ = Describe("Stream Framer", func() { }) It("pops frames for retransmission, even if the remaining space in the packet is too small, if the frame doesn't need to be split", func() { - setNoData(stream1) - setNoData(stream2) framer.AddFrameForRetransmission(retransmittedFrame1) fs := framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1})) @@ -107,46 +93,127 @@ var _ = Describe("Stream Framer", func() { }) It("returns normal frames", func() { - setNoData(stream2) + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil) f := &wire.StreamFrame{ StreamID: id1, Data: []byte("foobar"), Offset: 42, } - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) + framer.AddActiveStream(id1) fs := framer.PopStreamFrames(1000) Expect(fs).To(Equal([]*wire.StreamFrame{f})) }) - It("returns multiple normal frames", func() { + It("skips a stream that was reported active, but was completed shortly after", func() { + streamGetter.EXPECT().GetOrOpenStream(id1).Return(nil, errors.New("stream was already deleted")) + streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil) + f := &wire.StreamFrame{ + StreamID: id2, + Data: []byte("foobar"), + } + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f})) + }) + + It("skips a stream that was reported active, but doesn't have any data", func() { + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil) + f := &wire.StreamFrame{ + StreamID: id2, + Data: []byte("foobar"), + } + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(nil, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f})) + }) + + It("pops from a stream multiple times, if it has enough data", func() { + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil).Times(2) + f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f2 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) + framer.AddActiveStream(id1) // only add it once + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f1})) + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) + // no further calls to popStreamFrame, after popStreamFrame said there's no more data + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(BeNil()) + }) + + It("re-queues a stream at the end, if it has enough data", func() { + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil).Times(2) + streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil) + f11 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f12 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobaz")} + f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f11, true) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f12, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) + framer.AddActiveStream(id1) // only add it once + framer.AddActiveStream(id2) + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f11})) // first a frame from stream 1 + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) // then a frame from stream 2 + Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f12})) // then another frame from stream 1 + }) + + It("only dequeues data from each stream once per packet", func() { + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil) + f1 := &wire.StreamFrame{StreamID: id1, Data: []byte("foobar")} + f2 := &wire.StreamFrame{StreamID: id2, Data: []byte("raboof")} + // both streams have more data, and will be re-queued + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true) + framer.AddActiveStream(id1) + framer.AddActiveStream(id2) + Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f1, f2})) + }) + + It("returns multiple normal frames in the order they were reported active", func() { + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil) + streamGetter.EXPECT().GetOrOpenStream(id2).Return(stream2, nil) f1 := &wire.StreamFrame{Data: []byte("foobar")} f2 := &wire.StreamFrame{Data: []byte("foobaz")} - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1) - stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2) - fs := framer.PopStreamFrames(1000) - Expect(fs).To(HaveLen(2)) - Expect(fs).To(ContainElement(f1)) - Expect(fs).To(ContainElement(f2)) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false) + stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) + framer.AddActiveStream(id2) + framer.AddActiveStream(id1) + Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f2, f1})) + }) + + It("only asks a stream for data once, even if it was reported active multiple times", func() { + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil) + f := &wire.StreamFrame{Data: []byte("foobar")} + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function + framer.AddActiveStream(id1) + framer.AddActiveStream(id1) + Expect(framer.PopStreamFrames(1000)).To(HaveLen(1)) }) It("returns retransmission frames before normal frames", func() { - setNoData(stream2) + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil) + framer.AddActiveStream(id1) f1 := &wire.StreamFrame{Data: []byte("foobar")} - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1) + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, false) framer.AddFrameForRetransmission(retransmittedFrame1) fs := framer.PopStreamFrames(1000) Expect(fs).To(Equal([]*wire.StreamFrame{retransmittedFrame1, f1})) }) It("does not pop empty frames", func() { - setNoData(stream1) - setNoData(stream2) fs := framer.PopStreamFrames(500) Expect(fs).To(BeEmpty()) }) It("pops frames that have the minimum size", func() { - stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}) + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil) + stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}, false) + framer.AddActiveStream(id1) framer.PopStreamFrames(protocol.MinStreamFrameSize) }) @@ -155,35 +222,15 @@ var _ = Describe("Stream Framer", func() { framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) }) - It("uses the round-robin scheduling", func() { - stream1.EXPECT().popStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ - StreamID: id1, - Data: []byte("foobar"), - }) - stream1.EXPECT().popStreamFrame(gomock.Any()).MaxTimes(1) - stream2.EXPECT().popStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ - StreamID: id2, - Data: []byte("foobaz"), - }) - stream2.EXPECT().popStreamFrame(gomock.Any()).MaxTimes(1) - fs := framer.PopStreamFrames(protocol.MinStreamFrameSize) - Expect(fs).To(HaveLen(1)) - // it doesn't matter here if this data is from stream1 or from stream2... - firstStreamID := fs[0].StreamID - fs = framer.PopStreamFrames(protocol.MinStreamFrameSize) - Expect(fs).To(HaveLen(1)) - // ... but the data popped this time has to be from the other stream - Expect(fs[0].StreamID).ToNot(Equal(firstStreamID)) - }) - It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() { + streamGetter.EXPECT().GetOrOpenStream(id1).Return(stream1, nil) // pop a frame such that the remaining size is one byte less than the minimum STREAM frame size f := &wire.StreamFrame{ StreamID: id1, Data: bytes.Repeat([]byte("f"), int(500-protocol.MinStreamFrameSize)), } - stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f) - setNoData(stream2) + stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false) + framer.AddActiveStream(id1) fs := framer.PopStreamFrames(500) Expect(fs).To(Equal([]*wire.StreamFrame{f})) }) @@ -234,8 +281,6 @@ var _ = Describe("Stream Framer", func() { }) It("only removes a frame from the framer after returning all split parts", func() { - setNoData(stream1) - setNoData(stream2) frameHeaderLen := protocol.ByteCount(4) frame := &wire.StreamFrame{Data: bytes.Repeat([]byte{0}, int(501-frameHeaderLen))} framer.AddFrameForRetransmission(frame) diff --git a/stream_test.go b/stream_test.go index 9a7ad757..d691a5bd 100644 --- a/stream_test.go +++ b/stream_test.go @@ -140,7 +140,10 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) close(writeReturned) }() - Eventually(func() *wire.StreamFrame { return str.popStreamFrame(1000) }).ShouldNot(BeNil()) + Eventually(func() *wire.StreamFrame { + frame, _ := str.popStreamFrame(1000) + return frame + }).ShouldNot(BeNil()) Eventually(writeReturned).Should(BeClosed()) mockSender.EXPECT().queueControlFrame(&wire.RstStreamFrame{ StreamID: streamID, diff --git a/streams_map.go b/streams_map.go index 299c7af6..acf68b9d 100644 --- a/streams_map.go +++ b/streams_map.go @@ -255,30 +255,6 @@ func (m *streamsMap) DeleteClosedStreams() error { return nil } -// RoundRobinIterate executes the streamLambda for every open stream, until the streamLambda returns false -// It uses a round-robin-like scheduling to ensure that every stream is considered fairly -// It prioritizes the the header-stream (StreamID 3) -func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - numStreams := len(m.streams) - startIndex := m.roundRobinIndex - - for i := 0; i < numStreams; i++ { - streamID := m.openStreams[(i+startIndex)%numStreams] - cont, err := m.iterateFunc(streamID, fn) - if err != nil { - return err - } - m.roundRobinIndex = (m.roundRobinIndex + 1) % numStreams - if !cont { - break - } - } - return nil -} - // Range executes a callback for all streams, in pseudo-random order func (m *streamsMap) Range(cb func(s streamI)) { m.mutex.RLock() diff --git a/streams_map_test.go b/streams_map_test.go index 81d3f010..a5c04d0c 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -689,112 +689,5 @@ var _ = Describe("Streams Map", func() { Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) }) }) - - Context("RoundRobinIterate", func() { - // create 5 streams, ids 4 to 8 - var lambdaCalledForStream []protocol.StreamID - var numIterations int - - BeforeEach(func() { - lambdaCalledForStream = lambdaCalledForStream[:0] - numIterations = 0 - for i := 4; i <= 8; i++ { - err := m.putStream(newStream(protocol.StreamID(i))) - Expect(err).NotTo(HaveOccurred()) - } - }) - - It("executes the lambda exactly once for every stream", func() { - fn := func(str streamI) (bool, error) { - lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) - numIterations++ - return true, nil - } - err := m.RoundRobinIterate(fn) - Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(5)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) - Expect(m.roundRobinIndex).To(BeZero()) - }) - - It("goes around once when starting in the middle", func() { - fn := func(str streamI) (bool, error) { - lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) - numIterations++ - return true, nil - } - m.roundRobinIndex = 3 // pointing to stream 7 - err := m.RoundRobinIterate(fn) - Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(5)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6})) - Expect(m.roundRobinIndex).To(BeEquivalentTo(3)) - }) - - It("picks up at the index+1 where it last stopped", func() { - fn := func(str streamI) (bool, error) { - lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) - numIterations++ - if str.StreamID() == 5 { - return false, nil - } - return true, nil - } - err := m.RoundRobinIterate(fn) - Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(2)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5})) - Expect(m.roundRobinIndex).To(BeEquivalentTo(2)) - numIterations = 0 - lambdaCalledForStream = lambdaCalledForStream[:0] - fn2 := func(str streamI) (bool, error) { - lambdaCalledForStream = append(lambdaCalledForStream, str.StreamID()) - numIterations++ - if str.StreamID() == 7 { - return false, nil - } - return true, nil - } - err = m.RoundRobinIterate(fn2) - Expect(err).ToNot(HaveOccurred()) - Expect(numIterations).To(Equal(2)) - Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{6, 7})) - }) - - Context("adjusting the RoundRobinIndex when deleting streams", func() { - /* - Index: 0 1 2 3 4 - StreamID: [ 4, 5, 6, 7, 8 ] - */ - - It("adjusts when deleting an element in front", func() { - m.roundRobinIndex = 3 // stream 7 - deleteStream(5) - Expect(m.roundRobinIndex).To(BeEquivalentTo(2)) - }) - - It("doesn't adjust when deleting an element at the back", func() { - m.roundRobinIndex = 1 // stream 5 - deleteStream(7) - Expect(m.roundRobinIndex).To(BeEquivalentTo(1)) - }) - - It("doesn't adjust when deleting the element it is pointing to", func() { - m.roundRobinIndex = 3 // stream 7 - deleteStream(7) - Expect(m.roundRobinIndex).To(BeEquivalentTo(3)) - }) - - It("adjusts when deleting multiple elements", func() { - m.roundRobinIndex = 3 // stream 7 - closeStream(5) - closeStream(6) - closeStream(8) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.roundRobinIndex).To(BeEquivalentTo(1)) - }) - }) - }) }) })