implement and use the MAX_STREAMS frame

This commit is contained in:
Marten Seemann 2018-11-09 10:58:13 +07:00
parent 0f931ca54e
commit 9518c90c0a
18 changed files with 248 additions and 167 deletions

View file

@ -29,6 +29,12 @@ func (s StreamID) Type() StreamType {
return StreamTypeBidi
}
// StreamNum returns how many streams in total are below this
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
func (s StreamID) StreamNum() uint64 {
return uint64(s/4) + 1
}
// MaxStreamID is the highest stream ID that a peer is allowed to open,
// when it is allowed to open numStreams.
func MaxStreamID(stype StreamType, numStreams uint64, pers Perspective) StreamID {

View file

@ -27,6 +27,17 @@ var _ = Describe("Stream ID", func() {
Expect(FirstStream(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3)))
})
It("tells the stream number", func() {
Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1))
Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1))
Expect(StreamID(2).StreamNum()).To(BeEquivalentTo(1))
Expect(StreamID(3).StreamNum()).To(BeEquivalentTo(1))
Expect(StreamID(8).StreamNum()).To(BeEquivalentTo(3))
Expect(StreamID(9).StreamNum()).To(BeEquivalentTo(3))
Expect(StreamID(10).StreamNum()).To(BeEquivalentTo(3))
Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3))
})
Context("maximum stream IDs", func() {
It("doesn't allow any", func() {
Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveClient)).To(Equal(StreamID(0)))

View file

@ -55,11 +55,6 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x6:
frame, err = parseMaxStreamIDFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x7:
frame, err = parsePingFrame(r, v)
case 0x8:
@ -97,6 +92,11 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame
if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error())
}
case 0x1c, 0x1d:
frame, err = parseMaxStreamsFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x18:
frame, err = parseCryptoFrame(r, v)
if err != nil {

View file

@ -97,8 +97,11 @@ var _ = Describe("Frame parsing", func() {
Expect(frame).To(Equal(f))
})
It("unpacks MAX_STREAM_ID frames", func() {
f := &MaxStreamIDFrame{StreamID: 0x1337}
It("unpacks MAX_STREAMS frames", func() {
f := &MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: 0x1337,
}
buf := &bytes.Buffer{}
err := f.Write(buf, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())

View file

@ -1,37 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A MaxStreamIDFrame is a MAX_STREAM_ID frame
type MaxStreamIDFrame struct {
StreamID protocol.StreamID
}
// parseMaxStreamIDFrame parses a MAX_STREAM_ID frame
func parseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) {
// read the Type byte
if _, err := r.ReadByte(); err != nil {
return nil, err
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil
}
func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x6)
utils.WriteVarInt(b, uint64(f.StreamID))
return nil
}
// Length of a written frame
func (f *MaxStreamIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.StreamID))
}

View file

@ -1,51 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("MAX_STREAM_ID frame", func() {
Context("parsing", func() {
It("accepts sample frame", func() {
data := []byte{0x6}
data = append(data, encodeVarInt(0xdecafbad)...)
b := bytes.NewReader(data)
f, err := parseMaxStreamIDFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.StreamID).To(Equal(protocol.StreamID(0xdecafbad)))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x06}
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
_, err := parseMaxStreamIDFrame(bytes.NewReader(data), protocol.VersionWhatever)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := parseMaxStreamIDFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
Expect(err).To(HaveOccurred())
}
})
})
Context("writing", func() {
It("writes a sample frame", func() {
b := &bytes.Buffer{}
frame := MaxStreamIDFrame{StreamID: 0x12345678}
frame.Write(b, protocol.VersionWhatever)
expected := []byte{0x6}
expected = append(expected, encodeVarInt(0x12345678)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := MaxStreamIDFrame{StreamID: 0x1337}
Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
})
})
})

View file

@ -0,0 +1,51 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A MaxStreamsFrame is a MAX_STREAMS frame
type MaxStreamsFrame struct {
Type protocol.StreamType
MaxStreams uint64
}
func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
f := &MaxStreamsFrame{}
switch typeByte {
case 0x1c:
f.Type = protocol.StreamTypeBidi
case 0x1d:
f.Type = protocol.StreamTypeUni
}
streamID, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
f.MaxStreams = streamID
return f, nil
}
func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
switch f.Type {
case protocol.StreamTypeBidi:
b.WriteByte(0x1c)
case protocol.StreamTypeUni:
b.WriteByte(0x1d)
}
utils.WriteVarInt(b, f.MaxStreams)
return nil
}
// Length of a written frame
func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(f.MaxStreams)
}

View file

@ -0,0 +1,78 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("MAX_STREAMS frame", func() {
Context("parsing", func() {
It("accepts a frame for a bidirectional stream", func() {
data := []byte{0x1c}
data = append(data, encodeVarInt(0xdecaf)...)
b := bytes.NewReader(data)
f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.Type).To(Equal(protocol.StreamTypeBidi))
Expect(f.MaxStreams).To(BeEquivalentTo(0xdecaf))
Expect(b.Len()).To(BeZero())
})
It("accepts a frame for a bidirectional stream", func() {
data := []byte{0x1d}
data = append(data, encodeVarInt(0xdecaf)...)
b := bytes.NewReader(data)
f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Expect(f.Type).To(Equal(protocol.StreamTypeUni))
Expect(f.MaxStreams).To(BeEquivalentTo(0xdecaf))
Expect(b.Len()).To(BeZero())
})
It("errors on EOFs", func() {
data := []byte{0x1d}
data = append(data, encodeVarInt(0xdeadbeefcafe13)...)
_, err := parseMaxStreamsFrame(bytes.NewReader(data), protocol.VersionWhatever)
Expect(err).NotTo(HaveOccurred())
for i := range data {
_, err := parseMaxStreamsFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever)
Expect(err).To(HaveOccurred())
}
})
})
Context("writing", func() {
It("for a bidirectional stream", func() {
f := &MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: 0xdeadbeef,
}
b := &bytes.Buffer{}
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
expected := []byte{0x1c}
expected = append(expected, encodeVarInt(0xdeadbeef)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("for a unidirectional stream", func() {
f := &MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: 0xdecafbad,
}
b := &bytes.Buffer{}
Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed())
expected := []byte{0x1d}
expected = append(expected, encodeVarInt(0xdecafbad)...)
Expect(b.Bytes()).To(Equal(expected))
})
It("has the correct min length", func() {
frame := MaxStreamsFrame{MaxStreams: 0x1337}
Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337)))
})
})
})

View file

@ -110,16 +110,16 @@ func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0)
}
// HandleMaxStreamIDFrame mocks base method
func (m *MockStreamManager) HandleMaxStreamIDFrame(arg0 *wire.MaxStreamIDFrame) error {
ret := m.ctrl.Call(m, "HandleMaxStreamIDFrame", arg0)
// HandleMaxStreamsFrame mocks base method
func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) error {
ret := m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// HandleMaxStreamIDFrame indicates an expected call of HandleMaxStreamIDFrame
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamIDFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamIDFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamIDFrame), arg0)
// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame
func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0)
}
// OpenStream mocks base method

View file

@ -40,7 +40,7 @@ type streamManager interface {
AcceptUniStream() (ReceiveStream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*handshake.TransportParameters)
HandleMaxStreamIDFrame(*wire.MaxStreamIDFrame) error
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
CloseWithError(error)
}
@ -563,8 +563,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
s.handleMaxDataFrame(frame)
case *wire.MaxStreamDataFrame:
err = s.handleMaxStreamDataFrame(frame)
case *wire.MaxStreamIDFrame:
err = s.handleMaxStreamIDFrame(frame)
case *wire.MaxStreamsFrame:
err = s.handleMaxStreamsFrame(frame)
case *wire.DataBlockedFrame:
case *wire.StreamDataBlockedFrame:
case *wire.StreamIDBlockedFrame:
@ -641,8 +641,8 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error
return nil
}
func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error {
return s.streamsMap.HandleMaxStreamIDFrame(frame)
func (s *session) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) error {
return s.streamsMap.HandleMaxStreamsFrame(frame)
}
func (s *session) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {

View file

@ -249,17 +249,20 @@ var _ = Describe("Session", func() {
Context("handling MAX_STREAM_ID frames", func() {
It("passes the frame to the streamsMap", func() {
f := &wire.MaxStreamIDFrame{StreamID: 10}
streamManager.EXPECT().HandleMaxStreamIDFrame(f)
err := sess.handleMaxStreamIDFrame(f)
f := &wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: 10,
}
streamManager.EXPECT().HandleMaxStreamsFrame(f)
err := sess.handleMaxStreamsFrame(f)
Expect(err).ToNot(HaveOccurred())
})
It("returns errors", func() {
f := &wire.MaxStreamIDFrame{StreamID: 10}
f := &wire.MaxStreamsFrame{MaxStreams: 10}
testErr := errors.New("test error")
streamManager.EXPECT().HandleMaxStreamIDFrame(f).Return(testErr)
err := sess.handleMaxStreamIDFrame(f)
streamManager.EXPECT().HandleMaxStreamsFrame(f).Return(testErr)
err := sess.handleMaxStreamsFrame(f)
Expect(err).To(MatchError(testErr))
})
})

View file

@ -146,15 +146,13 @@ func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, err
panic("")
}
func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
id := f.StreamID
if id.InitiatedBy() != m.perspective {
return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id)
}
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
id := protocol.MaxStreamID(f.Type, f.MaxStreams, m.perspective)
switch id.Type() {
case protocol.StreamTypeUni:
m.outgoingUniStreams.SetMaxStream(id)
case protocol.StreamTypeBidi:
fmt.Printf("")
m.outgoingBidiStreams.SetMaxStream(id)
}
return nil

