diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go index 63e506c4..8157e77c 100644 --- a/internal/wire/max_streams_frame.go +++ b/internal/wire/max_streams_frame.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -31,6 +32,9 @@ func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStream return nil, err } f.MaxStreamNum = protocol.StreamNum(streamID) + if f.MaxStreamNum > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) + } return f, nil } diff --git a/internal/wire/max_streams_frame_test.go b/internal/wire/max_streams_frame_test.go index 35eb6e90..3f75dfbf 100644 --- a/internal/wire/max_streams_frame_test.go +++ b/internal/wire/max_streams_frame_test.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -43,6 +44,33 @@ var _ = Describe("MAX_STREAMS frame", func() { Expect(err).To(HaveOccurred()) } }) + + for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { + streamType := t + + It("accepts a frame containing the maximum stream count", func() { + f := &MaxStreamsFrame{ + Type: streamType, + MaxStreamNum: protocol.MaxStreamCount, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when receiving a too large stream count", func() { + f := &MaxStreamsFrame{ + Type: streamType, + MaxStreamNum: protocol.MaxStreamCount + 1, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + _, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) + }) + } }) Context("writing", func() { diff --git a/streams_map.go b/streams_map.go index 4dc8fc31..7b878a25 100644 --- a/streams_map.go +++ b/streams_map.go @@ -214,9 +214,6 @@ func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, err } func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error { - if f.MaxStreamNum > protocol.MaxStreamCount { - return qerr.StreamLimitError - } switch f.Type { case protocol.StreamTypeUni: m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) diff --git a/streams_map_test.go b/streams_map_test.go index 7ab5bf7d..0ae351b9 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -381,13 +381,6 @@ var _ = Describe("Streams Map", func() { _, err = m.OpenUniStream() expectTooManyStreamsError(err) }) - - It("rejects MAX_STREAMS frames with too large values", func() { - Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: protocol.MaxStreamCount + 1, - })).To(MatchError(qerr.StreamLimitError)) - }) }) Context("sending MAX_STREAMS frames", func() {