introduce a quic.StreamError type and use it for stream cancelations

This commit is contained in:
Marten Seemann 2021-04-26 16:26:05 +07:00
parent 93cfef57ca
commit 90727cb41a
26 changed files with 128 additions and 114 deletions

View file

@ -1,6 +1,8 @@
package quic package quic
import ( import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
) )
@ -14,6 +16,7 @@ type (
type ( type (
TransportErrorCode = qerr.TransportErrorCode TransportErrorCode = qerr.TransportErrorCode
ApplicationErrorCode = qerr.ApplicationErrorCode ApplicationErrorCode = qerr.ApplicationErrorCode
StreamErrorCode = qerr.StreamErrorCode
) )
const ( const (
@ -35,3 +38,19 @@ const (
AEADLimitReached = qerr.AEADLimitReached AEADLimitReached = qerr.AEADLimitReached
NoViablePathError = qerr.NoViablePathError NoViablePathError = qerr.NoViablePathError
) )
// A StreamError is used for Stream.CancelRead and Stream.CancelWrite.
// It is also returned from Stream.Read and Stream.Write if the peer canceled reading or writing.
type StreamError struct {
StreamID StreamID
ErrorCode StreamErrorCode
}
func (e *StreamError) Is(target error) bool {
_, ok := target.(*StreamError)
return ok
}
func (e *StreamError) Error() string {
return fmt.Sprintf("stream %d canceled with error code %d", e.StreamID, e.ErrorCode)
}

View file

@ -124,17 +124,17 @@ func getFrames() []wire.Frame {
&wire.PingFrame{}, &wire.PingFrame{},
&wire.ResetStreamFrame{ &wire.ResetStreamFrame{
StreamID: protocol.StreamID(getRandomNumber()), StreamID: protocol.StreamID(getRandomNumber()),
ErrorCode: quic.ApplicationErrorCode(getRandomNumber()), ErrorCode: quic.StreamErrorCode(getRandomNumber()),
FinalSize: protocol.ByteCount(getRandomNumber()), FinalSize: protocol.ByteCount(getRandomNumber()),
}, },
&wire.ResetStreamFrame{ // at maximum offset &wire.ResetStreamFrame{ // at maximum offset
StreamID: protocol.StreamID(getRandomNumber()), StreamID: protocol.StreamID(getRandomNumber()),
ErrorCode: quic.ApplicationErrorCode(getRandomNumber()), ErrorCode: quic.StreamErrorCode(getRandomNumber()),
FinalSize: protocol.MaxByteCount, FinalSize: protocol.MaxByteCount,
}, },
&wire.StopSendingFrame{ &wire.StopSendingFrame{
StreamID: protocol.StreamID(getRandomNumber()), StreamID: protocol.StreamID(getRandomNumber()),
ErrorCode: quic.ApplicationErrorCode(getRandomNumber()), ErrorCode: quic.StreamErrorCode(getRandomNumber()),
}, },
&wire.CryptoFrame{ &wire.CryptoFrame{
Data: getRandomData(100), Data: getRandomData(100),

View file

@ -93,6 +93,6 @@ func (r *body) requestDone() {
func (r *body) Close() error { func (r *body) Close() error {
r.requestDone() r.requestDone()
// If the EOF was read, CancelRead() is a no-op. // If the EOF was read, CancelRead() is a no-op.
r.str.CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)) r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
return nil return nil
} }

View file

@ -173,12 +173,12 @@ var _ = Describe("Body", func() {
}) })
It("closes responses", func() { It("closes responses", func() {
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)) str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled))
Expect(rb.Close()).To(Succeed()) Expect(rb.Close()).To(Succeed())
}) })
It("allows multiple calls to Close", func() { It("allows multiple calls to Close", func() {
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)).MaxTimes(2) str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).MaxTimes(2)
Expect(rb.Close()).To(Succeed()) Expect(rb.Close()).To(Succeed())
Expect(reqDone).To(BeClosed()) Expect(reqDone).To(BeClosed())
Expect(rb.Close()).To(Succeed()) Expect(rb.Close()).To(Succeed())

View file

@ -165,7 +165,7 @@ func (c *client) handleUnidirectionalStreams() {
c.session.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") c.session.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return return
default: default:
str.CancelRead(quic.ApplicationErrorCode(errorStreamCreationError)) str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return return
} }
f, err := parseNextFrame(str) f, err := parseNextFrame(str)
@ -243,8 +243,8 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
go func() { go func() {
select { select {
case <-req.Context().Done(): case <-req.Context().Done():
str.CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)) str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
str.CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)) str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
case <-reqDone: case <-reqDone:
} }
}() }()
@ -253,7 +253,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
if rerr.err != nil { // if any error occurred if rerr.err != nil { // if any error occurred
close(reqDone) close(reqDone)
if rerr.streamErr != 0 { // if it was a stream error if rerr.streamErr != 0 { // if it was a stream error
str.CancelWrite(quic.ApplicationErrorCode(rerr.streamErr)) str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
} }
if rerr.connErr != 0 { // if it was a connection error if rerr.connErr != 0 { // if it was a connection error
var reason string var reason string

View file

@ -267,7 +267,7 @@ var _ = Describe("Client", func() {
str := mockquic.NewMockStream(mockCtrl) str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
done := make(chan struct{}) done := make(chan struct{})
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorStreamCreationError)).Do(func(code quic.ApplicationErrorCode) { str.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)).Do(func(code quic.StreamErrorCode) {
close(done) close(done)
}) })
@ -546,7 +546,7 @@ var _ = Describe("Client", func() {
request.Body.(*mockBody).readErr = errors.New("testErr") request.Body.(*mockBody).readErr = errors.New("testErr")
done := make(chan struct{}) done := make(chan struct{})
gomock.InOrder( gomock.InOrder(
str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)).Do(func(quic.ApplicationErrorCode) { str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) {
close(done) close(done)
}), }),
str.EXPECT().CancelWrite(gomock.Any()), str.EXPECT().CancelWrite(gomock.Any()),
@ -596,7 +596,7 @@ var _ = Describe("Client", func() {
It("cancels the stream when the HEADERS frame is too large", func() { It("cancels the stream when the HEADERS frame is too large", func() {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
(&headersFrame{Length: 1338}).Write(buf) (&headersFrame{Length: 1338}).Write(buf)
str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorFrameError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError))
closed := make(chan struct{}) closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
@ -635,8 +635,8 @@ var _ = Describe("Client", func() {
done := make(chan struct{}) done := make(chan struct{})
canceled := make(chan struct{}) canceled := make(chan struct{})
gomock.InOrder( gomock.InOrder(
str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)).Do(func(quic.ApplicationErrorCode) { close(canceled) }), str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)).Do(func(quic.ApplicationErrorCode) { close(done) }), str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
) )
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
@ -663,8 +663,8 @@ var _ = Describe("Client", func() {
done := make(chan struct{}) done := make(chan struct{})
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)) str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorRequestCanceled)).Do(func(quic.ApplicationErrorCode) { close(done) }) str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
_, err := client.RoundTrip(req) _, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
cancel() cancel()

View file

@ -79,7 +79,7 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bo
if rerr == io.EOF { if rerr == io.EOF {
break break
} }
str.CancelWrite(quic.ApplicationErrorCode(errorRequestCanceled)) str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
w.logger.Errorf("Error writing request: %s", rerr) w.logger.Errorf("Error writing request: %s", rerr)
return return
} }

View file

@ -263,7 +263,7 @@ func (s *Server) handleConn(sess quic.EarlySession) {
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
s.logger.Debugf("Handling request failed: %s", err) s.logger.Debugf("Handling request failed: %s", err)
if rerr.streamErr != 0 { if rerr.streamErr != 0 {
str.CancelWrite(quic.ApplicationErrorCode(rerr.streamErr)) str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
} }
if rerr.connErr != 0 { if rerr.connErr != 0 {
var reason string var reason string
@ -304,7 +304,7 @@ func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) {
sess.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "") sess.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
return return
default: default:
str.CancelRead(quic.ApplicationErrorCode(errorStreamCreationError)) str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return return
} }
f, err := parseNextFrame(str) f, err := parseNextFrame(str)
@ -410,7 +410,7 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac
r.WriteHeader(200) r.WriteHeader(200)
} }
// If the EOF was read by the handler, CancelRead() is a no-op. // If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.ApplicationErrorCode(errorNoError)) str.CancelRead(quic.StreamErrorCode(errorNoError))
} }
return requestError{} return requestError{}
} }

