diff --git a/client.go b/client.go index 3446799e..955c908e 100644 --- a/client.go +++ b/client.go @@ -169,9 +169,11 @@ func (c *client) dialTLS() error { params := &handshake.TransportParameters{ StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - MaxStreams: protocol.MaxIncomingStreams, IdleTimeout: c.config.IdleTimeout, OmitConnectionID: c.config.RequestConnectionIDOmission, + // TODO(#523): make these values configurable + MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient), + MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient), } csc := handshake.NewCryptoStreamConn(nil) extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version) diff --git a/internal/handshake/tls_extension_handler_server.go b/internal/handshake/tls_extension_handler_server.go index 313751c9..3e7e2705 100644 --- a/internal/handshake/tls_extension_handler_server.go +++ b/internal/handshake/tls_extension_handler_server.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "math" "github.com/lucas-clemente/quic-go/qerr" @@ -105,8 +104,6 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte if err != nil { return err } - // TODO(#878): remove this when implementing the MAX_STREAM_ID frame - params.MaxStreams = math.MaxUint32 h.paramsChan <- *params return nil } diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 17f6c26c..2de372f8 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -127,6 +127,8 @@ var _ = Describe("Transport Parameters", func() { Expect(err).ToNot(HaveOccurred()) Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344))) Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455))) + Expect(params.MaxBidiStreamID).To(Equal(protocol.StreamID(0x33445566))) + Expect(params.MaxUniStreamID).To(Equal(protocol.StreamID(0x44556677))) Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second)) Expect(params.OmitConnectionID).To(BeFalse()) }) @@ -224,18 +226,20 @@ var _ = Describe("Transport Parameters", func() { StreamFlowControlWindow: 0xdeadbeef, ConnectionFlowControlWindow: 0xdecafbad, IdleTimeout: 0xcafe * time.Second, + MaxBidiStreamID: 0xbadf000d, + MaxUniStreamID: 0xface, } }) It("creates the parameters list", func() { values := paramsListToMap(params.getTransportParameters()) - Expect(values).To(HaveLen(5)) + Expect(values).To(HaveLen(6)) Expect(values).To(HaveKeyWithValue(initialMaxStreamDataParameterID, []byte{0xde, 0xad, 0xbe, 0xef})) Expect(values).To(HaveKeyWithValue(initialMaxDataParameterID, []byte{0xde, 0xca, 0xfb, 0xad})) - Expect(values).To(HaveKeyWithValue(initialMaxStreamIDBiDiParameterID, []byte{0xff, 0xff, 0xff, 0xff})) + Expect(values).To(HaveKeyWithValue(initialMaxStreamIDBiDiParameterID, []byte{0xba, 0xdf, 0x00, 0x0d})) + Expect(values).To(HaveKeyWithValue(initialMaxStreamIDUniParameterID, []byte{0x0, 0x0, 0xfa, 0xce})) Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe})) Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac - Expect(values).ToNot(HaveKey(initialMaxStreamIDUniParameterID)) }) It("request ommision of the connection ID", func() { diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index e2d6df72..a0283582 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "errors" "fmt" - "math" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -21,7 +20,9 @@ type TransportParameters struct { StreamFlowControlWindow protocol.ByteCount ConnectionFlowControlWindow protocol.ByteCount - MaxStreams uint32 + MaxBidiStreamID protocol.StreamID // only used for IETF QUIC + MaxUniStreamID protocol.StreamID // only used for IETF QUIC + MaxStreams uint32 // only used for gQUIC OmitConnectionID bool IdleTimeout time.Duration @@ -117,12 +118,14 @@ func readTransportParamters(paramsList []transportParameter) (*TransportParamete if len(p.Value) != 4 { return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 4)", len(p.Value)) } - // TODO: handle this value + // TODO(#1154): validate the stream ID + params.MaxBidiStreamID = protocol.StreamID(binary.BigEndian.Uint32(p.Value)) case initialMaxStreamIDUniParameterID: if len(p.Value) != 4 { return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 4)", len(p.Value)) } - // TODO: handle this value + // TODO(#1154): validate the stream ID + params.MaxUniStreamID = protocol.StreamID(binary.BigEndian.Uint32(p.Value)) case idleTimeoutParameterID: foundIdleTimeout = true if len(p.Value) != 2 { @@ -150,9 +153,10 @@ func (p *TransportParameters) getTransportParameters() []transportParameter { binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow)) initialMaxData := make([]byte, 4) binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow)) - initialMaxStreamIDBiDi := make([]byte, 4) - // TODO: use a reasonable value here - binary.BigEndian.PutUint32(initialMaxStreamIDBiDi, math.MaxUint32) + initialMaxBidiStreamID := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxBidiStreamID, uint32(p.MaxBidiStreamID)) + initialMaxUniStreamID := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxUniStreamID, uint32(p.MaxUniStreamID)) idleTimeout := make([]byte, 2) binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second)) maxPacketSize := make([]byte, 2) @@ -160,7 +164,8 @@ func (p *TransportParameters) getTransportParameters() []transportParameter { params := []transportParameter{ {initialMaxStreamDataParameterID, initialMaxStreamData}, {initialMaxDataParameterID, initialMaxData}, - {initialMaxStreamIDBiDiParameterID, initialMaxStreamIDBiDi}, + {initialMaxStreamIDBiDiParameterID, initialMaxBidiStreamID}, + {initialMaxStreamIDUniParameterID, initialMaxUniStreamID}, {idleTimeoutParameterID, idleTimeout}, {maxPacketSizeParameterID, maxPacketSize}, } diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 16229836..0901b19e 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -55,9 +55,6 @@ func (t PacketType) String() string { // A ConnectionID in QUIC type ConnectionID uint64 -// A StreamID in QUIC -type StreamID uint64 - // A ByteCount in QUIC type ByteCount uint64 diff --git a/internal/protocol/stream_id.go b/internal/protocol/stream_id.go new file mode 100644 index 00000000..a0dced0c --- /dev/null +++ b/internal/protocol/stream_id.go @@ -0,0 +1,36 @@ +package protocol + +// A StreamID in QUIC +type StreamID uint64 + +// MaxBidiStreamID is the highest stream ID that the peer is allowed to open, +// when it is allowed to open numStreams bidirectional streams. +// It is only valid for IETF QUIC. +func MaxBidiStreamID(numStreams int, pers Perspective) StreamID { + if numStreams == 0 { + return 0 + } + var first StreamID + if pers == PerspectiveClient { + first = 1 + } else { + first = 4 + } + return first + 4*StreamID(numStreams-1) +} + +// MaxUniStreamID is the highest stream ID that the peer is allowed to open, +// when it is allowed to open numStreams unidirectional streams. +// It is only valid for IETF QUIC. +func MaxUniStreamID(numStreams int, pers Perspective) StreamID { + if numStreams == 0 { + return 0 + } + var first StreamID + if pers == PerspectiveClient { + first = 3 + } else { + first = 2 + } + return first + 4*StreamID(numStreams-1) +} diff --git a/internal/protocol/stream_id_test.go b/internal/protocol/stream_id_test.go new file mode 100644 index 00000000..cca4f928 --- /dev/null +++ b/internal/protocol/stream_id_test.go @@ -0,0 +1,42 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stream ID", func() { + Context("bidirectional streams", func() { + It("doesn't allow any", func() { + Expect(MaxBidiStreamID(0, PerspectiveClient)).To(Equal(StreamID(0))) + Expect(MaxBidiStreamID(0, PerspectiveServer)).To(Equal(StreamID(0))) + }) + + It("allows one", func() { + Expect(MaxBidiStreamID(1, PerspectiveClient)).To(Equal(StreamID(1))) + Expect(MaxBidiStreamID(1, PerspectiveServer)).To(Equal(StreamID(4))) + }) + + It("allows many", func() { + Expect(MaxBidiStreamID(100, PerspectiveClient)).To(Equal(StreamID(397))) + Expect(MaxBidiStreamID(100, PerspectiveServer)).To(Equal(StreamID(400))) + }) + }) + + Context("unidirectional streams", func() { + It("doesn't allow any", func() { + Expect(MaxUniStreamID(0, PerspectiveClient)).To(Equal(StreamID(0))) + Expect(MaxUniStreamID(0, PerspectiveServer)).To(Equal(StreamID(0))) + }) + + It("allows one", func() { + Expect(MaxUniStreamID(1, PerspectiveClient)).To(Equal(StreamID(3))) + Expect(MaxUniStreamID(1, PerspectiveServer)).To(Equal(StreamID(2))) + }) + + It("allows many", func() { + Expect(MaxUniStreamID(100, PerspectiveClient)).To(Equal(StreamID(399))) + Expect(MaxUniStreamID(100, PerspectiveServer)).To(Equal(StreamID(398))) + }) + }) +}) diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index ddce5555..bbf32a5d 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -10,6 +10,7 @@ import ( gomock "github.com/golang/mock/gomock" handshake "github.com/lucas-clemente/quic-go/internal/handshake" protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" ) // MockStreamManager is a mock of StreamManager interface @@ -96,6 +97,18 @@ func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) } +// HandleMaxStreamIDFrame mocks base method +func (m *MockStreamManager) HandleMaxStreamIDFrame(arg0 *wire.MaxStreamIDFrame) error { + ret := m.ctrl.Call(m, "HandleMaxStreamIDFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleMaxStreamIDFrame indicates an expected call of HandleMaxStreamIDFrame +func (mr *MockStreamManagerMockRecorder) HandleMaxStreamIDFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamIDFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamIDFrame), arg0) +} + // OpenStream mocks base method func (m *MockStreamManager) OpenStream() (Stream, error) { ret := m.ctrl.Call(m, "OpenStream") diff --git a/server_tls.go b/server_tls.go index ff766d06..c45de35b 100644 --- a/server_tls.go +++ b/server_tls.go @@ -66,8 +66,10 @@ func newServerTLS( params: &handshake.TransportParameters{ StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - MaxStreams: protocol.MaxIncomingStreams, IdleTimeout: config.IdleTimeout, + // TODO(#523): make these values configurable + MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveServer), + MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveServer), }, } s.newMintConn = s.newMintConnImpl diff --git a/session.go b/session.go index 4f606641..f93bf012 100644 --- a/session.go +++ b/session.go @@ -37,6 +37,7 @@ type streamManager interface { AcceptStream() (Stream, error) DeleteStream(protocol.StreamID) error UpdateLimits(*handshake.TransportParameters) + HandleMaxStreamIDFrame(*wire.MaxStreamIDFrame) error CloseWithError(error) } @@ -563,6 +564,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve s.handleMaxDataFrame(frame) case *wire.MaxStreamDataFrame: err = s.handleMaxStreamDataFrame(frame) + case *wire.MaxStreamIDFrame: + err = s.handleMaxStreamIDFrame(frame) case *wire.BlockedFrame: case *wire.StreamBlockedFrame: case *wire.StopSendingFrame: @@ -634,6 +637,10 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error return nil } +func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error { + return s.streamsMap.HandleMaxStreamIDFrame(frame) +} + func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { if frame.StreamID == s.version.CryptoStreamID() { return errors.New("Received RST_STREAM frame for the crypto stream") diff --git a/session_test.go b/session_test.go index 4f7b804a..53c27ed0 100644 --- a/session_test.go +++ b/session_test.go @@ -357,6 +357,23 @@ var _ = Describe("Session", func() { }) }) + Context("handling MAX_STREAM_ID frames", func() { + It("passes the frame to the streamsMap", func() { + f := &wire.MaxStreamIDFrame{StreamID: 10} + streamManager.EXPECT().HandleMaxStreamIDFrame(f) + err := sess.handleMaxStreamIDFrame(f) + Expect(err).ToNot(HaveOccurred()) + }) + + It("returns errors", func() { + f := &wire.MaxStreamIDFrame{StreamID: 10} + testErr := errors.New("test error") + streamManager.EXPECT().HandleMaxStreamIDFrame(f).Return(testErr) + err := sess.handleMaxStreamIDFrame(f) + Expect(err).To(MatchError(testErr)) + }) + }) + Context("handling STOP_SENDING frames", func() { It("passes the frame to the stream", func() { f := &wire.StopSendingFrame{ diff --git a/streams_map.go b/streams_map.go index 9b19dc9d..8287c83b 100644 --- a/streams_map.go +++ b/streams_map.go @@ -6,6 +6,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type streamType int @@ -64,9 +65,23 @@ func newStreamsMap( return newReceiveStream(id, m.sender, m.newFlowController(id), version) } m.outgoingBidiStreams = newOutgoingBidiStreamsMap(firstOutgoingBidiStream, newBidiStream) - m.incomingBidiStreams = newIncomingBidiStreamsMap(firstIncomingBidiStream, newBidiStream) + // TODO(#523): make these values configurable + m.incomingBidiStreams = newIncomingBidiStreamsMap( + firstIncomingBidiStream, + protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, perspective), + protocol.MaxIncomingStreams, + sender.queueControlFrame, + newBidiStream, + ) m.outgoingUniStreams = newOutgoingUniStreamsMap(firstOutgoingUniStream, newUniSendStream) - m.incomingUniStreams = newIncomingUniStreamsMap(firstIncomingUniStream, newUniReceiveStream) + // TODO(#523): make these values configurable + m.incomingUniStreams = newIncomingUniStreamsMap( + firstIncomingUniStream, + protocol.MaxUniStreamID(protocol.MaxIncomingStreams, perspective), + protocol.MaxIncomingStreams, + sender.queueControlFrame, + newUniReceiveStream, + ) return m } @@ -102,8 +117,7 @@ func (m *streamsMap) OpenStream() (Stream, error) { } func (m *streamsMap) OpenStreamSync() (Stream, error) { - // TODO(#1150): implement limits - return m.OpenStream() + return m.outgoingBidiStreams.OpenStreamSync() } func (m *streamsMap) OpenUniStream() (SendStream, error) { @@ -111,8 +125,7 @@ func (m *streamsMap) OpenUniStream() (SendStream, error) { } func (m *streamsMap) OpenUniStreamSync() (SendStream, error) { - // TODO(#1150): implement limits - return m.outgoingUniStreams.OpenStream() + return m.outgoingUniStreams.OpenStreamSync() } func (m *streamsMap) AcceptStream() (Stream, error) { @@ -170,8 +183,23 @@ func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, err } } -func (m *streamsMap) UpdateLimits(*handshake.TransportParameters) { - // TODO(#1150): implement limits +func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { + id := f.StreamID + switch m.getStreamType(id) { + case streamTypeOutgoingBidi: + m.outgoingBidiStreams.SetMaxStream(id) + return nil + case streamTypeOutgoingUni: + m.outgoingUniStreams.SetMaxStream(id) + return nil + default: + return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id) + } +} + +func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) { + m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamID) + m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamID) } func (m *streamsMap) CloseWithError(err error) { diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index 774bf1a6..8a35f044 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type incomingBidiStreamsMap struct { @@ -17,18 +18,31 @@ type incomingBidiStreamsMap struct { streams map[protocol.StreamID]streamI - nextStream protocol.StreamID - highestStream protocol.StreamID - newStream func(protocol.StreamID) streamI + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) streamI + queueMaxStreamID func(*wire.MaxStreamIDFrame) closeErr error } -func newIncomingBidiStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) streamI) *incomingBidiStreamsMap { +func newIncomingBidiStreamsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) streamI, +) *incomingBidiStreamsMap { m := &incomingBidiStreamsMap{ - streams: make(map[protocol.StreamID]streamI), - nextStream: nextStream, - newStream: newStream, + streams: make(map[protocol.StreamID]streamI), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -55,6 +69,9 @@ func (m *incomingBidiStreamsMap) AcceptStream() (streamI, error) { } func (m *incomingBidiStreamsMap) GetOrOpenStream(id protocol.StreamID) (streamI, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } // if the id is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil @@ -90,6 +107,11 @@ func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { return fmt.Errorf("Tried to delete unknown stream %d", id) } delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } return nil } diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index e03311c6..830b690d 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) //go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream" @@ -15,18 +16,31 @@ type incomingItemsMap struct { streams map[protocol.StreamID]item - nextStream protocol.StreamID - highestStream protocol.StreamID - newStream func(protocol.StreamID) item + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) item + queueMaxStreamID func(*wire.MaxStreamIDFrame) closeErr error } -func newIncomingItemsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) item) *incomingItemsMap { +func newIncomingItemsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) item, +) *incomingItemsMap { m := &incomingItemsMap{ - streams: make(map[protocol.StreamID]item), - nextStream: nextStream, - newStream: newStream, + streams: make(map[protocol.StreamID]item), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -53,6 +67,9 @@ func (m *incomingItemsMap) AcceptStream() (item, error) { } func (m *incomingItemsMap) GetOrOpenStream(id protocol.StreamID) (item, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } // if the id is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil @@ -88,6 +105,11 @@ func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error { return fmt.Errorf("Tried to delete unknown stream %d", id) } delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } return nil } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 6b2b9b43..5d66fcd1 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -2,19 +2,28 @@ package quic import ( "errors" + "fmt" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Streams Map (outgoing)", func() { - const firstNewStream protocol.StreamID = 20 + const ( + firstNewStream protocol.StreamID = 20 + maxNumStreams int = 10 + initialMaxStream protocol.StreamID = firstNewStream + 4*protocol.StreamID(maxNumStreams-1) + ) + var ( m *incomingItemsMap newItem func(id protocol.StreamID) item newItemCounter int + mockSender *MockStreamSender ) BeforeEach(func() { @@ -23,7 +32,8 @@ var _ = Describe("Streams Map (outgoing)", func() { newItemCounter++ return id } - m = newIncomingItemsMap(firstNewStream, newItem) + mockSender = NewMockStreamSender(mockCtrl) + m = newIncomingItemsMap(firstNewStream, initialMaxStream, maxNumStreams, mockSender.queueControlFrame, newItem) }) It("opens all streams up to the id on GetOrOpenStream", func() { @@ -53,6 +63,17 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(str).To(Equal(firstNewStream + 4)) }) + It("allows opening the maximum stream ID", func() { + str, err := m.GetOrOpenStream(initialMaxStream) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(initialMaxStream)) + }) + + It("errors when trying to get a stream ID higher than the maximum", func() { + _, err := m.GetOrOpenStream(initialMaxStream + 4) + Expect(err).To(MatchError(fmt.Errorf("peer tried to open stream %d (current limit: %d)", initialMaxStream+4, initialMaxStream))) + }) + It("blocks AcceptStream until a new stream is available", func() { strChan := make(chan item) go func() { @@ -90,6 +111,7 @@ var _ = Describe("Streams Map (outgoing)", func() { }) It("deletes streams", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) _, err := m.GetOrOpenStream(20) Expect(err).ToNot(HaveOccurred()) err = m.DeleteStream(20) @@ -103,4 +125,14 @@ var _ = Describe("Streams Map (outgoing)", func() { err := m.DeleteStream(1337) Expect(err).To(MatchError("Tried to delete unknown stream 1337")) }) + + It("sends MAX_STREAM_ID frames when streams are deleted", func() { + // open a bunch of streams + _, err := m.GetOrOpenStream(firstNewStream + 4*4) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 4}) + Expect(m.DeleteStream(firstNewStream + 4)).To(Succeed()) + mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{StreamID: initialMaxStream + 8}) + Expect(m.DeleteStream(firstNewStream + 3*4)).To(Succeed()) + }) }) diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index 7cf57afb..9091d635 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" ) type incomingUniStreamsMap struct { @@ -17,18 +18,31 @@ type incomingUniStreamsMap struct { streams map[protocol.StreamID]receiveStreamI - nextStream protocol.StreamID - highestStream protocol.StreamID - newStream func(protocol.StreamID) receiveStreamI + nextStream protocol.StreamID // the next stream that will be returned by AcceptStream() + highestStream protocol.StreamID // the highest stream that the peer openend + maxStream protocol.StreamID // the highest stream that the peer is allowed to open + maxNumStreams int // maximum number of streams + + newStream func(protocol.StreamID) receiveStreamI + queueMaxStreamID func(*wire.MaxStreamIDFrame) closeErr error } -func newIncomingUniStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) receiveStreamI) *incomingUniStreamsMap { +func newIncomingUniStreamsMap( + nextStream protocol.StreamID, + initialMaxStreamID protocol.StreamID, + maxNumStreams int, + queueControlFrame func(wire.Frame), + newStream func(protocol.StreamID) receiveStreamI, +) *incomingUniStreamsMap { m := &incomingUniStreamsMap{ - streams: make(map[protocol.StreamID]receiveStreamI), - nextStream: nextStream, - newStream: newStream, + streams: make(map[protocol.StreamID]receiveStreamI), + nextStream: nextStream, + maxStream: initialMaxStreamID, + maxNumStreams: maxNumStreams, + newStream: newStream, + queueMaxStreamID: func(f *wire.MaxStreamIDFrame) { queueControlFrame(f) }, } m.cond.L = &m.mutex return m @@ -55,6 +69,9 @@ func (m *incomingUniStreamsMap) AcceptStream() (receiveStreamI, error) { } func (m *incomingUniStreamsMap) GetOrOpenStream(id protocol.StreamID) (receiveStreamI, error) { + if id > m.maxStream { + return nil, fmt.Errorf("peer tried to open stream %d (current limit: %d)", id, m.maxStream) + } // if the id is smaller than the highest we accepted // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil @@ -90,6 +107,11 @@ func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error { return fmt.Errorf("Tried to delete unknown stream %d", id) } delete(m.streams, id) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if numNewStreams := m.maxNumStreams - len(m.streams); numNewStreams > 0 { + m.maxStream = m.highestStream + protocol.StreamID(numNewStreams*4) + m.queueMaxStreamID(&wire.MaxStreamIDFrame{StreamID: m.maxStream}) + } return nil } diff --git a/streams_map_legacy.go b/streams_map_legacy.go index abd06574..152f20e5 100644 --- a/streams_map_legacy.go +++ b/streams_map_legacy.go @@ -256,3 +256,8 @@ func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) { m.mutex.Unlock() m.openStreamOrErrCond.Broadcast() } + +// should never be called, since MAX_STREAM_ID frames can only be unpacked for IETF QUIC +func (m *streamsMapLegacy) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error { + return errors.New("gQUIC doesn't have MAX_STREAM_ID frames") +} diff --git a/streams_map_legacy_test.go b/streams_map_legacy_test.go index 72852f1d..404de4a3 100644 --- a/streams_map_legacy_test.go +++ b/streams_map_legacy_test.go @@ -546,4 +546,8 @@ var _ = Describe("Streams Map (for gQUIC)", func() { }) m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321}) }) + + It("doesn't accept MAX_STREAM_ID frames", func() { + Expect(m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{})).ToNot(Succeed()) + }) }) diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index 8bad5324..d6951a23 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -14,30 +14,57 @@ import ( type outgoingBidiStreamsMap struct { mutex sync.RWMutex + cond sync.Cond streams map[protocol.StreamID]streamI nextStream protocol.StreamID + maxStream protocol.StreamID newStream func(protocol.StreamID) streamI closeErr error } func newOutgoingBidiStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) streamI) *outgoingBidiStreamsMap { - return &outgoingBidiStreamsMap{ + m := &outgoingBidiStreamsMap{ streams: make(map[protocol.StreamID]streamI), nextStream: nextStream, newStream: newStream, } + m.cond.L = &m.mutex + return m } func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { m.mutex.Lock() defer m.mutex.Unlock() + return m.openStreamImpl() +} + +func (m *outgoingBidiStreamsMap) OpenStreamSync() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingBidiStreamsMap) openStreamImpl() (streamI, error) { if m.closeErr != nil { return nil, m.closeErr } + if m.nextStream > m.maxStream { + return nil, qerr.TooManyOpenStreams + } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s m.nextStream += 4 @@ -65,8 +92,18 @@ func (m *outgoingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { return nil } +func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + func (m *outgoingBidiStreamsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + m.cond.Broadcast() m.mutex.Unlock() } diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index 93f82d5a..49ead874 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -15,30 +15,57 @@ type item generic.Type //go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream" type outgoingItemsMap struct { mutex sync.RWMutex + cond sync.Cond streams map[protocol.StreamID]item nextStream protocol.StreamID + maxStream protocol.StreamID newStream func(protocol.StreamID) item closeErr error } func newOutgoingItemsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) item) *outgoingItemsMap { - return &outgoingItemsMap{ + m := &outgoingItemsMap{ streams: make(map[protocol.StreamID]item), nextStream: nextStream, newStream: newStream, } + m.cond.L = &m.mutex + return m } func (m *outgoingItemsMap) OpenStream() (item, error) { m.mutex.Lock() defer m.mutex.Unlock() + return m.openStreamImpl() +} + +func (m *outgoingItemsMap) OpenStreamSync() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingItemsMap) openStreamImpl() (item, error) { if m.closeErr != nil { return nil, m.closeErr } + if m.nextStream > m.maxStream { + return nil, qerr.TooManyOpenStreams + } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s m.nextStream += 4 @@ -66,8 +93,18 @@ func (m *outgoingItemsMap) DeleteStream(id protocol.StreamID) error { return nil } +func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + func (m *outgoingItemsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + m.cond.Broadcast() m.mutex.Unlock() } diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index 4c010231..a4cf22f1 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -23,56 +23,107 @@ var _ = Describe("Streams Map (outgoing)", func() { m = newOutgoingItemsMap(firstNewStream, newItem) }) - It("opens streams", func() { - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream)) - str, err = m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream + 4)) + Context("no stream ID limit", func() { + BeforeEach(func() { + m.SetMaxStream(0xffffffff) + }) + + It("opens streams", func() { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(firstNewStream)) + str, err = m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(firstNewStream + 4)) + }) + + It("doesn't open streams after it has been closed", func() { + testErr := errors.New("close") + m.CloseWithError(testErr) + _, err := m.OpenStream() + Expect(err).To(MatchError(testErr)) + }) + + It("gets streams", func() { + _, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetStream(firstNewStream) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(firstNewStream)) + }) + + It("errors when trying to get a stream that has not yet been opened", func() { + _, err := m.GetStream(10) + Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamID, "peer attempted to open stream 10"))) + }) + + It("deletes streams", func() { + _, err := m.OpenStream() // opens stream 10 + Expect(err).ToNot(HaveOccurred()) + err = m.DeleteStream(10) + Expect(err).ToNot(HaveOccurred()) + str, err := m.GetStream(10) + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(BeNil()) + }) + + It("errors when deleting a non-existing stream", func() { + err := m.DeleteStream(1337) + Expect(err).To(MatchError("Tried to delete unknown stream 1337")) + }) + + It("errors when deleting a stream twice", func() { + _, err := m.OpenStream() // opens stream 10 + Expect(err).ToNot(HaveOccurred()) + err = m.DeleteStream(10) + Expect(err).ToNot(HaveOccurred()) + err = m.DeleteStream(10) + Expect(err).To(MatchError("Tried to delete unknown stream 10")) + }) }) - It("doesn't open streams after it has been closed", func() { - testErr := errors.New("close") - m.CloseWithError(testErr) - _, err := m.OpenStream() - Expect(err).To(MatchError(testErr)) - }) + Context("with stream ID limits", func() { + It("errors when no stream can be opened immediately", func() { + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) - It("gets streams", func() { - _, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - str, err := m.GetStream(firstNewStream) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal(firstNewStream)) - }) + It("blocks until a stream can be opened synchronously", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + str, err := m.OpenStreamSync() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(firstNewStream)) + close(done) + }() - It("errors when trying to get a stream that has not yet been opened", func() { - _, err := m.GetStream(10) - Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamID, "peer attempted to open stream 10"))) - }) + Consistently(done).ShouldNot(BeClosed()) + m.SetMaxStream(firstNewStream) + Eventually(done).Should(BeClosed()) + }) - It("deletes streams", func() { - _, err := m.OpenStream() // opens stream 10 - Expect(err).ToNot(HaveOccurred()) - err = m.DeleteStream(10) - Expect(err).ToNot(HaveOccurred()) - str, err := m.GetStream(10) - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(BeNil()) - }) + It("stops opening synchronously when it is closed", func() { + testErr := errors.New("test error") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := m.OpenStreamSync() + Expect(err).To(MatchError(testErr)) + close(done) + }() - It("errors when deleting a non-existing stream", func() { - err := m.DeleteStream(1337) - Expect(err).To(MatchError("Tried to delete unknown stream 1337")) - }) + Consistently(done).ShouldNot(BeClosed()) + m.CloseWithError(testErr) + Eventually(done).Should(BeClosed()) + }) - It("errors when deleting a stream twice", func() { - _, err := m.OpenStream() // opens stream 10 - Expect(err).ToNot(HaveOccurred()) - err = m.DeleteStream(10) - Expect(err).ToNot(HaveOccurred()) - err = m.DeleteStream(10) - Expect(err).To(MatchError("Tried to delete unknown stream 10")) + It("doesn't reduce the stream limit", func() { + m.SetMaxStream(firstNewStream) + m.SetMaxStream(firstNewStream - 4) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal(firstNewStream)) + }) }) }) diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index 913f70c6..2f0e6bc0 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -14,30 +14,57 @@ import ( type outgoingUniStreamsMap struct { mutex sync.RWMutex + cond sync.Cond streams map[protocol.StreamID]sendStreamI nextStream protocol.StreamID + maxStream protocol.StreamID newStream func(protocol.StreamID) sendStreamI closeErr error } func newOutgoingUniStreamsMap(nextStream protocol.StreamID, newStream func(protocol.StreamID) sendStreamI) *outgoingUniStreamsMap { - return &outgoingUniStreamsMap{ + m := &outgoingUniStreamsMap{ streams: make(map[protocol.StreamID]sendStreamI), nextStream: nextStream, newStream: newStream, } + m.cond.L = &m.mutex + return m } func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { m.mutex.Lock() defer m.mutex.Unlock() + return m.openStreamImpl() +} + +func (m *outgoingUniStreamsMap) OpenStreamSync() (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + for { + str, err := m.openStreamImpl() + if err == nil { + return str, err + } + if err != nil && err != qerr.TooManyOpenStreams { + return nil, err + } + m.cond.Wait() + } +} + +func (m *outgoingUniStreamsMap) openStreamImpl() (sendStreamI, error) { if m.closeErr != nil { return nil, m.closeErr } + if m.nextStream > m.maxStream { + return nil, qerr.TooManyOpenStreams + } s := m.newStream(m.nextStream) m.streams[m.nextStream] = s m.nextStream += 4 @@ -65,8 +92,18 @@ func (m *outgoingUniStreamsMap) DeleteStream(id protocol.StreamID) error { return nil } +func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) { + m.mutex.Lock() + if id > m.maxStream { + m.maxStream = id + m.cond.Broadcast() + } + m.mutex.Unlock() +} + func (m *outgoingUniStreamsMap) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err + m.cond.Broadcast() m.mutex.Unlock() } diff --git a/streams_map_test.go b/streams_map_test.go index 6193f7da..a9901d8c 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -4,9 +4,12 @@ import ( "errors" "fmt" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -48,14 +51,26 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { } Context(perspective.String(), func() { - var m *streamsMap + var ( + m *streamsMap + mockSender *MockStreamSender + ) + + allowUnlimitedStreams := func() { + m.UpdateLimits(&handshake.TransportParameters{ + MaxBidiStreamID: 0xffffffff, + MaxUniStreamID: 0xffffffff, + }) + } BeforeEach(func() { - m = newStreamsMap(nil, newFlowController, perspective, versionIETFFrames).(*streamsMap) + mockSender = NewMockStreamSender(mockCtrl) + m = newStreamsMap(mockSender, newFlowController, perspective, versionIETFFrames).(*streamsMap) }) Context("opening", func() { It("opens bidirectional streams", func() { + allowUnlimitedStreams() str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeAssignableToTypeOf(&stream{})) @@ -67,6 +82,7 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { }) It("opens unidirectional streams", func() { + allowUnlimitedStreams() str, err := m.OpenUniStream() Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeAssignableToTypeOf(&sendStream{})) @@ -99,6 +115,11 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { }) Context("deleting", func() { + BeforeEach(func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + allowUnlimitedStreams() + }) + It("deletes outgoing bidirectional streams", func() { id := ids.firstOutgoingBidiStream str, err := m.OpenStream() @@ -145,6 +166,10 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { }) Context("getting streams", func() { + BeforeEach(func() { + allowUnlimitedStreams() + }) + Context("send streams", func() { It("gets an outgoing bidirectional stream", func() { // need to open the stream ourselves first @@ -231,6 +256,82 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { }) }) + Context("updating stream ID limits", func() { + It("processes the parameter for outgoing bidirectional streams", func() { + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + m.UpdateLimits(&handshake.TransportParameters{ + MaxBidiStreamID: ids.firstOutgoingBidiStream, + }) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) + }) + + It("processes the parameter for outgoing bidirectional streams", func() { + _, err := m.OpenUniStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + m.UpdateLimits(&handshake.TransportParameters{ + MaxUniStreamID: ids.firstOutgoingUniStream, + }) + str, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) + }) + }) + + Context("handling MAX_STREAM_ID frames", func() { + It("processes IDs for outgoing bidirectional streams", func() { + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingBidiStream}) + Expect(err).ToNot(HaveOccurred()) + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream)) + }) + + It("processes IDs for outgoing bidirectional streams", func() { + _, err := m.OpenUniStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingUniStream}) + Expect(err).ToNot(HaveOccurred()) + str, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream)) + }) + + It("rejects IDs for incoming bidirectional streams", func() { + err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingBidiStream}) + Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingBidiStream))) + }) + + It("rejects IDs for incoming unidirectional streams", func() { + err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingUniStream}) + Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingUniStream))) + }) + }) + + Context("sending MAX_STREAM_ID frames", func() { + It("sends MAX_STREAM_ID frames for bidirectional streams", func() { + _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream + 4*10) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{ + StreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, perspective) + 4, + }) + Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed()) + }) + + It("sends MAX_STREAM_ID frames for unidirectional streams", func() { + _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream + 4*10) + Expect(err).ToNot(HaveOccurred()) + mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{ + StreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, perspective) + 4, + }) + Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed()) + }) + }) + It("closes", func() { testErr := errors.New("test error") m.CloseWithError(testErr)