diff --git a/send_stream.go b/send_stream.go index c3191ec6..6c71f0b8 100644 --- a/send_stream.go +++ b/send_stream.go @@ -46,7 +46,8 @@ type sendStream struct { finSent bool // set when a STREAM_FRAME with FIN bit has been sent completed bool // set when this stream has been reported to the streamSender as completed - dataForWriting []byte + dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out + nextFrame *wire.StreamFrame writeChan chan struct{} deadline time.Time @@ -108,20 +109,40 @@ func (s *sendStream) Write(p []byte) (int, error) { notifiedSender bool ) for { - bytesWritten = len(p) - len(s.dataForWriting) - deadline := s.deadline - if !deadline.IsZero() { - if !time.Now().Before(deadline) { - s.dataForWriting = nil - return bytesWritten, errDeadline + var copied bool + var deadline time.Time + // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame), + // which can the be popped the next time we assemble a packet. + // This allows us to return Write() when all data but x bytes have been sent out. + // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame, + // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN). + if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 { + f := wire.GetStreamFrame() + f.Offset = s.writeOffset + f.StreamID = s.streamID + f.DataLenPresent = true + f.Data = f.Data[:len(s.dataForWriting)] + copy(f.Data, s.dataForWriting) + s.nextFrame = f + s.dataForWriting = nil + bytesWritten = len(p) + copied = true + } else { + bytesWritten = len(p) - len(s.dataForWriting) + deadline = s.deadline + if !deadline.IsZero() { + if !time.Now().Before(deadline) { + s.dataForWriting = nil + return bytesWritten, errDeadline + } + if deadlineTimer == nil { + deadlineTimer = utils.NewTimer() + } + deadlineTimer.Reset(deadline) } - if deadlineTimer == nil { - deadlineTimer = utils.NewTimer() + if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { + break } - deadlineTimer.Reset(deadline) - } - if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { - break } s.mutex.Unlock() @@ -129,6 +150,10 @@ func (s *sendStream) Write(p []byte) (int, error) { s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex notifiedSender = true } + if copied { + s.mutex.Lock() + break + } if deadline.IsZero() { <-s.writeChan } else { @@ -149,6 +174,10 @@ func (s *sendStream) Write(p []byte) (int, error) { return bytesWritten, nil } +func (s *sendStream) canBufferStreamFrame() bool { + return s.nextFrame == nil && protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxReceivePacketSize +} + // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) { @@ -182,6 +211,63 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun } } + if len(s.dataForWriting) == 0 && s.nextFrame == nil { + if s.finishedWriting && !s.finSent { + s.finSent = true + return &wire.StreamFrame{ + StreamID: s.streamID, + Offset: s.writeOffset, + DataLenPresent: true, + FinBit: true, + }, false + } + return nil, false + } + + sendWindow := s.flowController.SendWindowSize() + if sendWindow == 0 { + if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked { + s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{ + StreamID: s.streamID, + DataLimit: offset, + }) + return nil, false + } + return nil, true + } + + f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow) + if dataLen := f.DataLen(); dataLen > 0 { + s.writeOffset += f.DataLen() + s.flowController.AddBytesSent(f.DataLen()) + } + f.FinBit = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent + if f.FinBit { + s.finSent = true + } + return f, hasMoreData +} + +func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) { + if s.nextFrame != nil { + nextFrame := s.nextFrame + s.nextFrame = nil + + maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version)) + if nextFrame.DataLen() > maxDataLen { + s.nextFrame = wire.GetStreamFrame() + s.nextFrame.StreamID = s.streamID + s.nextFrame.Offset = s.writeOffset + maxDataLen + s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen] + s.nextFrame.DataLenPresent = true + copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:]) + nextFrame.Data = nextFrame.Data[:maxDataLen] + } else { + s.signalWrite() + } + return nextFrame, s.nextFrame != nil || s.dataForWriting != nil + } + f := wire.GetStreamFrame() f.FinBit = false f.StreamID = s.streamID @@ -189,8 +275,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun f.DataLenPresent = true f.Data = f.Data[:0] - hasMoreData := s.popNewStreamFrame(f, maxBytes) - + hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow) if len(f.Data) == 0 && !f.FinBit { f.PutBack() return nil, hasMoreData @@ -198,33 +283,14 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun return f, hasMoreData } -func (s *sendStream) popNewStreamFrame(f *wire.StreamFrame, maxBytes protocol.ByteCount) bool { +func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool { maxDataLen := f.MaxDataLen(maxBytes, s.version) if maxDataLen == 0 { // a STREAM frame must have at least one byte of data - return s.dataForWriting != nil + return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting } - s.getDataForWriting(f, maxDataLen) - if len(f.Data) == 0 && !f.FinBit { - // this can happen if: - // - popStreamFrame is called but there's no data for writing - // - there's data for writing, but the stream is stream-level flow control blocked - // - there's data for writing, but the stream is connection-level flow control blocked - if s.dataForWriting == nil { - return false - } - if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked { - s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{ - StreamID: s.streamID, - DataLimit: offset, - }) - return false - } - return true - } - if f.FinBit { - s.finSent = true - } - return s.dataForWriting != nil + s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow)) + + return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting } func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) { @@ -245,29 +311,19 @@ func (s *sendStream) hasData() bool { } func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) { - if s.dataForWriting == nil { - f.FinBit = s.finishedWriting && !s.finSent - return - } - - maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) - if maxBytes == 0 { - return - } - - if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { - f.Data = f.Data[:maxBytes] - copy(f.Data, s.dataForWriting) - s.dataForWriting = s.dataForWriting[maxBytes:] - } else { + if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes { f.Data = f.Data[:len(s.dataForWriting)] copy(f.Data, s.dataForWriting) s.dataForWriting = nil s.signalWrite() + return + } + f.Data = f.Data[:maxBytes] + copy(f.Data, s.dataForWriting) + s.dataForWriting = s.dataForWriting[maxBytes:] + if s.canBufferStreamFrame() { + s.signalWrite() } - s.writeOffset += f.DataLen() - s.flowController.AddBytesSent(f.DataLen()) - f.FinBit = s.finishedWriting && s.dataForWriting == nil && !s.finSent } func (s *sendStream) frameAcked(f wire.Frame) { @@ -357,7 +413,7 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { s.mutex.Lock() - hasStreamData := s.dataForWriting != nil + hasStreamData := s.dataForWriting != nil || s.nextFrame != nil s.mutex.Unlock() s.flowController.UpdateSendWindow(frame.ByteOffset) diff --git a/send_stream_test.go b/send_stream_test.go index 57b92e9d..409f65c6 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -37,13 +37,33 @@ var _ = Describe("Send Stream", func() { strWithTimeout = gbytes.TimeoutWriter(str, timeout) }) + expectedFrameHeaderLen := func(offset protocol.ByteCount) protocol.ByteCount { + return (&wire.StreamFrame{ + StreamID: streamID, + Offset: offset, + DataLenPresent: true, + }).Length(protocol.VersionWhatever) + } + waitForWrite := func() { - EventuallyWithOffset(0, func() []byte { + EventuallyWithOffset(0, func() bool { str.mutex.Lock() - data := str.dataForWriting + hasData := str.dataForWriting != nil || str.nextFrame != nil str.mutex.Unlock() - return data - }).ShouldNot(BeEmpty()) + return hasData + }).Should(BeTrue()) + } + + getDataAtOffset := func(offset, length protocol.ByteCount) []byte { + b := make([]byte, length) + for i := protocol.ByteCount(0); i < length; i++ { + b[i] = uint8(offset + i) + } + return b + } + + getData := func(length protocol.ByteCount) []byte { + return getDataAtOffset(0, length) } It("gets stream id", func() { @@ -52,19 +72,19 @@ var _ = Describe("Send Stream", func() { Context("writing", func() { It("writes and gets all data at once", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) - mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) done := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) n, err := strWithTimeout.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(6)) - close(done) }() waitForWrite() - frame, _ := str.popStreamFrame(1000) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + frame, _ := str.popStreamFrame(protocol.MaxByteCount) f := frame.Frame.(*wire.StreamFrame) Expect(f.Data).To(Equal([]byte("foobar"))) Expect(f.FinBit).To(BeFalse()) @@ -76,26 +96,25 @@ var _ = Describe("Send Stream", func() { }) It("writes and gets data in two turns", func() { - mockSender.EXPECT().onHasStreamData(streamID) - frameHeaderLen := protocol.ByteCount(4) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) - mockFC.EXPECT().AddBytesSent(gomock.Any() /* protocol.ByteCount(3)*/).Times(2) done := make(chan struct{}) go func() { defer GinkgoRecover() + mockSender.EXPECT().onHasStreamData(streamID) n, err := strWithTimeout.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(6)) close(done) }() waitForWrite() - frame, _ := str.popStreamFrame(3 + frameHeaderLen) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)).Times(2) + frame, _ := str.popStreamFrame(expectedFrameHeaderLen(0) + 3) f := frame.Frame.(*wire.StreamFrame) - Expect(f.Data).To(Equal([]byte("foo"))) - Expect(f.FinBit).To(BeFalse()) Expect(f.Offset).To(BeZero()) + Expect(f.FinBit).To(BeFalse()) + Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.DataLenPresent).To(BeTrue()) - frame, _ = str.popStreamFrame(100) + frame, _ = str.popStreamFrame(protocol.MaxByteCount) f = frame.Frame.(*wire.StreamFrame) Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.FinBit).To(BeFalse()) @@ -105,6 +124,85 @@ var _ = Describe("Send Stream", func() { Eventually(done).Should(BeClosed()) }) + It("writes and gets data in multiple turns, for large writes", func() { + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(5) + var totalBytesSent protocol.ByteCount + mockFC.EXPECT().AddBytesSent(gomock.Any()).Do(func(l protocol.ByteCount) { totalBytesSent += l }).Times(5) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + mockSender.EXPECT().onHasStreamData(streamID) + n, err := strWithTimeout.Write(getData(5000)) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(5000)) + close(done) + }() + waitForWrite() + for i := 0; i < 5; i++ { + frame, _ := str.popStreamFrame(1100) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Offset).To(BeNumerically("~", 1100*i, 10*i)) + Expect(f.FinBit).To(BeFalse()) + Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) + Expect(f.DataLenPresent).To(BeTrue()) + } + Expect(totalBytesSent).To(Equal(protocol.ByteCount(5000))) + Eventually(done).Should(BeClosed()) + }) + + It("unblocks Write as soon as a STREAM frame can be buffered", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := strWithTimeout.Write(getData(protocol.MaxReceivePacketSize + 3)) + Expect(err).ToNot(HaveOccurred()) + }() + waitForWrite() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) + frame, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 2) + Expect(hasMoreData).To(BeTrue()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.DataLen()).To(Equal(protocol.ByteCount(2))) + Consistently(done).ShouldNot(BeClosed()) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) + frame, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(1) + 1) + Expect(hasMoreData).To(BeTrue()) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.DataLen()).To(Equal(protocol.ByteCount(1))) + Eventually(done).Should(BeClosed()) + }) + + It("only unblocks Write once a previously buffered STREAM frame has been fully dequeued", func() { + mockSender.EXPECT().onHasStreamData(streamID) + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := str.Write(getData(protocol.MaxReceivePacketSize)) + Expect(err).ToNot(HaveOccurred()) + }() + waitForWrite() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) + frame, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 2) + Expect(hasMoreData).To(BeTrue()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("fo"))) + Consistently(done).ShouldNot(BeClosed()) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(4)) + frame, hasMoreData = str.popStreamFrame(expectedFrameHeaderLen(2) + 4) + Expect(hasMoreData).To(BeTrue()) + f = frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal([]byte("obar"))) + Eventually(done).Should(BeClosed()) + }) + It("popStreamFrame returns nil if no data is available", func() { frame, hasMoreData := str.popStreamFrame(1000) Expect(frame).To(BeNil()) @@ -112,43 +210,45 @@ var _ = Describe("Send Stream", func() { }) It("says if it has more data for writing", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) - mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) done := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) n, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(100)) - close(done) }() waitForWrite() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) frame, hasMoreData := str.popStreamFrame(50) Expect(frame).ToNot(BeNil()) + Expect(frame.Frame.(*wire.StreamFrame).FinBit).To(BeFalse()) Expect(hasMoreData).To(BeTrue()) - frame, hasMoreData = str.popStreamFrame(1000) + frame, hasMoreData = str.popStreamFrame(protocol.MaxByteCount) Expect(frame).ToNot(BeNil()) + Expect(frame.Frame.(*wire.StreamFrame).FinBit).To(BeFalse()) Expect(hasMoreData).To(BeFalse()) - frame, _ = str.popStreamFrame(1000) + frame, _ = str.popStreamFrame(protocol.MaxByteCount) Expect(frame).To(BeNil()) Eventually(done).Should(BeClosed()) }) It("copies the slice while writing", func() { - mockSender.EXPECT().onHasStreamData(streamID) frameHeaderSize := protocol.ByteCount(4) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(1)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(2)) s := []byte("foo") done := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) n, err := strWithTimeout.Write(s) Expect(err).ToNot(HaveOccurred()) Expect(n).To(Equal(3)) - close(done) }() waitForWrite() frame, _ := str.popStreamFrame(frameHeaderSize + 1) @@ -178,7 +278,7 @@ var _ = Describe("Send Stream", func() { It("cancels the context when Close is called", func() { mockSender.EXPECT().onHasStreamData(streamID) Expect(str.Context().Done()).ToNot(BeClosed()) - str.Close() + Expect(str.Close()).To(Succeed()) Expect(str.Context().Done()).To(BeClosed()) }) @@ -190,13 +290,13 @@ var _ = Describe("Send Stream", func() { StreamID: streamID, DataLimit: 12, }) - mockSender.EXPECT().onHasStreamData(streamID) done := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) _, err := str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - close(done) }() waitForWrite() f, hasMoreData := str.popStreamFrame(1000) @@ -208,22 +308,20 @@ var _ = Describe("Send Stream", func() { }) It("says that it doesn't have any more data, when it is flow control blocked", func() { - frameHeaderSize := protocol.ByteCount(4) - mockSender.EXPECT().onHasStreamData(streamID) - done := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) _, err := str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - close(done) }() waitForWrite() // first pop a STREAM frame of the maximum size allowed by flow control mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(3)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(3)) - f, hasMoreData := str.popStreamFrame(frameHeaderSize + 3) + f, hasMoreData := str.popStreamFrame(expectedFrameHeaderLen(0) + 3) Expect(f).ToNot(BeNil()) Expect(hasMoreData).To(BeTrue()) @@ -256,7 +354,7 @@ var _ = Describe("Send Stream", func() { mockSender.EXPECT().onHasStreamData(streamID) deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) str.SetWriteDeadline(deadline) - n, err := strWithTimeout.Write([]byte("foobar")) + n, err := strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError(errDeadline)) Expect(n).To(BeZero()) Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) @@ -268,7 +366,7 @@ var _ = Describe("Send Stream", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := str.Write([]byte("foobar")) + _, err := str.Write(getData(5000)) Expect(err).To(MatchError(errDeadline)) close(done) }() @@ -278,8 +376,7 @@ var _ = Describe("Send Stream", func() { }) It("returns the number of bytes written, when the deadline expires", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()) deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) str.SetWriteDeadline(deadline) @@ -287,11 +384,12 @@ var _ = Describe("Send Stream", func() { writeReturned := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(writeReturned) + mockSender.EXPECT().onHasStreamData(streamID) var err error - n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + n, err = strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError(errDeadline)) Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) - close(writeReturned) }() waitForWrite() frame, hasMoreData := str.popStreamFrame(50) @@ -302,17 +400,17 @@ var _ = Describe("Send Stream", func() { }) It("doesn't pop any data after the deadline expired", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes() + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()) deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) str.SetWriteDeadline(deadline) writeReturned := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + defer close(writeReturned) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError(errDeadline)) - close(writeReturned) }() waitForWrite() frame, hasMoreData := str.popStreamFrame(50) @@ -339,7 +437,7 @@ var _ = Describe("Send Stream", func() { close(done) }() runtime.Gosched() - n, err := strWithTimeout.Write([]byte("foobar")) + n, err := strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError(errDeadline)) Expect(n).To(BeZero()) Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) @@ -361,7 +459,7 @@ var _ = Describe("Send Stream", func() { }() str.SetWriteDeadline(deadline1) runtime.Gosched() - _, err := strWithTimeout.Write([]byte("foobar")) + _, err := strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError(errDeadline)) Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) Eventually(done).Should(BeClosed()) @@ -383,7 +481,7 @@ var _ = Describe("Send Stream", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := strWithTimeout.Write([]byte("foobar")) + _, err := strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError("test done")) close(done) }() @@ -412,27 +510,55 @@ var _ = Describe("Send Stream", func() { f := frame.Frame.(*wire.StreamFrame) Expect(f.Data).To(BeEmpty()) Expect(f.FinBit).To(BeTrue()) + Expect(f.DataLenPresent).To(BeTrue()) Expect(hasMoreData).To(BeFalse()) }) It("doesn't send a FIN when there's still data", func() { - mockSender.EXPECT().onHasStreamData(streamID) - frameHeaderLen := protocol.ByteCount(4) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)).Times(2) - mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) - str.dataForWriting = []byte("foobar") + const frameHeaderLen protocol.ByteCount = 4 + mockSender.EXPECT().onHasStreamData(streamID).Times(2) + _, err := strWithTimeout.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).Times(2) + mockFC.EXPECT().AddBytesSent(gomock.Any()).Times(2) frame, _ := str.popStreamFrame(3 + frameHeaderLen) Expect(frame).ToNot(BeNil()) f := frame.Frame.(*wire.StreamFrame) Expect(f.Data).To(Equal([]byte("foo"))) Expect(f.FinBit).To(BeFalse()) - frame, _ = str.popStreamFrame(100) + frame, _ = str.popStreamFrame(protocol.MaxByteCount) f = frame.Frame.(*wire.StreamFrame) Expect(f.Data).To(Equal([]byte("bar"))) Expect(f.FinBit).To(BeTrue()) }) + It("doesn't send a FIN when there's still data, for long writes", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + mockSender.EXPECT().onHasStreamData(streamID) + _, err := strWithTimeout.Write(getData(5000)) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onHasStreamData(streamID) + Expect(str.Close()).To(Succeed()) + }() + waitForWrite() + for i := 1; i <= 5; i++ { + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + if i == 5 { + Eventually(done).Should(BeClosed()) + } + frame, _ := str.popStreamFrame(1100) + Expect(frame).ToNot(BeNil()) + f := frame.Frame.(*wire.StreamFrame) + Expect(f.Data).To(Equal(getDataAtOffset(f.Offset, f.DataLen()))) + Expect(f.FinBit).To(Equal(i == 5)) // the last frame should have the FIN bit set + } + }) + It("doesn't allow FIN after it is closed for shutdown", func() { str.closeForShutdown(errors.New("test")) f, hasMoreData := str.popStreamFrame(1000) @@ -470,13 +596,13 @@ var _ = Describe("Send Stream", func() { }) It("doesn't get data for writing if an error occurred", func() { - mockSender.EXPECT().onHasStreamData(streamID) - mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 500)) + _, err := strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError(testErr)) close(done) }() @@ -510,7 +636,7 @@ var _ = Describe("Send Stream", func() { It("says when it has data for sending", func() { mockFC.EXPECT().UpdateSendWindow(gomock.Any()) - mockSender.EXPECT().onHasStreamData(streamID).Times(2) // once for Write, once for the MAX_STREAM_DATA frame + mockSender.EXPECT().onHasStreamData(streamID) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -519,6 +645,7 @@ var _ = Describe("Send Stream", func() { close(done) }() waitForWrite() + mockSender.EXPECT().onHasStreamData(streamID) str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ StreamID: streamID, ByteOffset: 42, @@ -545,8 +672,8 @@ var _ = Describe("Send Stream", func() { }) It("unblocks Write", func() { - mockSender.EXPECT().onHasStreamData(streamID) mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) writeReturned := make(chan struct{}) @@ -554,7 +681,7 @@ var _ = Describe("Send Stream", func() { go func() { defer GinkgoRecover() var err error - n, err = strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + n, err = strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) close(writeReturned) }() @@ -567,14 +694,36 @@ var _ = Describe("Send Stream", func() { }) It("doesn't pop STREAM frames after being canceled", func() { - mockSender.EXPECT().onHasStreamData(streamID) mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) mockFC.EXPECT().AddBytesSent(gomock.Any()) writeReturned := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := strWithTimeout.Write(bytes.Repeat([]byte{0}, 100)) + strWithTimeout.Write(getData(100)) + close(writeReturned) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(hasMoreData).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + str.CancelWrite(1234) + frame, hasMoreData = str.popStreamFrame(10) + Expect(frame).To(BeNil()) + Expect(hasMoreData).To(BeFalse()) + Eventually(writeReturned).Should(BeClosed()) + }) + + It("doesn't pop STREAM frames after being canceled, for large writes", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(getData(5000)) Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) close(writeReturned) }() @@ -645,7 +794,7 @@ var _ = Describe("Send Stream", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := str.Write([]byte("foobar")) + _, err := str.Write(getData(5000)) Expect(err).To(MatchError("stream 1337 was reset with error code 123")) Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) @@ -733,7 +882,7 @@ var _ = Describe("Send Stream", func() { }) It("queues lost STREAM frames", func() { - mockSender.EXPECT().onHasStreamData(streamID).Times(2) + mockSender.EXPECT().onHasStreamData(streamID) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(9999)) mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) done := make(chan struct{}) @@ -750,6 +899,7 @@ var _ = Describe("Send Stream", func() { Expect(frame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) // now lose the frame + mockSender.EXPECT().onHasStreamData(streamID) frame.OnLost(frame.Frame) newFrame, _ := str.popStreamFrame(protocol.MaxByteCount) Expect(newFrame).ToNot(BeNil()) @@ -775,16 +925,16 @@ var _ = Describe("Send Stream", func() { Context("determining when a stream is completed", func() { BeforeEach(func() { - mockSender.EXPECT().onHasStreamData(streamID).AnyTimes() mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount).AnyTimes() mockFC.EXPECT().AddBytesSent(gomock.Any()).AnyTimes() }) It("says when a stream is completed", func() { + mockSender.EXPECT().onHasStreamData(streamID) done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := strWithTimeout.Write(make([]byte, 1000)) + _, err := strWithTimeout.Write(make([]byte, 100)) Expect(err).ToNot(HaveOccurred()) close(done) }() @@ -793,7 +943,7 @@ var _ = Describe("Send Stream", func() { // get a bunch of small frames (max. 20 bytes) var frames []ackhandler.Frame for { - frame, hasMoreData := str.popStreamFrame(200) + frame, hasMoreData := str.popStreamFrame(20) if frame == nil { continue } @@ -811,6 +961,7 @@ var _ = Describe("Send Stream", func() { } // Now close the stream and acknowledge the FIN. + mockSender.EXPECT().onHasStreamData(streamID) Expect(str.Close()).To(Succeed()) frame, _ := str.popStreamFrame(protocol.MaxByteCount) Expect(frame).ToNot(BeNil()) @@ -818,12 +969,36 @@ var _ = Describe("Send Stream", func() { frame.OnAcked(frame.Frame) }) - It("doesn't say it's completed when there are frames waiting to be retransmitted", func() { + It("says when a stream is completed, if Close() is called before popping the frame", func() { + mockSender.EXPECT().onHasStreamData(streamID).Times(2) done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := strWithTimeout.Write(make([]byte, 10)) + _, err := strWithTimeout.Write(make([]byte, 100)) Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + Eventually(done).Should(BeClosed()) + Expect(str.Close()).To(Succeed()) + + frame, hasMoreData := str.popStreamFrame(protocol.MaxByteCount) + Expect(hasMoreData).To(BeFalse()) + Expect(frame).ToNot(BeNil()) + Expect(frame.Frame.(*wire.StreamFrame).FinBit).To(BeTrue()) + + mockSender.EXPECT().onStreamCompleted(streamID) + frame.OnAcked(frame.Frame) + }) + + It("doesn't say it's completed when there are frames waiting to be retransmitted", func() { + mockSender.EXPECT().onHasStreamData(streamID) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := strWithTimeout.Write(getData(100)) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().onHasStreamData(streamID) Expect(str.Close()).To(Succeed()) close(done) }() @@ -832,7 +1007,7 @@ var _ = Describe("Send Stream", func() { // get a bunch of small frames (max. 20 bytes) var frames []ackhandler.Frame for { - frame, _ := str.popStreamFrame(protocol.MaxByteCount) + frame, _ := str.popStreamFrame(20) if frame == nil { continue } @@ -847,6 +1022,7 @@ var _ = Describe("Send Stream", func() { for _, f := range frames[1:] { f.OnAcked(f.Frame) } + mockSender.EXPECT().onHasStreamData(streamID) frames[0].OnLost(frames[0].Frame) // get the retransmission and acknowledge it