diff --git a/internal/protocol/stream_id.go b/internal/protocol/stream_id.go index 61d027b2..b96e0c2b 100644 --- a/internal/protocol/stream_id.go +++ b/internal/protocol/stream_id.go @@ -29,6 +29,12 @@ func (s StreamID) Type() StreamType { return StreamTypeBidi } +// StreamNum returns how many streams in total are below this +// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9) +func (s StreamID) StreamNum() uint64 { + return uint64(s/4) + 1 +} + // MaxStreamID is the highest stream ID that a peer is allowed to open, // when it is allowed to open numStreams. func MaxStreamID(stype StreamType, numStreams uint64, pers Perspective) StreamID { diff --git a/internal/protocol/stream_id_test.go b/internal/protocol/stream_id_test.go index 669343c8..8c992484 100644 --- a/internal/protocol/stream_id_test.go +++ b/internal/protocol/stream_id_test.go @@ -27,6 +27,17 @@ var _ = Describe("Stream ID", func() { Expect(FirstStream(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(3))) }) + It("tells the stream number", func() { + Expect(StreamID(0).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(1).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(2).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(3).StreamNum()).To(BeEquivalentTo(1)) + Expect(StreamID(8).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(9).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(10).StreamNum()).To(BeEquivalentTo(3)) + Expect(StreamID(11).StreamNum()).To(BeEquivalentTo(3)) + }) + Context("maximum stream IDs", func() { It("doesn't allow any", func() { Expect(MaxStreamID(StreamTypeBidi, 0, PerspectiveClient)).To(Equal(StreamID(0))) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index 131c425e..71ba202f 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -55,11 +55,6 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame if err != nil { err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) } - case 0x6: - frame, err = parseMaxStreamIDFrame(r, v) - if err != nil { - err = qerr.Error(qerr.InvalidFrameData, err.Error()) - } case 0x7: frame, err = parsePingFrame(r, v) case 0x8: @@ -97,6 +92,11 @@ func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame if err != nil { err = qerr.Error(qerr.InvalidAckData, err.Error()) } + case 0x1c, 0x1d: + frame, err = parseMaxStreamsFrame(r, v) + if err != nil { + err = qerr.Error(qerr.InvalidFrameData, err.Error()) + } case 0x18: frame, err = parseCryptoFrame(r, v) if err != nil { diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 92e96408..b56bf323 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -97,8 +97,11 @@ var _ = Describe("Frame parsing", func() { Expect(frame).To(Equal(f)) }) - It("unpacks MAX_STREAM_ID frames", func() { - f := &MaxStreamIDFrame{StreamID: 0x1337} + It("unpacks MAX_STREAMS frames", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreams: 0x1337, + } buf := &bytes.Buffer{} err := f.Write(buf, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/wire/max_stream_id_frame.go b/internal/wire/max_stream_id_frame.go deleted file mode 100644 index 9f5424da..00000000 --- a/internal/wire/max_stream_id_frame.go +++ /dev/null @@ -1,37 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" -) - -// A MaxStreamIDFrame is a MAX_STREAM_ID frame -type MaxStreamIDFrame struct { - StreamID protocol.StreamID -} - -// parseMaxStreamIDFrame parses a MAX_STREAM_ID frame -func parseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) { - // read the Type byte - if _, err := r.ReadByte(); err != nil { - return nil, err - } - streamID, err := utils.ReadVarInt(r) - if err != nil { - return nil, err - } - return &MaxStreamIDFrame{StreamID: protocol.StreamID(streamID)}, nil -} - -func (f *MaxStreamIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x6) - utils.WriteVarInt(b, uint64(f.StreamID)) - return nil -} - -// Length of a written frame -func (f *MaxStreamIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { - return 1 + utils.VarIntLen(uint64(f.StreamID)) -} diff --git a/internal/wire/max_stream_id_frame_test.go b/internal/wire/max_stream_id_frame_test.go deleted file mode 100644 index 66a27e7b..00000000 --- a/internal/wire/max_stream_id_frame_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("MAX_STREAM_ID frame", func() { - Context("parsing", func() { - It("accepts sample frame", func() { - data := []byte{0x6} - data = append(data, encodeVarInt(0xdecafbad)...) - b := bytes.NewReader(data) - f, err := parseMaxStreamIDFrame(b, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Expect(f.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) - Expect(b.Len()).To(BeZero()) - }) - - It("errors on EOFs", func() { - data := []byte{0x06} - data = append(data, encodeVarInt(0xdeadbeefcafe13)...) - _, err := parseMaxStreamIDFrame(bytes.NewReader(data), protocol.VersionWhatever) - Expect(err).NotTo(HaveOccurred()) - for i := range data { - _, err := parseMaxStreamIDFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) - Expect(err).To(HaveOccurred()) - } - }) - }) - - Context("writing", func() { - It("writes a sample frame", func() { - b := &bytes.Buffer{} - frame := MaxStreamIDFrame{StreamID: 0x12345678} - frame.Write(b, protocol.VersionWhatever) - expected := []byte{0x6} - expected = append(expected, encodeVarInt(0x12345678)...) - Expect(b.Bytes()).To(Equal(expected)) - }) - - It("has the correct min length", func() { - frame := MaxStreamIDFrame{StreamID: 0x1337} - Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337))) - }) - }) -}) diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go new file mode 100644 index 00000000..65bc38f3 --- /dev/null +++ b/internal/wire/max_streams_frame.go @@ -0,0 +1,51 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A MaxStreamsFrame is a MAX_STREAMS frame +type MaxStreamsFrame struct { + Type protocol.StreamType + MaxStreams uint64 +} + +func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &MaxStreamsFrame{} + switch typeByte { + case 0x1c: + f.Type = protocol.StreamTypeBidi + case 0x1d: + f.Type = protocol.StreamTypeUni + } + streamID, err := utils.ReadVarInt(r) + if err != nil { + return nil, err + } + f.MaxStreams = streamID + return f, nil +} + +func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + switch f.Type { + case protocol.StreamTypeBidi: + b.WriteByte(0x1c) + case protocol.StreamTypeUni: + b.WriteByte(0x1d) + } + utils.WriteVarInt(b, f.MaxStreams) + return nil +} + +// Length of a written frame +func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + utils.VarIntLen(f.MaxStreams) +} diff --git a/internal/wire/max_streams_frame_test.go b/internal/wire/max_streams_frame_test.go new file mode 100644 index 00000000..b2c32439 --- /dev/null +++ b/internal/wire/max_streams_frame_test.go @@ -0,0 +1,78 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("MAX_STREAMS frame", func() { + Context("parsing", func() { + It("accepts a frame for a bidirectional stream", func() { + data := []byte{0x1c} + data = append(data, encodeVarInt(0xdecaf)...) + b := bytes.NewReader(data) + f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeBidi)) + Expect(f.MaxStreams).To(BeEquivalentTo(0xdecaf)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts a frame for a bidirectional stream", func() { + data := []byte{0x1d} + data = append(data, encodeVarInt(0xdecaf)...) + b := bytes.NewReader(data) + f, err := parseMaxStreamsFrame(b, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Type).To(Equal(protocol.StreamTypeUni)) + Expect(f.MaxStreams).To(BeEquivalentTo(0xdecaf)) + Expect(b.Len()).To(BeZero()) + }) + + It("errors on EOFs", func() { + data := []byte{0x1d} + data = append(data, encodeVarInt(0xdeadbeefcafe13)...) + _, err := parseMaxStreamsFrame(bytes.NewReader(data), protocol.VersionWhatever) + Expect(err).NotTo(HaveOccurred()) + for i := range data { + _, err := parseMaxStreamsFrame(bytes.NewReader(data[0:i]), protocol.VersionWhatever) + Expect(err).To(HaveOccurred()) + } + }) + }) + + Context("writing", func() { + It("for a bidirectional stream", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreams: 0xdeadbeef, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x1c} + expected = append(expected, encodeVarInt(0xdeadbeef)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("for a unidirectional stream", func() { + f := &MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreams: 0xdecafbad, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + expected := []byte{0x1d} + expected = append(expected, encodeVarInt(0xdecafbad)...) + Expect(b.Bytes()).To(Equal(expected)) + }) + + It("has the correct min length", func() { + frame := MaxStreamsFrame{MaxStreams: 0x1337} + Expect(frame.Length(protocol.VersionWhatever)).To(Equal(1 + utils.VarIntLen(0x1337))) + }) + }) +}) diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 0fdd7262..9991d28f 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -110,16 +110,16 @@ func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) } -// HandleMaxStreamIDFrame mocks base method -func (m *MockStreamManager) HandleMaxStreamIDFrame(arg0 *wire.MaxStreamIDFrame) error { - ret := m.ctrl.Call(m, "HandleMaxStreamIDFrame", arg0) +// HandleMaxStreamsFrame mocks base method +func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) error { + ret := m.ctrl.Call(m, "HandleMaxStreamsFrame", arg0) ret0, _ := ret[0].(error) return ret0 } -// HandleMaxStreamIDFrame indicates an expected call of HandleMaxStreamIDFrame -func (mr *MockStreamManagerMockRecorder) HandleMaxStreamIDFrame(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamIDFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamIDFrame), arg0) +// HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame +func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0) } // OpenStream mocks base method diff --git a/session.go b/session.go index ba2c1922..c4394ef9 100644 --- a/session.go +++ b/session.go @@ -40,7 +40,7 @@ type streamManager interface { AcceptUniStream() (ReceiveStream, error) DeleteStream(protocol.StreamID) error UpdateLimits(*handshake.TransportParameters) - HandleMaxStreamIDFrame(*wire.MaxStreamIDFrame) error + HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error CloseWithError(error) } @@ -563,8 +563,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve s.handleMaxDataFrame(frame) case *wire.MaxStreamDataFrame: err = s.handleMaxStreamDataFrame(frame) - case *wire.MaxStreamIDFrame: - err = s.handleMaxStreamIDFrame(frame) + case *wire.MaxStreamsFrame: + err = s.handleMaxStreamsFrame(frame) case *wire.DataBlockedFrame: case *wire.StreamDataBlockedFrame: case *wire.StreamIDBlockedFrame: @@ -641,8 +641,8 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error return nil } -func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error { - return s.streamsMap.HandleMaxStreamIDFrame(frame) +func (s *session) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) error { + return s.streamsMap.HandleMaxStreamsFrame(frame) } func (s *session) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { diff --git a/session_test.go b/session_test.go index ddc3d827..796c0086 100644 --- a/session_test.go +++ b/session_test.go @@ -249,17 +249,20 @@ var _ = Describe("Session", func() { Context("handling MAX_STREAM_ID frames", func() { It("passes the frame to the streamsMap", func() { - f := &wire.MaxStreamIDFrame{StreamID: 10} - streamManager.EXPECT().HandleMaxStreamIDFrame(f) - err := sess.handleMaxStreamIDFrame(f) + f := &wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreams: 10, + } + streamManager.EXPECT().HandleMaxStreamsFrame(f) + err := sess.handleMaxStreamsFrame(f) Expect(err).ToNot(HaveOccurred()) }) It("returns errors", func() { - f := &wire.MaxStreamIDFrame{StreamID: 10} + f := &wire.MaxStreamsFrame{MaxStreams: 10} testErr := errors.New("test error") - streamManager.EXPECT().HandleMaxStreamIDFrame(f).Return(testErr) - err := sess.handleMaxStreamIDFrame(f) + streamManager.EXPECT().HandleMaxStreamsFrame(f).Return(testErr) + err := sess.handleMaxStreamsFrame(f) Expect(err).To(MatchError(testErr)) }) }) diff --git a/streams_map.go b/streams_map.go index a63640ba..4be37bf4 100644 --- a/streams_map.go +++ b/streams_map.go @@ -146,15 +146,13 @@ func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, err panic("") } -func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { - id := f.StreamID - if id.InitiatedBy() != m.perspective { - return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id) - } +func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error { + id := protocol.MaxStreamID(f.Type, f.MaxStreams, m.perspective) switch id.Type() { case protocol.StreamTypeUni: m.outgoingUniStreams.SetMaxStream(id) case protocol.StreamTypeBidi: + fmt.Printf("") m.outgoingBidiStreams.SetMaxStream(id) } return nil diff --git a/streams_map_generic_helper.go b/streams_map_generic_helper.go index f48db212..692f093e 100644 --- a/streams_map_generic_helper.go +++ b/streams_map_generic_helper.go @@ -1,6 +1,10 @@ package quic -import "github.com/cheekybits/genny/generic" +import ( + "github.com/cheekybits/genny/generic" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) // In the auto-generated streams maps, we need to be able to close the streams. // Therefore, extend the generic.Type with the stream close method. @@ -9,3 +13,5 @@ type item interface { generic.Type closeForShutdown(error) } + +const streamTypeGeneric protocol.StreamType = protocol.StreamTypeUni diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index d507d323..8ee8ccdd 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -24,7 +24,7 @@ type incomingBidiStreamsMap struct { maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamID) streamI - queueMaxStreamID func(*wire.MaxStreamIDFrame) + queueMaxStreamID func(*wire.MaxStreamsFrame) closeErr error } @@ -43,7 +43,7 @@ func newIncomingBidiStreamsMap( maxStream: initialMaxStreamID, maxNumStreams: maxNumStreams, newStream: newStream, - queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -111,7 +111,10 @@ func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { if m.maxNumStreams > uint64(len(m.streams)) { numNewStreams := m.maxNumStreams - uint64(len(m.streams)) m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4) - m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreams: m.maxStream.StreamNum(), + }) } return nil } diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index 0ea7ecaf..0dc6d9fc 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -8,8 +8,8 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream" -//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream" +//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" +//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" type incomingItemsMap struct { mutex sync.RWMutex cond sync.Cond @@ -22,7 +22,7 @@ type incomingItemsMap struct { maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamID) item - queueMaxStreamID func(*wire.MaxStreamIDFrame) + queueMaxStreamID func(*wire.MaxStreamsFrame) closeErr error } @@ -41,7 +41,7 @@ func newIncomingItemsMap( maxStream: initialMaxStreamID, maxNumStreams: maxNumStreams, newStream: newStream, - queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -109,7 +109,10 @@ func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error { if m.maxNumStreams > uint64(len(m.streams)) { numNewStreams := m.maxNumStreams - uint64(len(m.streams)) m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4) - m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: streamTypeGeneric, + MaxStreams: m.maxStream.StreamNum(), + }) } return nil } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index fe970348..b5a33cfc 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -26,8 +26,8 @@ func (s *mockGenericStream) closeForShutdown(err error) { var _ = Describe("Streams Map (incoming)", func() { const ( - firstNewStream protocol.StreamID = 20 - maxNumStreams uint64 = 10 + firstNewStream protocol.StreamID = 2 + maxNumStreams uint64 = 5 initialMaxStream protocol.StreamID = firstNewStream + 4*protocol.StreamID(maxNumStreams-1) ) @@ -49,9 +49,9 @@ var _ = Describe("Streams Map (incoming)", func() { }) It("opens all streams up to the id on GetOrOpenStream", func() { - _, err := m.GetOrOpenStream(firstNewStream + 4*5) + _, err := m.GetOrOpenStream(firstNewStream + 4*4) Expect(err).ToNot(HaveOccurred()) - Expect(newItemCounter).To(Equal(6)) + Expect(newItemCounter).To(Equal(5)) }) It("starts opening streams at the right position", func() { @@ -59,9 +59,9 @@ var _ = Describe("Streams Map (incoming)", func() { _, err := m.GetOrOpenStream(firstNewStream + 4) Expect(err).ToNot(HaveOccurred()) Expect(newItemCounter).To(Equal(2)) - _, err = m.GetOrOpenStream(firstNewStream + 4*5) + _, err = m.GetOrOpenStream(firstNewStream + 4*4) Expect(err).ToNot(HaveOccurred()) - Expect(newItemCounter).To(Equal(6)) + Expect(newItemCounter).To(Equal(5)) }) It("accepts streams in the right order", func() { @@ -143,9 +143,9 @@ var _ = Describe("Streams Map (incoming)", func() { }) It("closes all streams when CloseWithError is called", func() { - str1, err := m.GetOrOpenStream(20) + str1, err := m.GetOrOpenStream(firstNewStream) Expect(err).ToNot(HaveOccurred()) - str2, err := m.GetOrOpenStream(20 + 8) + str2, err := m.GetOrOpenStream(firstNewStream + 8) Expect(err).ToNot(HaveOccurred()) testErr := errors.New("test err") m.CloseWithError(testErr) @@ -157,11 +157,11 @@ var _ = Describe("Streams Map (incoming)", func() { It("deletes streams", func() { mockSender.EXPECT().queueControlFrame(gomock.Any()) - _, err := m.GetOrOpenStream(20) + _, err := m.GetOrOpenStream(initialMaxStream) Expect(err).ToNot(HaveOccurred()) - err = m.DeleteStream(20) + err = m.DeleteStream(initialMaxStream) Expect(err).ToNot(HaveOccurred()) - str, err := m.GetOrOpenStream(20) + str, err := m.GetOrOpenStream(initialMaxStream) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) }) @@ -171,13 +171,17 @@ var _ = Describe("Streams Map (incoming)", func() { Expect(err).To(MatchError("Tried to delete unknown stream 1337")) }) - It("sends MAX_STREAM_ID frames when streams are deleted", func() { + It("sends MAX_STREAMS frames when streams are deleted", func() { // open a bunch of streams _, err := m.GetOrOpenStream(firstNewStream + 4*4) Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 4}) - Expect(m.DeleteStream(firstNewStream + 4)).To(Succeed()) - mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 8}) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.MaxStreamsFrame).MaxStreams).To(Equal(maxNumStreams + 1)) + }) + Expect(m.DeleteStream(firstNewStream + 2*4)).To(Succeed()) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Do(func(f wire.Frame) { + Expect(f.(*wire.MaxStreamsFrame).MaxStreams).To(Equal(maxNumStreams + 2)) + }) Expect(m.DeleteStream(firstNewStream + 3*4)).To(Succeed()) }) }) diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index d4019ca1..a1ceadd2 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -24,7 +24,7 @@ type incomingUniStreamsMap struct { maxNumStreams uint64 // maximum number of streams newStream func(protocol.StreamID) receiveStreamI - queueMaxStreamID func(*wire.MaxStreamIDFrame) + queueMaxStreamID func(*wire.MaxStreamsFrame) closeErr error } @@ -43,7 +43,7 @@ func newIncomingUniStreamsMap( maxStream: initialMaxStreamID, maxNumStreams: maxNumStreams, newStream: newStream, - queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -111,7 +111,10 @@ func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error { if m.maxNumStreams > uint64(len(m.streams)) { numNewStreams := m.maxNumStreams - uint64(len(m.streams)) m.maxStream = m.nextStreamToOpen + protocol.StreamID((numNewStreams-1)*4) - m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreams: m.maxStream.StreamNum(), + }) } return nil } diff --git a/streams_map_test.go b/streams_map_test.go index 62dc7709..bda250d1 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -292,7 +292,7 @@ var _ = Describe("Streams Map", func() { }) }) - Context("handling MAX_STREAM_ID frames", func() { + Context("handling MAX_STREAMS frames", func() { BeforeEach(func() { mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() }) @@ -300,49 +300,49 @@ var _ = Describe("Streams Map", func() { It("processes IDs for outgoing bidirectional streams", func() { _, err := m.OpenStream() Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingBidiStream}) - Expect(err).ToNot(HaveOccurred()) + Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreams: 1, + })).To(Succeed()) str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) + _, err = m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) - It("processes IDs for outgoing bidirectional streams", func() { + It("processes IDs for outgoing unidirectional streams", func() { _, err := m.OpenUniStream() Expect(err).To(MatchError(qerr.TooManyOpenStreams)) - err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingUniStream}) - Expect(err).ToNot(HaveOccurred()) + Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreams: 1, + })).To(Succeed()) str, err := m.OpenUniStream() Expect(err).ToNot(HaveOccurred()) Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) - }) - - It("rejects IDs for incoming bidirectional streams", func() { - err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingBidiStream}) - Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingBidiStream))) - }) - - It("rejects IDs for incoming unidirectional streams", func() { - err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingUniStream}) - Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingUniStream))) + _, err = m.OpenUniStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) }) - Context("sending MAX_STREAM_ID frames", func() { - It("sends MAX_STREAM_ID frames for bidirectional streams", func() { + Context("sending MAX_STREAMS frames", func() { + It("sends a MAX_STREAMS frame for bidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream) Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{ - StreamID: ids.firstIncomingBidiStream + 4*maxBidiStreams, + mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreams: maxBidiStreams + 1, }) Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed()) }) - It("sends MAX_STREAM_ID frames for unidirectional streams", func() { + It("sends a MAX_STREAMS frame for unidirectional streams", func() { _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream) Expect(err).ToNot(HaveOccurred()) - mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{ - StreamID: ids.firstIncomingUniStream + 4*maxUniStreams, + mockSender.EXPECT().queueControlFrame(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreams: maxUniStreams + 1, }) Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed()) })