preserve error from SendStream during cancellation and closing (#4882)

* preserve error from cancelled SendStream during cancellation and closing

* clarify that canceling a send stream after closing is valid

* preserve close and reset error when send stream is closed for shutdown
This commit is contained in:
Marten Seemann 2025-01-21 01:42:47 -08:00 committed by GitHub
parent fba6ea90a0
commit 5d4835e422
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 77 additions and 53 deletions

View file

@ -37,10 +37,10 @@ type sendStream struct {
writeOffset protocol.ByteCount writeOffset protocol.ByteCount
cancelWriteErr *StreamError // finalError is the error that is returned by Write.
closeForShutdownErr error // It can either be a cancellation error or the shutdown error.
finalError error
queuedResetStreamFrame bool queuedResetStreamFrame *wire.ResetStreamFrame
finishedWriting bool // set once Close() is called finishedWriting bool // set once Close() is called
finSent bool // set when a STREAM_FRAME with FIN bit has been sent finSent bool // set when a STREAM_FRAME with FIN bit has been sent
@ -48,6 +48,8 @@ type sendStream struct {
// This can happen because the application called CancelWrite, // This can happen because the application called CancelWrite,
// or because Write returned the error (for remote cancellations). // or because Write returned the error (for remote cancellations).
cancellationFlagged bool cancellationFlagged bool
cancelled bool // both local and remote cancellations
closedForShutdown bool // set by closeForShutdown
completed bool // set when this stream has been reported to the streamSender as completed completed bool // set when this stream has been reported to the streamSender as completed
dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
@ -105,16 +107,15 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.finalError != nil {
if s.cancelled {
s.cancellationFlagged = true
}
return s.isNewlyCompleted(), 0, s.finalError
}
if s.finishedWriting { if s.finishedWriting {
return false, 0, fmt.Errorf("write on closed stream %d", s.streamID) return false, 0, fmt.Errorf("write on closed stream %d", s.streamID)
} }
if s.cancelWriteErr != nil {
s.cancellationFlagged = true
return s.isNewlyCompleted(), 0, s.cancelWriteErr
}
if s.closeForShutdownErr != nil {
return false, 0, s.closeForShutdownErr
}
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
return false, 0, errDeadline return false, 0, errDeadline
} }
@ -168,7 +169,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
} }
deadlineTimer.Reset(deadline) deadlineTimer.Reset(deadline)
} }
if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil { if s.dataForWriting == nil || s.finalError != nil {
break break
} }
} }
@ -197,11 +198,11 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
if bytesWritten == len(p) { if bytesWritten == len(p) {
return false, bytesWritten, nil return false, bytesWritten, nil
} }
if s.closeForShutdownErr != nil { if s.finalError != nil {
return false, bytesWritten, s.closeForShutdownErr if s.cancelled {
} else if s.cancelWriteErr != nil {
s.cancellationFlagged = true s.cancellationFlagged = true
return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr }
return s.isNewlyCompleted(), bytesWritten, s.finalError
} }
return false, bytesWritten, nil return false, bytesWritten, nil
} }
@ -234,7 +235,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
} }
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) { func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) {
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil { if s.finalError != nil {
return nil, nil, false return nil, nil, false
} }
@ -374,7 +375,7 @@ func (s *sendStream) isNewlyCompleted() bool {
return false return false
} }
// We need to keep the stream around until all frames have been sent and acknowledged. // We need to keep the stream around until all frames have been sent and acknowledged.
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame { if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil {
return false return false
} }
// The stream is completed if we sent the FIN. // The stream is completed if we sent the FIN.
@ -387,7 +388,7 @@ func (s *sendStream) isNewlyCompleted() bool {
// 2. we received a STOP_SENDING, and // 2. we received a STOP_SENDING, and
// * the application consumed the error via Write, or // * the application consumed the error via Write, or
// * the application called Close // * the application called Close
if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) { if s.cancelled && (s.cancellationFlagged || s.finishedWriting) {
s.completed = true s.completed = true
return true return true
} }
@ -396,13 +397,13 @@ func (s *sendStream) isNewlyCompleted() bool {
func (s *sendStream) Close() error { func (s *sendStream) Close() error {
s.mutex.Lock() s.mutex.Lock()
if s.closeForShutdownErr != nil || s.finishedWriting { if s.closedForShutdown || s.finishedWriting {
s.mutex.Unlock() s.mutex.Unlock()
return nil return nil
} }
s.finishedWriting = true s.finishedWriting = true
cancelWriteErr := s.cancelWriteErr cancelled := s.cancelled
if cancelWriteErr != nil { if cancelled {
s.cancellationFlagged = true s.cancellationFlagged = true
} }
completed := s.isNewlyCompleted() completed := s.isNewlyCompleted()
@ -411,7 +412,7 @@ func (s *sendStream) Close() error {
if completed { if completed {
s.sender.onStreamCompleted(s.streamID) s.sender.onStreamCompleted(s.streamID)
} }
if cancelWriteErr != nil { if cancelled {
return fmt.Errorf("close called for canceled stream %d", s.streamID) return fmt.Errorf("close called for canceled stream %d", s.streamID)
} }
s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex
@ -421,18 +422,21 @@ func (s *sendStream) Close() error {
} }
func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
s.cancelWriteImpl(errorCode, false) s.cancelWrite(errorCode, false)
} }
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { // cancelWrite cancels the stream
// It is possible to cancel a stream after it has been closed, both locally and remotely.
// This is useful to prevent the retransmission of outstanding stream data.
func (s *sendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) {
s.mutex.Lock() s.mutex.Lock()
if s.closeForShutdownErr != nil { if s.closedForShutdown {
s.mutex.Unlock() s.mutex.Unlock()
return return
} }
if !remote { if !remote {
s.cancellationFlagged = true s.cancellationFlagged = true
if s.cancelWriteErr != nil { if s.cancelled {
completed := s.isNewlyCompleted() completed := s.isNewlyCompleted()
s.mutex.Unlock() s.mutex.Unlock()
// The user has called CancelWrite. If the previous cancellation was // The user has called CancelWrite. If the previous cancellation was
@ -444,15 +448,20 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
return return
} }
} }
if s.cancelWriteErr != nil { if s.cancelled {
s.mutex.Unlock() s.mutex.Unlock()
return return
} }
s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} s.cancelled = true
s.ctxCancel(s.cancelWriteErr) s.finalError = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
s.ctxCancel(s.finalError)
s.numOutstandingFrames = 0 s.numOutstandingFrames = 0
s.retransmissionQueue = nil s.retransmissionQueue = nil
s.queuedResetStreamFrame = true s.queuedResetStreamFrame = &wire.ResetStreamFrame{
StreamID: s.streamID,
FinalSize: s.writeOffset,
ErrorCode: errorCode,
}
s.mutex.Unlock() s.mutex.Unlock()
s.signalWrite() s.signalWrite()
@ -473,26 +482,23 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
} }
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
s.cancelWriteImpl(frame.ErrorCode, true) s.cancelWrite(frame.ErrorCode, true)
} }
func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) { func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if !s.queuedResetStreamFrame { if s.queuedResetStreamFrame == nil {
return ackhandler.Frame{}, false, false return ackhandler.Frame{}, false, false
} }
s.queuedResetStreamFrame = false
s.numOutstandingFrames++ s.numOutstandingFrames++
return ackhandler.Frame{ f := ackhandler.Frame{
Frame: &wire.ResetStreamFrame{ Frame: s.queuedResetStreamFrame,
StreamID: s.streamID,
FinalSize: s.writeOffset,
ErrorCode: s.cancelWriteErr.ErrorCode,
},
Handler: (*sendStreamResetStreamHandler)(s), Handler: (*sendStreamResetStreamHandler)(s),
}, true, false }
s.queuedResetStreamFrame = nil
return f, true, false
} }
func (s *sendStream) Context() context.Context { func (s *sendStream) Context() context.Context {
@ -512,7 +518,10 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. // The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) { func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock() s.mutex.Lock()
s.closeForShutdownErr = err s.closedForShutdown = true
if s.finalError == nil && !s.finishedWriting {
s.finalError = err
}
s.mutex.Unlock() s.mutex.Unlock()
s.signalWrite() s.signalWrite()
} }
@ -533,7 +542,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
sf := f.(*wire.StreamFrame) sf := f.(*wire.StreamFrame)
sf.PutBack() sf.PutBack()
s.mutex.Lock() s.mutex.Lock()
if s.cancelWriteErr != nil { if s.cancelled {
s.mutex.Unlock() s.mutex.Unlock()
return return
} }
@ -552,7 +561,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
func (s *sendStreamAckHandler) OnLost(f wire.Frame) { func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
sf := f.(*wire.StreamFrame) sf := f.(*wire.StreamFrame)
s.mutex.Lock() s.mutex.Lock()
if s.cancelWriteErr != nil { if s.cancelled {
s.mutex.Unlock() s.mutex.Unlock()
return return
} }
@ -585,9 +594,9 @@ func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) {
} }
} }
func (s *sendStreamResetStreamHandler) OnLost(wire.Frame) { func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) {
s.mutex.Lock() s.mutex.Lock()
s.queuedResetStreamFrame = true s.queuedResetStreamFrame = f.(*wire.ResetStreamFrame)
s.numOutstandingFrames-- s.numOutstandingFrames--
s.mutex.Unlock() s.mutex.Unlock()
s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s)) s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s))

