mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 05:37:36 +03:00
add an OpenStreamSync method to the streamsMap
This commit is contained in:
parent
af4b36b002
commit
d1562e5f87
2 changed files with 81 additions and 7 deletions
|
@ -25,8 +25,10 @@ type streamsMap struct {
|
|||
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||
highestStreamOpenedByPeer protocol.StreamID
|
||||
nextStreamOrErrCond sync.Cond
|
||||
closeErr error
|
||||
nextStreamToAccept protocol.StreamID
|
||||
openStreamCond sync.Cond
|
||||
|
||||
closeErr error
|
||||
nextStreamToAccept protocol.StreamID
|
||||
|
||||
newStream newStreamLambda
|
||||
|
||||
|
@ -52,6 +54,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect
|
|||
connectionParameters: connectionParameters,
|
||||
}
|
||||
sm.nextStreamOrErrCond.L = &sm.mutex
|
||||
sm.openStreamCond.L = &sm.mutex
|
||||
|
||||
if pers == protocol.PerspectiveClient {
|
||||
sm.nextStream = 1
|
||||
|
@ -139,11 +142,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
// OpenStream opens the next available stream
|
||||
func (m *streamsMap) OpenStream() (*stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
func (m *streamsMap) openStreamImpl() (*stream, error) {
|
||||
id := m.nextStream
|
||||
if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() {
|
||||
return nil, qerr.TooManyOpenStreams
|
||||
|
@ -165,6 +164,30 @@ func (m *streamsMap) OpenStream() (*stream, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
// OpenStream opens the next available stream
|
||||
func (m *streamsMap) OpenStream() (*stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.openStreamImpl()
|
||||
}
|
||||
|
||||
func (m *streamsMap) OpenStreamSync() (*stream, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for {
|
||||
str, err := m.openStreamImpl()
|
||||
if err == nil {
|
||||
return str, err
|
||||
}
|
||||
if err != nil && err != qerr.TooManyOpenStreams {
|
||||
return nil, err
|
||||
}
|
||||
m.openStreamCond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// AcceptStream returns the next stream opened by the peer
|
||||
// it blocks until a new stream is opened
|
||||
func (m *streamsMap) AcceptStream() (utils.Stream, error) {
|
||||
|
@ -290,6 +313,7 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
|||
}
|
||||
|
||||
delete(m.streams, id)
|
||||
m.openStreamCond.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -175,6 +175,13 @@ var _ = Describe("Streams Map", func() {
|
|||
Expect(m.numOutgoingStreams).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
It("errors if the stream can't be created", func() {
|
||||
testErr := errors.New("test error")
|
||||
m.newStream = func(protocol.StreamID) (*stream, error) { return nil, testErr }
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
Context("counting streams", func() {
|
||||
var maxNumStreams int
|
||||
|
||||
|
@ -210,6 +217,49 @@ var _ = Describe("Streams Map", func() {
|
|||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("opening streams synchronously", func() {
|
||||
var maxNumStreams int
|
||||
|
||||
BeforeEach(func() {
|
||||
maxNumStreams = int(cpm.GetMaxOutgoingStreams())
|
||||
})
|
||||
|
||||
openMaxNumStreams := func() {
|
||||
for i := 1; i <= maxNumStreams; i++ {
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
_, err := m.OpenStream()
|
||||
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||
}
|
||||
|
||||
It("waits until another stream is closed", func() {
|
||||
openMaxNumStreams()
|
||||
var returned bool
|
||||
var str *stream
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
str, err = m.OpenStreamSync()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
returned = true
|
||||
}()
|
||||
|
||||
Consistently(func() bool { return returned }).Should(BeFalse())
|
||||
err := m.RemoveStream(6)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(func() bool { return returned }).Should(BeTrue())
|
||||
Expect(str.StreamID()).To(Equal(protocol.StreamID(2*maxNumStreams + 2)))
|
||||
})
|
||||
|
||||
It("errors if the stream can't be created", func() {
|
||||
testErr := errors.New("test error")
|
||||
m.newStream = func(protocol.StreamID) (*stream, error) { return nil, testErr }
|
||||
_, err := m.OpenStreamSync()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("accepting streams", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue