From 3c62c48fcaea5dacf908b6ee33ae00a8945e52f7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 11 Aug 2020 10:19:31 +0700 Subject: [PATCH] never send a value larger than 2^60 in MAX_STREAMS frames --- streams_map_incoming_bidi.go | 15 +++++---- streams_map_incoming_generic.go | 15 +++++---- streams_map_incoming_generic_test.go | 49 +++++++++++++++++++++++++--- streams_map_incoming_uni.go | 15 +++++---- 4 files changed, 71 insertions(+), 23 deletions(-) diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index 2f1bc3fa..0825b6d4 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -159,12 +159,15 @@ func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error { delete(m.streams, num) // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream if m.maxNumStreams > uint64(len(m.streams)) { - numNewStreams := m.maxNumStreams - uint64(len(m.streams)) - m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 - m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: m.maxStream, - }) + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: m.maxStream, + }) + } } return nil } diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index 2078f1e8..f9ae12b7 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -157,12 +157,15 @@ func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error { delete(m.streams, num) // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream if m.maxNumStreams > uint64(len(m.streams)) { - numNewStreams := m.maxNumStreams - uint64(len(m.streams)) - m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 - m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: streamTypeGeneric, - MaxStreamNum: m.maxStream, - }) + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: streamTypeGeneric, + MaxStreamNum: m.maxStream, + }) + } } return nil } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 2434db75..0bf4a3d1 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -1,6 +1,7 @@ package quic import ( + "bytes" "context" "errors" @@ -25,17 +26,25 @@ func (s *mockGenericStream) closeForShutdown(err error) { } var _ = Describe("Streams Map (incoming)", func() { - const ( - maxNumStreams uint64 = 5 - ) - var ( m *incomingItemsMap newItemCounter int mockSender *MockStreamSender + maxNumStreams uint64 ) - BeforeEach(func() { + // check that the frame can be serialized and deserialized + checkFrameSerialization := func(f wire.Frame) { + b := &bytes.Buffer{} + ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed()) + frame, err := wire.NewFrameParser(protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT) + ExpectWithOffset(1, err).ToNot(HaveOccurred()) + Expect(f).To(Equal(frame)) + } + + BeforeEach(func() { maxNumStreams = 5 }) + + JustBeforeEach(func() { newItemCounter = 0 mockSender = NewMockStreamSender(mockCtrl) m = newIncomingItemsMap( @@ -211,11 +220,41 @@ var _ = Describe("Streams Map (incoming)", func() { } mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) + checkFrameSerialization(f) }) Expect(m.DeleteStream(3)).To(Succeed()) mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) + checkFrameSerialization(f) }) Expect(m.DeleteStream(4)).To(Succeed()) }) + + Context("using high stream limits", func() { + BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 }) + + It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() { + // open a bunch of streams + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + // accept all streams + for i := 0; i < 5; i++ { + _, err := m.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + } + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1)) + checkFrameSerialization(f) + }) + Expect(m.DeleteStream(4)).To(Succeed()) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount)) + checkFrameSerialization(f) + }) + Expect(m.DeleteStream(3)).To(Succeed()) + // at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent + Expect(m.DeleteStream(2)).To(Succeed()) + Expect(m.DeleteStream(1)).To(Succeed()) + }) + }) }) diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index 69b9f8c4..70467a76 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -159,12 +159,15 @@ func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { delete(m.streams, num) // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream if m.maxNumStreams > uint64(len(m.streams)) { - numNewStreams := m.maxNumStreams - uint64(len(m.streams)) - m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 - m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeUni, - MaxStreamNum: m.maxStream, - }) + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: m.maxStream, + }) + } } return nil }