unblock Accept in streamsMap when an error is registered

This commit is contained in:
Marten Seemann 2017-02-11 12:03:05 +07:00
parent 7c005ec7ad
commit 2fe9da6d27
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
2 changed files with 34 additions and 6 deletions

View file

@ -24,7 +24,8 @@ type streamsMap struct {
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
nextStreamCond sync.Cond
nextStreamOrErrCond sync.Cond
closeErr error
nextStreamToAccept protocol.StreamID
newStream newStreamLambda
@ -50,7 +51,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect
newStream: newStream,
connectionParameters: connectionParameters,
}
sm.nextStreamCond.L = &sm.mutex
sm.nextStreamOrErrCond.L = &sm.mutex
if pers == protocol.PerspectiveClient {
sm.nextStream = 1
@ -100,7 +101,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
sid -= 2
}
m.nextStreamCond.Broadcast()
m.nextStreamOrErrCond.Broadcast()
return m.streams[id], nil
}
@ -165,20 +166,22 @@ func (m *streamsMap) OpenStream() (*stream, error) {
// AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened
// TODO: implement error conditions
func (m *streamsMap) AcceptStream() (utils.Stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var str utils.Stream
for {
var ok bool
if m.closeErr != nil {
return nil, m.closeErr
}
str, ok = m.streams[m.nextStreamToAccept]
if ok {
break
}
m.nextStreamCond.Wait()
m.nextStreamOrErrCond.Wait()
}
m.nextStreamToAccept += 2
m.mutex.Unlock()
return str, nil
}
@ -288,3 +291,10 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
delete(m.streams, id)
return nil
}
func (m *streamsMap) CloseWithError(err error) {
m.mutex.Lock()
m.closeErr = err
m.nextStreamOrErrCond.Broadcast()
m.mutex.Unlock()
}

View file

@ -318,6 +318,24 @@ var _ = Describe("Streams Map", func() {
}()
Consistently(func() bool { return accepted }).Should(BeFalse())
})
It("stops waiting when an error is registered", func() {
testErr := errors.New("testErr")
var acceptErr error
go func() {
_, acceptErr = m.AcceptStream()
}()
Consistently(func() error { return acceptErr }).ShouldNot(HaveOccurred())
m.CloseWithError(testErr)
Eventually(func() error { return acceptErr }).Should(MatchError(testErr))
})
It("immediately returns when Accept is called after an error was registered", func() {
testErr := errors.New("testErr")
m.CloseWithError(testErr)
_, err := m.AcceptStream()
Expect(err).To(MatchError(testErr))
})
})
})