never send a value larger than 2^60 in MAX_STREAMS frames

This commit is contained in:
Marten Seemann 2020-08-11 10:19:31 +07:00
parent d7c2169c55
commit 3c62c48fca
4 changed files with 71 additions and 23 deletions

View file

@ -159,13 +159,16 @@ func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error {
delete(m.streams, num) delete(m.streams, num)
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
if m.maxNumStreams > uint64(len(m.streams)) { if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams)) maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1
m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 // Never send a value larger than protocol.MaxStreamCount.
if maxStream <= protocol.MaxStreamCount {
m.maxStream = maxStream
m.queueMaxStreamID(&wire.MaxStreamsFrame{ m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi, Type: protocol.StreamTypeBidi,
MaxStreamNum: m.maxStream, MaxStreamNum: m.maxStream,
}) })
} }
}
return nil return nil
} }

View file

@ -157,13 +157,16 @@ func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error {
delete(m.streams, num) delete(m.streams, num)
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
if m.maxNumStreams > uint64(len(m.streams)) { if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams)) maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1
m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 // Never send a value larger than protocol.MaxStreamCount.
if maxStream <= protocol.MaxStreamCount {
m.maxStream = maxStream
m.queueMaxStreamID(&wire.MaxStreamsFrame{ m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: streamTypeGeneric, Type: streamTypeGeneric,
MaxStreamNum: m.maxStream, MaxStreamNum: m.maxStream,
}) })
} }
}
return nil return nil
} }

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"bytes"
"context" "context"
"errors" "errors"
@ -25,17 +26,25 @@ func (s *mockGenericStream) closeForShutdown(err error) {
} }
var _ = Describe("Streams Map (incoming)", func() { var _ = Describe("Streams Map (incoming)", func() {
const (
maxNumStreams uint64 = 5
)
var ( var (
m *incomingItemsMap m *incomingItemsMap
newItemCounter int newItemCounter int
mockSender *MockStreamSender mockSender *MockStreamSender
maxNumStreams uint64
) )
BeforeEach(func() { // check that the frame can be serialized and deserialized
checkFrameSerialization := func(f wire.Frame) {
b := &bytes.Buffer{}
ExpectWithOffset(1, f.Write(b, protocol.VersionTLS)).To(Succeed())
frame, err := wire.NewFrameParser(protocol.VersionTLS).ParseNext(bytes.NewReader(b.Bytes()), protocol.Encryption1RTT)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
Expect(f).To(Equal(frame))
}
BeforeEach(func() { maxNumStreams = 5 })
JustBeforeEach(func() {
newItemCounter = 0 newItemCounter = 0
mockSender = NewMockStreamSender(mockCtrl) mockSender = NewMockStreamSender(mockCtrl)
m = newIncomingItemsMap( m = newIncomingItemsMap(
@ -211,11 +220,41 @@ var _ = Describe("Streams Map (incoming)", func() {
} }
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1))) Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 1)))
checkFrameSerialization(f)
}) })
Expect(m.DeleteStream(3)).To(Succeed()) Expect(m.DeleteStream(3)).To(Succeed())
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2))) Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.StreamNum(maxNumStreams + 2)))
checkFrameSerialization(f)
}) })
Expect(m.DeleteStream(4)).To(Succeed()) Expect(m.DeleteStream(4)).To(Succeed())
}) })
Context("using high stream limits", func() {
BeforeEach(func() { maxNumStreams = uint64(protocol.MaxStreamCount) - 2 })
It("doesn't send MAX_STREAMS frames if they would overflow 2^60 (the maximum stream count)", func() {
// open a bunch of streams
_, err := m.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
// accept all streams
for i := 0; i < 5; i++ {
_, err := m.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred())
}
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount - 1))
checkFrameSerialization(f)
})
Expect(m.DeleteStream(4)).To(Succeed())
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.MaxStreamsFrame).MaxStreamNum).To(Equal(protocol.MaxStreamCount))
checkFrameSerialization(f)
})
Expect(m.DeleteStream(3)).To(Succeed())
// at this point, we can't increase the stream limit any further, so no more MAX_STREAMS frames will be sent
Expect(m.DeleteStream(2)).To(Succeed())
Expect(m.DeleteStream(1)).To(Succeed())
})
})
}) })

View file

@ -159,13 +159,16 @@ func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error {
delete(m.streams, num) delete(m.streams, num)
// queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream
if m.maxNumStreams > uint64(len(m.streams)) { if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams)) maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1
m.maxStream = m.nextStreamToOpen + protocol.StreamNum(numNewStreams) - 1 // Never send a value larger than protocol.MaxStreamCount.
if maxStream <= protocol.MaxStreamCount {
m.maxStream = maxStream
m.queueMaxStreamID(&wire.MaxStreamsFrame{ m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni, Type: protocol.StreamTypeUni,
MaxStreamNum: m.maxStream, MaxStreamNum: m.maxStream,
}) })
} }
}
return nil return nil
} }