View file

@ -257,7 +257,7 @@ var _ = Describe("Server", func() {
str := mockquic.NewMockStream(mockCtrl) str := mockquic.NewMockStream(mockCtrl)
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
done := make(chan struct{}) done := make(chan struct{})
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorStreamCreationError)).Do(func(code quic.ApplicationErrorCode) { str.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)).Do(func(code quic.StreamErrorCode) {
close(done) close(done)
}) })
@ -408,7 +408,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{}) done := make(chan struct{})
str.EXPECT().Context().Return(reqContext) str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorNoError)) str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError))
str.EXPECT().Close().Do(func() { close(done) }) str.EXPECT().Close().Do(func() { close(done) })
s.handleConn(sess) s.handleConn(sess)
@ -431,7 +431,7 @@ var _ = Describe("Server", func() {
setRequest(append(requestData, buf.Bytes()...)) setRequest(append(requestData, buf.Bytes()...))
done := make(chan struct{}) done := make(chan struct{})
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorFrameError)).Do(func(quic.ApplicationErrorCode) { close(done) }) str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(sess) s.handleConn(sess)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -446,7 +446,7 @@ var _ = Describe("Server", func() {
testErr := errors.New("stream reset") testErr := errors.New("stream reset")
done := make(chan struct{}) done := make(chan struct{})
str.EXPECT().Read(gomock.Any()).Return(0, testErr) str.EXPECT().Read(gomock.Any()).Return(0, testErr)
str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorRequestIncomplete)).Do(func(quic.ApplicationErrorCode) { close(done) }) str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(sess) s.handleConn(sess)
Consistently(handlerCalled).ShouldNot(BeClosed()) Consistently(handlerCalled).ShouldNot(BeClosed())
@ -491,7 +491,7 @@ var _ = Describe("Server", func() {
return len(p), nil return len(p), nil
}).AnyTimes() }).AnyTimes()
done := make(chan struct{}) done := make(chan struct{})
str.EXPECT().CancelWrite(quic.ApplicationErrorCode(errorFrameError)).Do(func(quic.ApplicationErrorCode) { close(done) }) str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(sess) s.handleConn(sess)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -513,7 +513,7 @@ var _ = Describe("Server", func() {
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil return len(p), nil
}).AnyTimes() }).AnyTimes()
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorNoError)) str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError))
serr := s.handleRequest(sess, str, qpackDecoder, nil) serr := s.handleRequest(sess, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred()) Expect(serr.err).ToNot(HaveOccurred())
@ -536,7 +536,7 @@ var _ = Describe("Server", func() {
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return len(p), nil return len(p), nil
}).AnyTimes() }).AnyTimes()
str.EXPECT().CancelRead(quic.ApplicationErrorCode(errorNoError)) str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError))
serr := s.handleRequest(sess, str, qpackDecoder, nil) serr := s.handleRequest(sess, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred()) Expect(serr.err).ToNot(HaveOccurred())

View file

