From c7d3aeea3d7e49fdb34e97365f60e04e7668c60f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 18 Dec 2020 13:00:33 +0700 Subject: [PATCH] move the transport parameter stream limit check to the parser --- internal/wire/transport_parameter_test.go | 42 +++++++++++++++++++---- internal/wire/transport_parameters.go | 6 ++++ mock_stream_manager_test.go | 6 ++-- session.go | 11 ++---- streams_map.go | 6 +--- streams_map_test.go | 17 ++------- 6 files changed, 50 insertions(+), 38 deletions(-) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 45929717..fbfdee24 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -15,9 +15,17 @@ import ( ) var _ = Describe("Transport Parameters", func() { - getRandomValue := func() uint64 { + getRandomValueUpTo := func(max int64) uint64 { maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} - return uint64(rand.Int63n(maxVals[int(rand.Int31n(4))])) + m := maxVals[int(rand.Int31n(4))] + if m > max { + m = max + } + return uint64(rand.Int63n(m)) + } + + getRandomValue := func() uint64 { + return getRandomValueUpTo(math.MaxInt64) } BeforeEach(func() { @@ -79,8 +87,8 @@ var _ = Describe("Transport Parameters", func() { InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), InitialMaxData: protocol.ByteCount(getRandomValue()), MaxIdleTimeout: 0xcafe * time.Second, - MaxBidiStreamNum: protocol.StreamNum(getRandomValue()), - MaxUniStreamNum: protocol.StreamNum(getRandomValue()), + MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), DisableActiveMigration: true, StatelessResetToken: &token, OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, @@ -252,6 +260,28 @@ var _ = Describe("Transport Parameters", func() { Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: inconsistent transport parameter length")) }) + It("errors if initial_max_streams_bidi is too large", func() { + b := &bytes.Buffer{} + utils.WriteVarInt(b, uint64(initialMaxStreamsBidiParameterID)) + utils.WriteVarInt(b, uint64(utils.VarIntLen(uint64(protocol.MaxStreamCount+1)))) + utils.WriteVarInt(b, uint64(protocol.MaxStreamCount+1)) + addInitialSourceConnectionID(b) + err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)")) + }) + + It("errors if initial_max_streams_uni is too large", func() { + b := &bytes.Buffer{} + utils.WriteVarInt(b, uint64(initialMaxStreamsUniParameterID)) + utils.WriteVarInt(b, uint64(utils.VarIntLen(uint64(protocol.MaxStreamCount+1)))) + utils.WriteVarInt(b, uint64(protocol.MaxStreamCount+1)) + addInitialSourceConnectionID(b) + err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: initial_max_streams_uni too large: 1152921504606846977 (maximum 1152921504606846976)")) + }) + It("handles huge max_ack_delay values", func() { b := &bytes.Buffer{} val := uint64(math.MaxUint64) / 5 @@ -416,8 +446,8 @@ var _ = Describe("Transport Parameters", func() { InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), InitialMaxData: protocol.ByteCount(getRandomValue()), - MaxBidiStreamNum: protocol.StreamNum(getRandomValue()), - MaxUniStreamNum: protocol.StreamNum(getRandomValue()), + MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), ActiveConnectionIDLimit: getRandomValue(), } Expect(params.ValidFor0RTT(params)).To(BeTrue()) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index fbb10d6c..5775f7f2 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -279,8 +279,14 @@ func (p *TransportParameters) readNumericTransportParameter( p.InitialMaxData = protocol.ByteCount(val) case initialMaxStreamsBidiParameterID: p.MaxBidiStreamNum = protocol.StreamNum(val) + if p.MaxBidiStreamNum > protocol.MaxStreamCount { + return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount) + } case initialMaxStreamsUniParameterID: p.MaxUniStreamNum = protocol.StreamNum(val) + if p.MaxUniStreamNum > protocol.MaxStreamCount { + return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) + } case maxIdleTimeoutParameterID: p.MaxIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) case maxUDPPayloadSizeParameterID: diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 238a945b..cfe5b2be 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -197,11 +197,9 @@ func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *go } // UpdateLimits mocks base method -func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) error { +func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateLimits", arg0) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.Call(m, "UpdateLimits", arg0) } // UpdateLimits indicates an expected call of UpdateLimits diff --git a/session.go b/session.go index a1994b1a..f400e76d 100644 --- a/session.go +++ b/session.go @@ -42,7 +42,7 @@ type streamManager interface { AcceptStream(context.Context) (Stream, error) AcceptUniStream(context.Context) (ReceiveStream, error) DeleteStream(protocol.StreamID) error - UpdateLimits(*wire.TransportParameters) error + UpdateLimits(*wire.TransportParameters) HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error CloseWithError(error) } @@ -1389,10 +1389,7 @@ func (s *session) restoreTransportParameters(params *wire.TransportParameters) { s.peerParams = params s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) s.connFlowController.UpdateSendWindow(params.InitialMaxData) - if err := s.streamsMap.UpdateLimits(params); err != nil { - s.closeLocal(err) - return - } + s.streamsMap.UpdateLimits(params) } func (s *session) processTransportParameters(params *wire.TransportParameters) { @@ -1435,9 +1432,7 @@ func (s *session) processTransportParametersImpl(params *wire.TransportParameter // Our local idle timeout will always be > 0. s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) s.keepAliveInterval = utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval) - if err := s.streamsMap.UpdateLimits(params); err != nil { - return err - } + s.streamsMap.UpdateLimits(params) s.packer.HandleTransportParameters(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) diff --git a/streams_map.go b/streams_map.go index 0abfa9ae..72f68adc 100644 --- a/streams_map.go +++ b/streams_map.go @@ -222,13 +222,9 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error { return nil } -func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) error { - if p.MaxBidiStreamNum > protocol.MaxStreamCount || p.MaxUniStreamNum > protocol.MaxStreamCount { - return qerr.StreamLimitError - } +func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) - return nil } func (m *streamsMap) CloseWithError(err error) { diff --git a/streams_map_test.go b/streams_map_test.go index e716319b..27f4703a 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -10,7 +10,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -330,10 +329,10 @@ var _ = Describe("Streams Map", func() { m.perspective = pers _, err := m.OpenStream() expectTooManyStreamsError(err) - Expect(m.UpdateLimits(&wire.TransportParameters{ + m.UpdateLimits(&wire.TransportParameters{ MaxBidiStreamNum: 5, MaxUniStreamNum: 8, - })).To(Succeed()) + }) mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2) // test we can only 5 bidirectional streams @@ -354,18 +353,6 @@ var _ = Describe("Streams Map", func() { expectTooManyStreamsError(err) }) } - - It("rejects parameters with too large unidirectional stream counts", func() { - Expect(m.UpdateLimits(&wire.TransportParameters{ - MaxUniStreamNum: protocol.MaxStreamCount + 1, - })).To(MatchError(qerr.StreamLimitError)) - }) - - It("rejects parameters with too large unidirectional stream counts", func() { - Expect(m.UpdateLimits(&wire.TransportParameters{ - MaxBidiStreamNum: protocol.MaxStreamCount + 1, - })).To(MatchError(qerr.StreamLimitError)) - }) }) Context("handling MAX_STREAMS frames", func() {