return StreamErrors for all kinds of stream cancelations (#3681)

* fix: return typed errors when after cancelling actions

This is errors.Is'ed by libp2p and the fmt.Errorf messages didn't passed that test:
82315917f7/p2p/transport/quic/stream.go (L23)

* replace StreamErrorAction with a local / remote flag

Co-authored-by: Jorropo <jorropo.pgm@gmail.com>
This commit is contained in:
Marten Seemann 2023-01-26 12:58:06 -08:00 committed by GitHub
parent dfbfe71309
commit 3f9d8feab2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 69 additions and 27 deletions

View file

@ -46,6 +46,7 @@ const (
type StreamError struct { type StreamError struct {
StreamID StreamID StreamID StreamID
ErrorCode StreamErrorCode ErrorCode StreamErrorCode
Remote bool
} }
func (e *StreamError) Is(target error) bool { func (e *StreamError) Is(target error) bool {
@ -54,5 +55,9 @@ func (e *StreamError) Is(target error) bool {
} }
func (e *StreamError) Error() string { func (e *StreamError) Error() string {
return fmt.Sprintf("stream %d canceled with error code %d", e.StreamID, e.ErrorCode) pers := "local"
if e.Remote {
pers = "remote"
}
return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode)
} }

View file

@ -16,7 +16,7 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("Stream Cancelations", func() { var _ = Describe("Stream Cancellations", func() {
const numStreams = 80 const numStreams = 80
Context("canceling the read side", func() { Context("canceling the read side", func() {
@ -45,9 +45,10 @@ var _ = Describe("Stream Cancelations", func() {
str, err := conn.OpenUniStreamSync(context.Background()) str, err := conn.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
if _, err := str.Write(data); err != nil { if _, err := str.Write(data); err != nil {
Expect(err).To(MatchError(&quic.StreamError{ Expect(err).To(Equal(&quic.StreamError{
StreamID: str.StreamID(), StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()), ErrorCode: quic.StreamErrorCode(str.StreamID()),
Remote: true,
})) }))
atomic.AddInt32(&canceledCounter, 1) atomic.AddInt32(&canceledCounter, 1)
return return
@ -90,7 +91,14 @@ var _ = Describe("Stream Cancelations", func() {
// cancel around 2/3 of the streams // cancel around 2/3 of the streams
if rand.Int31()%3 != 0 { if rand.Int31()%3 != 0 {
atomic.AddInt32(&canceledCounter, 1) atomic.AddInt32(&canceledCounter, 1)
str.CancelRead(quic.StreamErrorCode(str.StreamID())) resetErr := quic.StreamErrorCode(str.StreamID())
str.CancelRead(resetErr)
_, err := str.Read([]byte{0})
Expect(err).To(Equal(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: resetErr,
Remote: false,
}))
return return
} }
data, err := io.ReadAll(str) data, err := io.ReadAll(str)
@ -189,7 +197,11 @@ var _ = Describe("Stream Cancelations", func() {
b := make([]byte, 32) b := make([]byte, 32)
if _, err := str.Read(b); err != nil { if _, err := str.Read(b); err != nil {
atomic.AddInt32(&counter, 1) atomic.AddInt32(&counter, 1)
Expect(err.Error()).To(ContainSubstring("canceled with error code 1234")) Expect(err).To(Equal(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: 1234,
Remote: false,
}))
return return
} }
}() }()

View file

@ -44,7 +44,7 @@ type receiveStream struct {
closedForShutdown bool // set when CloseForShutdown() is called closedForShutdown bool // set when CloseForShutdown() is called
finRead bool // set once we read a frame with a Fin finRead bool // set once we read a frame with a Fin
canceledRead bool // set when CancelRead() is called canceledRead bool // set when CancelRead() is called
resetRemotely bool // set when HandleResetStreamFrame() is called resetRemotely bool // set when handleResetStreamFrame() is called
readChan chan struct{} readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
@ -215,7 +215,7 @@ func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* c
return false return false
} }
s.canceledRead = true s.canceledRead = true
s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead() s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{ s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID, StreamID: s.streamID,
@ -287,6 +287,7 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
s.resetRemotelyErr = &StreamError{ s.resetRemotelyErr = &StreamError{
StreamID: s.streamID, StreamID: s.streamID,
ErrorCode: frame.ErrorCode, ErrorCode: frame.ErrorCode,
Remote: true,
} }
s.signalRead() s.signalRead()
return newlyRcvdFinalOffset, nil return newlyRcvdFinalOffset, nil

View file

@ -484,7 +484,7 @@ var _ = Describe("Receive Stream", func() {
}) })
}) })
Context("stream cancelations", func() { Context("stream cancellations", func() {
Context("canceling read", func() { Context("canceling read", func() {
It("unblocks Read", func() { It("unblocks Read", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any()) mockSender.EXPECT().queueControlFrame(gomock.Any())
@ -492,7 +492,11 @@ var _ = Describe("Receive Stream", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0}) _, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) Expect(err).To(Equal(&StreamError{
StreamID: streamID,
ErrorCode: 1234,
Remote: false,
}))
close(done) close(done)
}() }()
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
@ -504,7 +508,11 @@ var _ = Describe("Receive Stream", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any()) mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234) str.CancelRead(1234)
_, err := strWithTimeout.Read([]byte{0}) _, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) Expect(err).To(Equal(&StreamError{
StreamID: streamID,
ErrorCode: 1234,
Remote: false,
}))
}) })
It("does nothing when CancelRead is called twice", func() { It("does nothing when CancelRead is called twice", func() {
@ -512,7 +520,11 @@ var _ = Describe("Receive Stream", func() {
str.CancelRead(1234) str.CancelRead(1234)
str.CancelRead(1234) str.CancelRead(1234)
_, err := strWithTimeout.Read([]byte{0}) _, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("Read on stream 1337 canceled with error code 1234")) Expect(err).To(Equal(&StreamError{
StreamID: streamID,
ErrorCode: 1234,
Remote: false,
}))
}) })
It("queues a STOP_SENDING frame", func() { It("queues a STOP_SENDING frame", func() {
@ -609,9 +621,10 @@ var _ = Describe("Receive Stream", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0}) _, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError(&StreamError{ Expect(err).To(Equal(&StreamError{
StreamID: streamID, StreamID: streamID,
ErrorCode: 1234, ErrorCode: 1234,
Remote: true,
})) }))
close(done) close(done)
}() }()
@ -668,7 +681,7 @@ var _ = Describe("Receive Stream", func() {
Expect(str.handleResetStreamFrame(rst)).To(Succeed()) Expect(str.handleResetStreamFrame(rst)).To(Succeed())
}) })
It("doesn't do anyting when it was closed for shutdown", func() { It("doesn't do anything when it was closed for shutdown", func() {
str.closeForShutdown(nil) str.closeForShutdown(nil)
err := str.handleResetStreamFrame(rst) err := str.handleResetStreamFrame(rst)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -416,11 +416,11 @@ func (s *sendStream) Close() error {
} }
func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) s.cancelWriteImpl(errorCode, false)
} }
// must be called after locking the mutex // must be called after locking the mutex
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) { func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
s.mutex.Lock() s.mutex.Lock()
if s.canceledWrite { if s.canceledWrite {
s.mutex.Unlock() s.mutex.Unlock()
@ -428,7 +428,7 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr er
} }
s.ctxCancel() s.ctxCancel()
s.canceledWrite = true s.canceledWrite = true
s.cancelWriteErr = writeErr s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
s.numOutstandingFrames = 0 s.numOutstandingFrames = 0
s.retransmissionQueue = nil s.retransmissionQueue = nil
newlyCompleted := s.isNewlyCompleted() newlyCompleted := s.isNewlyCompleted()
@ -457,10 +457,7 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
} }
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
s.cancelWriteImpl(frame.ErrorCode, &StreamError{ s.cancelWriteImpl(frame.ErrorCode, true)
StreamID: s.streamID,
ErrorCode: frame.ErrorCode,
})
} }
func (s *sendStream) Context() context.Context { func (s *sendStream) Context() context.Context {

View file

@ -726,7 +726,11 @@ var _ = Describe("Send Stream", func() {
defer GinkgoRecover() defer GinkgoRecover()
var err error var err error
n, err = strWithTimeout.Write(getData(5000)) n, err = strWithTimeout.Write(getData(5000))
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) Expect(err).To(Equal(&StreamError{
StreamID: streamID,
ErrorCode: 1234,
Remote: false,
}))
close(writeReturned) close(writeReturned)
}() }()
waitForWrite() waitForWrite()
@ -770,7 +774,11 @@ var _ = Describe("Send Stream", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := strWithTimeout.Write(getData(5000)) _, err := strWithTimeout.Write(getData(5000))
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) Expect(err).To(Equal(&StreamError{
StreamID: streamID,
ErrorCode: 1234,
Remote: false,
}))
close(writeReturned) close(writeReturned)
}() }()
waitForWrite() waitForWrite()
@ -818,7 +826,11 @@ var _ = Describe("Send Stream", func() {
mockSender.EXPECT().onStreamCompleted(gomock.Any()) mockSender.EXPECT().onStreamCompleted(gomock.Any())
str.CancelWrite(1234) str.CancelWrite(1234)
_, err := strWithTimeout.Write([]byte("foobar")) _, err := strWithTimeout.Write([]byte("foobar"))
Expect(err).To(MatchError("Write on stream 1337 canceled with error code 1234")) Expect(err).To(MatchError(&StreamError{
StreamID: streamID,
ErrorCode: 1234,
Remote: false,
}))
}) })
It("only cancels once", func() { It("only cancels once", func() {
@ -862,9 +874,10 @@ var _ = Describe("Send Stream", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := str.Write(getData(5000)) _, err := str.Write(getData(5000))
Expect(err).To(MatchError(&StreamError{ Expect(err).To(Equal(&StreamError{
StreamID: streamID, StreamID: streamID,
ErrorCode: 1234, ErrorCode: 123,
Remote: true,
})) }))
close(done) close(done)
}() }()
@ -884,9 +897,10 @@ var _ = Describe("Send Stream", func() {
ErrorCode: 123, ErrorCode: 123,
}) })
_, err := str.Write([]byte("foobar")) _, err := str.Write([]byte("foobar"))
Expect(err).To(MatchError(&StreamError{ Expect(err).To(Equal(&StreamError{
StreamID: streamID, StreamID: streamID,
ErrorCode: 1234, ErrorCode: 123,
Remote: true,
})) }))
}) })
}) })