diff --git a/streams_map.go b/streams_map.go index e8308f69..fac38c3a 100644 --- a/streams_map.go +++ b/streams_map.go @@ -24,7 +24,8 @@ type streamsMap struct { nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID - nextStreamCond sync.Cond + nextStreamOrErrCond sync.Cond + closeErr error nextStreamToAccept protocol.StreamID newStream newStreamLambda @@ -50,7 +51,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect newStream: newStream, connectionParameters: connectionParameters, } - sm.nextStreamCond.L = &sm.mutex + sm.nextStreamOrErrCond.L = &sm.mutex if pers == protocol.PerspectiveClient { sm.nextStream = 1 @@ -100,7 +101,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { sid -= 2 } - m.nextStreamCond.Broadcast() + m.nextStreamOrErrCond.Broadcast() return m.streams[id], nil } @@ -165,20 +166,22 @@ func (m *streamsMap) OpenStream() (*stream, error) { // AcceptStream returns the next stream opened by the peer // it blocks until a new stream is opened -// TODO: implement error conditions func (m *streamsMap) AcceptStream() (utils.Stream, error) { m.mutex.Lock() + defer m.mutex.Unlock() var str utils.Stream for { var ok bool + if m.closeErr != nil { + return nil, m.closeErr + } str, ok = m.streams[m.nextStreamToAccept] if ok { break } - m.nextStreamCond.Wait() + m.nextStreamOrErrCond.Wait() } m.nextStreamToAccept += 2 - m.mutex.Unlock() return str, nil } @@ -288,3 +291,10 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { delete(m.streams, id) return nil } + +func (m *streamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + m.nextStreamOrErrCond.Broadcast() + m.mutex.Unlock() +} diff --git a/streams_map_test.go b/streams_map_test.go index b8a00241..d2e4ddac 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -318,6 +318,24 @@ var _ = Describe("Streams Map", func() { }() Consistently(func() bool { return accepted }).Should(BeFalse()) }) + + It("stops waiting when an error is registered", func() { + testErr := errors.New("testErr") + var acceptErr error + go func() { + _, acceptErr = m.AcceptStream() + }() + Consistently(func() error { return acceptErr }).ShouldNot(HaveOccurred()) + m.CloseWithError(testErr) + Eventually(func() error { return acceptErr }).Should(MatchError(testErr)) + }) + + It("immediately returns when Accept is called after an error was registered", func() { + testErr := errors.New("testErr") + m.CloseWithError(testErr) + _, err := m.AcceptStream() + Expect(err).To(MatchError(testErr)) + }) }) })