From 3b4feedb2c017de450c112c4639ac27f5261d654 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 9 Dec 2016 23:14:15 +0700 Subject: [PATCH] add support for server-side opened streams fixes #296 --- streams_map.go | 54 ++++++++++++--- streams_map_test.go | 165 ++++++++++++++++++++++++++++++-------------- 2 files changed, 158 insertions(+), 61 deletions(-) diff --git a/streams_map.go b/streams_map.go index 78953784..b4606772 100644 --- a/streams_map.go +++ b/streams_map.go @@ -21,11 +21,14 @@ type streamsMap struct { highestStreamOpenedByClient protocol.StreamID streamsOpenedAfterLastGarbageCollect int - newStream newStreamLambda - maxOpenOutgoingStreams uint32 - maxIncomingStreams uint32 + newStream newStreamLambda - roundRobinIndex int + maxOutgoingStreams uint32 + numOutgoingStreams uint32 + maxIncomingStreams uint32 + numIncomingStreams uint32 + + roundRobinIndex uint32 } type streamLambda func(*stream) (bool, error) @@ -62,7 +65,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if ok { return s, nil } - if uint32(len(m.openStreams)) == m.connectionParameters.GetMaxIncomingStreams() { + if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { return nil, qerr.TooManyOpenStreams } if id%2 == 0 { @@ -76,11 +79,13 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { if err != nil { return nil, err } + m.numIncomingStreams++ if id > m.highestStreamOpenedByClient { m.highestStreamOpenedByClient = id } + // maybe trigger garbage collection of streams map m.streamsOpenedAfterLastGarbageCollect++ if m.streamsOpenedAfterLastGarbageCollect%protocol.MaxNewStreamIDDelta == 0 { m.garbageCollectClosedStreams() @@ -92,7 +97,28 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { // OpenStream opens a stream from the server's side func (m *streamsMap) OpenStream(id protocol.StreamID) (*stream, error) { - panic("OpenStream: not implemented") + if id%2 == 1 { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + _, ok := m.streams[id] + if ok { + return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is already open", id)) + } + if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { + return nil, qerr.TooManyOpenStreams + } + + s, err := m.newStream(id) + if err != nil { + return nil, err + } + m.numOutgoingStreams++ + + m.putStream(s) + return s, nil } func (m *streamsMap) Iterate(fn streamLambda) error { @@ -118,7 +144,7 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { m.mutex.Lock() defer m.mutex.Unlock() - numStreams := len(m.openStreams) + numStreams := uint32(len(m.openStreams)) startIndex := m.roundRobinIndex for _, i := range []protocol.StreamID{1, 3} { @@ -131,7 +157,7 @@ func (m *streamsMap) RoundRobinIterate(fn streamLambda) error { } } - for i := 0; i < numStreams; i++ { + for i := uint32(0); i < numStreams; i++ { streamID := m.openStreams[(i+startIndex)%numStreams] if streamID == 1 || streamID == 3 { @@ -181,13 +207,18 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error { } m.streams[id] = nil + if id%2 == 0 { + m.numOutgoingStreams-- + } else { + m.numIncomingStreams-- + } for i, s := range m.openStreams { if s == id { // delete the streamID from the openStreams slice m.openStreams = m.openStreams[:i+copy(m.openStreams[i:], m.openStreams[i+1:])] // adjust round-robin index, if necessary - if i < m.roundRobinIndex { + if uint32(i) < m.roundRobinIndex { m.roundRobinIndex-- } break @@ -204,7 +235,10 @@ func (m *streamsMap) garbageCollectClosedStreams() { if str != nil { continue } - if id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient { + + // server-side streams can be gargage collected immediately + // client-side streams need to be kept as nils in the streams map for a bit longer, in order to prevent a client from reopening closed streams + if id%2 == 0 || id+protocol.MaxNewStreamIDDelta <= m.highestStreamOpenedByClient { delete(m.streams, id) } } diff --git a/streams_map_test.go b/streams_map_test.go index aaac8064..300e7c92 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -14,6 +14,7 @@ import ( type mockConnectionParametersManager struct { maxIncomingStreams uint32 + maxOutgoingStreams uint32 idleTime time.Duration } @@ -35,7 +36,7 @@ func (m *mockConnectionParametersManager) GetReceiveStreamFlowControlWindow() pr func (m *mockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { return math.MaxUint64 } -func (m *mockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { panic("not implemented") } +func (m *mockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { return m.maxOutgoingStreams } func (m *mockConnectionParametersManager) GetMaxIncomingStreams() uint32 { return m.maxIncomingStreams } func (m *mockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { return m.idleTime @@ -53,6 +54,7 @@ var _ = Describe("Streams Map", func() { BeforeEach(func() { cpm = &mockConnectionParametersManager{ maxIncomingStreams: 75, + maxOutgoingStreams: 60, } m = newStreamsMap(nil, cpm) }) @@ -68,57 +70,118 @@ var _ = Describe("Streams Map", func() { s, err := m.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) + Expect(m.numIncomingStreams).To(Equal(uint32(1))) + Expect(m.numOutgoingStreams).To(BeZero()) }) - It("rejects streams with even IDs", func() { - _, err := m.GetOrOpenStream(6) - Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) - }) - - It("gets existing streams", func() { - s, err := m.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - s, err = m.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) - }) - - It("returns nil for closed streams", func() { - s, err := m.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - err = m.RemoveStream(5) - Expect(err).NotTo(HaveOccurred()) - s, err = m.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - Expect(s).To(BeNil()) - }) - - It("panics on OpenStream", func() { - Expect(func() { m.OpenStream(0) }).To(Panic()) - }) - - Context("counting streams", func() { - var maxNumStreams int - - BeforeEach(func() { - maxNumStreams = int(cpm.GetMaxIncomingStreams()) + Context("client-side streams", func() { + It("rejects streams with even IDs", func() { + _, err := m.GetOrOpenStream(6) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 6 from client-side")) }) - It("errors when too many streams are opened", func() { - for i := 0; i < maxNumStreams; i++ { - _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) - Expect(err).NotTo(HaveOccurred()) - } - _, err := m.GetOrOpenStream(protocol.StreamID(2*maxNumStreams + 2)) - Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + It("gets existing streams", func() { + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + s, err = m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(5))) + Expect(m.numIncomingStreams).To(Equal(uint32(1))) }) - It("does not error when many streams are opened and closed", func() { - for i := 2; i < 10*maxNumStreams; i++ { - _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) - Expect(err).NotTo(HaveOccurred()) - m.RemoveStream(protocol.StreamID(i*2 + 1)) - } + It("returns nil for closed streams", func() { + s, err := m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + err = m.RemoveStream(5) + Expect(err).NotTo(HaveOccurred()) + s, err = m.GetOrOpenStream(5) + Expect(err).NotTo(HaveOccurred()) + Expect(s).To(BeNil()) + Expect(m.numIncomingStreams).To(BeZero()) + }) + + Context("counting streams", func() { + var maxNumStreams int + + BeforeEach(func() { + maxNumStreams = int(cpm.GetMaxIncomingStreams()) + }) + + It("errors when too many streams are opened", func() { + for i := 0; i < maxNumStreams; i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.GetOrOpenStream(protocol.StreamID(2*maxNumStreams + 2)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("does not error when many streams are opened and closed", func() { + for i := 2; i < 10*maxNumStreams; i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(i*2 + 1)) + Expect(err).NotTo(HaveOccurred()) + m.RemoveStream(protocol.StreamID(i*2 + 1)) + } + }) + }) + }) + + Context("server-side streams", func() { + It("rejects streams with odd IDs", func() { + _, err := m.OpenStream(5) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 5 from server-side")) + }) + + It("opens a new stream", func() { + s, err := m.OpenStream(6) + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(6))) + Expect(m.numIncomingStreams).To(BeZero()) + Expect(m.numOutgoingStreams).To(Equal(uint32(1))) + }) + + It("returns an error for already openend streams", func() { + _, err := m.OpenStream(4) + Expect(err).ToNot(HaveOccurred()) + _, err = m.OpenStream(4) + Expect(err).To(MatchError("InvalidStreamID: attempted to open stream 4, which is already open")) + }) + + Context("counting streams", func() { + var maxNumStreams int + + BeforeEach(func() { + maxNumStreams = int(cpm.GetMaxOutgoingStreams()) + }) + + It("errors when too many streams are opened", func() { + for i := 1; i <= maxNumStreams; i++ { + _, err := m.OpenStream(protocol.StreamID(2 * i)) + Expect(err).NotTo(HaveOccurred()) + } + _, err := m.OpenStream(protocol.StreamID(2*maxNumStreams + 10)) + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + + It("does not error when many streams are opened and closed", func() { + for i := 2; i < 10*maxNumStreams; i++ { + _, err := m.OpenStream(protocol.StreamID(2*i + 2)) + Expect(err).NotTo(HaveOccurred()) + m.RemoveStream(protocol.StreamID(2 * i)) + } + }) + + It("allows many server- and client-side streams at the same time", func() { + for i := 1; i < int(cpm.GetMaxOutgoingStreams()); i++ { + _, err := m.OpenStream(protocol.StreamID(2 * i)) + Expect(err).ToNot(HaveOccurred()) + } + for i := 0; i < int(cpm.GetMaxIncomingStreams()); i++ { + _, err := m.GetOrOpenStream(protocol.StreamID(2*i + 1)) + Expect(err).ToNot(HaveOccurred()) + } + }) }) }) @@ -343,7 +406,7 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) Expect(numIterations).To(Equal(5)) Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{7, 8, 4, 5, 6})) - Expect(m.roundRobinIndex).To(Equal(3)) + Expect(m.roundRobinIndex).To(Equal(uint32(3))) }) It("picks up at the index+1 where it last stopped", func() { @@ -359,7 +422,7 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) Expect(numIterations).To(Equal(2)) Expect(lambdaCalledForStream).To(Equal([]protocol.StreamID{4, 5})) - Expect(m.roundRobinIndex).To(Equal(2)) + Expect(m.roundRobinIndex).To(Equal(uint32(2))) numIterations = 0 lambdaCalledForStream = lambdaCalledForStream[:0] fn2 := func(str *stream) (bool, error) { @@ -379,19 +442,19 @@ var _ = Describe("Streams Map", func() { It("adjust the RoundRobinIndex when deleting an element in front", func() { m.roundRobinIndex = 3 // stream 7 m.RemoveStream(5) - Expect(m.roundRobinIndex).To(Equal(2)) + Expect(m.roundRobinIndex).To(Equal(uint32(2))) }) It("doesn't adjust the RoundRobinIndex when deleting an element at the back", func() { m.roundRobinIndex = 1 // stream 5 m.RemoveStream(7) - Expect(m.roundRobinIndex).To(Equal(1)) + Expect(m.roundRobinIndex).To(Equal(uint32(1))) }) It("doesn't adjust the RoundRobinIndex when deleting the element it is pointing to", func() { m.roundRobinIndex = 3 // stream 7 m.RemoveStream(7) - Expect(m.roundRobinIndex).To(Equal(3)) + Expect(m.roundRobinIndex).To(Equal(uint32(3))) }) Context("Prioritizing crypto- and header streams", func() {