mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
move overflow check in MAX_STREAMS frame to frame parser
This commit is contained in:
parent
9f5df74bd1
commit
1e7c7842b9
4 changed files with 32 additions and 10 deletions
|
@ -2,6 +2,7 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
@ -31,6 +32,9 @@ func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStream
|
|||
return nil, err
|
||||
}
|
||||
f.MaxStreamNum = protocol.StreamNum(streamID)
|
||||
if f.MaxStreamNum > protocol.MaxStreamCount {
|
||||
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
@ -43,6 +44,33 @@ var _ = Describe("MAX_STREAMS frame", func() {
|
|||
Expect(err).To(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} {
|
||||
streamType := t
|
||||
|
||||
It("accepts a frame containing the maximum stream count", func() {
|
||||
f := &MaxStreamsFrame{
|
||||
Type: streamType,
|
||||
MaxStreamNum: protocol.MaxStreamCount,
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
|
||||
frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame).To(Equal(f))
|
||||
})
|
||||
|
||||
It("errors when receiving a too large stream count", func() {
|
||||
f := &MaxStreamsFrame{
|
||||
Type: streamType,
|
||||
MaxStreamNum: protocol.MaxStreamCount + 1,
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
|
||||
_, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever)
|
||||
Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1)))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
Context("writing", func() {
|
||||
|
|
|
@ -214,9 +214,6 @@ func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, err
|
|||
}
|
||||
|
||||
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
|
||||
if f.MaxStreamNum > protocol.MaxStreamCount {
|
||||
return qerr.StreamLimitError
|
||||
}
|
||||
switch f.Type {
|
||||
case protocol.StreamTypeUni:
|
||||
m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
|
||||
|
|
|
@ -381,13 +381,6 @@ var _ = Describe("Streams Map", func() {
|
|||
_, err = m.OpenUniStream()
|
||||
expectTooManyStreamsError(err)
|
||||
})
|
||||
|
||||
It("rejects MAX_STREAMS frames with too large values", func() {
|
||||
Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
|
||||
Type: protocol.StreamTypeBidi,
|
||||
MaxStreamNum: protocol.MaxStreamCount + 1,
|
||||
})).To(MatchError(qerr.StreamLimitError))
|
||||
})
|
||||
})
|
||||
|
||||
Context("sending MAX_STREAMS frames", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue