From 128eb13d0c1fc37ed15111e6ee59a2c9417912f3 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 5 Dec 2017 16:32:22 +0700 Subject: [PATCH] fix counting of incoming and outgoing streams for the client --- streams_map.go | 16 +--- streams_map_test.go | 193 +++++++++++++++++++++++++++++++------------- 2 files changed, 141 insertions(+), 68 deletions(-) diff --git a/streams_map.go b/streams_map.go index 141d1cdb..8b71903a 100644 --- a/streams_map.go +++ b/streams_map.go @@ -135,12 +135,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (streamI, error) { return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) } - if m.perspective == protocol.PerspectiveServer { - m.numIncomingStreams++ - } else { - m.numOutgoingStreams++ - } - + m.numIncomingStreams++ if id > m.highestStreamOpenedByPeer { m.highestStreamOpenedByPeer = id } @@ -156,12 +151,7 @@ func (m *streamsMap) openStreamImpl() (streamI, error) { return nil, qerr.TooManyOpenStreams } - if m.perspective == protocol.PerspectiveServer { - m.numOutgoingStreams++ - } else { - m.numIncomingStreams++ - } - + m.numOutgoingStreams++ m.nextStream += 2 s := m.newStream(id) m.putStream(s) @@ -235,7 +225,7 @@ func (m *streamsMap) DeleteClosedStreams() error { } numDeletedStreams++ m.openStreams[i] = 0 - if streamID%2 == 0 { + if m.streamInitiatedBy(streamID) == m.perspective { m.numOutgoingStreams-- } else { m.numIncomingStreams-- diff --git a/streams_map_test.go b/streams_map_test.go index 11953811..2cde482e 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -418,7 +418,7 @@ var _ = Describe("Streams Map", func() { m.UpdateMaxStreamLimit(100) }) - Context("client-side streams", func() { + Context("server-side streams", func() { It("rejects streams with odd IDs", func() { _, err := m.GetOrOpenStream(5) Expect(err).To(MatchError("InvalidStreamID: peer attempted to open stream 5")) @@ -435,8 +435,8 @@ var _ = Describe("Streams Map", func() { s, err := m.GetOrOpenStream(2) Expect(err).NotTo(HaveOccurred()) Expect(s.StreamID()).To(Equal(protocol.StreamID(2))) - Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) - Expect(m.numIncomingStreams).To(BeZero()) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) }) It("opens skipped streams", func() { @@ -445,6 +445,8 @@ var _ = Describe("Streams Map", func() { Expect(m.streams).To(HaveKey(protocol.StreamID(2))) Expect(m.streams).To(HaveKey(protocol.StreamID(4))) Expect(m.streams).To(HaveKey(protocol.StreamID(6))) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) }) It("doesn't reopen an already closed stream", func() { @@ -459,7 +461,7 @@ var _ = Describe("Streams Map", func() { }) }) - Context("server-side streams", func() { + Context("client-side streams", func() { It("starts with stream 1, if the crypto stream is stream 0", func() { setNewStreamsMap(protocol.PerspectiveClient, versionCryptoStream0) m.UpdateMaxStreamLimit(100) @@ -467,8 +469,8 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) Expect(s.StreamID()).To(BeEquivalentTo(1)) - Expect(m.numOutgoingStreams).To(BeZero()) - Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + Expect(m.numIncomingStreams).To(BeZero()) }) It("starts with stream 3, if the crypto stream is stream 1", func() { @@ -476,8 +478,8 @@ var _ = Describe("Streams Map", func() { Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) Expect(s.StreamID()).To(BeEquivalentTo(3)) - Expect(m.numOutgoingStreams).To(BeZero()) - Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + Expect(m.numIncomingStreams).To(BeZero()) }) It("opens multiple streams", func() { @@ -524,63 +526,144 @@ var _ = Describe("Streams Map", func() { closeStream := func(id protocol.StreamID) { str := m.streams[id] - Expect(str).ToNot(BeNil()) + ExpectWithOffset(1, str).ToNot(BeNil()) finishedStreams[id].Return(true) } Context("deleting streams", func() { - BeforeEach(func() { - for i := 1; i <= 5; i++ { - err := m.putStream(newStream(protocol.StreamID(i))) + Context("as a server", func() { + BeforeEach(func() { + m.UpdateMaxStreamLimit(100) + for i := 1; i <= 5; i++ { + if i%2 == 1 { + _, err := m.openRemoteStream(protocol.StreamID(i)) + Expect(err).ToNot(HaveOccurred()) + } else { + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + } + } + Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) // 2 and 4 + Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) // 1, 3 and 5 + }) + + It("does not delete streams with Close()", func() { + str, err := m.GetOrOpenStream(55) Expect(err).ToNot(HaveOccurred()) - } - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4, 5})) + str.(*mocks.MockStreamI).EXPECT().Close() + str.Close() + err = m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + str, err = m.GetOrOpenStream(55) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + }) + + It("removes the first stream", func() { + closeStream(1) + err := m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + Expect(m.openStreams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5})) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) + Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) + }) + + It("removes a stream in the middle", func() { + closeStream(3) + err := m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5})) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) + Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) + }) + + It("removes a client-initiated stream", func() { + closeStream(2) + err := m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 3, 4, 5})) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(1)) + Expect(m.numIncomingStreams).To(BeEquivalentTo(3)) + }) + + It("removes a stream at the end", func() { + closeStream(5) + err := m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + Expect(m.openStreams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4})) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) + Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) + }) + + It("removes all streams", func() { + for i := 1; i <= 5; i++ { + closeStream(protocol.StreamID(i)) + } + err := m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(BeEmpty()) + Expect(m.openStreams).To(BeEmpty()) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeZero()) + }) }) - It("does not delete streams with Close()", func() { - str, err := m.GetOrOpenStream(55) - Expect(err).ToNot(HaveOccurred()) - str.(*mocks.MockStreamI).EXPECT().Close() - str.Close() - err = m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - str, err = m.GetOrOpenStream(55) - Expect(err).ToNot(HaveOccurred()) - Expect(str).ToNot(BeNil()) - }) + Context("as a client", func() { + BeforeEach(func() { + setNewStreamsMap(protocol.PerspectiveClient, versionCryptoStream1) + m.UpdateMaxStreamLimit(100) + for i := 1; i <= 5; i++ { + if i%2 == 0 { + _, err := m.openRemoteStream(protocol.StreamID(i)) + Expect(err).ToNot(HaveOccurred()) + } else { + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + } + } + Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 2, 5, 4, 7})) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) // 3, 5 and 7 + Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) // 2 and 4 + }) - It("removes the first stream", func() { - closeStream(1) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.openStreams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 3, 4, 5})) - }) + It("removes a stream that we initiated", func() { + closeStream(3) + err := m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{2, 5, 4, 7})) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(2)) + Expect(m.numIncomingStreams).To(BeEquivalentTo(2)) + }) - It("removes a stream in the middle", func() { - closeStream(3) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 4, 5})) - }) + It("removes a stream that the server initiated", func() { + closeStream(2) + err := m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + Expect(m.openStreams).To(HaveLen(4)) + Expect(m.openStreams).To(Equal([]protocol.StreamID{3, 5, 4, 7})) + Expect(m.numOutgoingStreams).To(BeEquivalentTo(3)) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + }) - It("removes a stream at the end", func() { - closeStream(5) - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.openStreams).To(HaveLen(4)) - Expect(m.openStreams).To(Equal([]protocol.StreamID{1, 2, 3, 4})) - }) - - It("removes all streams", func() { - for i := 1; i <= 5; i++ { - closeStream(protocol.StreamID(i)) - } - err := m.DeleteClosedStreams() - Expect(err).ToNot(HaveOccurred()) - Expect(m.streams).To(BeEmpty()) - Expect(m.openStreams).To(BeEmpty()) + It("removes all streams", func() { + closeStream(3) + closeStream(2) + closeStream(5) + closeStream(4) + closeStream(7) + err := m.DeleteClosedStreams() + Expect(err).ToNot(HaveOccurred()) + Expect(m.streams).To(BeEmpty()) + Expect(m.openStreams).To(BeEmpty()) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeZero()) + }) }) })