diff --git a/send_stream.go b/send_stream.go index cb7a828d..885ed7a9 100644 --- a/send_stream.go +++ b/send_stream.go @@ -346,6 +346,10 @@ func (s *sendStream) frameAcked(f wire.Frame) { f.(*wire.StreamFrame).PutBack() s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() + return + } s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { panic("numOutStandingFrames negative") @@ -371,6 +375,10 @@ func (s *sendStream) queueRetransmission(f wire.Frame) { sf := f.(*wire.StreamFrame) sf.DataLenPresent = true s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() + return + } s.retransmissionQueue = append(s.retransmissionQueue, sf) s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { @@ -413,6 +421,8 @@ func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, wr s.ctxCancel() s.canceledWrite = true s.cancelWriteErr = writeErr + s.numOutstandingFrames = 0 + s.retransmissionQueue = nil newlyCompleted := s.isNewlyCompleted() s.mutex.Unlock() diff --git a/send_stream_test.go b/send_stream_test.go index eaddb943..ec7bed00 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -738,6 +738,7 @@ var _ = Describe("Send Stream", func() { waitForWrite() frame, _ := str.popStreamFrame(50) Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) str.CancelWrite(1234) Eventually(writeReturned).Should(BeClosed()) Expect(n).To(BeEquivalentTo(frame.Frame.(*wire.StreamFrame).DataLen())) @@ -758,6 +759,7 @@ var _ = Describe("Send Stream", func() { frame, hasMoreData := str.popStreamFrame(50) Expect(hasMoreData).To(BeTrue()) Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) str.CancelWrite(1234) frame, hasMoreData = str.popStreamFrame(10) Expect(frame).To(BeNil()) @@ -781,6 +783,7 @@ var _ = Describe("Send Stream", func() { frame, hasMoreData := str.popStreamFrame(50) Expect(hasMoreData).To(BeTrue()) Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) str.CancelWrite(1234) frame, hasMoreData = str.popStreamFrame(10) Expect(hasMoreData).To(BeFalse()) @@ -788,6 +791,26 @@ var _ = Describe("Send Stream", func() { Eventually(writeReturned).Should(BeClosed()) }) + It("ignores acknowledgements for STREAM frames after it was cancelled", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onHasStreamData(streamID) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(gomock.Any()) + writeReturned := make(chan struct{}) + go func() { + defer GinkgoRecover() + strWithTimeout.Write(getData(100)) + close(writeReturned) + }() + waitForWrite() + frame, hasMoreData := str.popStreamFrame(50) + Expect(hasMoreData).To(BeTrue()) + Expect(frame).ToNot(BeNil()) + mockSender.EXPECT().onStreamCompleted(streamID) + str.CancelWrite(1234) + frame.OnAcked(frame.Frame) + }) + It("cancels the context", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) mockSender.EXPECT().onStreamCompleted(gomock.Any()) @@ -956,20 +979,28 @@ var _ = Describe("Send Stream", func() { Expect(newFrame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) }) - It("doesn't get a retransmission after a stream was canceled", func() { - str.numOutstandingFrames = 1 - f := &wire.StreamFrame{ - Data: []byte("foobar"), - Offset: 0x42, - DataLenPresent: false, - } + It("doesn't queue retransmissions for a stream that was canceled", func() { mockSender.EXPECT().onHasStreamData(streamID) - str.queueRetransmission(f) - mockSender.EXPECT().queueControlFrame(gomock.Any()) - str.CancelWrite(0) - frame, hasMoreData := str.popStreamFrame(protocol.MaxByteCount) - Expect(hasMoreData).To(BeFalse()) - Expect(frame).To(BeNil()) + mockFC.EXPECT().SendWindowSize().Return(protocol.MaxByteCount) + mockFC.EXPECT().AddBytesSent(protocol.ByteCount(6)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + waitForWrite() + f, _ := str.popStreamFrame(100) + Expect(f).ToNot(BeNil()) + gomock.InOrder( + mockSender.EXPECT().queueControlFrame(gomock.Any()), + mockSender.EXPECT().onStreamCompleted(streamID), + ) + str.CancelWrite(9876) + // don't EXPECT any calls to onHasStreamData + f.OnLost(f.Frame) + Expect(str.retransmissionQueue).To(BeEmpty()) }) })