preserve error from SendStream during cancellation and closing (#4882)

* preserve error from cancelled SendStream during cancellation and closing

* clarify that canceling a send stream after closing is valid

* preserve close and reset error when send stream is closed for shutdown
This commit is contained in:
Marten Seemann 2025-01-21 01:42:47 -08:00 committed by GitHub
parent fba6ea90a0
commit 5d4835e422
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 77 additions and 53 deletions

View file

@ -37,10 +37,10 @@ type sendStream struct {
writeOffset protocol.ByteCount
cancelWriteErr *StreamError
closeForShutdownErr error
queuedResetStreamFrame bool
// finalError is the error that is returned by Write.
// It can either be a cancellation error or the shutdown error.
finalError error
queuedResetStreamFrame *wire.ResetStreamFrame
finishedWriting bool // set once Close() is called
finSent bool // set when a STREAM_FRAME with FIN bit has been sent
@ -48,6 +48,8 @@ type sendStream struct {
// This can happen because the application called CancelWrite,
// or because Write returned the error (for remote cancellations).
cancellationFlagged bool
cancelled bool // both local and remote cancellations
closedForShutdown bool // set by closeForShutdown
completed bool // set when this stream has been reported to the streamSender as completed
dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
@ -105,16 +107,15 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finalError != nil {
if s.cancelled {
s.cancellationFlagged = true
}
return s.isNewlyCompleted(), 0, s.finalError
}
if s.finishedWriting {
return false, 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if s.cancelWriteErr != nil {
s.cancellationFlagged = true
return s.isNewlyCompleted(), 0, s.cancelWriteErr
}
if s.closeForShutdownErr != nil {
return false, 0, s.closeForShutdownErr
}
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
return false, 0, errDeadline
}
@ -168,7 +169,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
}
deadlineTimer.Reset(deadline)
}
if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
if s.dataForWriting == nil || s.finalError != nil {
break
}
}
@ -197,11 +198,11 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
if bytesWritten == len(p) {
return false, bytesWritten, nil
}
if s.closeForShutdownErr != nil {
return false, bytesWritten, s.closeForShutdownErr
} else if s.cancelWriteErr != nil {
s.cancellationFlagged = true
return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr
if s.finalError != nil {
if s.cancelled {
s.cancellationFlagged = true
}
return s.isNewlyCompleted(), bytesWritten, s.finalError
}
return false, bytesWritten, nil
}
@ -234,7 +235,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
}
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) {
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
if s.finalError != nil {
return nil, nil, false
}
@ -374,7 +375,7 @@ func (s *sendStream) isNewlyCompleted() bool {
return false
}
// We need to keep the stream around until all frames have been sent and acknowledged.
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame {
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil {
return false
}
// The stream is completed if we sent the FIN.
@ -387,7 +388,7 @@ func (s *sendStream) isNewlyCompleted() bool {
// 2. we received a STOP_SENDING, and
// * the application consumed the error via Write, or
// * the application called Close
if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) {
if s.cancelled && (s.cancellationFlagged || s.finishedWriting) {
s.completed = true
return true
}
@ -396,13 +397,13 @@ func (s *sendStream) isNewlyCompleted() bool {
func (s *sendStream) Close() error {
s.mutex.Lock()
if s.closeForShutdownErr != nil || s.finishedWriting {
if s.closedForShutdown || s.finishedWriting {
s.mutex.Unlock()
return nil
}
s.finishedWriting = true
cancelWriteErr := s.cancelWriteErr
if cancelWriteErr != nil {
cancelled := s.cancelled
if cancelled {
s.cancellationFlagged = true
}
completed := s.isNewlyCompleted()
@ -411,7 +412,7 @@ func (s *sendStream) Close() error {
if completed {
s.sender.onStreamCompleted(s.streamID)
}
if cancelWriteErr != nil {
if cancelled {
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex
@ -421,18 +422,21 @@ func (s *sendStream) Close() error {
}
func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
s.cancelWriteImpl(errorCode, false)
s.cancelWrite(errorCode, false)
}
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
// cancelWrite cancels the stream
// It is possible to cancel a stream after it has been closed, both locally and remotely.
// This is useful to prevent the retransmission of outstanding stream data.
func (s *sendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) {
s.mutex.Lock()
if s.closeForShutdownErr != nil {
if s.closedForShutdown {
s.mutex.Unlock()
return
}
if !remote {
s.cancellationFlagged = true
if s.cancelWriteErr != nil {
if s.cancelled {
completed := s.isNewlyCompleted()
s.mutex.Unlock()
// The user has called CancelWrite. If the previous cancellation was
@ -444,15 +448,20 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
return
}
}
if s.cancelWriteErr != nil {
if s.cancelled {
s.mutex.Unlock()
return
}
s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
s.ctxCancel(s.cancelWriteErr)
s.cancelled = true
s.finalError = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
s.ctxCancel(s.finalError)
s.numOutstandingFrames = 0
s.retransmissionQueue = nil
s.queuedResetStreamFrame = true
s.queuedResetStreamFrame = &wire.ResetStreamFrame{
StreamID: s.streamID,
FinalSize: s.writeOffset,
ErrorCode: errorCode,
}
s.mutex.Unlock()
s.signalWrite()
@ -473,26 +482,23 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
}
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
s.cancelWriteImpl(frame.ErrorCode, true)
s.cancelWrite(frame.ErrorCode, true)
}
func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.queuedResetStreamFrame {
if s.queuedResetStreamFrame == nil {
return ackhandler.Frame{}, false, false
}
s.queuedResetStreamFrame = false
s.numOutstandingFrames++
return ackhandler.Frame{
Frame: &wire.ResetStreamFrame{
StreamID: s.streamID,
FinalSize: s.writeOffset,
ErrorCode: s.cancelWriteErr.ErrorCode,
},
f := ackhandler.Frame{
Frame: s.queuedResetStreamFrame,
Handler: (*sendStreamResetStreamHandler)(s),
}, true, false
}
s.queuedResetStreamFrame = nil
return f, true, false
}
func (s *sendStream) Context() context.Context {
@ -512,7 +518,10 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.closeForShutdownErr = err
s.closedForShutdown = true
if s.finalError == nil && !s.finishedWriting {
s.finalError = err
}
s.mutex.Unlock()
s.signalWrite()
}
@ -533,7 +542,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
sf := f.(*wire.StreamFrame)
sf.PutBack()
s.mutex.Lock()
if s.cancelWriteErr != nil {
if s.cancelled {
s.mutex.Unlock()
return
}
@ -552,7 +561,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
sf := f.(*wire.StreamFrame)
s.mutex.Lock()
if s.cancelWriteErr != nil {
if s.cancelled {
s.mutex.Unlock()
return
}
@ -585,9 +594,9 @@ func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) {
}
}
func (s *sendStreamResetStreamHandler) OnLost(wire.Frame) {
func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) {
s.mutex.Lock()
s.queuedResetStreamFrame = true
s.queuedResetStreamFrame = f.(*wire.ResetStreamFrame)
s.numOutstandingFrames--
s.mutex.Unlock()
s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s))

View file

@ -432,6 +432,11 @@ func TestSendStreamClose(t *testing.T) {
require.Nil(t, frame.Frame)
require.False(t, hasMore)
require.True(t, mockCtrl.Satisfied())
// shutting down has no effect
str.closeForShutdown(errors.New("goodbye"))
_, err = strWithTimeout.Write([]byte("foobar"))
require.ErrorContains(t, err, "write on closed stream 1234")
}
func TestSendStreamImmediateClose(t *testing.T) {
@ -630,21 +635,28 @@ func TestSendStreamCancellation(t *testing.T) {
require.ErrorContains(t, str.Close(), "close called for canceled stream")
frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
require.Nil(t, frame.Frame)
_, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
_, err = strWithTimeout.Write([]byte("foobar"))
require.Error(t, err)
// TODO(#4808):error code and remote flag are unchanged
// require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
// shutting down has no effect
str.closeForShutdown(errors.New("goodbyte"))
_, err = strWithTimeout.Write([]byte("foobar"))
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
}
// It is possible to cancel a stream after it has been closed.
// This is useful if the applications wants to prevent the retransmission of outstanding stream data.
func TestSendStreamCancellationAfterClose(t *testing.T) {
const streamID protocol.StreamID = 1234
mockCtrl := gomock.NewController(t)
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
mockSender := NewMockStreamSender(mockCtrl)
str := newSendStream(context.Background(), streamID, mockSender, mockFC)
strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second}
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
_, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
_, err := strWithTimeout.Write([]byte("foobar"))
require.NoError(t, err)
require.NoError(t, str.Close())
@ -659,6 +671,10 @@ func TestSendStreamCancellationAfterClose(t *testing.T) {
require.True(t, ok)
require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, cf.Frame)
require.False(t, hasMore)
_, err = strWithTimeout.Write([]byte("foobar"))
require.Error(t, err)
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: false})
}
func TestSendStreamCancellationStreamRetransmission(t *testing.T) {
@ -809,8 +825,7 @@ func TestSendStreamStopSending(t *testing.T) {
require.Nil(t, frame.Frame)
_, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
require.Error(t, err)
// TODO(#4808):error code and remote flag are unchanged
// require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true})
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true})
}
// This test is inherently racy, as it tests a concurrent call to Write() and CancelRead().