@ -45,7 +45,10 @@ var _ = Describe("Stream Cancelations", func() {
str, err := sess.OpenUniStreamSync(context.Background()) str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
if _, err := str.Write(PRData); err != nil { if _, err := str.Write(PRData); err != nil {
Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()),
}))
atomic.AddInt32(&canceledCounter, 1) atomic.AddInt32(&canceledCounter, 1)
return return
} }
@ -87,7 +90,7 @@ var _ = Describe("Stream Cancelations", func() {
// cancel around 2/3 of the streams // cancel around 2/3 of the streams
if rand.Int31()%3 != 0 { if rand.Int31()%3 != 0 {
atomic.AddInt32(&canceledCounter, 1) atomic.AddInt32(&canceledCounter, 1)
str.CancelRead(quic.ApplicationErrorCode(str.StreamID())) str.CancelRead(quic.StreamErrorCode(str.StreamID()))
return return
} }
data, err := ioutil.ReadAll(str) data, err := ioutil.ReadAll(str)
@ -133,7 +136,7 @@ var _ = Describe("Stream Cancelations", func() {
length := int(rand.Int31n(int32(len(PRData) - 1))) length := int(rand.Int31n(int32(len(PRData) - 1)))
data, err := ioutil.ReadAll(io.LimitReader(str, int64(length))) data, err := ioutil.ReadAll(io.LimitReader(str, int64(length)))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str.CancelRead(quic.ApplicationErrorCode(str.StreamID())) str.CancelRead(quic.StreamErrorCode(str.StreamID()))
Expect(data).To(Equal(PRData[:length])) Expect(data).To(Equal(PRData[:length]))
atomic.AddInt32(&canceledCounter, 1) atomic.AddInt32(&canceledCounter, 1)
return return
@ -179,7 +182,10 @@ var _ = Describe("Stream Cancelations", func() {
data, err := ioutil.ReadAll(str) data, err := ioutil.ReadAll(str)
if err != nil { if err != nil {
atomic.AddInt32(&counter, 1) atomic.AddInt32(&counter, 1)
Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()),
}))
return return
} }
Expect(data).To(Equal(PRData)) Expect(data).To(Equal(PRData))
@ -212,7 +218,7 @@ var _ = Describe("Stream Cancelations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// cancel about 2/3 of the streams // cancel about 2/3 of the streams
if rand.Int31()%3 != 0 { if rand.Int31()%3 != 0 {
str.CancelWrite(quic.ApplicationErrorCode(str.StreamID())) str.CancelWrite(quic.StreamErrorCode(str.StreamID()))
atomic.AddInt32(&canceledCounter, 1) atomic.AddInt32(&canceledCounter, 1)
return return
} }
@ -246,7 +252,7 @@ var _ = Describe("Stream Cancelations", func() {
length := int(rand.Int31n(int32(len(PRData) - 1))) length := int(rand.Int31n(int32(len(PRData) - 1)))
_, err = str.Write(PRData[:length]) _, err = str.Write(PRData[:length])
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str.CancelWrite(quic.ApplicationErrorCode(str.StreamID())) str.CancelWrite(quic.StreamErrorCode(str.StreamID()))
atomic.AddInt32(&canceledCounter, 1) atomic.AddInt32(&canceledCounter, 1)
return return
} }
@ -282,11 +288,14 @@ var _ = Describe("Stream Cancelations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// cancel about half of the streams // cancel about half of the streams
if rand.Int31()%2 == 0 { if rand.Int31()%2 == 0 {
str.CancelWrite(quic.ApplicationErrorCode(str.StreamID())) str.CancelWrite(quic.StreamErrorCode(str.StreamID()))
return return
} }
if _, err = str.Write(PRData); err != nil { if _, err = str.Write(PRData); err != nil {
Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()),
}))
return return
} }
if err := str.Close(); err != nil { if err := str.Close(); err != nil {
@ -317,12 +326,15 @@ var _ = Describe("Stream Cancelations", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// cancel around half of the streams // cancel around half of the streams
if rand.Int31()%2 == 0 { if rand.Int31()%2 == 0 {
str.CancelRead(quic.ApplicationErrorCode(str.StreamID())) str.CancelRead(quic.StreamErrorCode(str.StreamID()))
return return
} }
data, err := ioutil.ReadAll(str) data, err := ioutil.ReadAll(str)
if err != nil { if err != nil {
Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()),
}))
return return
} }
atomic.AddInt32(&counter, 1) atomic.AddInt32(&counter, 1)
@ -364,11 +376,14 @@ var _ = Describe("Stream Cancelations", func() {
length = int(rand.Int31n(int32(len(PRData) - 1))) length = int(rand.Int31n(int32(len(PRData) - 1)))
} }
if _, err = str.Write(PRData[:length]); err != nil { if _, err = str.Write(PRData[:length]); err != nil {
Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()),
}))
return return
} }
if length < len(PRData) { if length < len(PRData) {
str.CancelWrite(quic.ApplicationErrorCode(str.StreamID())) str.CancelWrite(quic.StreamErrorCode(str.StreamID()))
} else if err := str.Close(); err != nil { } else if err := str.Close(); err != nil {
Expect(err).To(MatchError(fmt.Sprintf("close called for canceled stream %d", str.StreamID()))) Expect(err).To(MatchError(fmt.Sprintf("close called for canceled stream %d", str.StreamID())))
return return
@ -405,12 +420,15 @@ var _ = Describe("Stream Cancelations", func() {
} }
data, err := ioutil.ReadAll(r) data, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
Expect(err).To(MatchError(fmt.Sprintf("stream %d was reset with error code %d", str.StreamID(), str.StreamID()))) Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()),
}))
return return
} }
Expect(data).To(Equal(PRData[:length])) Expect(data).To(Equal(PRData[:length]))
if length < len(PRData) { if length < len(PRData) {
str.CancelRead(quic.ApplicationErrorCode(str.StreamID())) str.CancelRead(quic.StreamErrorCode(str.StreamID()))
return return
} }

View file

@ -6,6 +6,7 @@ import (
"compress/gzip" "compress/gzip"
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -24,11 +25,6 @@ import (
"github.com/onsi/gomega/gbytes" "github.com/onsi/gomega/gbytes"
) )
type streamCancelError interface {
Canceled() bool
ErrorCode() quic.ApplicationErrorCode
}
var _ = Describe("HTTP tests", func() { var _ = Describe("HTTP tests", func() {
var ( var (
mux *http.ServeMux mux *http.ServeMux
@ -260,10 +256,9 @@ var _ = Describe("HTTP tests", func() {
for { for {
if _, err := w.Write([]byte("foobar")); err != nil { if _, err := w.Write([]byte("foobar")); err != nil {
Expect(r.Context().Done()).To(BeClosed()) Expect(r.Context().Done()).To(BeClosed())
serr, ok := err.(streamCancelError) var strErr *quic.StreamError
Expect(ok).To(BeTrue()) Expect(errors.As(err, &strErr)).To(BeTrue())
Expect(serr.Canceled()).To(BeTrue()) Expect(strErr.ErrorCode).To(Equal(quic.StreamErrorCode(0x10c)))
Expect(serr.ErrorCode()).To(BeEquivalentTo(0x10c))
return return
} }
} }

View file

@ -104,7 +104,7 @@ type ReceiveStream interface {
// It will ask the peer to stop transmitting stream data. // It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail. // Read will unblock immediately, and future Read calls will fail.
// When called multiple times or after reading the io.EOF it is a no-op. // When called multiple times or after reading the io.EOF it is a no-op.
CancelRead(ApplicationErrorCode) CancelRead(StreamErrorCode)
// SetReadDeadline sets the deadline for future Read calls and // SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call. // any currently-blocked Read call.
// A zero value for t means Read will not time out. // A zero value for t means Read will not time out.
@ -133,7 +133,7 @@ type SendStream interface {
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail. // Write will unblock immediately, and future calls to Write will fail.
// When called multiple times or after closing the stream it is a no-op. // When called multiple times or after closing the stream it is a no-op.
CancelWrite(ApplicationErrorCode) CancelWrite(StreamErrorCode)
// The context is canceled as soon as the write-side of the stream is closed. // The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() or CancelWrite() is called, or when the peer // This happens when Close() or CancelWrite() is called, or when the peer
// cancels the read-side of their stream. // cancels the read-side of their stream.
@ -147,13 +147,6 @@ type SendStream interface {
SetWriteDeadline(t time.Time) error SetWriteDeadline(t time.Time) error
} }
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ApplicationErrorCode
}
// A Session is a QUIC connection between two peers. // A Session is a QUIC connection between two peers.
type Session interface { type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available. // AcceptStream returns the next stream opened by the peer, blocking until one is available.

View file

@ -38,7 +38,7 @@ func (m *MockStream) EXPECT() *MockStreamMockRecorder {
} }
// CancelRead mocks base method. // CancelRead mocks base method.
func (m *MockStream) CancelRead(arg0 qerr.ApplicationErrorCode) { func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "CancelRead", arg0) m.ctrl.Call(m, "CancelRead", arg0)
} }
@ -50,7 +50,7 @@ func (mr *MockStreamMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
} }
// CancelWrite mocks base method. // CancelWrite mocks base method.
func (m *MockStream) CancelWrite(arg0 qerr.ApplicationErrorCode) { func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "CancelWrite", arg0) m.ctrl.Call(m, "CancelWrite", arg0)
} }

View file

@ -52,6 +52,9 @@ func (e *TransportError) Error() string {
// An ApplicationErrorCode is an application-defined error code. // An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint64 type ApplicationErrorCode uint64
// A StreamErrorCode is an error code used to cancel streams.
type StreamErrorCode uint64
type ApplicationError struct { type ApplicationError struct {
Remote bool Remote bool
ErrorCode ApplicationErrorCode ErrorCode ApplicationErrorCode

View file

@ -11,7 +11,7 @@ import (
// A ResetStreamFrame is a RESET_STREAM frame in QUIC // A ResetStreamFrame is a RESET_STREAM frame in QUIC
type ResetStreamFrame struct { type ResetStreamFrame struct {
StreamID protocol.StreamID StreamID protocol.StreamID
ErrorCode qerr.ApplicationErrorCode ErrorCode qerr.StreamErrorCode
FinalSize protocol.ByteCount FinalSize protocol.ByteCount
} }
@ -39,7 +39,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr
return &ResetStreamFrame{ return &ResetStreamFrame{
StreamID: streamID, StreamID: streamID,
ErrorCode: qerr.ApplicationErrorCode(errorCode), ErrorCode: qerr.StreamErrorCode(errorCode),
FinalSize: byteOffset, FinalSize: byteOffset,
}, nil }, nil
} }

View file

@ -23,7 +23,7 @@ var _ = Describe("RESET_STREAM frame", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef)))
Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321))) Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321)))
Expect(frame.ErrorCode).To(Equal(qerr.ApplicationErrorCode(0x1337))) Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337)))
}) })
It("errors on EOFs", func() { It("errors on EOFs", func() {

View file

@ -11,7 +11,7 @@ import (
// A StopSendingFrame is a STOP_SENDING frame // A StopSendingFrame is a STOP_SENDING frame
type StopSendingFrame struct { type StopSendingFrame struct {
StreamID protocol.StreamID StreamID protocol.StreamID
ErrorCode qerr.ApplicationErrorCode ErrorCode qerr.StreamErrorCode
} }
// parseStopSendingFrame parses a STOP_SENDING frame // parseStopSendingFrame parses a STOP_SENDING frame
@ -31,7 +31,7 @@ func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSend
return &StopSendingFrame{ return &StopSendingFrame{
StreamID: protocol.StreamID(streamID), StreamID: protocol.StreamID(streamID),
ErrorCode: qerr.ApplicationErrorCode(errorCode), ErrorCode: qerr.StreamErrorCode(errorCode),
}, nil }, nil
} }

View file

@ -21,7 +21,7 @@ var _ = Describe("STOP_SENDING frame", func() {
frame, err := parseStopSendingFrame(b, versionIETFFrames) frame, err := parseStopSendingFrame(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(frame.ErrorCode).To(Equal(qerr.ApplicationErrorCode(0x1337))) Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337)))
Expect(b.Len()).To(BeZero()) Expect(b.Len()).To(BeZero())
}) })

View file

@ -37,7 +37,7 @@ func (m *MockReceiveStreamI) EXPECT() *MockReceiveStreamIMockRecorder {
} }
// CancelRead mocks base method. // CancelRead mocks base method.
func (m *MockReceiveStreamI) CancelRead(arg0 ApplicationErrorCode) { func (m *MockReceiveStreamI) CancelRead(arg0 StreamErrorCode) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "CancelRead", arg0) m.ctrl.Call(m, "CancelRead", arg0)
} }

View file

@ -39,7 +39,7 @@ func (m *MockSendStreamI) EXPECT() *MockSendStreamIMockRecorder {
} }
// CancelWrite mocks base method. // CancelWrite mocks base method.
func (m *MockSendStreamI) CancelWrite(arg0 ApplicationErrorCode) { func (m *MockSendStreamI) CancelWrite(arg0 StreamErrorCode) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "CancelWrite", arg0) m.ctrl.Call(m, "CancelWrite", arg0)
} }

View file

@ -39,7 +39,7 @@ func (m *MockStreamI) EXPECT() *MockStreamIMockRecorder {
} }
// CancelRead mocks base method. // CancelRead mocks base method.
func (m *MockStreamI) CancelRead(arg0 ApplicationErrorCode) { func (m *MockStreamI) CancelRead(arg0 StreamErrorCode) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "CancelRead", arg0) m.ctrl.Call(m, "CancelRead", arg0)
} }
@ -51,7 +51,7 @@ func (mr *MockStreamIMockRecorder) CancelRead(arg0 interface{}) *gomock.Call {
} }
// CancelWrite mocks base method. // CancelWrite mocks base method.
func (m *MockStreamI) CancelWrite(arg0 ApplicationErrorCode) { func (m *MockStreamI) CancelWrite(arg0 StreamErrorCode) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "CancelWrite", arg0) m.ctrl.Call(m, "CancelWrite", arg0)
} }

