move the transport parameter stream limit check to the parser

This commit is contained in:
Marten Seemann 2020-12-18 13:00:33 +07:00
parent 6217ec55fd
commit c7d3aeea3d
6 changed files with 50 additions and 38 deletions

View file

@ -15,9 +15,17 @@ import (
)
var _ = Describe("Transport Parameters", func() {
getRandomValue := func() uint64 {
getRandomValueUpTo := func(max int64) uint64 {
maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4}
return uint64(rand.Int63n(maxVals[int(rand.Int31n(4))]))
m := maxVals[int(rand.Int31n(4))]
if m > max {
m = max
}
return uint64(rand.Int63n(m))
}
getRandomValue := func() uint64 {
return getRandomValueUpTo(math.MaxInt64)
}
BeforeEach(func() {
@ -79,8 +87,8 @@ var _ = Describe("Transport Parameters", func() {
InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()),
InitialMaxData: protocol.ByteCount(getRandomValue()),
MaxIdleTimeout: 0xcafe * time.Second,
MaxBidiStreamNum: protocol.StreamNum(getRandomValue()),
MaxUniStreamNum: protocol.StreamNum(getRandomValue()),
MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))),
MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))),
DisableActiveMigration: true,
StatelessResetToken: &token,
OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
@ -252,6 +260,28 @@ var _ = Describe("Transport Parameters", func() {
Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: inconsistent transport parameter length"))
})
It("errors if initial_max_streams_bidi is too large", func() {
b := &bytes.Buffer{}
utils.WriteVarInt(b, uint64(initialMaxStreamsBidiParameterID))
utils.WriteVarInt(b, uint64(utils.VarIntLen(uint64(protocol.MaxStreamCount+1))))
utils.WriteVarInt(b, uint64(protocol.MaxStreamCount+1))
addInitialSourceConnectionID(b)
err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)"))
})
It("errors if initial_max_streams_uni is too large", func() {
b := &bytes.Buffer{}
utils.WriteVarInt(b, uint64(initialMaxStreamsUniParameterID))
utils.WriteVarInt(b, uint64(utils.VarIntLen(uint64(protocol.MaxStreamCount+1))))
utils.WriteVarInt(b, uint64(protocol.MaxStreamCount+1))
addInitialSourceConnectionID(b)
err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: initial_max_streams_uni too large: 1152921504606846977 (maximum 1152921504606846976)"))
})
It("handles huge max_ack_delay values", func() {
b := &bytes.Buffer{}
val := uint64(math.MaxUint64) / 5
@ -416,8 +446,8 @@ var _ = Describe("Transport Parameters", func() {
InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()),
InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()),
InitialMaxData: protocol.ByteCount(getRandomValue()),
MaxBidiStreamNum: protocol.StreamNum(getRandomValue()),
MaxUniStreamNum: protocol.StreamNum(getRandomValue()),
MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))),
MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))),
ActiveConnectionIDLimit: getRandomValue(),
}
Expect(params.ValidFor0RTT(params)).To(BeTrue())

View file

@ -279,8 +279,14 @@ func (p *TransportParameters) readNumericTransportParameter(
p.InitialMaxData = protocol.ByteCount(val)
case initialMaxStreamsBidiParameterID:
p.MaxBidiStreamNum = protocol.StreamNum(val)
if p.MaxBidiStreamNum > protocol.MaxStreamCount {
return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount)
}
case initialMaxStreamsUniParameterID:
p.MaxUniStreamNum = protocol.StreamNum(val)
if p.MaxUniStreamNum > protocol.MaxStreamCount {
return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount)
}
case maxIdleTimeoutParameterID:
p.MaxIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond)
case maxUDPPayloadSizeParameterID:

View file

@ -197,11 +197,9 @@ func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *go
}
// UpdateLimits mocks base method
func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) error {
func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateLimits", arg0)
ret0, _ := ret[0].(error)
return ret0
m.ctrl.Call(m, "UpdateLimits", arg0)
}
// UpdateLimits indicates an expected call of UpdateLimits

View file

@ -42,7 +42,7 @@ type streamManager interface {
AcceptStream(context.Context) (Stream, error)
AcceptUniStream(context.Context) (ReceiveStream, error)
DeleteStream(protocol.StreamID) error
UpdateLimits(*wire.TransportParameters) error
UpdateLimits(*wire.TransportParameters)
HandleMaxStreamsFrame(*wire.MaxStreamsFrame) error
CloseWithError(error)
}
@ -1389,10 +1389,7 @@ func (s *session) restoreTransportParameters(params *wire.TransportParameters) {
s.peerParams = params
s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit)
s.connFlowController.UpdateSendWindow(params.InitialMaxData)
if err := s.streamsMap.UpdateLimits(params); err != nil {
s.closeLocal(err)
return
}
s.streamsMap.UpdateLimits(params)
}
func (s *session) processTransportParameters(params *wire.TransportParameters) {
@ -1435,9 +1432,7 @@ func (s *session) processTransportParametersImpl(params *wire.TransportParameter
// Our local idle timeout will always be > 0.
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
s.keepAliveInterval = utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval)
if err := s.streamsMap.UpdateLimits(params); err != nil {
return err
}
s.streamsMap.UpdateLimits(params)
s.packer.HandleTransportParameters(params)
s.frameParser.SetAckDelayExponent(params.AckDelayExponent)
s.connFlowController.UpdateSendWindow(params.InitialMaxData)

View file

@ -222,13 +222,9 @@ func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error {
return nil
}
func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) error {
if p.MaxBidiStreamNum > protocol.MaxStreamCount || p.MaxUniStreamNum > protocol.MaxStreamCount {
return qerr.StreamLimitError
}
func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
return nil
}
func (m *streamsMap) CloseWithError(err error) {

View file

@ -10,7 +10,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
@ -330,10 +329,10 @@ var _ = Describe("Streams Map", func() {
m.perspective = pers
_, err := m.OpenStream()
expectTooManyStreamsError(err)
Expect(m.UpdateLimits(&wire.TransportParameters{
m.UpdateLimits(&wire.TransportParameters{
MaxBidiStreamNum: 5,
MaxUniStreamNum: 8,
})).To(Succeed())
})
mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2)
// test we can only 5 bidirectional streams
@ -354,18 +353,6 @@ var _ = Describe("Streams Map", func() {
expectTooManyStreamsError(err)
})
}
It("rejects parameters with too large unidirectional stream counts", func() {
Expect(m.UpdateLimits(&wire.TransportParameters{
MaxUniStreamNum: protocol.MaxStreamCount + 1,
})).To(MatchError(qerr.StreamLimitError))
})
It("rejects parameters with too large unidirectional stream counts", func() {
Expect(m.UpdateLimits(&wire.TransportParameters{
MaxBidiStreamNum: protocol.MaxStreamCount + 1,
})).To(MatchError(qerr.StreamLimitError))
})
})
Context("handling MAX_STREAMS frames", func() {