diff --git a/session.go b/session.go index 6357c5d0..8f117d93 100644 --- a/session.go +++ b/session.go @@ -688,18 +688,13 @@ func (s *session) handleCloseError(closeErr closeError) error { func (s *session) processTransportParameters(params *handshake.TransportParameters) { s.peerParams = params - s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams) + s.streamsMap.UpdateLimits(params) if params.OmitConnectionID { s.packer.SetOmitConnectionID() } s.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) - // increase the flow control windows of all streams by sending them a fake MAX_STREAM_DATA frame - s.streamsMap.Range(func(str streamI) { - str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ - StreamID: str.StreamID(), - ByteOffset: params.StreamFlowControlWindow, - }) - }) + // the crypto stream is the only open stream at this moment + // so we don't need to update stream flow control windows } func (s *session) sendPacket() error { diff --git a/session_test.go b/session_test.go index 82b2d88a..da4b1b55 100644 --- a/session_test.go +++ b/session_test.go @@ -201,7 +201,7 @@ var _ = Describe("Session", func() { Context("handling STREAM frames", func() { BeforeEach(func() { - sess.streamsMap.UpdateMaxStreamLimit(100) + sess.streamsMap.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) }) It("makes new streams", func() { @@ -494,9 +494,9 @@ var _ = Describe("Session", func() { }() _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - sess.streamsMap.Range(func(s streamI) { + for _, s := range sess.streamsMap.streams { s.(*MockStreamI).EXPECT().closeForShutdown(gomock.Any()) - }) + } err = sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) Eventually(sess.Context().Done()).Should(BeClosed()) diff --git a/streams_map.go b/streams_map.go index 956b725d..40556d67 100644 --- a/streams_map.go +++ b/streams_map.go @@ -5,8 +5,10 @@ import ( "fmt" "sync" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -225,16 +227,6 @@ func (m *streamsMap) DeleteStream(id protocol.StreamID) error { return nil } -// Range executes a callback for all streams, in pseudo-random order -func (m *streamsMap) Range(cb func(s streamI)) { - m.mutex.RLock() - defer m.mutex.RUnlock() - - for _, s := range m.streams { - cb(s) - } -} - func (m *streamsMap) putStream(s streamI) error { id := s.StreamID() if _, ok := m.streams[id]; ok { @@ -255,9 +247,16 @@ func (m *streamsMap) CloseWithError(err error) { } } -func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) { +// TODO(#952): this won't be needed when gQUIC supports stateless handshakes +func (m *streamsMap) UpdateLimits(params *handshake.TransportParameters) { m.mutex.Lock() - defer m.mutex.Unlock() - m.maxOutgoingStreams = limit + m.maxOutgoingStreams = params.MaxStreams + for id, str := range m.streams { + str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: id, + ByteOffset: params.StreamFlowControlWindow, + }) + } + m.mutex.Unlock() m.openStreamOrErrCond.Broadcast() } diff --git a/streams_map_test.go b/streams_map_test.go index 6bb7e71c..0c90d627 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -2,9 +2,10 @@ package quic import ( "errors" - "sort" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -126,7 +127,7 @@ var _ = Describe("Streams Map", func() { }) It("opens a stream 2 first", func() { - m.UpdateMaxStreamLimit(100) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) s, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) @@ -143,7 +144,7 @@ var _ = Describe("Streams Map", func() { }) It("doesn't reopen an already closed stream", func() { - m.UpdateMaxStreamLimit(100) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) @@ -158,7 +159,8 @@ var _ = Describe("Streams Map", func() { const maxOutgoingStreams = 50 BeforeEach(func() { - m.UpdateMaxStreamLimit(maxOutgoingStreams) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: maxOutgoingStreams}) + }) It("errors when too many streams are opened", func() { @@ -194,7 +196,7 @@ var _ = Describe("Streams Map", func() { const maxOutgoingStreams = 10 BeforeEach(func() { - m.UpdateMaxStreamLimit(maxOutgoingStreams) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: maxOutgoingStreams}) }) openMaxNumStreams := func() { @@ -398,7 +400,7 @@ var _ = Describe("Streams Map", func() { Context("as a client", func() { BeforeEach(func() { setNewStreamsMap(protocol.PerspectiveClient, versionGQUICFrames) - m.UpdateMaxStreamLimit(100) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) }) Context("server-side streams", func() { @@ -448,7 +450,7 @@ var _ = Describe("Streams Map", func() { Context("client-side streams", func() { It("starts with stream 1, if the crypto stream is stream 0", func() { setNewStreamsMap(protocol.PerspectiveClient, versionIETFFrames) - m.UpdateMaxStreamLimit(100) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) s, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) @@ -503,32 +505,6 @@ var _ = Describe("Streams Map", func() { }) }) - Context("Ranging", func() { - It("ranges over all open streams", func() { - setNewStreamsMap(protocol.PerspectiveServer, protocol.VersionWhatever) - var callbackCalledForStream []protocol.StreamID - callback := func(str streamI) { - callbackCalledForStream = append(callbackCalledForStream, str.StreamID()) - sort.Slice(callbackCalledForStream, func(i, j int) bool { - return callbackCalledForStream[i] < callbackCalledForStream[j] - }) - } - - Expect(m.streams).To(BeEmpty()) - // create 5 streams, ids 4 to 8 - callbackCalledForStream = callbackCalledForStream[:0] - for i := 4; i <= 8; i++ { - str := NewMockStreamI(mockCtrl) - str.EXPECT().StreamID().Return(protocol.StreamID(i)).AnyTimes() - err := m.putStream(str) - Expect(err).NotTo(HaveOccurred()) - } - // execute the callback for all streams - m.Range(callback) - Expect(callbackCalledForStream).To(Equal([]protocol.StreamID{4, 5, 6, 7, 8})) - }) - }) - Context("deleting streams", func() { BeforeEach(func() { setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames) @@ -546,7 +522,7 @@ var _ = Describe("Streams Map", func() { }) It("deletes an outgoing stream", func() { - m.UpdateMaxStreamLimit(10000) + m.UpdateLimits(&handshake.TransportParameters{MaxStreams: 10000}) _, err := m.OpenStream() // open stream 2 Expect(err).ToNot(HaveOccurred()) _, err = m.OpenStream() @@ -562,4 +538,19 @@ var _ = Describe("Streams Map", func() { Expect(err).To(MatchError(errMapAccess)) }) }) + + It("sets the flow control limit", func() { + setNewStreamsMap(protocol.PerspectiveServer, versionGQUICFrames) + _, err := m.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + m.streams[3].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 3, + ByteOffset: 321, + }) + m.streams[5].(*MockStreamI).EXPECT().handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ + StreamID: 5, + ByteOffset: 321, + }) + m.UpdateLimits(&handshake.TransportParameters{StreamFlowControlWindow: 321}) + }) })