View file

@ -1,6 +1,10 @@
package quic
import "github.com/cheekybits/genny/generic"
import (
"github.com/cheekybits/genny/generic"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// In the auto-generated streams maps, we need to be able to close the streams.
// Therefore, extend the generic.Type with the stream close method.
@ -9,3 +13,5 @@ type item interface {
generic.Type
closeForShutdown(error)
}
const streamTypeGeneric protocol.StreamType = protocol.StreamTypeUni

View file

@ -24,7 +24,7 @@ type incomingBidiStreamsMap struct {
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamID) streamI
queueMaxStreamID func(*wire.MaxStreamIDFrame)
queueMaxStreamID func(*wire.MaxStreamsFrame)
closeErr error
}
@ -43,7 +43,7 @@ func newIncomingBidiStreamsMap(
maxStream: initialMaxStreamID,
maxNumStreams: maxNumStreams,
newStream: newStream,
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -111,7 +111,10 @@ func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error {
if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: m.maxStream.StreamNum(),
})
}
return nil
}

View file

@ -8,8 +8,8 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream"
//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream"
//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi"
//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni"
type incomingItemsMap struct {
mutex sync.RWMutex
cond sync.Cond
@ -22,7 +22,7 @@ type incomingItemsMap struct {
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamID) item
queueMaxStreamID func(*wire.MaxStreamIDFrame)
queueMaxStreamID func(*wire.MaxStreamsFrame)
closeErr error
}
@ -41,7 +41,7 @@ func newIncomingItemsMap(
maxStream: initialMaxStreamID,
maxNumStreams: maxNumStreams,
newStream: newStream,
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -109,7 +109,10 @@ func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error {
if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: streamTypeGeneric,
MaxStreams: m.maxStream.StreamNum(),
})
}
return nil
}

View file

@ -26,8 +26,8 @@ func (s *mockGenericStream) closeForShutdown(err error) {
var _ = Describe("Streams Map (incoming)", func() {
const (
firstNewStream protocol.StreamID = 20
maxNumStreams uint64 = 10
firstNewStream protocol.StreamID = 2
maxNumStreams uint64 = 5
initialMaxStream protocol.StreamID = firstNewStream + 4*protocol.StreamID(maxNumStreams-1)
)
@ -49,9 +49,9 @@ var _ = Describe("Streams Map (incoming)", func() {
})
It("opens all streams up to the id on GetOrOpenStream", func() {
_, err := m.GetOrOpenStream(firstNewStream + 4*5)
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
Expect(err).ToNot(HaveOccurred())
Expect(newItemCounter).To(Equal(6))
Expect(newItemCounter).To(Equal(5))
})
It("starts opening streams at the right position", func() {
@ -59,9 +59,9 @@ var _ = Describe("Streams Map (incoming)", func() {
_, err := m.GetOrOpenStream(firstNewStream + 4)
Expect(err).ToNot(HaveOccurred())
Expect(newItemCounter).To(Equal(2))
_, err = m.GetOrOpenStream(firstNewStream + 4*5)
_, err = m.GetOrOpenStream(firstNewStream + 4*4)
Expect(err).ToNot(HaveOccurred())
Expect(newItemCounter).To(Equal(6))
Expect(newItemCounter).To(Equal(5))
})
It("accepts streams in the right order", func() {
@ -143,9 +143,9 @@ var _ = Describe("Streams Map (incoming)", func() {
})
It("closes all streams when CloseWithError is called", func() {
str1, err := m.GetOrOpenStream(20)
str1, err := m.GetOrOpenStream(firstNewStream)
Expect(err).ToNot(HaveOccurred())
str2, err := m.GetOrOpenStream(20 + 8)
str2, err := m.GetOrOpenStream(firstNewStream + 8)
Expect(err).ToNot(HaveOccurred())
testErr := errors.New("test err")
m.CloseWithError(testErr)
@ -157,11 +157,11 @@ var _ = Describe("Streams Map (incoming)", func() {
It("deletes streams", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
_, err := m.GetOrOpenStream(20)
_, err := m.GetOrOpenStream(initialMaxStream)
Expect(err).ToNot(HaveOccurred())
err = m.DeleteStream(20)
err = m.DeleteStream(initialMaxStream)
Expect(err).ToNot(HaveOccurred())
str, err := m.GetOrOpenStream(20)
str, err := m.GetOrOpenStream(initialMaxStream)
Expect(err).ToNot(HaveOccurred())
Expect(str).To(BeNil())
})
@ -171,13 +171,17 @@ var _ = Describe("Streams Map (incoming)", func() {
Expect(err).To(MatchError("Tried to delete unknown stream 1337"))
})
It("sends MAX_STREAM_ID frames when streams are deleted", func() {
It("sends MAX_STREAMS frames when streams are deleted", func() {
// open a bunch of streams
_, err := m.GetOrOpenStream(firstNewStream + 4*4)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 4})
Expect(m.DeleteStream(firstNewStream + 4)).To(Succeed())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 8})
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.MaxStreamsFrame).MaxStreams).To(Equal(maxNumStreams + 1))
})
Expect(m.DeleteStream(firstNewStream + 2*4)).To(Succeed())
mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) {
Expect(f.(*wire.MaxStreamsFrame).MaxStreams).To(Equal(maxNumStreams + 2))
})
Expect(m.DeleteStream(firstNewStream + 3*4)).To(Succeed())
})
})

