don't create stream 1 in the streamsMap on GetOrOpenStream (for gQUIC)

This commit is contained in:
Marten Seemann 2017-12-25 17:19:26 +07:00
parent 091b6b2602
commit 38281575a2
2 changed files with 34 additions and 28 deletions

View file

@ -20,7 +20,7 @@ type streamsMap struct {
openStreams []protocol.StreamID openStreams []protocol.StreamID
roundRobinIndex int roundRobinIndex int
nextStream protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
highestStreamOpenedByPeer protocol.StreamID highestStreamOpenedByPeer protocol.StreamID
nextStreamOrErrCond sync.Cond nextStreamOrErrCond sync.Cond
openStreamOrErrCond sync.Cond openStreamOrErrCond sync.Cond
@ -58,18 +58,22 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver pro
sm.nextStreamOrErrCond.L = &sm.mutex sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex
nextOddStream := protocol.StreamID(1) nextClientInitiatedStream := protocol.StreamID(1)
if ver.CryptoStreamID() == protocol.StreamID(1) { nextServerInitiatedStream := protocol.StreamID(2)
nextOddStream = 3 if !ver.UsesTLS() {
nextServerInitiatedStream = 2
nextClientInitiatedStream = 3
if pers == protocol.PerspectiveServer {
sm.highestStreamOpenedByPeer = 1
}
} }
if pers == protocol.PerspectiveClient { if pers == protocol.PerspectiveServer {
sm.nextStream = nextOddStream sm.nextStreamToOpen = nextServerInitiatedStream
sm.nextStreamToAccept = 2 sm.nextStreamToAccept = nextClientInitiatedStream
} else { } else {
sm.nextStream = 2 sm.nextStreamToOpen = nextClientInitiatedStream
sm.nextStreamToAccept = nextOddStream sm.nextStreamToAccept = nextServerInitiatedStream
} }
return &sm return &sm
} }
@ -81,6 +85,13 @@ func (m *streamsMap) streamInitiatedBy(id protocol.StreamID) protocol.Perspectiv
return protocol.PerspectiveClient return protocol.PerspectiveClient
} }
func (m *streamsMap) nextStreamID(id protocol.StreamID) protocol.StreamID {
if m.perspective == protocol.PerspectiveServer && id == 0 {
return 1
}
return id + 2
}
// GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
// Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
@ -101,7 +112,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
} }
if m.perspective == m.streamInitiatedBy(id) { if m.perspective == m.streamInitiatedBy(id) {
if id <= m.nextStream { // this is a stream opened by us. Must have been closed already if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
return nil, nil return nil, nil
} }
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
@ -110,14 +121,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) {
return nil, nil return nil, nil
} }
// sid is the next stream that will be opened for sid := m.nextStreamID(m.highestStreamOpenedByPeer); sid <= id; sid = m.nextStreamID(sid) {
sid := m.highestStreamOpenedByPeer + 2
// if there is no stream opened yet, and this is the server, stream 1 should be openend
if sid == 2 && m.perspective == protocol.PerspectiveServer {
sid = 1
}
for ; sid <= id; sid += 2 {
if _, err := m.openRemoteStream(sid); err != nil { if _, err := m.openRemoteStream(sid); err != nil {
return nil, err return nil, err
} }
@ -146,15 +150,14 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) {
} }
func (m *streamsMap) openStreamImpl() (streamI, error) { func (m *streamsMap) openStreamImpl() (streamI, error) {
id := m.nextStream
if m.numOutgoingStreams >= m.maxOutgoingStreams { if m.numOutgoingStreams >= m.maxOutgoingStreams {
return nil, qerr.TooManyOpenStreams return nil, qerr.TooManyOpenStreams
} }
m.numOutgoingStreams++ m.numOutgoingStreams++
m.nextStream += 2 s := m.newStream(m.nextStreamToOpen)
s := m.newStream(id)
m.putStream(s) m.putStream(s)
m.nextStreamToOpen = m.nextStreamID(m.nextStreamToOpen)
return s, nil return s, nil
} }

View file

@ -54,9 +54,11 @@ var _ = Describe("Streams Map", func() {
Context("client-side streams", func() { Context("client-side streams", func() {
It("gets new streams", func() { It("gets new streams", func() {
s, err := m.GetOrOpenStream(1) s, err := m.GetOrOpenStream(3)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(1))) Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(Equal(protocol.StreamID(3)))
Expect(m.streams).To(HaveLen(1))
Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
Expect(m.numOutgoingStreams).To(BeZero()) Expect(m.numOutgoingStreams).To(BeZero())
}) })
@ -93,11 +95,11 @@ var _ = Describe("Streams Map", func() {
}) })
It("opens skipped streams", func() { It("opens skipped streams", func() {
_, err := m.GetOrOpenStream(5) _, err := m.GetOrOpenStream(7)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(m.streams).To(HaveKey(protocol.StreamID(1)))
Expect(m.streams).To(HaveKey(protocol.StreamID(3))) Expect(m.streams).To(HaveKey(protocol.StreamID(3)))
Expect(m.streams).To(HaveKey(protocol.StreamID(5))) Expect(m.streams).To(HaveKey(protocol.StreamID(5)))
Expect(m.streams).To(HaveKey(protocol.StreamID(7)))
}) })
It("doesn't reopen an already closed stream", func() { It("doesn't reopen an already closed stream", func() {
@ -121,7 +123,7 @@ var _ = Describe("Streams Map", func() {
}) })
It("errors when too many streams are opened implicitely", func() { It("errors when too many streams are opened implicitely", func() {
_, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 1)) _, err := m.GetOrOpenStream(protocol.StreamID(m.maxIncomingStreams*2 + 3))
Expect(err).To(MatchError(qerr.TooManyOpenStreams)) Expect(err).To(MatchError(qerr.TooManyOpenStreams))
}) })
@ -423,7 +425,7 @@ var _ = Describe("Streams Map", func() {
Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5")) Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5"))
}) })
It("rejects streams with odds IDs, which are lower thatn the highest server-side stream", func() { It("rejects streams with odds IDs, which are lower than the highest server-side stream", func() {
_, err := m.GetOrOpenStream(6) _, err := m.GetOrOpenStream(6)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
_, err = m.GetOrOpenStream(5) _, err = m.GetOrOpenStream(5)
@ -434,6 +436,7 @@ var _ = Describe("Streams Map", func() {
s, err := m.GetOrOpenStream(2) s, err := m.GetOrOpenStream(2)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) Expect(s.StreamID()).To(Equal(protocol.StreamID(2)))
Expect(m.streams).To(HaveLen(1))
Expect(m.numOutgoingStreams).To(BeZero()) Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
}) })