From a2e48e204bc413f337b77459a255ab81e2cb0069 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 6 Jan 2019 17:18:43 +0700 Subject: [PATCH] return a net.Error when opening streams net.Error.Temporary() will be true if no stream can be opened when the peer's stream limit is reached. --- interface.go | 14 +++++++----- streams_map.go | 12 ++++++++++ streams_map_outgoing_bidi.go | 14 ++++++++---- streams_map_outgoing_generic.go | 14 ++++++++---- streams_map_outgoing_generic_test.go | 20 +++++++++++----- streams_map_outgoing_uni.go | 14 ++++++++---- streams_map_test.go | 34 ++++++++++++++++++++-------- 7 files changed, 85 insertions(+), 37 deletions(-) diff --git a/interface.go b/interface.go index 5ba909cb..3f93d2bc 100644 --- a/interface.go +++ b/interface.go @@ -120,20 +120,22 @@ type Session interface { // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. AcceptUniStream() (ReceiveStream, error) // OpenStream opens a new bidirectional QUIC stream. - // It returns a special error when the peer's concurrent stream limit is reached. // There is no signaling to the peer about new streams: // The peer can only accept the stream after data has been sent on the stream. - // TODO(#1152): Enable testing for the special error + // If the error is non-nil, it satisfies the net.Error interface. + // When reaching the peer's stream limit, err.Temporary() will be true. OpenStream() (Stream, error) // OpenStreamSync opens a new bidirectional QUIC stream. - // It blocks until the peer's concurrent stream limit allows a new stream to be opened. + // It blocks until a new stream can be opened. + // If the error is non-nil, it satisfies the net.Error interface. OpenStreamSync() (Stream, error) // OpenUniStream opens a new outgoing unidirectional QUIC stream. - // It returns a special error when the peer's concurrent stream limit is reached. - // TODO(#1152): Enable testing for the special error + // If the error is non-nil, it satisfies the net.Error interface. + // When reaching the peer's stream limit, Temporary() will be true. OpenUniStream() (SendStream, error) // OpenUniStreamSync opens a new outgoing unidirectional QUIC stream. - // It blocks until the peer's concurrent stream limit allows a new stream to be opened. + // It blocks until a new stream can be opened. + // If the error is non-nil, it satisfies the net.Error interface. OpenUniStreamSync() (SendStream, error) // LocalAddr returns the local address. LocalAddr() net.Addr diff --git a/streams_map.go b/streams_map.go index 4be37bf4..b7195375 100644 --- a/streams_map.go +++ b/streams_map.go @@ -1,7 +1,9 @@ package quic import ( + "errors" "fmt" + "net" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/handshake" @@ -9,6 +11,16 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type streamOpenErr struct{ error } + +var _ net.Error = &streamOpenErr{} + +func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams } +func (streamOpenErr) Timeout() bool { return false } + +// errTooManyOpenStreams is used internally by the outgoing streams maps. +var errTooManyOpenStreams = errors.New("too many open streams") + type streamsMap struct { perspective protocol.Perspective diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index 6405e438..480d8d77 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -49,7 +49,11 @@ func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.openStreamImpl() + str, err := m.openStreamImpl() + if err != nil { + return nil, streamOpenErr{err} + } + return str, nil } func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { @@ -59,10 +63,10 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { for { str, err := m.openStreamImpl() if err == nil { - return str, err + return str, nil } - if err != nil && err != qerr.TooManyOpenStreams { - return nil, err + if err != nil && err != errTooManyOpenStreams { + return nil, streamOpenErr{err} } m.cond.Wait() } @@ -87,7 +91,7 @@ func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { } m.blockedSent = true } - return nil, qerr.TooManyOpenStreams + return nil, errTooManyOpenStreams } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index 23bd1917..e9cff98b 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -47,7 +47,11 @@ func (m *outgoingItemsMap) OpenStream() (item, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.openStreamImpl() + str, err := m.openStreamImpl() + if err != nil { + return nil, streamOpenErr{err} + } + return str, nil } func (m *outgoingItemsMap) OpenStreamSync() (item, error) { @@ -57,10 +61,10 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) { for { str, err := m.openStreamImpl() if err == nil { - return str, err + return str, nil } - if err != nil && err != qerr.TooManyOpenStreams { - return nil, err + if err != nil && err != errTooManyOpenStreams { + return nil, streamOpenErr{err} } m.cond.Wait() } @@ -85,7 +89,7 @@ func (m *outgoingItemsMap) openStreamImpl() (item, error) { } m.blockedSent = true } - return nil, qerr.TooManyOpenStreams + return nil, errTooManyOpenStreams } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index 59a1aa82..ce70c17a 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -2,6 +2,7 @@ package quic import ( "errors" + "net" "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -46,7 +47,12 @@ var _ = Describe("Streams Map (outgoing)", func() { testErr := errors.New("close") m.CloseWithError(testErr) _, err := m.OpenStream() - Expect(err).To(MatchError(testErr)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeFalse()) + Expect(nerr.Temporary()).To(BeFalse()) }) It("gets streams", func() { @@ -104,7 +110,7 @@ var _ = Describe("Streams Map (outgoing)", func() { It("errors when no stream can be opened immediately", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) }) It("blocks until a stream can be opened synchronously", func() { @@ -149,7 +155,8 @@ var _ = Describe("Streams Map (outgoing)", func() { go func() { defer GinkgoRecover() _, err := m.OpenStreamSync() - Expect(err).To(MatchError(testErr)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) close(done) }() @@ -180,7 +187,8 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(f.(*wire.StreamsBlockedFrame).StreamLimit).To(BeEquivalentTo(6)) }) _, err := m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(errTooManyOpenStreams.Error())) }) It("only sends one STREAM_ID_BLOCKED frame for one stream ID", func() { @@ -192,9 +200,9 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(err).ToNot(HaveOccurred()) // try to open a stream twice, but expect only one STREAM_ID_BLOCKED to be sent _, err = m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) _, err = m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) }) }) }) diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index 838c9aa9..98456bfd 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -49,7 +49,11 @@ func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.openStreamImpl() + str, err := m.openStreamImpl() + if err != nil { + return nil, streamOpenErr{err} + } + return str, nil } func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { @@ -59,10 +63,10 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { for { str, err := m.openStreamImpl() if err == nil { - return str, err + return str, nil } - if err != nil && err != qerr.TooManyOpenStreams { - return nil, err + if err != nil && err != errTooManyOpenStreams { + return nil, streamOpenErr{err} } m.cond.Wait() } @@ -87,7 +91,7 @@ func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { } m.blockedSent = true } - return nil, qerr.TooManyOpenStreams + return nil, errTooManyOpenStreams } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s diff --git a/streams_map_test.go b/streams_map_test.go index b29781cd..bec5de22 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math" + "net" "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/flowcontrol" @@ -24,6 +25,15 @@ type streamMapping struct { firstOutgoingUniStream protocol.StreamID } +func expectTooManyStreamsError(err error) { + ExpectWithOffset(1, err).To(HaveOccurred()) + ExpectWithOffset(1, err.Error()).To(Equal(errTooManyOpenStreams.Error())) + nerr, ok := err.(net.Error) + ExpectWithOffset(1, ok).To(BeTrue()) + ExpectWithOffset(1, nerr.Temporary()).To(BeTrue()) + ExpectWithOffset(1, nerr.Timeout()).To(BeFalse()) +} + var _ = Describe("Streams Map", func() { newFlowController := func(protocol.StreamID) flowcontrol.StreamFlowController { return mocks.NewMockStreamFlowController(mockCtrl) @@ -270,7 +280,7 @@ var _ = Describe("Streams Map", func() { It("processes the parameter for outgoing streams, as a server", func() { m.perspective = protocol.PerspectiveServer _, err := m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) m.UpdateLimits(&handshake.TransportParameters{ MaxBidiStreams: 5, MaxUniStreams: 5, @@ -282,7 +292,7 @@ var _ = Describe("Streams Map", func() { It("processes the parameter for outgoing streams, as a client", func() { m.perspective = protocol.PerspectiveClient _, err := m.OpenUniStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) m.UpdateLimits(&handshake.TransportParameters{ MaxBidiStreams: 5, MaxUniStreams: 5, @@ -299,7 +309,7 @@ var _ = Describe("Streams Map", func() { It("processes IDs for outgoing bidirectional streams", func() { _, err := m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeBidi, MaxStreams: 1, @@ -308,12 +318,12 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) _, err = m.OpenStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) }) It("processes IDs for outgoing unidirectional streams", func() { _, err := m.OpenUniStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ Type: protocol.StreamTypeUni, MaxStreams: 1, @@ -322,7 +332,7 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) _, err = m.OpenUniStream() - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + expectTooManyStreamsError(err) }) }) @@ -352,13 +362,17 @@ var _ = Describe("Streams Map", func() { testErr := errors.New("test error") m.CloseWithError(testErr) _, err := m.OpenStream() - Expect(err).To(MatchError(testErr)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) _, err = m.OpenUniStream() - Expect(err).To(MatchError(testErr)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) _, err = m.AcceptStream() - Expect(err).To(MatchError(testErr)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) _, err = m.AcceptUniStream() - Expect(err).To(MatchError(testErr)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(testErr.Error())) }) }) }