View file

@ -24,7 +24,7 @@ type incomingUniStreamsMap struct {
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamID) receiveStreamI
queueMaxStreamID func(*wire.MaxStreamIDFrame)
queueMaxStreamID func(*wire.MaxStreamsFrame)
closeErr error
}
@ -43,7 +43,7 @@ func newIncomingUniStreamsMap(
maxStream: initialMaxStreamID,
maxNumStreams: maxNumStreams,
newStream: newStream,
queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) },
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
}
m.cond.L = &m.mutex
return m
@ -111,7 +111,10 @@ func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error {
if m.maxNumStreams > uint64(len(m.streams)) {
numNewStreams := m.maxNumStreams - uint64(len(m.streams))
m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4)
m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream})
m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: m.maxStream.StreamNum(),
})
}
return nil
}

View file

@ -292,7 +292,7 @@ var _ = Describe("Streams Map", func() {
})
})
Context("handling MAX_STREAM_ID frames", func() {
Context("handling MAX_STREAMS frames", func() {
BeforeEach(func() {
mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
})
@ -300,49 +300,49 @@ var _ = Describe("Streams Map", func() {
It("processes IDs for outgoing bidirectional streams", func() {
_, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingBidiStream})
Expect(err).ToNot(HaveOccurred())
Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: 1,
})).To(Succeed())
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
_, err = m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("processes IDs for outgoing bidirectional streams", func() {
It("processes IDs for outgoing unidirectional streams", func() {
_, err := m.OpenUniStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingUniStream})
Expect(err).ToNot(HaveOccurred())
Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: 1,
})).To(Succeed())
str, err := m.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
})
It("rejects IDs for incoming bidirectional streams", func() {
err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingBidiStream})
Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingBidiStream)))
})
It("rejects IDs for incoming unidirectional streams", func() {
err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingUniStream})
Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingUniStream)))
_, err = m.OpenUniStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
})
Context("sending MAX_STREAM_ID frames", func() {
It("sends MAX_STREAM_ID frames for bidirectional streams", func() {
Context("sending MAX_STREAMS frames", func() {
It("sends a MAX_STREAMS frame for bidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{
StreamID: ids.firstIncomingBidiStream + 4*maxBidiStreams,
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeBidi,
MaxStreams: maxBidiStreams + 1,
})
Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed())
})
It("sends MAX_STREAM_ID frames for unidirectional streams", func() {
It("sends a MAX_STREAMS frame for unidirectional streams", func() {
_, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
Expect(err).ToNot(HaveOccurred())
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{
StreamID: ids.firstIncomingUniStream + 4*maxUniStreams,
mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{
Type: protocol.StreamTypeUni,
MaxStreams: maxUniStreams + 1,
})
Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed())
})