add an OpenStreamSync method to the streamsMap

This commit is contained in:
Marten Seemann 2017-02-20 14:45:44 +07:00
parent af4b36b002
commit d1562e5f87
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
2 changed files with 81 additions and 7 deletions

View file

@ -25,8 +25,10 @@ type streamsMap struct {
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond nextStreamOrErrCond sync.Cond
closeErr error openStreamCond sync.Cond
nextStreamToAccept protocol.StreamID
closeErr error
nextStreamToAccept protocol.StreamID
newStream newStreamLambda newStream newStreamLambda
@ -52,6 +54,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect
connectionParameters: connectionParameters, connectionParameters: connectionParameters,
} }
sm.nextStreamOrErrCond.L = &sm.mutex sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamCond.L = &sm.mutex
if pers == protocol.PerspectiveClient { if pers == protocol.PerspectiveClient {
sm.nextStream = 1 sm.nextStream = 1
@ -139,11 +142,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
return s, nil return s, nil
} }
// OpenStream opens the next available stream func (m *streamsMap) openStreamImpl() (*stream, error) {
func (m *streamsMap) OpenStream() (*stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
id := m.nextStream id := m.nextStream
if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() {
return nil, qerr.TooManyOpenStreams return nil, qerr.TooManyOpenStreams
@ -165,6 +164,30 @@ func (m *streamsMap) OpenStream() (*stream, error) {
return s, nil return s, nil
} }
// OpenStream opens the next available stream
func (m *streamsMap) OpenStream() (*stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.openStreamImpl()
}
func (m *streamsMap) OpenStreamSync() (*stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for {
str, err := m.openStreamImpl()
if err == nil {
return str, err
}
if err != nil && err != qerr.TooManyOpenStreams {
return nil, err
}
m.openStreamCond.Wait()
}
}
// AcceptStream returns the next stream opened by the peer // AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened // it blocks until a new stream is opened
func (m *streamsMap) AcceptStream() (utils.Stream, error) { func (m *streamsMap) AcceptStream() (utils.Stream, error) {
@ -290,6 +313,7 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
} }
delete(m.streams, id) delete(m.streams, id)
m.openStreamCond.Signal()
return nil return nil
} }

View file

@ -175,6 +175,13 @@ var _ = Describe("Streams Map", func() {
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
}) })
It("errors if the stream can't be created", func() {
testErr := errors.New("test error")
m.newStream = func(protocol.StreamID) (*stream, error) { return nil, testErr }
_, err := m.OpenStream()
Expect(err).To(MatchError(testErr))
})
Context("counting streams", func() { Context("counting streams", func() {
var maxNumStreams int var maxNumStreams int
@ -210,6 +217,49 @@ var _ = Describe("Streams Map", func() {
} }
}) })
}) })
Context("opening streams synchronously", func() {
var maxNumStreams int
BeforeEach(func() {
maxNumStreams = int(cpm.GetMaxOutgoingStreams())
})
openMaxNumStreams := func() {
for i := 1; i <= maxNumStreams; i++ {
_, err := m.OpenStream()
Expect(err).NotTo(HaveOccurred())
}
_, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
}
It("waits until another stream is closed", func() {
openMaxNumStreams()
var returned bool
var str *stream
go func() {
defer GinkgoRecover()
var err error
str, err = m.OpenStreamSync()
Expect(err).ToNot(HaveOccurred())
returned = true
}()
Consistently(func() bool { return returned }).Should(BeFalse())
err := m.RemoveStream(6)
Expect(err).ToNot(HaveOccurred())
Eventually(func() bool { return returned }).Should(BeTrue())
Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxNumStreams + 2)))
})
It("errors if the stream can't be created", func() {
testErr := errors.New("test error")
m.newStream = func(protocol.StreamID) (*stream, error) { return nil, testErr }
_, err := m.OpenStreamSync()
Expect(err).To(MatchError(testErr))
})
})
}) })
Context("accepting streams", func() { Context("accepting streams", func() {