From 3f9d8feab2de6a552af39bb00ae3c1c9123a22b1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 26 Jan 2023 12:58:06 -0800 Subject: [PATCH] return StreamErrors for all kinds of stream cancelations (#3681) * fix: return typed errors when after cancelling actions This is errors.Is'ed by libp2p and the fmt.Errorf messages didn't passed that test: https://github.com/libp2p/go-libp2p/blob/82315917f76ca1e9e1c3aca3ad957080f0ac870c/p2p/transport/quic/stream.go#L23 * replace StreamErrorAction with a local / remote flag Co-authored-by: Jorropo --- errors.go | 7 +++++- integrationtests/self/cancelation_test.go | 20 ++++++++++++---- receive_stream.go | 5 ++-- receive_stream_test.go | 25 +++++++++++++++----- send_stream.go | 11 ++++----- send_stream_test.go | 28 +++++++++++++++++------ 6 files changed, 69 insertions(+), 27 deletions(-) diff --git a/errors.go b/errors.go index 89149fd0..c9fb0a07 100644 --- a/errors.go +++ b/errors.go @@ -46,6 +46,7 @@ const ( type StreamError struct { StreamID StreamID ErrorCode StreamErrorCode + Remote bool } func (e *StreamError) Is(target error) bool { @@ -54,5 +55,9 @@ func (e *StreamError) Is(target error) bool { } func (e *StreamError) Error() string { - return fmt.Sprintf("stream %d canceled with error code %d", e.StreamID, e.ErrorCode) + pers := "local" + if e.Remote { + pers = "remote" + } + return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode) } diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index 48a64413..7b98aba4 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -16,7 +16,7 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Stream Cancelations", func() { +var _ = Describe("Stream Cancellations", func() { const numStreams = 80 Context("canceling the read side", func() { @@ -45,9 +45,10 @@ var _ = Describe("Stream Cancelations", func() { str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) if _, err := str.Write(data); err != nil { - Expect(err).To(MatchError(&quic.StreamError{ + Expect(err).To(Equal(&quic.StreamError{ StreamID: str.StreamID(), ErrorCode: quic.StreamErrorCode(str.StreamID()), + Remote: true, })) atomic.AddInt32(&canceledCounter, 1) return @@ -90,7 +91,14 @@ var _ = Describe("Stream Cancelations", func() { // cancel around 2/3 of the streams if rand.Int31()%3 != 0 { atomic.AddInt32(&canceledCounter, 1) - str.CancelRead(quic.StreamErrorCode(str.StreamID())) + resetErr := quic.StreamErrorCode(str.StreamID()) + str.CancelRead(resetErr) + _, err := str.Read([]byte{0}) + Expect(err).To(Equal(&quic.StreamError{ + StreamID: str.StreamID(), + ErrorCode: resetErr, + Remote: false, + })) return } data, err := io.ReadAll(str) @@ -189,7 +197,11 @@ var _ = Describe("Stream Cancelations", func() { b := make([]byte, 32) if _, err := str.Read(b); err != nil { atomic.AddInt32(&counter, 1) - Expect(err.Error()).To(ContainSubstring("canceled with error code 1234")) + Expect(err).To(Equal(&quic.StreamError{ + StreamID: str.StreamID(), + ErrorCode: 1234, + Remote: false, + })) return } }() diff --git a/receive_stream.go b/receive_stream.go index e0bb3667..5d220e22 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -44,7 +44,7 @@ type receiveStream struct { closedForShutdown bool // set when CloseForShutdown() is called finRead bool // set once we read a frame with a Fin canceledRead bool // set when CancelRead() is called - resetRemotely bool // set when HandleResetStreamFrame() is called + resetRemotely bool // set when handleResetStreamFrame() is called readChan chan struct{} readOnce chan struct{} // cap: 1, to protect against concurrent use of Read @@ -215,7 +215,7 @@ func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* c return false } s.canceledRead = true - s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) + s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false} s.signalRead() s.sender.queueControlFrame(&wire.StopSendingFrame{ StreamID: s.streamID, @@ -287,6 +287,7 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) s.resetRemotelyErr = &StreamError{ StreamID: s.streamID, ErrorCode: frame.ErrorCode, + Remote: true, } s.signalRead() return newlyRcvdFinalOffset, nil diff --git a/receive_stream_test.go b/receive_stream_test.go index 18a05bc9..3987eea6 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -484,7 +484,7 @@ var _ = Describe("Receive Stream", func() { }) }) - Context("stream cancelations", func() { + Context("stream cancellations", func() { Context("canceling read", func() { It("unblocks Read", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) @@ -492,7 +492,11 @@ var _ = Describe("Receive Stream", func() { go func() { defer GinkgoRecover() _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + Expect(err).To(Equal(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + Remote: false, + })) close(done) }() Consistently(done).ShouldNot(BeClosed()) @@ -504,7 +508,11 @@ var _ = Describe("Receive Stream", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) str.CancelRead(1234) _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + Expect(err).To(Equal(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + Remote: false, + })) }) It("does nothing when CancelRead is called twice", func() { @@ -512,7 +520,11 @@ var _ = Describe("Receive Stream", func() { str.CancelRead(1234) str.CancelRead(1234) _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) + Expect(err).To(Equal(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + Remote: false, + })) }) It("queues a STOP_SENDING frame", func() { @@ -609,9 +621,10 @@ var _ = Describe("Receive Stream", func() { go func() { defer GinkgoRecover() _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError(&StreamError{ + Expect(err).To(Equal(&StreamError{ StreamID: streamID, ErrorCode: 1234, + Remote: true, })) close(done) }() @@ -668,7 +681,7 @@ var _ = Describe("Receive Stream", func() { Expect(str.handleResetStreamFrame(rst)).To(Succeed()) }) - It("doesn't do anyting when it was closed for shutdown", func() { + It("doesn't do anything when it was closed for shutdown", func() { str.closeForShutdown(nil) err := str.handleResetStreamFrame(rst) Expect(err).ToNot(HaveOccurred()) diff --git a/send_stream.go b/send_stream.go index 146c9c85..6b1d7b17 100644 --- a/send_stream.go +++ b/send_stream.go @@ -416,11 +416,11 @@ func (s *sendStream) Close() error { } func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { - s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) + s.cancelWriteImpl(errorCode, false) } // must be called after locking the mutex -func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) { +func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { s.mutex.Lock() if s.canceledWrite { s.mutex.Unlock() @@ -428,7 +428,7 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr er } s.ctxCancel() s.canceledWrite = true - s.cancelWriteErr = writeErr + s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} s.numOutstandingFrames = 0 s.retransmissionQueue = nil newlyCompleted := s.isNewlyCompleted() @@ -457,10 +457,7 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { } func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - s.cancelWriteImpl(frame.ErrorCode, &StreamError{ - StreamID: s.streamID, - ErrorCode: frame.ErrorCode, - }) + s.cancelWriteImpl(frame.ErrorCode, true) } func (s *sendStream) Context() context.Context { diff --git a/send_stream_test.go b/send_stream_test.go index 4079761e..0fcdf6c8 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -726,7 +726,11 @@ var _ = Describe("Send Stream", func() { defer GinkgoRecover() var err error n, err = strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + Expect(err).To(Equal(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + Remote: false, + })) close(writeReturned) }() waitForWrite() @@ -770,7 +774,11 @@ var _ = Describe("Send Stream", func() { go func() { defer GinkgoRecover() _, err := strWithTimeout.Write(getData(5000)) - Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + Expect(err).To(Equal(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + Remote: false, + })) close(writeReturned) }() waitForWrite() @@ -818,7 +826,11 @@ var _ = Describe("Send Stream", func() { mockSender.EXPECT().onStreamCompleted(gomock.Any()) str.CancelWrite(1234) _, err := strWithTimeout.Write([]byte("foobar")) - Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + Remote: false, + })) }) It("only cancels once", func() { @@ -862,9 +874,10 @@ var _ = Describe("Send Stream", func() { go func() { defer GinkgoRecover() _, err := str.Write(getData(5000)) - Expect(err).To(MatchError(&StreamError{ + Expect(err).To(Equal(&StreamError{ StreamID: streamID, - ErrorCode: 1234, + ErrorCode: 123, + Remote: true, })) close(done) }() @@ -884,9 +897,10 @@ var _ = Describe("Send Stream", func() { ErrorCode: 123, }) _, err := str.Write([]byte("foobar")) - Expect(err).To(MatchError(&StreamError{ + Expect(err).To(Equal(&StreamError{ StreamID: streamID, - ErrorCode: 1234, + ErrorCode: 123, + Remote: true, })) }) })