mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
preserve error from SendStream during cancellation and closing (#4882)
* preserve error from cancelled SendStream during cancellation and closing * clarify that canceling a send stream after closing is valid * preserve close and reset error when send stream is closed for shutdown
This commit is contained in:
parent
fba6ea90a0
commit
5d4835e422
2 changed files with 77 additions and 53 deletions
103
send_stream.go
103
send_stream.go
|
@ -37,10 +37,10 @@ type sendStream struct {
|
|||
|
||||
writeOffset protocol.ByteCount
|
||||
|
||||
cancelWriteErr *StreamError
|
||||
closeForShutdownErr error
|
||||
|
||||
queuedResetStreamFrame bool
|
||||
// finalError is the error that is returned by Write.
|
||||
// It can either be a cancellation error or the shutdown error.
|
||||
finalError error
|
||||
queuedResetStreamFrame *wire.ResetStreamFrame
|
||||
|
||||
finishedWriting bool // set once Close() is called
|
||||
finSent bool // set when a STREAM_FRAME with FIN bit has been sent
|
||||
|
@ -48,6 +48,8 @@ type sendStream struct {
|
|||
// This can happen because the application called CancelWrite,
|
||||
// or because Write returned the error (for remote cancellations).
|
||||
cancellationFlagged bool
|
||||
cancelled bool // both local and remote cancellations
|
||||
closedForShutdown bool // set by closeForShutdown
|
||||
completed bool // set when this stream has been reported to the streamSender as completed
|
||||
|
||||
dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
|
||||
|
@ -105,16 +107,15 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
|
|||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.finalError != nil {
|
||||
if s.cancelled {
|
||||
s.cancellationFlagged = true
|
||||
}
|
||||
return s.isNewlyCompleted(), 0, s.finalError
|
||||
}
|
||||
if s.finishedWriting {
|
||||
return false, 0, fmt.Errorf("write on closed stream %d", s.streamID)
|
||||
}
|
||||
if s.cancelWriteErr != nil {
|
||||
s.cancellationFlagged = true
|
||||
return s.isNewlyCompleted(), 0, s.cancelWriteErr
|
||||
}
|
||||
if s.closeForShutdownErr != nil {
|
||||
return false, 0, s.closeForShutdownErr
|
||||
}
|
||||
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
|
||||
return false, 0, errDeadline
|
||||
}
|
||||
|
@ -168,7 +169,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
|
|||
}
|
||||
deadlineTimer.Reset(deadline)
|
||||
}
|
||||
if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
|
||||
if s.dataForWriting == nil || s.finalError != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -197,11 +198,11 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
|
|||
if bytesWritten == len(p) {
|
||||
return false, bytesWritten, nil
|
||||
}
|
||||
if s.closeForShutdownErr != nil {
|
||||
return false, bytesWritten, s.closeForShutdownErr
|
||||
} else if s.cancelWriteErr != nil {
|
||||
s.cancellationFlagged = true
|
||||
return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr
|
||||
if s.finalError != nil {
|
||||
if s.cancelled {
|
||||
s.cancellationFlagged = true
|
||||
}
|
||||
return s.isNewlyCompleted(), bytesWritten, s.finalError
|
||||
}
|
||||
return false, bytesWritten, nil
|
||||
}
|
||||
|
@ -234,7 +235,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
|
|||
}
|
||||
|
||||
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) {
|
||||
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
|
||||
if s.finalError != nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
|
@ -374,7 +375,7 @@ func (s *sendStream) isNewlyCompleted() bool {
|
|||
return false
|
||||
}
|
||||
// We need to keep the stream around until all frames have been sent and acknowledged.
|
||||
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame {
|
||||
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil {
|
||||
return false
|
||||
}
|
||||
// The stream is completed if we sent the FIN.
|
||||
|
@ -387,7 +388,7 @@ func (s *sendStream) isNewlyCompleted() bool {
|
|||
// 2. we received a STOP_SENDING, and
|
||||
// * the application consumed the error via Write, or
|
||||
// * the application called Close
|
||||
if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) {
|
||||
if s.cancelled && (s.cancellationFlagged || s.finishedWriting) {
|
||||
s.completed = true
|
||||
return true
|
||||
}
|
||||
|
@ -396,13 +397,13 @@ func (s *sendStream) isNewlyCompleted() bool {
|
|||
|
||||
func (s *sendStream) Close() error {
|
||||
s.mutex.Lock()
|
||||
if s.closeForShutdownErr != nil || s.finishedWriting {
|
||||
if s.closedForShutdown || s.finishedWriting {
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.finishedWriting = true
|
||||
cancelWriteErr := s.cancelWriteErr
|
||||
if cancelWriteErr != nil {
|
||||
cancelled := s.cancelled
|
||||
if cancelled {
|
||||
s.cancellationFlagged = true
|
||||
}
|
||||
completed := s.isNewlyCompleted()
|
||||
|
@ -411,7 +412,7 @@ func (s *sendStream) Close() error {
|
|||
if completed {
|
||||
s.sender.onStreamCompleted(s.streamID)
|
||||
}
|
||||
if cancelWriteErr != nil {
|
||||
if cancelled {
|
||||
return fmt.Errorf("close called for canceled stream %d", s.streamID)
|
||||
}
|
||||
s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex
|
||||
|
@ -421,18 +422,21 @@ func (s *sendStream) Close() error {
|
|||
}
|
||||
|
||||
func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
|
||||
s.cancelWriteImpl(errorCode, false)
|
||||
s.cancelWrite(errorCode, false)
|
||||
}
|
||||
|
||||
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
|
||||
// cancelWrite cancels the stream
|
||||
// It is possible to cancel a stream after it has been closed, both locally and remotely.
|
||||
// This is useful to prevent the retransmission of outstanding stream data.
|
||||
func (s *sendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) {
|
||||
s.mutex.Lock()
|
||||
if s.closeForShutdownErr != nil {
|
||||
if s.closedForShutdown {
|
||||
s.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
if !remote {
|
||||
s.cancellationFlagged = true
|
||||
if s.cancelWriteErr != nil {
|
||||
if s.cancelled {
|
||||
completed := s.isNewlyCompleted()
|
||||
s.mutex.Unlock()
|
||||
// The user has called CancelWrite. If the previous cancellation was
|
||||
|
@ -444,15 +448,20 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
|
|||
return
|
||||
}
|
||||
}
|
||||
if s.cancelWriteErr != nil {
|
||||
if s.cancelled {
|
||||
s.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
|
||||
s.ctxCancel(s.cancelWriteErr)
|
||||
s.cancelled = true
|
||||
s.finalError = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
|
||||
s.ctxCancel(s.finalError)
|
||||
s.numOutstandingFrames = 0
|
||||
s.retransmissionQueue = nil
|
||||
s.queuedResetStreamFrame = true
|
||||
s.queuedResetStreamFrame = &wire.ResetStreamFrame{
|
||||
StreamID: s.streamID,
|
||||
FinalSize: s.writeOffset,
|
||||
ErrorCode: errorCode,
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
s.signalWrite()
|
||||
|
@ -473,26 +482,23 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
|
|||
}
|
||||
|
||||
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
|
||||
s.cancelWriteImpl(frame.ErrorCode, true)
|
||||
s.cancelWrite(frame.ErrorCode, true)
|
||||
}
|
||||
|
||||
func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.queuedResetStreamFrame {
|
||||
if s.queuedResetStreamFrame == nil {
|
||||
return ackhandler.Frame{}, false, false
|
||||
}
|
||||
s.queuedResetStreamFrame = false
|
||||
s.numOutstandingFrames++
|
||||
return ackhandler.Frame{
|
||||
Frame: &wire.ResetStreamFrame{
|
||||
StreamID: s.streamID,
|
||||
FinalSize: s.writeOffset,
|
||||
ErrorCode: s.cancelWriteErr.ErrorCode,
|
||||
},
|
||||
f := ackhandler.Frame{
|
||||
Frame: s.queuedResetStreamFrame,
|
||||
Handler: (*sendStreamResetStreamHandler)(s),
|
||||
}, true, false
|
||||
}
|
||||
s.queuedResetStreamFrame = nil
|
||||
return f, true, false
|
||||
}
|
||||
|
||||
func (s *sendStream) Context() context.Context {
|
||||
|
@ -512,7 +518,10 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
|
|||
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
|
||||
func (s *sendStream) closeForShutdown(err error) {
|
||||
s.mutex.Lock()
|
||||
s.closeForShutdownErr = err
|
||||
s.closedForShutdown = true
|
||||
if s.finalError == nil && !s.finishedWriting {
|
||||
s.finalError = err
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
s.signalWrite()
|
||||
}
|
||||
|
@ -533,7 +542,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
|
|||
sf := f.(*wire.StreamFrame)
|
||||
sf.PutBack()
|
||||
s.mutex.Lock()
|
||||
if s.cancelWriteErr != nil {
|
||||
if s.cancelled {
|
||||
s.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -552,7 +561,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
|
|||
func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
|
||||
sf := f.(*wire.StreamFrame)
|
||||
s.mutex.Lock()
|
||||
if s.cancelWriteErr != nil {
|
||||
if s.cancelled {
|
||||
s.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -585,9 +594,9 @@ func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *sendStreamResetStreamHandler) OnLost(wire.Frame) {
|
||||
func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) {
|
||||
s.mutex.Lock()
|
||||
s.queuedResetStreamFrame = true
|
||||
s.queuedResetStreamFrame = f.(*wire.ResetStreamFrame)
|
||||
s.numOutstandingFrames--
|
||||
s.mutex.Unlock()
|
||||
s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s))
|
||||
|
|
|
@ -432,6 +432,11 @@ func TestSendStreamClose(t *testing.T) {
|
|||
require.Nil(t, frame.Frame)
|
||||
require.False(t, hasMore)
|
||||
require.True(t, mockCtrl.Satisfied())
|
||||
|
||||
// shutting down has no effect
|
||||
str.closeForShutdown(errors.New("goodbye"))
|
||||
_, err = strWithTimeout.Write([]byte("foobar"))
|
||||
require.ErrorContains(t, err, "write on closed stream 1234")
|
||||
}
|
||||
|
||||
func TestSendStreamImmediateClose(t *testing.T) {
|
||||
|
@ -630,21 +635,28 @@ func TestSendStreamCancellation(t *testing.T) {
|
|||
require.ErrorContains(t, str.Close(), "close called for canceled stream")
|
||||
frame, _, _ = str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
|
||||
require.Nil(t, frame.Frame)
|
||||
_, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
|
||||
_, err = strWithTimeout.Write([]byte("foobar"))
|
||||
require.Error(t, err)
|
||||
// TODO(#4808):error code and remote flag are unchanged
|
||||
// require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
|
||||
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
|
||||
|
||||
// shutting down has no effect
|
||||
str.closeForShutdown(errors.New("goodbyte"))
|
||||
_, err = strWithTimeout.Write([]byte("foobar"))
|
||||
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1234, Remote: false})
|
||||
}
|
||||
|
||||
// It is possible to cancel a stream after it has been closed.
|
||||
// This is useful if the applications wants to prevent the retransmission of outstanding stream data.
|
||||
func TestSendStreamCancellationAfterClose(t *testing.T) {
|
||||
const streamID protocol.StreamID = 1234
|
||||
mockCtrl := gomock.NewController(t)
|
||||
mockFC := mocks.NewMockStreamFlowController(mockCtrl)
|
||||
mockSender := NewMockStreamSender(mockCtrl)
|
||||
str := newSendStream(context.Background(), streamID, mockSender, mockFC)
|
||||
strWithTimeout := &writerWithTimeout{Writer: str, Timeout: time.Second}
|
||||
|
||||
mockSender.EXPECT().onHasStreamData(streamID, str).Times(2)
|
||||
_, err := (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
|
||||
_, err := strWithTimeout.Write([]byte("foobar"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, str.Close())
|
||||
|
||||
|
@ -659,6 +671,10 @@ func TestSendStreamCancellationAfterClose(t *testing.T) {
|
|||
require.True(t, ok)
|
||||
require.Equal(t, &wire.ResetStreamFrame{StreamID: streamID, FinalSize: 0, ErrorCode: 1337}, cf.Frame)
|
||||
require.False(t, hasMore)
|
||||
|
||||
_, err = strWithTimeout.Write([]byte("foobar"))
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: false})
|
||||
}
|
||||
|
||||
func TestSendStreamCancellationStreamRetransmission(t *testing.T) {
|
||||
|
@ -809,8 +825,7 @@ func TestSendStreamStopSending(t *testing.T) {
|
|||
require.Nil(t, frame.Frame)
|
||||
_, err = (&writerWithTimeout{Writer: str, Timeout: time.Second}).Write([]byte("foobar"))
|
||||
require.Error(t, err)
|
||||
// TODO(#4808):error code and remote flag are unchanged
|
||||
// require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true})
|
||||
require.ErrorIs(t, err, &StreamError{StreamID: streamID, ErrorCode: 1337, Remote: true})
|
||||
}
|
||||
|
||||
// This test is inherently racy, as it tests a concurrent call to Write() and CancelRead().
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue