diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index 480d8d77..cd2940cc 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -49,6 +49,10 @@ func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() + if m.closeErr != nil { + return nil, m.closeErr + } + str, err := m.openStreamImpl() if err != nil { return nil, streamOpenErr{err} @@ -61,6 +65,9 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { defer m.mutex.Unlock() for { + if m.closeErr != nil { + return nil, m.closeErr + } str, err := m.openStreamImpl() if err == nil { return str, nil @@ -73,9 +80,6 @@ func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { } func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { - if m.closeErr != nil { - return nil, m.closeErr - } if !m.maxStreamSet || m.nextStream > m.maxStream { if !m.blockedSent { if m.maxStreamSet { diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index e9cff98b..8e2a4794 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -47,6 +47,10 @@ func (m *outgoingItemsMap) OpenStream() (item, error) { m.mutex.Lock() defer m.mutex.Unlock() + if m.closeErr != nil { + return nil, m.closeErr + } + str, err := m.openStreamImpl() if err != nil { return nil, streamOpenErr{err} @@ -59,6 +63,9 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) { defer m.mutex.Unlock() for { + if m.closeErr != nil { + return nil, m.closeErr + } str, err := m.openStreamImpl() if err == nil { return str, nil @@ -71,9 +78,6 @@ func (m *outgoingItemsMap) OpenStreamSync() (item, error) { } func (m *outgoingItemsMap) openStreamImpl() (item, error) { - if m.closeErr != nil { - return nil, m.closeErr - } if !m.maxStreamSet || m.nextStream > m.maxStream { if !m.blockedSent { if m.maxStreamSet { diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index ce70c17a..0d288e2b 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -2,7 +2,6 @@ package quic import ( "errors" - "net" "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -47,12 +46,7 @@ var _ = Describe("Streams Map (outgoing)", func() { testErr := errors.New("close") m.CloseWithError(testErr) _, err := m.OpenStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(testErr.Error())) - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeFalse()) - Expect(nerr.Temporary()).To(BeFalse()) + Expect(err).To(MatchError(testErr)) }) It("gets streams", func() { @@ -155,8 +149,7 @@ var _ = Describe("Streams Map (outgoing)", func() { go func() { defer GinkgoRecover() _, err := m.OpenStreamSync() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(testErr.Error())) + Expect(err).To(MatchError(testErr)) close(done) }() diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index 98456bfd..c5b7f4ab 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -49,6 +49,10 @@ func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { m.mutex.Lock() defer m.mutex.Unlock() + if m.closeErr != nil { + return nil, m.closeErr + } + str, err := m.openStreamImpl() if err != nil { return nil, streamOpenErr{err} @@ -61,6 +65,9 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { defer m.mutex.Unlock() for { + if m.closeErr != nil { + return nil, m.closeErr + } str, err := m.openStreamImpl() if err == nil { return str, nil @@ -73,9 +80,6 @@ func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { } func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { - if m.closeErr != nil { - return nil, m.closeErr - } if !m.maxStreamSet || m.nextStream > m.maxStream { if !m.blockedSent { if m.maxStreamSet {