add an OpenStreamSync method to the streamsMap

This commit is contained in:
Marten Seemann 2017-02-20 14:45:44 +07:00
parent af4b36b002
commit d1562e5f87
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
2 changed files with 81 additions and 7 deletions

View file

@ -25,8 +25,10 @@ type streamsMap struct {
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond
closeErr error
nextStreamToAccept protocol.StreamID
openStreamCond sync.Cond
closeErr error
nextStreamToAccept protocol.StreamID
newStream newStreamLambda
@ -52,6 +54,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect
connectionParameters: connectionParameters,
}
sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamCond.L = &sm.mutex
if pers == protocol.PerspectiveClient {
sm.nextStream = 1
@ -139,11 +142,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
return s, nil
}
// OpenStream opens the next available stream
func (m *streamsMap) OpenStream() (*stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
func (m *streamsMap) openStreamImpl() (*stream, error) {
id := m.nextStream
if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() {
return nil, qerr.TooManyOpenStreams
@ -165,6 +164,30 @@ func (m *streamsMap) OpenStream() (*stream, error) {
return s, nil
}
// OpenStream opens the next available stream
func (m *streamsMap) OpenStream() (*stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.openStreamImpl()
}
func (m *streamsMap) OpenStreamSync() (*stream, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
for {
str, err := m.openStreamImpl()
if err == nil {
return str, err
}
if err != nil && err != qerr.TooManyOpenStreams {
return nil, err
}
m.openStreamCond.Wait()
}
}
// AcceptStream returns the next stream opened by the peer
// it blocks until a new stream is opened
func (m *streamsMap) AcceptStream() (utils.Stream, error) {
@ -290,6 +313,7 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
}
delete(m.streams, id)
m.openStreamCond.Signal()
return nil
}

View file

@ -175,6 +175,13 @@ var _ = Describe("Streams Map", func() {
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
})
It("errors if the stream can't be created", func() {
testErr := errors.New("test error")
m.newStream = func(protocol.StreamID) (*stream, error) { return nil, testErr }
_, err := m.OpenStream()
Expect(err).To(MatchError(testErr))
})
Context("counting streams", func() {
var maxNumStreams int
@ -210,6 +217,49 @@ var _ = Describe("Streams Map", func() {
}
})
})
Context("opening streams synchronously", func() {
var maxNumStreams int
BeforeEach(func() {
maxNumStreams = int(cpm.GetMaxOutgoingStreams())
})
openMaxNumStreams := func() {
for i := 1; i <= maxNumStreams; i++ {
_, err := m.OpenStream()
Expect(err).NotTo(HaveOccurred())
}
_, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
}
It("waits until another stream is closed", func() {
openMaxNumStreams()
var returned bool
var str *stream
go func() {
defer GinkgoRecover()
var err error
str, err = m.OpenStreamSync()
Expect(err).ToNot(HaveOccurred())
returned = true
}()
Consistently(func() bool { return returned }).Should(BeFalse())
err := m.RemoveStream(6)
Expect(err).ToNot(HaveOccurred())
Eventually(func() bool { return returned }).Should(BeTrue())
Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxNumStreams + 2)))
})
It("errors if the stream can't be created", func() {
testErr := errors.New("test error")
m.newStream = func(protocol.StreamID) (*stream, error) { return nil, testErr }
_, err := m.OpenStreamSync()
Expect(err).To(MatchError(testErr))
})
})
})
Context("accepting streams", func() {