diff --git a/send_stream.go b/send_stream.go index c674e236..97503d23 100644 --- a/send_stream.go +++ b/send_stream.go @@ -43,7 +43,8 @@ type sendStream struct { closedForShutdown bool // set when CloseForShutdown() is called finishedWriting bool // set once Close() is called canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received - finSent bool // set when a STREAM_FRAME with FIN bit has b + finSent bool // set when a STREAM_FRAME with FIN bit has been sent + completed bool // set when this stream has been reported to the streamSender as completed dataForWriting []byte @@ -255,22 +256,28 @@ func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, boo } func (s *sendStream) frameAcked() { - var streamCompleted bool s.mutex.Lock() s.numOutstandingFrames-- if s.numOutstandingFrames < 0 { panic("numOutStandingFrames negative") } - if (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 { - streamCompleted = true - } + newlyCompleted := s.isNewlyCompleted() s.mutex.Unlock() - if streamCompleted { + if newlyCompleted { s.sender.onStreamCompleted(s.streamID) } } +func (s *sendStream) isNewlyCompleted() bool { + completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 + if completed && !s.completed { + s.completed = true + return true + } + return false +} + func (s *sendStream) queueRetransmission(f wire.Frame) { sf := f.(*wire.StreamFrame) sf.DataLenPresent = true @@ -299,27 +306,32 @@ func (s *sendStream) Close() error { return nil } -// TODO: also complete the stream when this is called after all outstanding data has been acknowledged func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) { - s.mutex.Lock() s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) - s.mutex.Unlock() + } // must be called after locking the mutex func (s *sendStream) cancelWriteImpl(errorCode protocol.ApplicationErrorCode, writeErr error) { - if s.canceledWrite || s.finishedWriting { + s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() return } s.canceledWrite = true s.cancelWriteErr = writeErr + newlyCompleted := s.isNewlyCompleted() + s.mutex.Unlock() + s.signalWrite() s.sender.queueControlFrame(&wire.ResetStreamFrame{ StreamID: s.streamID, ByteOffset: s.writeOffset, ErrorCode: errorCode, }) - // TODO(#991): cancel retransmissions for this stream + if newlyCompleted { + s.sender.onStreamCompleted(s.streamID) + } s.ctxCancel() } @@ -334,11 +346,7 @@ func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) { } } -// TODO: also complete the stream when the frame is received after all outstanding data has been acknowledged func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - s.mutex.Lock() - defer s.mutex.Unlock() - writeErr := streamCanceledError{ errorCode: frame.ErrorCode, error: fmt.Errorf("Stream %d was reset with error code %d", s.streamID, frame.ErrorCode), diff --git a/send_stream_test.go b/send_stream_test.go index db3c27fe..dadec815 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -527,11 +527,14 @@ var _ = Describe("Send Stream", func() { Context("stream cancellations", func() { Context("canceling writing", func() { It("queues a RESET_STREAM frame", func() { - mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ - StreamID: streamID, - ByteOffset: 1234, - ErrorCode: 9876, - }) + gomock.InOrder( + mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{ + StreamID: streamID, + ByteOffset: 1234, + ErrorCode: 9876, + }), + mockSender.EXPECT().onStreamCompleted(streamID), + ) str.writeOffset = 1234 str.CancelWrite(9876) }) @@ -583,6 +586,7 @@ var _ = Describe("Send Stream", func() { It("cancels the context", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) Expect(str.Context().Done()).ToNot(BeClosed()) str.CancelWrite(1234) Expect(str.Context().Done()).To(BeClosed()) @@ -590,6 +594,7 @@ var _ = Describe("Send Stream", func() { It("doesn't allow further calls to Write", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) str.CancelWrite(1234) _, err := strWithTimeout.Write([]byte("foobar")) Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) @@ -597,12 +602,17 @@ var _ = Describe("Send Stream", func() { It("only cancels once", func() { mockSender.EXPECT().queueControlFrame(&wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 1234}) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) str.CancelWrite(1234) str.CancelWrite(4321) }) - It("doesn't do anything when the stream was already closed", func() { + It("queues a RESET_STREAM frame, even if the stream was already closed", func() { mockSender.EXPECT().onHasStreamData(streamID) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f).To(BeAssignableToTypeOf(&wire.ResetStreamFrame{})) + }) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) Expect(str.Close()).To(Succeed()) // don't EXPECT any calls to queueControlFrame str.CancelWrite(123) @@ -615,6 +625,8 @@ var _ = Describe("Send Stream", func() { StreamID: streamID, ErrorCode: 101, }) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) + str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 101, @@ -624,6 +636,7 @@ var _ = Describe("Send Stream", func() { It("unblocks Write", func() { mockSender.EXPECT().onHasStreamData(streamID) mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -644,6 +657,7 @@ var _ = Describe("Send Stream", func() { It("doesn't allow further calls to Write", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) + mockSender.EXPECT().onStreamCompleted(gomock.Any()) str.handleStopSendingFrame(&wire.StopSendingFrame{ StreamID: streamID, ErrorCode: 123,