View file

@ -39,7 +39,7 @@ type receiveStream struct {
closeForShutdownErr error closeForShutdownErr error
cancelReadErr error cancelReadErr error
resetRemotelyErr StreamError resetRemotelyErr *StreamError
closedForShutdown bool // set when CloseForShutdown() is called closedForShutdown bool // set when CloseForShutdown() is called
finRead bool // set once we read a frame with a Fin finRead bool // set once we read a frame with a Fin
@ -197,7 +197,7 @@ func (s *receiveStream) dequeueNextFrame() {
s.readPosInFrame = 0 s.readPosInFrame = 0
} }
func (s *receiveStream) CancelRead(errorCode qerr.ApplicationErrorCode) { func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
s.mutex.Lock() s.mutex.Lock()
completed := s.cancelReadImpl(errorCode) completed := s.cancelReadImpl(errorCode)
s.mutex.Unlock() s.mutex.Unlock()
@ -208,7 +208,7 @@ func (s *receiveStream) CancelRead(errorCode qerr.ApplicationErrorCode) {
} }
} }
func (s *receiveStream) cancelReadImpl(errorCode qerr.ApplicationErrorCode) bool /* completed */ { func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ {
if s.finRead || s.canceledRead || s.resetRemotely { if s.finRead || s.canceledRead || s.resetRemotely {
return false return false
} }
@ -282,9 +282,9 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
return false, nil return false, nil
} }
s.resetRemotely = true s.resetRemotely = true
s.resetRemotelyErr = streamCanceledError{ s.resetRemotelyErr = &StreamError{
errorCode: frame.ErrorCode, StreamID: s.streamID,
error: fmt.Errorf("stream %d was reset with error code %d", s.streamID, frame.ErrorCode), ErrorCode: frame.ErrorCode,
} }
s.signalRead() s.signalRead()
return newlyRcvdFinalOffset, nil return newlyRcvdFinalOffset, nil

View file

