move overflow check in MAX_STREAMS frame to frame parser

This commit is contained in:
Marten Seemann 2019-11-08 11:44:56 +07:00
parent 9f5df74bd1
commit 1e7c7842b9
4 changed files with 32 additions and 10 deletions

View file

@ -2,6 +2,7 @@ package wire
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -31,6 +32,9 @@ func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStream
return nil, err
}
f.MaxStreamNum = protocol.StreamNum(streamID)
if f.MaxStreamNum > protocol.MaxStreamCount {
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
}
return f, nil
}

View file

@ -2,6 +2,7 @@ package wire
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -43,6 +44,33 @@ var _ = Describe("MAX_STREAMS frame", func() {
Expect(err).To(HaveOccurred())
}
})
for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} {
streamType := t
It("accepts a frame containing the maximum stream count", func() {
f := &MaxStreamsFrame{
Type: streamType,
MaxStreamNum: protocol.MaxStreamCount,
}
b := &bytes.Buffer{}
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
})
It("errors when receiving a too large stream count", func() {
f := &MaxStreamsFrame{
Type: streamType,
MaxStreamNum: protocol.MaxStreamCount + 1,
}
b := &bytes.Buffer{}
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
_, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever)
Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1)))
})
}
})
Context("writing", func() {

View file

@ -214,9 +214,6 @@ func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, err
}
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
if f.MaxStreamNum > protocol.MaxStreamCount {
return qerr.StreamLimitError
}
switch f.Type {
case protocol.StreamTypeUni:
m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)

View file

@ -381,13 +381,6 @@ var _ = Describe("Streams Map", func() {
_, err = m.OpenUniStream()
expectTooManyStreamsError(err)
})
It("rejects MAX_STREAMS frames with too large values", func() {
Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreamNum: protocol.MaxStreamCount + 1,
})).To(MatchError(qerr.StreamLimitError))
})
})
Context("sending MAX_STREAMS frames", func() {