View file

@ -432,6 +432,11 @@ func TestSendStreamClose(t *testing.T) {
require.Nil(t, frame.Frame) require.Nil(t, frame.Frame)
require.False(t, hasMore) require.False(t, hasMore)
require.True(t, mockCtrl.Satisfied()) require.True(t, mockCtrl.Satisfied())
// shutting down has no effect
str.closeForShutdown(errors.New("goodbye"))
_, err = strWithTimeout.Write([]byte("foobar"))
require.ErrorContains(t, err, "write on closed stream 1234")
} }
func TestSendStreamImmediateClose(t *testing.T) { func TestSendStreamImmediateClose(t *testing.T) {
@ -630,21 +635,28 @@ func TestSendStreamCancellation(t *testing.T) {
require.ErrorContains(t, str.Close(), "close called for canceled stream") require.ErrorContains(t, str.Close(), "close called for canceled stream")
frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1) frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
require.Nil(t, frame.Frame) require.Nil(t, frame.Frame)
_, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) _, err = strWithTimeout.Write([]byte("foobar"))
require.Error(t, err) require.Error(t, err)
// TODO(#4808):error code and remote flag are unchanged require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
// require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
// shutting down has no effect
str.closeForShutdown(errors.New("goodbyte"))
_, err = strWithTimeout.Write([]byte("foobar"))
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
} }
// It is possible to cancel a stream after it has been closed.
// This is useful if the applications wants to prevent the retransmission of outstanding stream data.
func TestSendStreamCancellationAfterClose(t *testing.T) { func TestSendStreamCancellationAfterClose(t *testing.T) {
const streamID protocol.StreamID = 1234 const streamID protocol.StreamID = 1234
mockCtrl := gomock.NewController(t) mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl) mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl) mockSender := NewMockStreamSender(mockCtrl)
str := newSendStream(context.Background(), streamID, mockSender, mockFC) str := newSendStream(context.Background(), streamID, mockSender, mockFC)
strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second}
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2) mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
_, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) _, err := strWithTimeout.Write([]byte("foobar"))
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, str.Close()) require.NoError(t, str.Close())
@ -659,6 +671,10 @@ func TestSendStreamCancellationAfterClose(t *testing.T) {
require.True(t, ok) require.True(t, ok)
require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, cf.Frame) require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, cf.Frame)
require.False(t, hasMore) require.False(t, hasMore)
_, err = strWithTimeout.Write([]byte("foobar"))
require.Error(t, err)
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: false})
} }
func TestSendStreamCancellationStreamRetransmission(t *testing.T) { func TestSendStreamCancellationStreamRetransmission(t *testing.T) {
@ -809,8 +825,7 @@ func TestSendStreamStopSending(t *testing.T) {
require.Nil(t, frame.Frame) require.Nil(t, frame.Frame)
_, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar")) _, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
require.Error(t, err) require.Error(t, err)
// TODO(#4808):error code and remote flag are unchanged require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true})
// require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true})
} }
// This test is inherently racy, as it tests a concurrent call to Write() and CancelRead(). // This test is inherently racy, as it tests a concurrent call to Write() and CancelRead().