@ -9,7 +9,6 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -572,10 +571,10 @@ var _ = Describe("Receive Stream", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := strWithTimeout.Read([]byte{0}) _, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("stream 1337 was reset with error code 1234")) Expect(err).To(MatchError(&StreamError{
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) StreamID: streamID,
Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) ErrorCode: 1234,
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(qerr.ApplicationErrorCode(1234))) }))
close(done) close(done)
}() }()
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
@ -596,10 +595,10 @@ var _ = Describe("Receive Stream", func() {
) )
Expect(str.handleResetStreamFrame(rst)).To(Succeed()) Expect(str.handleResetStreamFrame(rst)).To(Succeed())
_, err := strWithTimeout.Read([]byte{0}) _, err := strWithTimeout.Read([]byte{0})
Expect(err).To(MatchError("stream 1337 was reset with error code 1234")) Expect(err).To(MatchError(&StreamError{
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) StreamID: streamID,
Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) ErrorCode: 1234,
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(qerr.ApplicationErrorCode(1234))) }))
}) })
It("errors when receiving a RESET_STREAM with an inconsistent offset", func() { It("errors when receiving a RESET_STREAM with an inconsistent offset", func() {

View file

@ -407,12 +407,12 @@ func (s *sendStream) Close() error {
return nil return nil
} }
func (s *sendStream) CancelWrite(errorCode qerr.ApplicationErrorCode) { func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode))
} }
// must be called after locking the mutex // must be called after locking the mutex
func (s *sendStream) cancelWriteImpl(errorCode qerr.ApplicationErrorCode, writeErr error) { func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) {
s.mutex.Lock() s.mutex.Lock()
if s.canceledWrite { if s.canceledWrite {
s.mutex.Unlock() s.mutex.Unlock()
@ -449,11 +449,10 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
} }
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
writeErr := streamCanceledError{ s.cancelWriteImpl(frame.ErrorCode, &StreamError{
errorCode: frame.ErrorCode, StreamID: s.streamID,
error: fmt.Errorf("stream %d was reset with error code %d", s.streamID, frame.ErrorCode), ErrorCode: frame.ErrorCode,
} })
s.cancelWriteImpl(frame.ErrorCode, writeErr)
} }
func (s *sendStream) Context() context.Context { func (s *sendStream) Context() context.Context {

View file

@ -12,7 +12,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/ackhandler"
"github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -863,10 +862,10 @@ var _ = Describe("Send Stream", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := str.Write(getData(5000)) _, err := str.Write(getData(5000))
Expect(err).To(MatchError("stream 1337 was reset with error code 123")) Expect(err).To(MatchError(&StreamError{
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) StreamID: streamID,
Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) ErrorCode: 1234,
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(qerr.ApplicationErrorCode(123))) }))
close(done) close(done)
}() }()
waitForWrite() waitForWrite()
@ -885,10 +884,10 @@ var _ = Describe("Send Stream", func() {
ErrorCode: 123, ErrorCode: 123,
}) })
_, err := str.Write([]byte("foobar")) _, err := str.Write([]byte("foobar"))
Expect(err).To(MatchError("stream 1337 was reset with error code 123")) Expect(err).To(MatchError(&StreamError{
Expect(err).To(BeAssignableToTypeOf(streamCanceledError{})) StreamID: streamID,
Expect(err.(streamCanceledError).Canceled()).To(BeTrue()) ErrorCode: 1234,
Expect(err.(streamCanceledError).ErrorCode()).To(Equal(qerr.ApplicationErrorCode(123))) }))
}) })
}) })
}) })

View file

@ -9,7 +9,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/ackhandler"
"github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
) )
@ -87,16 +86,6 @@ type stream struct {
var _ Stream = &stream{} var _ Stream = &stream{}
type streamCanceledError struct {
error
errorCode qerr.ApplicationErrorCode
}
func (streamCanceledError) Canceled() bool { return true }
func (e streamCanceledError) ErrorCode() qerr.ApplicationErrorCode { return e.errorCode }
var _ StreamError = &streamCanceledError{}
// newStream creates a new Stream // newStream creates a new Stream
func newStream(streamID protocol.StreamID, func newStream(streamID protocol.StreamID,
sender streamSender, sender streamSender,