mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
unblock Accept in streamsMap when an error is registered
This commit is contained in:
parent
7c005ec7ad
commit
2fe9da6d27
2 changed files with 34 additions and 6 deletions
|
@ -24,7 +24,8 @@ type streamsMap struct {
|
||||||
|
|
||||||
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
|
||||||
highestStreamOpenedByPeer protocol.StreamID
|
highestStreamOpenedByPeer protocol.StreamID
|
||||||
nextStreamCond sync.Cond
|
nextStreamOrErrCond sync.Cond
|
||||||
|
closeErr error
|
||||||
nextStreamToAccept protocol.StreamID
|
nextStreamToAccept protocol.StreamID
|
||||||
|
|
||||||
newStream newStreamLambda
|
newStream newStreamLambda
|
||||||
|
@ -50,7 +51,7 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connect
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
connectionParameters: connectionParameters,
|
connectionParameters: connectionParameters,
|
||||||
}
|
}
|
||||||
sm.nextStreamCond.L = &sm.mutex
|
sm.nextStreamOrErrCond.L = &sm.mutex
|
||||||
|
|
||||||
if pers == protocol.PerspectiveClient {
|
if pers == protocol.PerspectiveClient {
|
||||||
sm.nextStream = 1
|
sm.nextStream = 1
|
||||||
|
@ -100,7 +101,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
|
||||||
sid -= 2
|
sid -= 2
|
||||||
}
|
}
|
||||||
|
|
||||||
m.nextStreamCond.Broadcast()
|
m.nextStreamOrErrCond.Broadcast()
|
||||||
return m.streams[id], nil
|
return m.streams[id], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,20 +166,22 @@ func (m *streamsMap) OpenStream() (*stream, error) {
|
||||||
|
|
||||||
// AcceptStream returns the next stream opened by the peer
|
// AcceptStream returns the next stream opened by the peer
|
||||||
// it blocks until a new stream is opened
|
// it blocks until a new stream is opened
|
||||||
// TODO: implement error conditions
|
|
||||||
func (m *streamsMap) AcceptStream() (utils.Stream, error) {
|
func (m *streamsMap) AcceptStream() (utils.Stream, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
var str utils.Stream
|
var str utils.Stream
|
||||||
for {
|
for {
|
||||||
var ok bool
|
var ok bool
|
||||||
|
if m.closeErr != nil {
|
||||||
|
return nil, m.closeErr
|
||||||
|
}
|
||||||
str, ok = m.streams[m.nextStreamToAccept]
|
str, ok = m.streams[m.nextStreamToAccept]
|
||||||
if ok {
|
if ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
m.nextStreamCond.Wait()
|
m.nextStreamOrErrCond.Wait()
|
||||||
}
|
}
|
||||||
m.nextStreamToAccept += 2
|
m.nextStreamToAccept += 2
|
||||||
m.mutex.Unlock()
|
|
||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,3 +291,10 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
|
||||||
delete(m.streams, id)
|
delete(m.streams, id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *streamsMap) CloseWithError(err error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.closeErr = err
|
||||||
|
m.nextStreamOrErrCond.Broadcast()
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
|
@ -318,6 +318,24 @@ var _ = Describe("Streams Map", func() {
|
||||||
}()
|
}()
|
||||||
Consistently(func() bool { return accepted }).Should(BeFalse())
|
Consistently(func() bool { return accepted }).Should(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("stops waiting when an error is registered", func() {
|
||||||
|
testErr := errors.New("testErr")
|
||||||
|
var acceptErr error
|
||||||
|
go func() {
|
||||||
|
_, acceptErr = m.AcceptStream()
|
||||||
|
}()
|
||||||
|
Consistently(func() error { return acceptErr }).ShouldNot(HaveOccurred())
|
||||||
|
m.CloseWithError(testErr)
|
||||||
|
Eventually(func() error { return acceptErr }).Should(MatchError(testErr))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("immediately returns when Accept is called after an error was registered", func() {
|
||||||
|
testErr := errors.New("testErr")
|
||||||
|
m.CloseWithError(testErr)
|
||||||
|
_, err := m.AcceptStream()
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue