diff --git a/errors.go b/errors.go index bd52a289..b2d2b73b 100644 --- a/errors.go +++ b/errors.go @@ -1,6 +1,8 @@ package quic import ( + "fmt" + "github.com/lucas-clemente/quic-go/internal/qerr" ) @@ -14,6 +16,7 @@ type ( type ( TransportErrorCode = qerr.TransportErrorCode ApplicationErrorCode = qerr.ApplicationErrorCode + StreamErrorCode = qerr.StreamErrorCode ) const ( @@ -35,3 +38,19 @@ const ( AEADLimitReached = qerr.AEADLimitReached NoViablePathError = qerr.NoViablePathError ) + +// A StreamError is used for Stream.CancelRead and Stream.CancelWrite. +// It is also returned from Stream.Read and Stream.Write if the peer canceled reading or writing. +type StreamError struct { + StreamID StreamID + ErrorCode StreamErrorCode +} + +func (e *StreamError) Is(target error) bool { + _, ok := target.(*StreamError) + return ok +} + +func (e *StreamError) Error() string { + return fmt.Sprintf("stream %d canceled with error code %d", e.StreamID, e.ErrorCode) +} diff --git a/fuzzing/frames/cmd/corpus.go b/fuzzing/frames/cmd/corpus.go index 308d320d..78ef754f 100644 --- a/fuzzing/frames/cmd/corpus.go +++ b/fuzzing/frames/cmd/corpus.go @@ -124,17 +124,17 @@ func getFrames() []wire.Frame { &wire.PingFrame{}, &wire.ResetStreamFrame{ StreamID: protocol.StreamID(getRandomNumber()), - ErrorCode: quic.ApplicationErrorCode(getRandomNumber()), + ErrorCode: quic.StreamErrorCode(getRandomNumber()), FinalSize: protocol.ByteCount(getRandomNumber()), }, &wire.ResetStreamFrame{ // at maximum offset StreamID: protocol.StreamID(getRandomNumber()), - ErrorCode: quic.ApplicationErrorCode(getRandomNumber()), + ErrorCode: quic.StreamErrorCode(getRandomNumber()), FinalSize: protocol.MaxByteCount, }, &wire.StopSendingFrame{ StreamID: protocol.StreamID(getRandomNumber()), - ErrorCode: quic.ApplicationErrorCode(getRandomNumber()), + ErrorCode: quic.StreamErrorCode(getRandomNumber()), }, &wire.CryptoFrame{ Data: getRandomData(100), diff --git a/http3/body.go b/http3/body.go index 4433fef8..851eaa1f 100644 --- a/http3/body.go +++ b/http3/body.go @@ -93,6 +93,6 @@ func (r *body) requestDone() { func (r *body) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. - r.str.CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)) + r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) return nil } diff --git a/http3/body_test.go b/http3/body_test.go index e63c52be..d9d5c780 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -173,12 +173,12 @@ var _ = Describe("Body", func() { }) It("closes responses", func() { - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)) Expect(rb.Close()).To(Succeed()) }) It("allows multiple calls to Close", func() { - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)).MaxTimes(2) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2) Expect(rb.Close()).To(Succeed()) Expect(reqDone).To(BeClosed()) Expect(rb.Close()).To(Succeed()) diff --git a/http3/client.go b/http3/client.go index 2af59da9..8db038f8 100644 --- a/http3/client.go +++ b/http3/client.go @@ -165,7 +165,7 @@ func (c *client) handleUnidirectionalStreams() { c.session.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") return default: - str.CancelRead(quic.ApplicationErrorCode(errorStreamCreationError)) + str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } f, err := parseNextFrame(str) @@ -243,8 +243,8 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { go func() { select { case <-req.Context().Done(): - str.CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)) - str.CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)) + str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) case <-reqDone: } }() @@ -253,7 +253,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { if rerr.err != nil { // if any error occurred close(reqDone) if rerr.streamErr != 0 { // if it was a stream error - str.CancelWrite(quic.ApplicationErrorCode(rerr.streamErr)) + str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) } if rerr.connErr != 0 { // if it was a connection error var reason string diff --git a/http3/client_test.go b/http3/client_test.go index 28ca29bc..1282b4ed 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -267,7 +267,7 @@ var _ = Describe("Client", func() { str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() done := make(chan struct{}) - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorStreamCreationError)).Do(func(code quic.ApplicationErrorCode) { + str.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)).Do(func(code quic.StreamErrorCode) { close(done) }) @@ -546,7 +546,7 @@ var _ = Describe("Client", func() { request.Body.(*mockBody).readErr = errors.New("testErr") done := make(chan struct{}) gomock.InOrder( - str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)).Do(func(quic.ApplicationErrorCode) { + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), str.EXPECT().CancelWrite(gomock.Any()), @@ -596,7 +596,7 @@ var _ = Describe("Client", func() { It("cancels the stream when the HEADERS frame is too large", func() { buf := &bytes.Buffer{} (&headersFrame{Length: 1338}).Write(buf) - str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorFrameError)) + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)) closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() @@ -635,8 +635,8 @@ var _ = Describe("Client", func() { done := make(chan struct{}) canceled := make(chan struct{}) gomock.InOrder( - str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)).Do(func(quic.ApplicationErrorCode) { close(canceled) }), - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)).Do(func(quic.ApplicationErrorCode) { close(done) }), + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), ) str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { @@ -663,8 +663,8 @@ var _ = Describe("Client", func() { done := make(chan struct{}) str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)) - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)).Do(func(quic.ApplicationErrorCode) { close(done) }) + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) _, err := client.RoundTrip(req) Expect(err).ToNot(HaveOccurred()) cancel() diff --git a/http3/request_writer.go b/http3/request_writer.go index 2382e948..8878c8f1 100644 --- a/http3/request_writer.go +++ b/http3/request_writer.go @@ -79,7 +79,7 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo if rerr == io.EOF { break } - str.CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)) + str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) w.logger.Errorf("Error writing request: %s", rerr) return } diff --git a/http3/server.go b/http3/server.go index 69ae74f3..c710be9d 100644 --- a/http3/server.go +++ b/http3/server.go @@ -263,7 +263,7 @@ func (s *Server) handleConn(sess quic.EarlySession) { if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { s.logger.Debugf("Handling request failed: %s", err) if rerr.streamErr != 0 { - str.CancelWrite(quic.ApplicationErrorCode(rerr.streamErr)) + str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) } if rerr.connErr != 0 { var reason string @@ -304,7 +304,7 @@ func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) { sess.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "") return default: - str.CancelRead(quic.ApplicationErrorCode(errorStreamCreationError)) + str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } f, err := parseNextFrame(str) @@ -410,7 +410,7 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac r.WriteHeader(200) } // If the EOF was read by the handler, CancelRead() is a no-op. - str.CancelRead(quic.ApplicationErrorCode(errorNoError)) + str.CancelRead(quic.StreamErrorCode(errorNoError)) } return requestError{} } diff --git a/http3/server_test.go b/http3/server_test.go index 431e806b..987aae1a 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -257,7 +257,7 @@ var _ = Describe("Server", func() { str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() done := make(chan struct{}) - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorStreamCreationError)).Do(func(code quic.ApplicationErrorCode) { + str.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)).Do(func(code quic.StreamErrorCode) { close(done) }) @@ -408,7 +408,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorNoError)) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) str.EXPECT().Close().Do(func() { close(done) }) s.handleConn(sess) @@ -431,7 +431,7 @@ var _ = Describe("Server", func() { setRequest(append(requestData, buf.Bytes()...)) done := make(chan struct{}) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() - str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorFrameError)).Do(func(quic.ApplicationErrorCode) { close(done) }) + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) s.handleConn(sess) Eventually(done).Should(BeClosed()) @@ -446,7 +446,7 @@ var _ = Describe("Server", func() { testErr := errors.New("stream reset") done := make(chan struct{}) str.EXPECT().Read(gomock.Any()).Return(0, testErr) - str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorRequestIncomplete)).Do(func(quic.ApplicationErrorCode) { close(done) }) + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) s.handleConn(sess) Consistently(handlerCalled).ShouldNot(BeClosed()) @@ -491,7 +491,7 @@ var _ = Describe("Server", func() { return len(p), nil }).AnyTimes() done := make(chan struct{}) - str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorFrameError)).Do(func(quic.ApplicationErrorCode) { close(done) }) + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) s.handleConn(sess) Eventually(done).Should(BeClosed()) @@ -513,7 +513,7 @@ var _ = Describe("Server", func() { str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return len(p), nil }).AnyTimes() - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorNoError)) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) serr := s.handleRequest(sess, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) @@ -536,7 +536,7 @@ var _ = Describe("Server", func() { str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { return len(p), nil }).AnyTimes() - str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorNoError)) + str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) serr := s.handleRequest(sess, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index aa97a149..fb966638 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -45,7 +45,10 @@ var _ = Describe("Stream Cancelations", func() { str, err := sess.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) if _, err := str.Write(PRData); err != nil { - Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) + Expect(err).To(MatchError(&quic.StreamError{ + StreamID: str.StreamID(), + ErrorCode: quic.StreamErrorCode(str.StreamID()), + })) atomic.AddInt32(&canceledCounter, 1) return } @@ -87,7 +90,7 @@ var _ = Describe("Stream Cancelations", func() { // cancel around 2/3 of the streams if rand.Int31()%3 != 0 { atomic.AddInt32(&canceledCounter, 1) - str.CancelRead(quic.ApplicationErrorCode(str.StreamID())) + str.CancelRead(quic.StreamErrorCode(str.StreamID())) return } data, err := ioutil.ReadAll(str) @@ -133,7 +136,7 @@ var _ = Describe("Stream Cancelations", func() { length := int(rand.Int31n(int32(len(PRData) - 1))) data, err := ioutil.ReadAll(io.LimitReader(str, int64(length))) Expect(err).ToNot(HaveOccurred()) - str.CancelRead(quic.ApplicationErrorCode(str.StreamID())) + str.CancelRead(quic.StreamErrorCode(str.StreamID())) Expect(data).To(Equal(PRData[:length])) atomic.AddInt32(&canceledCounter, 1) return @@ -179,7 +182,10 @@ var _ = Describe("Stream Cancelations", func() { data, err := ioutil.ReadAll(str) if err != nil { atomic.AddInt32(&counter, 1) - Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) + Expect(err).To(MatchError(&quic.StreamError{ + StreamID: str.StreamID(), + ErrorCode: quic.StreamErrorCode(str.StreamID()), + })) return } Expect(data).To(Equal(PRData)) @@ -212,7 +218,7 @@ var _ = Describe("Stream Cancelations", func() { Expect(err).ToNot(HaveOccurred()) // cancel about 2/3 of the streams if rand.Int31()%3 != 0 { - str.CancelWrite(quic.ApplicationErrorCode(str.StreamID())) + str.CancelWrite(quic.StreamErrorCode(str.StreamID())) atomic.AddInt32(&canceledCounter, 1) return } @@ -246,7 +252,7 @@ var _ = Describe("Stream Cancelations", func() { length := int(rand.Int31n(int32(len(PRData) - 1))) _, err = str.Write(PRData[:length]) Expect(err).ToNot(HaveOccurred()) - str.CancelWrite(quic.ApplicationErrorCode(str.StreamID())) + str.CancelWrite(quic.StreamErrorCode(str.StreamID())) atomic.AddInt32(&canceledCounter, 1) return } @@ -282,11 +288,14 @@ var _ = Describe("Stream Cancelations", func() { Expect(err).ToNot(HaveOccurred()) // cancel about half of the streams if rand.Int31()%2 == 0 { - str.CancelWrite(quic.ApplicationErrorCode(str.StreamID())) + str.CancelWrite(quic.StreamErrorCode(str.StreamID())) return } if _, err = str.Write(PRData); err != nil { - Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) + Expect(err).To(MatchError(&quic.StreamError{ + StreamID: str.StreamID(), + ErrorCode: quic.StreamErrorCode(str.StreamID()), + })) return } if err := str.Close(); err != nil { @@ -317,12 +326,15 @@ var _ = Describe("Stream Cancelations", func() { Expect(err).ToNot(HaveOccurred()) // cancel around half of the streams if rand.Int31()%2 == 0 { - str.CancelRead(quic.ApplicationErrorCode(str.StreamID())) + str.CancelRead(quic.StreamErrorCode(str.StreamID())) return } data, err := ioutil.ReadAll(str) if err != nil { - Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) + Expect(err).To(MatchError(&quic.StreamError{ + StreamID: str.StreamID(), + ErrorCode: quic.StreamErrorCode(str.StreamID()), + })) return } atomic.AddInt32(&counter, 1) @@ -364,11 +376,14 @@ var _ = Describe("Stream Cancelations", func() { length = int(rand.Int31n(int32(len(PRData) - 1))) } if _, err = str.Write(PRData[:length]); err != nil { - Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) + Expect(err).To(MatchError(&quic.StreamError{ + StreamID: str.StreamID(), + ErrorCode: quic.StreamErrorCode(str.StreamID()), + })) return } if length < len(PRData) { - str.CancelWrite(quic.ApplicationErrorCode(str.StreamID())) + str.CancelWrite(quic.StreamErrorCode(str.StreamID())) } else if err := str.Close(); err != nil { Expect(err).To(MatchError(fmt.Sprintf("close called for canceled stream %d", str.StreamID()))) return @@ -405,12 +420,15 @@ var _ = Describe("Stream Cancelations", func() { } data, err := ioutil.ReadAll(r) if err != nil { - Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) + Expect(err).To(MatchError(&quic.StreamError{ + StreamID: str.StreamID(), + ErrorCode: quic.StreamErrorCode(str.StreamID()), + })) return } Expect(data).To(Equal(PRData[:length])) if length < len(PRData) { - str.CancelRead(quic.ApplicationErrorCode(str.StreamID())) + str.CancelRead(quic.StreamErrorCode(str.StreamID())) return } diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index a8817da3..b6339435 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -6,6 +6,7 @@ import ( "compress/gzip" "context" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -24,11 +25,6 @@ import ( "github.com/onsi/gomega/gbytes" ) -type streamCancelError interface { - Canceled() bool - ErrorCode() quic.ApplicationErrorCode -} - var _ = Describe("HTTP tests", func() { var ( mux *http.ServeMux @@ -260,10 +256,9 @@ var _ = Describe("HTTP tests", func() { for { if _, err := w.Write([]byte("foobar")); err != nil { Expect(r.Context().Done()).To(BeClosed()) - serr, ok := err.(streamCancelError) - Expect(ok).To(BeTrue()) - Expect(serr.Canceled()).To(BeTrue()) - Expect(serr.ErrorCode()).To(BeEquivalentTo(0x10c)) + var strErr *quic.StreamError + Expect(errors.As(err, &strErr)).To(BeTrue()) + Expect(strErr.ErrorCode).To(Equal(quic.StreamErrorCode(0x10c))) return } } diff --git a/interface.go b/interface.go index c57af432..56b90e04 100644 --- a/interface.go +++ b/interface.go @@ -104,7 +104,7 @@ type ReceiveStream interface { // It will ask the peer to stop transmitting stream data. // Read will unblock immediately, and future Read calls will fail. // When called multiple times or after reading the io.EOF it is a no-op. - CancelRead(ApplicationErrorCode) + CancelRead(StreamErrorCode) // SetReadDeadline sets the deadline for future Read calls and // any currently-blocked Read call. // A zero value for t means Read will not time out. @@ -133,7 +133,7 @@ type SendStream interface { // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. // Write will unblock immediately, and future calls to Write will fail. // When called multiple times or after closing the stream it is a no-op. - CancelWrite(ApplicationErrorCode) + CancelWrite(StreamErrorCode) // The context is canceled as soon as the write-side of the stream is closed. // This happens when Close() or CancelWrite() is called, or when the peer // cancels the read-side of their stream. @@ -147,13 +147,6 @@ type SendStream interface { SetWriteDeadline(t time.Time) error } -// StreamError is returned by Read and Write when the peer cancels the stream. -type StreamError interface { - error - Canceled() bool - ErrorCode() ApplicationErrorCode -} - // A Session is a QUIC connection between two peers. type Session interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. diff --git a/internal/mocks/quic/stream.go b/internal/mocks/quic/stream.go index 992baa0d..43f9c1d1 100644 --- a/internal/mocks/quic/stream.go +++ b/internal/mocks/quic/stream.go @@ -38,7 +38,7 @@ func (m *MockStream) EXPECT() *MockStreamMockRecorder { } // CancelRead mocks base method. -func (m *MockStream) CancelRead(arg0 qerr.ApplicationErrorCode) { +func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelRead", arg0) } @@ -50,7 +50,7 @@ func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { } // CancelWrite mocks base method. -func (m *MockStream) CancelWrite(arg0 qerr.ApplicationErrorCode) { +func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelWrite", arg0) } diff --git a/internal/qerr/quic_error.go b/internal/qerr/quic_error.go index c4ad86bd..d9c19a6f 100644 --- a/internal/qerr/quic_error.go +++ b/internal/qerr/quic_error.go @@ -52,6 +52,9 @@ func (e *TransportError) Error() string { // An ApplicationErrorCode is an application-defined error code. type ApplicationErrorCode uint64 +// A StreamErrorCode is an error code used to cancel streams. +type StreamErrorCode uint64 + type ApplicationError struct { Remote bool ErrorCode ApplicationErrorCode diff --git a/internal/wire/reset_stream_frame.go b/internal/wire/reset_stream_frame.go index 8c2e66d8..69bbc2b9 100644 --- a/internal/wire/reset_stream_frame.go +++ b/internal/wire/reset_stream_frame.go @@ -11,7 +11,7 @@ import ( // A ResetStreamFrame is a RESET_STREAM frame in QUIC type ResetStreamFrame struct { StreamID protocol.StreamID - ErrorCode qerr.ApplicationErrorCode + ErrorCode qerr.StreamErrorCode FinalSize protocol.ByteCount } @@ -39,7 +39,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr return &ResetStreamFrame{ StreamID: streamID, - ErrorCode: qerr.ApplicationErrorCode(errorCode), + ErrorCode: qerr.StreamErrorCode(errorCode), FinalSize: byteOffset, }, nil } diff --git a/internal/wire/reset_stream_frame_test.go b/internal/wire/reset_stream_frame_test.go index 61cf6c97..a1f0259d 100644 --- a/internal/wire/reset_stream_frame_test.go +++ b/internal/wire/reset_stream_frame_test.go @@ -23,7 +23,7 @@ var _ = Describe("RESET_STREAM frame", func() { Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321))) - Expect(frame.ErrorCode).To(Equal(qerr.ApplicationErrorCode(0x1337))) + Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) }) It("errors on EOFs", func() { diff --git a/internal/wire/stop_sending_frame.go b/internal/wire/stop_sending_frame.go index 283f3eae..fb1160c1 100644 --- a/internal/wire/stop_sending_frame.go +++ b/internal/wire/stop_sending_frame.go @@ -11,7 +11,7 @@ import ( // A StopSendingFrame is a STOP_SENDING frame type StopSendingFrame struct { StreamID protocol.StreamID - ErrorCode qerr.ApplicationErrorCode + ErrorCode qerr.StreamErrorCode } // parseStopSendingFrame parses a STOP_SENDING frame @@ -31,7 +31,7 @@ func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSend return &StopSendingFrame{ StreamID: protocol.StreamID(streamID), - ErrorCode: qerr.ApplicationErrorCode(errorCode), + ErrorCode: qerr.StreamErrorCode(errorCode), }, nil } diff --git a/internal/wire/stop_sending_frame_test.go b/internal/wire/stop_sending_frame_test.go index e78f0c0b..9a3dcda0 100644 --- a/internal/wire/stop_sending_frame_test.go +++ b/internal/wire/stop_sending_frame_test.go @@ -21,7 +21,7 @@ var _ = Describe("STOP_SENDING frame", func() { frame, err := parseStopSendingFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) - Expect(frame.ErrorCode).To(Equal(qerr.ApplicationErrorCode(0x1337))) + Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) Expect(b.Len()).To(BeZero()) }) diff --git a/mock_receive_stream_internal_test.go b/mock_receive_stream_internal_test.go index 0bddae4f..a39b97df 100644 --- a/mock_receive_stream_internal_test.go +++ b/mock_receive_stream_internal_test.go @@ -37,7 +37,7 @@ func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder { } // CancelRead mocks base method. -func (m *MockReceiveStreamI) CancelRead(arg0 ApplicationErrorCode) { +func (m *MockReceiveStreamI) CancelRead(arg0 StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelRead", arg0) } diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go index 07de6e25..764bcbaa 100644 --- a/mock_send_stream_internal_test.go +++ b/mock_send_stream_internal_test.go @@ -39,7 +39,7 @@ func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder { } // CancelWrite mocks base method. -func (m *MockSendStreamI) CancelWrite(arg0 ApplicationErrorCode) { +func (m *MockSendStreamI) CancelWrite(arg0 StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelWrite", arg0) } diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index a89aef4b..ff8a2f5f 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -39,7 +39,7 @@ func (m *MockStreamI) EXPECT() *MockStreamIMockRecorder { } // CancelRead mocks base method. -func (m *MockStreamI) CancelRead(arg0 ApplicationErrorCode) { +func (m *MockStreamI) CancelRead(arg0 StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelRead", arg0) } @@ -51,7 +51,7 @@ func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call { } // CancelWrite mocks base method. -func (m *MockStreamI) CancelWrite(arg0 ApplicationErrorCode) { +func (m *MockStreamI) CancelWrite(arg0 StreamErrorCode) { m.ctrl.T.Helper() m.ctrl.Call(m, "CancelWrite", arg0) } diff --git a/receive_stream.go b/receive_stream.go index 7df588e5..cca3f709 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -39,7 +39,7 @@ type receiveStream struct { closeForShutdownErr error cancelReadErr error - resetRemotelyErr StreamError + resetRemotelyErr *StreamError closedForShutdown bool // set when CloseForShutdown() is called finRead bool // set once we read a frame with a Fin @@ -197,7 +197,7 @@ func (s *receiveStream) dequeueNextFrame() { s.readPosInFrame = 0 } -func (s *receiveStream) CancelRead(errorCode qerr.ApplicationErrorCode) { +func (s *receiveStream) CancelRead(errorCode StreamErrorCode) { s.mutex.Lock() completed := s.cancelReadImpl(errorCode) s.mutex.Unlock() @@ -208,7 +208,7 @@ func (s *receiveStream) CancelRead(errorCode qerr.ApplicationErrorCode) { } } -func (s *receiveStream) cancelReadImpl(errorCode qerr.ApplicationErrorCode) bool /* completed */ { +func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ { if s.finRead || s.canceledRead || s.resetRemotely { return false } @@ -282,9 +282,9 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) return false, nil } s.resetRemotely = true - s.resetRemotelyErr = streamCanceledError{ - errorCode: frame.ErrorCode, - error: fmt.Errorf("stream %d was reset with error code %d", s.streamID, frame.ErrorCode), + s.resetRemotelyErr = &StreamError{ + StreamID: s.streamID, + ErrorCode: frame.ErrorCode, } s.signalRead() return newlyRcvdFinalOffset, nil diff --git a/receive_stream_test.go b/receive_stream_test.go index 45e8a771..088be429 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -9,7 +9,6 @@ import ( "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -572,10 +571,10 @@ var _ = Describe("Receive Stream", func() { go func() { defer GinkgoRecover() _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError("stream 1337 was reset with error code 1234")) - Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) - Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) - Expect(err.(streamCanceledError).ErrorCode()).To(Equal(qerr.ApplicationErrorCode(1234))) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + })) close(done) }() Consistently(done).ShouldNot(BeClosed()) @@ -596,10 +595,10 @@ var _ = Describe("Receive Stream", func() { ) Expect(str.handleResetStreamFrame(rst)).To(Succeed()) _, err := strWithTimeout.Read([]byte{0}) - Expect(err).To(MatchError("stream 1337 was reset with error code 1234")) - Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) - Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) - Expect(err.(streamCanceledError).ErrorCode()).To(Equal(qerr.ApplicationErrorCode(1234))) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + })) }) It("errors when receiving a RESET_STREAM with an inconsistent offset", func() { diff --git a/send_stream.go b/send_stream.go index 75bbceba..946243ca 100644 --- a/send_stream.go +++ b/send_stream.go @@ -407,12 +407,12 @@ func (s *sendStream) Close() error { return nil } -func (s *sendStream) CancelWrite(errorCode qerr.ApplicationErrorCode) { +func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) } // must be called after locking the mutex -func (s *sendStream) cancelWriteImpl(errorCode qerr.ApplicationErrorCode, writeErr error) { +func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) { s.mutex.Lock() if s.canceledWrite { s.mutex.Unlock() @@ -449,11 +449,10 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { } func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - writeErr := streamCanceledError{ - errorCode: frame.ErrorCode, - error: fmt.Errorf("stream %d was reset with error code %d", s.streamID, frame.ErrorCode), - } - s.cancelWriteImpl(frame.ErrorCode, writeErr) + s.cancelWriteImpl(frame.ErrorCode, &StreamError{ + StreamID: s.streamID, + ErrorCode: frame.ErrorCode, + }) } func (s *sendStream) Context() context.Context { diff --git a/send_stream_test.go b/send_stream_test.go index 953d1dec..d2e07528 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -12,7 +12,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -863,10 +862,10 @@ var _ = Describe("Send Stream", func() { go func() { defer GinkgoRecover() _, err := str.Write(getData(5000)) - Expect(err).To(MatchError("stream 1337 was reset with error code 123")) - Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) - Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) - Expect(err.(streamCanceledError).ErrorCode()).To(Equal(qerr.ApplicationErrorCode(123))) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + })) close(done) }() waitForWrite() @@ -885,10 +884,10 @@ var _ = Describe("Send Stream", func() { ErrorCode: 123, }) _, err := str.Write([]byte("foobar")) - Expect(err).To(MatchError("stream 1337 was reset with error code 123")) - Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) - Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) - Expect(err.(streamCanceledError).ErrorCode()).To(Equal(qerr.ApplicationErrorCode(123))) + Expect(err).To(MatchError(&StreamError{ + StreamID: streamID, + ErrorCode: 1234, + })) }) }) }) diff --git a/stream.go b/stream.go index 112464e7..95bbcb35 100644 --- a/stream.go +++ b/stream.go @@ -9,7 +9,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -87,16 +86,6 @@ type stream struct { var _ Stream = &stream{} -type streamCanceledError struct { - error - errorCode qerr.ApplicationErrorCode -} - -func (streamCanceledError) Canceled() bool { return true } -func (e streamCanceledError) ErrorCode() qerr.ApplicationErrorCode { return e.errorCode } - -var _ StreamError = &streamCanceledError{} - // newStream creates a new Stream func newStream(streamID protocol.StreamID, sender streamSender,