From c95f2054a874ef33cc882f34018b746c40bd3783 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 25 Sep 2017 19:50:26 +0700 Subject: [PATCH] rename the ConnectionParametersManager to ParamsNegotiator --- internal/flowcontrol/flow_control_manager.go | 12 +- .../flowcontrol/flow_control_manager_test.go | 12 +- internal/flowcontrol/flow_controller.go | 20 +-- internal/flowcontrol/flow_controller_test.go | 46 +++--- internal/handshake/crypto_setup_client.go | 14 +- .../handshake/crypto_setup_client_test.go | 8 +- internal/handshake/crypto_setup_server.go | 39 +++-- internal/handshake/crypto_setup_tls.go | 4 +- ...s_manager.go => params_negotiator_base.go} | 36 ++-- ..._manager.go => params_negotiator_gquic.go} | 18 +- ...est.go => params_negotiator_gquic_test.go} | 134 +++++++-------- internal/mocks/cpm.go | 155 ------------------ internal/mocks/gen.go | 2 +- internal/mocks/params_negotiator.go | 155 ++++++++++++++++++ packet_packer.go | 8 +- packet_packer_test.go | 6 +- session.go | 18 +- session_test.go | 70 ++++---- streams_map.go | 20 +-- streams_map_test.go | 12 +- 20 files changed, 393 insertions(+), 396 deletions(-) rename internal/handshake/{base_connection_parameters_manager.go => params_negotiator_base.go} (74%) rename internal/handshake/{gquic_connection_parameters_manager.go => params_negotiator_gquic.go} (85%) rename internal/handshake/{gquic_connection_parameters_manager_test.go => params_negotiator_gquic_test.go} (65%) delete mode 100644 internal/mocks/cpm.go create mode 100644 internal/mocks/params_negotiator.go diff --git a/internal/flowcontrol/flow_control_manager.go b/internal/flowcontrol/flow_control_manager.go index f974a849..e1f4b507 100644 --- a/internal/flowcontrol/flow_control_manager.go +++ b/internal/flowcontrol/flow_control_manager.go @@ -13,8 +13,8 @@ import ( ) type flowControlManager struct { - connectionParameters handshake.ConnectionParametersManager - rttStats *congestion.RTTStats + connParams handshake.ParamsNegotiator + rttStats *congestion.RTTStats streamFlowController map[protocol.StreamID]*flowController connFlowController *flowController @@ -26,12 +26,12 @@ var _ FlowControlManager = &flowControlManager{} var errMapAccess = errors.New("Error accessing the flowController map.") // NewFlowControlManager creates a new flow control manager -func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager { +func NewFlowControlManager(connParams handshake.ParamsNegotiator, rttStats *congestion.RTTStats) FlowControlManager { return &flowControlManager{ - connectionParameters: connectionParameters, + connParams: connParams, rttStats: rttStats, streamFlowController: make(map[protocol.StreamID]*flowController), - connFlowController: newFlowController(0, false, connectionParameters, rttStats), + connFlowController: newFlowController(0, false, connParams, rttStats), } } @@ -45,7 +45,7 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo return } - f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats) + f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connParams, f.rttStats) } // RemoveStream removes a closed stream from flow control diff --git a/internal/flowcontrol/flow_control_manager_test.go b/internal/flowcontrol/flow_control_manager_test.go index e8cfa583..02677a2a 100644 --- a/internal/flowcontrol/flow_control_manager_test.go +++ b/internal/flowcontrol/flow_control_manager_test.go @@ -15,12 +15,12 @@ var _ = Describe("Flow Control Manager", func() { var fcm *flowControlManager BeforeEach(func() { - mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) - mockCpm.EXPECT().GetReceiveStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(100)) - mockCpm.EXPECT().GetReceiveConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(200)) - mockCpm.EXPECT().GetMaxReceiveStreamFlowControlWindow().AnyTimes().Return(protocol.MaxByteCount) - mockCpm.EXPECT().GetMaxReceiveConnectionFlowControlWindow().AnyTimes().Return(protocol.MaxByteCount) - fcm = NewFlowControlManager(mockCpm, &congestion.RTTStats{}).(*flowControlManager) + mockPn := mocks.NewMockParamsNegotiator(mockCtrl) + mockPn.EXPECT().GetReceiveStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(100)) + mockPn.EXPECT().GetReceiveConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(200)) + mockPn.EXPECT().GetMaxReceiveStreamFlowControlWindow().AnyTimes().Return(protocol.MaxByteCount) + mockPn.EXPECT().GetMaxReceiveConnectionFlowControlWindow().AnyTimes().Return(protocol.MaxByteCount) + fcm = NewFlowControlManager(mockPn, &congestion.RTTStats{}).(*flowControlManager) }) It("creates a connection level flow controller", func() { diff --git a/internal/flowcontrol/flow_controller.go b/internal/flowcontrol/flow_controller.go index 00c0b4f3..11903cd1 100644 --- a/internal/flowcontrol/flow_controller.go +++ b/internal/flowcontrol/flow_controller.go @@ -14,8 +14,8 @@ type flowController struct { streamID protocol.StreamID contributesToConnection bool // does the stream contribute to connection level flow control - connectionParameters handshake.ConnectionParametersManager - rttStats *congestion.RTTStats + connParams handshake.ParamsNegotiator + rttStats *congestion.RTTStats bytesSent protocol.ByteCount sendWindow protocol.ByteCount @@ -33,22 +33,22 @@ type flowController struct { var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset") // newFlowController gets a new flow controller -func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController { +func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connParams handshake.ParamsNegotiator, rttStats *congestion.RTTStats) *flowController { fc := flowController{ streamID: streamID, contributesToConnection: contributesToConnection, - connectionParameters: connectionParameters, + connParams: connParams, rttStats: rttStats, } if streamID == 0 { - fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow() + fc.receiveWindow = connParams.GetReceiveConnectionFlowControlWindow() fc.receiveWindowIncrement = fc.receiveWindow - fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow() + fc.maxReceiveWindowIncrement = connParams.GetMaxReceiveConnectionFlowControlWindow() } else { - fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow() + fc.receiveWindow = connParams.GetReceiveStreamFlowControlWindow() fc.receiveWindowIncrement = fc.receiveWindow - fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow() + fc.maxReceiveWindowIncrement = connParams.GetMaxReceiveStreamFlowControlWindow() } return &fc @@ -61,9 +61,9 @@ func (c *flowController) ContributesToConnection() bool { func (c *flowController) getSendWindow() protocol.ByteCount { if c.sendWindow == 0 { if c.streamID == 0 { - return c.connectionParameters.GetSendConnectionFlowControlWindow() + return c.connParams.GetSendConnectionFlowControlWindow() } - return c.connectionParameters.GetSendStreamFlowControlWindow() + return c.connParams.GetSendStreamFlowControlWindow() } return c.sendWindow } diff --git a/internal/flowcontrol/flow_controller_test.go b/internal/flowcontrol/flow_controller_test.go index 3d324ae9..de38b481 100644 --- a/internal/flowcontrol/flow_controller_test.go +++ b/internal/flowcontrol/flow_controller_test.go @@ -20,57 +20,57 @@ var _ = Describe("Flow controller", func() { Context("Constructor", func() { var rttStats *congestion.RTTStats - var mockCpm *mocks.MockConnectionParametersManager + var mockPn *mocks.MockParamsNegotiator BeforeEach(func() { - mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) - mockCpm.EXPECT().GetSendStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(1000)) - mockCpm.EXPECT().GetReceiveStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(2000)) - mockCpm.EXPECT().GetSendConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(3000)) - mockCpm.EXPECT().GetReceiveConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(4000)) - mockCpm.EXPECT().GetMaxReceiveStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(8000)) - mockCpm.EXPECT().GetMaxReceiveConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(9000)) + mockPn = mocks.NewMockParamsNegotiator(mockCtrl) + mockPn.EXPECT().GetSendStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(1000)) + mockPn.EXPECT().GetReceiveStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(2000)) + mockPn.EXPECT().GetSendConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(3000)) + mockPn.EXPECT().GetReceiveConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(4000)) + mockPn.EXPECT().GetMaxReceiveStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(8000)) + mockPn.EXPECT().GetMaxReceiveConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(9000)) rttStats = &congestion.RTTStats{} }) It("reads the stream send and receive windows when acting as stream-level flow controller", func() { - fc := newFlowController(5, true, mockCpm, rttStats) + fc := newFlowController(5, true, mockPn, rttStats) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.receiveWindow).To(Equal(protocol.ByteCount(2000))) - Expect(fc.maxReceiveWindowIncrement).To(Equal(mockCpm.GetMaxReceiveStreamFlowControlWindow())) + Expect(fc.maxReceiveWindowIncrement).To(Equal(mockPn.GetMaxReceiveStreamFlowControlWindow())) }) It("reads the stream send and receive windows when acting as connection-level flow controller", func() { - fc := newFlowController(0, false, mockCpm, rttStats) + fc := newFlowController(0, false, mockPn, rttStats) Expect(fc.streamID).To(Equal(protocol.StreamID(0))) Expect(fc.receiveWindow).To(Equal(protocol.ByteCount(4000))) - Expect(fc.maxReceiveWindowIncrement).To(Equal(mockCpm.GetMaxReceiveConnectionFlowControlWindow())) + Expect(fc.maxReceiveWindowIncrement).To(Equal(mockPn.GetMaxReceiveConnectionFlowControlWindow())) }) It("does not set the stream flow control windows for sending", func() { - fc := newFlowController(5, true, mockCpm, rttStats) + fc := newFlowController(5, true, mockPn, rttStats) Expect(fc.sendWindow).To(BeZero()) }) It("does not set the connection flow control windows for sending", func() { - fc := newFlowController(0, false, mockCpm, rttStats) + fc := newFlowController(0, false, mockPn, rttStats) Expect(fc.sendWindow).To(BeZero()) }) It("says if it contributes to connection-level flow control", func() { - fc := newFlowController(1, false, mockCpm, rttStats) + fc := newFlowController(1, false, mockPn, rttStats) Expect(fc.ContributesToConnection()).To(BeFalse()) - fc = newFlowController(5, true, mockCpm, rttStats) + fc = newFlowController(5, true, mockPn, rttStats) Expect(fc.ContributesToConnection()).To(BeTrue()) }) }) Context("send flow control", func() { - var mockCpm *mocks.MockConnectionParametersManager + var mockPn *mocks.MockParamsNegotiator BeforeEach(func() { - mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) - controller.connectionParameters = mockCpm + mockPn = mocks.NewMockParamsNegotiator(mockCtrl) + controller.connParams = mockPn }) It("adds bytes sent", func() { @@ -110,10 +110,10 @@ var _ = Describe("Flow controller", func() { It("asks the ConnectionParametersManager for the stream flow control window size", func() { controller.streamID = 5 - mockCpm.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(1000)) + mockPn.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(1000)) Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(1000))) // make sure the value is not cached - mockCpm.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(2000)) + mockPn.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(2000)) Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(2000))) }) @@ -125,10 +125,10 @@ var _ = Describe("Flow controller", func() { It("asks the ConnectionParametersManager for the connection flow control window size", func() { controller.streamID = 0 - mockCpm.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(3000)) + mockPn.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(3000)) Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(3000))) // make sure the value is not cached - mockCpm.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(5000)) + mockPn.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(5000)) Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(5000))) }) diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index dfb7f897..b1f37bc9 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -52,7 +52,7 @@ type cryptoSetupClient struct { aeadChanged chan<- protocol.EncryptionLevel requestConnIDTruncation bool - connectionParameters *gquicConnectionParametersManager + params *paramsNegotiatorGQUIC } var _ CryptoSetup = &cryptoSetupClient{} @@ -72,14 +72,14 @@ func NewCryptoSetupClient( params *TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, negotiatedVersions []protocol.VersionNumber, -) (CryptoSetup, ConnectionParametersManager, error) { - cpm := newGQUICConnectionParamatersManager(protocol.PerspectiveClient, version, params) +) (CryptoSetup, ParamsNegotiator, error) { + pn := newParamsNegotiatorGQUIC(protocol.PerspectiveClient, version, params) return &cryptoSetupClient{ hostname: hostname, connID: connID, version: version, certManager: crypto.NewCertManager(tlsConfig), - connectionParameters: cpm, + params: pn, requestConnIDTruncation: params.RequestConnectionIDTruncation, keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyExchange: getEphermalKEX, @@ -87,7 +87,7 @@ func NewCryptoSetupClient( aeadChanged: aeadChanged, negotiatedVersions: negotiatedVersions, divNonceChan: make(chan []byte), - }, cpm, nil + }, pn, nil } func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error { @@ -264,7 +264,7 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { return err } - err = h.connectionParameters.SetFromMap(cryptoData) + err = h.params.SetFromMap(cryptoData) if err != nil { return qerr.InvalidCryptoMessageParameter } @@ -405,7 +405,7 @@ func (h *cryptoSetupClient) sendCHLO() error { } func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { - tags, err := h.connectionParameters.GetHelloMap() + tags, err := h.params.GetHelloMap() if err != nil { return nil, err } diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index dc54b74f..75501b86 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -440,7 +440,7 @@ var _ = Describe("Client Crypto Setup", func() { shloMap[TagICSL] = []byte{3, 0, 0, 0} // 3 seconds err := cs.handleSHLOMessage(shloMap) Expect(err).ToNot(HaveOccurred()) - Expect(cs.connectionParameters.GetIdleConnectionStateLifetime()).To(Equal(3 * time.Second)) + Expect(cs.params.GetIdleConnectionStateLifetime()).To(Equal(3 * time.Second)) }) It("errors if it can't read a connection parameter", func() { @@ -499,12 +499,12 @@ var _ = Describe("Client Crypto Setup", func() { }) It("adds the tags returned from the connectionParametersManager to the CHLO", func() { - cpmTags, err := cs.connectionParameters.GetHelloMap() + pnTags, err := cs.params.GetHelloMap() Expect(err).ToNot(HaveOccurred()) - Expect(cpmTags).ToNot(BeEmpty()) + Expect(pnTags).ToNot(BeEmpty()) tags, err := cs.getTags() Expect(err).ToNot(HaveOccurred()) - for t := range cpmTags { + for t := range pnTags { Expect(tags).To(HaveKey(t)) } }) diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index c3238052..f2dfa0b9 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -47,7 +47,7 @@ type cryptoSetupServer struct { cryptoStream io.ReadWriter - connectionParameters *gquicConnectionParametersManager + params *paramsNegotiatorGQUIC mutex sync.RWMutex } @@ -73,28 +73,28 @@ func NewCryptoSetup( supportedVersions []protocol.VersionNumber, acceptSTK func(net.Addr, *Cookie) bool, aeadChanged chan<- protocol.EncryptionLevel, -) (CryptoSetup, ConnectionParametersManager, error) { +) (CryptoSetup, ParamsNegotiator, error) { stkGenerator, err := NewCookieGenerator() if err != nil { return nil, nil, err } - cpm := newGQUICConnectionParamatersManager(protocol.PerspectiveServer, version, params) + pn := newParamsNegotiatorGQUIC(protocol.PerspectiveServer, version, params) return &cryptoSetupServer{ - connID: connID, - remoteAddr: remoteAddr, - version: version, - supportedVersions: supportedVersions, - scfg: scfg, - stkGenerator: stkGenerator, - keyDerivation: crypto.DeriveQuicCryptoAESKeys, - keyExchange: getEphermalKEX, - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), - connectionParameters: cpm, - acceptSTKCallback: acceptSTK, - sentSHLO: make(chan struct{}), - aeadChanged: aeadChanged, - }, cpm, nil + connID: connID, + remoteAddr: remoteAddr, + version: version, + supportedVersions: supportedVersions, + scfg: scfg, + stkGenerator: stkGenerator, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + keyExchange: getEphermalKEX, + nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), + params: pn, + acceptSTKCallback: acceptSTK, + sentSHLO: make(chan struct{}), + aeadChanged: aeadChanged, + }, pn, nil } // HandleCryptoStream reads and writes messages on the crypto stream @@ -418,12 +418,11 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T return nil, err } - err = h.connectionParameters.SetFromMap(cryptoData) - if err != nil { + if err := h.params.SetFromMap(cryptoData); err != nil { return nil, err } - replyMap, err := h.connectionParameters.GetHelloMap() + replyMap, err := h.params.GetHelloMap() if err != nil { return nil, err } diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 0601443b..0471d38b 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -40,7 +40,7 @@ func NewCryptoSetupTLS( version protocol.VersionNumber, tlsConfig *tls.Config, aeadChanged chan<- protocol.EncryptionLevel, -) (CryptoSetup, ConnectionParametersManager, error) { +) (CryptoSetup, ParamsNegotiator, error) { mintConf, err := tlsToMintConfig(tlsConfig, perspective) if err != nil { return nil, nil, err @@ -54,7 +54,7 @@ func NewCryptoSetupTLS( nullAEAD: crypto.NewNullAEAD(perspective, version), keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, - }, newGQUICConnectionParamatersManager(perspective, version, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}), nil + }, newParamsNegotiatorGQUIC(perspective, version, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}), nil } func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error { diff --git a/internal/handshake/base_connection_parameters_manager.go b/internal/handshake/params_negotiator_base.go similarity index 74% rename from internal/handshake/base_connection_parameters_manager.go rename to internal/handshake/params_negotiator_base.go index fc970e54..425a5f79 100644 --- a/internal/handshake/base_connection_parameters_manager.go +++ b/internal/handshake/params_negotiator_base.go @@ -8,9 +8,9 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -// ConnectionParametersManager negotiates and stores the connection parameters. -// A ConnectionParametersManager can be used for a server as well as a client. -type ConnectionParametersManager interface { +// The ParamsNegotiator negotiates and stores the connection parameters. +// It can be used for a server as well as a client. +type ParamsNegotiator interface { GetSendStreamFlowControlWindow() protocol.ByteCount GetSendConnectionFlowControlWindow() protocol.ByteCount GetReceiveStreamFlowControlWindow() protocol.ByteCount @@ -31,7 +31,7 @@ type ConnectionParametersManager interface { // For the client: // 1. call GetHelloMap to get the values to send in a CHLO // 2. call SetFromMap with the values received in the SHLO -type baseConnectionParametersManager struct { +type paramsNegotiatorBase struct { mutex sync.RWMutex version protocol.VersionNumber @@ -51,7 +51,7 @@ type baseConnectionParametersManager struct { maxReceiveConnectionFlowControlWindow protocol.ByteCount } -func (h *baseConnectionParametersManager) init(params *TransportParameters) { +func (h *paramsNegotiatorBase) init(params *TransportParameters) { h.sendStreamFlowControlWindow = protocol.InitialStreamFlowControlWindow // can only be changed by the client h.sendConnectionFlowControlWindow = protocol.InitialConnectionFlowControlWindow // can only be changed by the client h.receiveStreamFlowControlWindow = protocol.ReceiveStreamFlowControlWindow @@ -69,62 +69,62 @@ func (h *baseConnectionParametersManager) init(params *TransportParameters) { } } -func (h *baseConnectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { +func (h *paramsNegotiatorBase) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection) } -func (h *baseConnectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 { +func (h *paramsNegotiatorBase) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 { return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection) } -func (h *baseConnectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { +func (h *paramsNegotiatorBase) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { return utils.MinDuration(clientValue, h.idleConnectionStateLifetime) } // GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *baseConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { +func (h *paramsNegotiatorBase) GetSendStreamFlowControlWindow() protocol.ByteCount { h.mutex.RLock() defer h.mutex.RUnlock() return h.sendStreamFlowControlWindow } // GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *baseConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { +func (h *paramsNegotiatorBase) GetSendConnectionFlowControlWindow() protocol.ByteCount { h.mutex.RLock() defer h.mutex.RUnlock() return h.sendConnectionFlowControlWindow } -func (h *baseConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { +func (h *paramsNegotiatorBase) GetReceiveStreamFlowControlWindow() protocol.ByteCount { h.mutex.RLock() defer h.mutex.RUnlock() return h.receiveStreamFlowControlWindow } // GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data -func (h *baseConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { +func (h *paramsNegotiatorBase) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { return h.maxReceiveStreamFlowControlWindow } // GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data -func (h *baseConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { +func (h *paramsNegotiatorBase) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { h.mutex.RLock() defer h.mutex.RUnlock() return h.receiveConnectionFlowControlWindow } -func (h *baseConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { +func (h *paramsNegotiatorBase) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { return h.maxReceiveConnectionFlowControlWindow } -func (h *baseConnectionParametersManager) GetMaxOutgoingStreams() uint32 { +func (h *paramsNegotiatorBase) GetMaxOutgoingStreams() uint32 { h.mutex.RLock() defer h.mutex.RUnlock() return h.maxIncomingDynamicStreamsPerConnection } -func (h *baseConnectionParametersManager) GetMaxIncomingStreams() uint32 { +func (h *paramsNegotiatorBase) GetMaxIncomingStreams() uint32 { h.mutex.RLock() defer h.mutex.RUnlock() @@ -132,13 +132,13 @@ func (h *baseConnectionParametersManager) GetMaxIncomingStreams() uint32 { return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier)) } -func (h *baseConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { +func (h *paramsNegotiatorBase) GetIdleConnectionStateLifetime() time.Duration { h.mutex.RLock() defer h.mutex.RUnlock() return h.idleConnectionStateLifetime } -func (h *baseConnectionParametersManager) TruncateConnectionID() bool { +func (h *paramsNegotiatorBase) TruncateConnectionID() bool { if h.perspective == protocol.PerspectiveClient { return false } diff --git a/internal/handshake/gquic_connection_parameters_manager.go b/internal/handshake/params_negotiator_gquic.go similarity index 85% rename from internal/handshake/gquic_connection_parameters_manager.go rename to internal/handshake/params_negotiator_gquic.go index 68757075..b7d2932e 100644 --- a/internal/handshake/gquic_connection_parameters_manager.go +++ b/internal/handshake/params_negotiator_gquic.go @@ -9,21 +9,21 @@ import ( "github.com/lucas-clemente/quic-go/qerr" ) -var _ ConnectionParametersManager = &baseConnectionParametersManager{} - // errMalformedTag is returned when the tag value cannot be read var ( errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") errFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported") ) -type gquicConnectionParametersManager struct { - baseConnectionParametersManager +type paramsNegotiatorGQUIC struct { + paramsNegotiatorBase } -// newConnectionParamatersManager creates a new connection parameters manager -func newGQUICConnectionParamatersManager(pers protocol.Perspective, v protocol.VersionNumber, params *TransportParameters) *gquicConnectionParametersManager { - h := &gquicConnectionParametersManager{} +var _ ParamsNegotiator = ¶msNegotiatorGQUIC{} + +// newParamsNegotiatorGQUIC creates a new connection parameters manager +func newParamsNegotiatorGQUIC(pers protocol.Perspective, v protocol.VersionNumber, params *TransportParameters) *paramsNegotiatorGQUIC { + h := ¶msNegotiatorGQUIC{} h.perspective = pers h.version = v h.init(params) @@ -31,7 +31,7 @@ func newGQUICConnectionParamatersManager(pers protocol.Perspective, v protocol.V } // SetFromMap reads all params. -func (h *gquicConnectionParametersManager) SetFromMap(params map[Tag][]byte) error { +func (h *paramsNegotiatorGQUIC) SetFromMap(params map[Tag][]byte) error { h.mutex.Lock() defer h.mutex.Unlock() @@ -94,7 +94,7 @@ func (h *gquicConnectionParametersManager) SetFromMap(params map[Tag][]byte) err } // GetHelloMap gets all parameters needed for the Hello message. -func (h *gquicConnectionParametersManager) GetHelloMap() (map[Tag][]byte, error) { +func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) { sfcw := bytes.NewBuffer([]byte{}) utils.LittleEndian.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow())) cfcw := bytes.NewBuffer([]byte{}) diff --git a/internal/handshake/gquic_connection_parameters_manager_test.go b/internal/handshake/params_negotiator_gquic_test.go similarity index 65% rename from internal/handshake/gquic_connection_parameters_manager_test.go rename to internal/handshake/params_negotiator_gquic_test.go index d83291e4..ca7eecbe 100644 --- a/internal/handshake/gquic_connection_parameters_manager_test.go +++ b/internal/handshake/params_negotiator_gquic_test.go @@ -10,9 +10,9 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("ConnectionsParameterManager", func() { - var cpm *gquicConnectionParametersManager // a connectionParametersManager for a server - var cpmClient *gquicConnectionParametersManager +var _ = Describe("Params Negotiator (for gQUIC)", func() { + var pn *paramsNegotiatorGQUIC // a connectionParametersManager for a server + var pnClient *paramsNegotiatorGQUIC const MB = 1 << 20 maxReceiveStreamFlowControlWindowServer := protocol.ByteCount(math.Floor(1.1 * MB)) // default is 1 MB maxReceiveConnectionFlowControlWindowServer := protocol.ByteCount(math.Floor(1.5 * MB)) // default is 1.5 MB @@ -20,7 +20,7 @@ var _ = Describe("ConnectionsParameterManager", func() { maxReceiveConnectionFlowControlWindowClient := protocol.ByteCount(math.Floor(13 * MB)) // default is 15 MB idleTimeout := 42 * time.Second BeforeEach(func() { - cpm = newGQUICConnectionParamatersManager( + pn = newParamsNegotiatorGQUIC( protocol.PerspectiveServer, protocol.VersionWhatever, &TransportParameters{ @@ -29,7 +29,7 @@ var _ = Describe("ConnectionsParameterManager", func() { IdleTimeout: idleTimeout, }, ) - cpmClient = newGQUICConnectionParamatersManager( + pnClient = newParamsNegotiatorGQUIC( protocol.PerspectiveClient, protocol.VersionWhatever, &TransportParameters{ @@ -43,11 +43,11 @@ var _ = Describe("ConnectionsParameterManager", func() { Context("SHLO", func() { BeforeEach(func() { // these tests should only use the server connectionParametersManager. Make them panic if they don't - cpmClient = nil + pnClient = nil }) It("returns all parameters necessary for the SHLO", func() { - entryMap, err := cpm.GetHelloMap() + entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagICSL)) Expect(entryMap).To(HaveKey(TagMSPC)) @@ -55,24 +55,24 @@ var _ = Describe("ConnectionsParameterManager", func() { }) It("sets the stream-level flow control windows in SHLO", func() { - cpm.receiveStreamFlowControlWindow = 0xDEADBEEF - entryMap, err := cpm.GetHelloMap() + pn.receiveStreamFlowControlWindow = 0xDEADBEEF + entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagSFCW)) Expect(entryMap[TagSFCW]).To(Equal([]byte{0xEF, 0xBE, 0xAD, 0xDE})) }) It("sets the connection-level flow control windows in SHLO", func() { - cpm.receiveConnectionFlowControlWindow = 0xDECAFBAD - entryMap, err := cpm.GetHelloMap() + pn.receiveConnectionFlowControlWindow = 0xDECAFBAD + entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagCFCW)) Expect(entryMap[TagCFCW]).To(Equal([]byte{0xAD, 0xFB, 0xCA, 0xDE})) }) It("sets the connection-level flow control windows in SHLO", func() { - cpm.idleConnectionStateLifetime = 0xDECAFBAD * time.Second - entryMap, err := cpm.GetHelloMap() + pn.idleConnectionStateLifetime = 0xDECAFBAD * time.Second + entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagICSL)) Expect(entryMap[TagICSL]).To(Equal([]byte{0xAD, 0xFB, 0xCA, 0xDE})) @@ -81,17 +81,17 @@ var _ = Describe("ConnectionsParameterManager", func() { It("sets the negotiated value for maximum streams in the SHLO", func() { val := 50 Expect(val).To(BeNumerically("<", protocol.MaxStreamsPerConnection)) - err := cpm.SetFromMap(map[Tag][]byte{TagMSPC: []byte{byte(val), 0, 0, 0}}) + err := pn.SetFromMap(map[Tag][]byte{TagMSPC: []byte{byte(val), 0, 0, 0}}) Expect(err).ToNot(HaveOccurred()) - entryMap, err := cpm.GetHelloMap() + entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap[TagMSPC]).To(Equal([]byte{byte(val), 0, 0, 0})) }) It("always sends its own value for the maximum incoming dynamic streams in the SHLO", func() { - err := cpm.SetFromMap(map[Tag][]byte{TagMIDS: []byte{5, 0, 0, 0}}) + err := pn.SetFromMap(map[Tag][]byte{TagMIDS: []byte{5, 0, 0, 0}}) Expect(err).ToNot(HaveOccurred()) - entryMap, err := cpm.GetHelloMap() + entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap[TagMIDS]).To(Equal([]byte{byte(protocol.MaxIncomingDynamicStreamsPerConnection), 0, 0, 0})) }) @@ -100,11 +100,11 @@ var _ = Describe("ConnectionsParameterManager", func() { Context("CHLO", func() { BeforeEach(func() { // these tests should only use the client connectionParametersManager. Make them panic if they don't - cpm = nil + pn = nil }) It("has the right values", func() { - entryMap, err := cpmClient.GetHelloMap() + entryMap, err := pnClient.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagICSL)) Expect(binary.LittleEndian.Uint32(entryMap[TagICSL])).To(BeEquivalentTo(idleTimeout / time.Second)) @@ -121,76 +121,76 @@ var _ = Describe("ConnectionsParameterManager", func() { Context("Truncated connection IDs", func() { It("does not send truncated connection IDs if the TCID tag is missing", func() { - Expect(cpm.TruncateConnectionID()).To(BeFalse()) + Expect(pn.TruncateConnectionID()).To(BeFalse()) }) It("reads the tag for truncated connection IDs", func() { values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}} - cpm.SetFromMap(values) - Expect(cpm.TruncateConnectionID()).To(BeTrue()) + pn.SetFromMap(values) + Expect(pn.TruncateConnectionID()).To(BeTrue()) }) It("ignores the TCID tag, as a client", func() { values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}} - cpmClient.SetFromMap(values) - Expect(cpmClient.TruncateConnectionID()).To(BeFalse()) + pnClient.SetFromMap(values) + Expect(pnClient.TruncateConnectionID()).To(BeFalse()) }) It("errors when given an invalid value", func() { values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).To(MatchError(errMalformedTag)) }) }) Context("flow control", func() { It("has the correct default flow control windows for sending", func() { - Expect(cpm.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) - Expect(cpm.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) - Expect(cpmClient.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) - Expect(cpmClient.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) + Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) + Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) + Expect(pnClient.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) + Expect(pnClient.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) }) It("has the correct default flow control windows for receiving", func() { - Expect(cpm.GetReceiveStreamFlowControlWindow()).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow)) - Expect(cpm.GetReceiveConnectionFlowControlWindow()).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow)) - Expect(cpmClient.GetReceiveStreamFlowControlWindow()).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow)) - Expect(cpmClient.GetReceiveConnectionFlowControlWindow()).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow)) + Expect(pn.GetReceiveStreamFlowControlWindow()).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow)) + Expect(pn.GetReceiveConnectionFlowControlWindow()).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow)) + Expect(pnClient.GetReceiveStreamFlowControlWindow()).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow)) + Expect(pnClient.GetReceiveConnectionFlowControlWindow()).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow)) }) It("has the correct maximum flow control windows", func() { - Expect(cpm.GetMaxReceiveStreamFlowControlWindow()).To(Equal(maxReceiveStreamFlowControlWindowServer)) - Expect(cpm.GetMaxReceiveConnectionFlowControlWindow()).To(Equal(maxReceiveConnectionFlowControlWindowServer)) - Expect(cpmClient.GetMaxReceiveStreamFlowControlWindow()).To(Equal(maxReceiveStreamFlowControlWindowClient)) - Expect(cpmClient.GetMaxReceiveConnectionFlowControlWindow()).To(Equal(maxReceiveConnectionFlowControlWindowClient)) + Expect(pn.GetMaxReceiveStreamFlowControlWindow()).To(Equal(maxReceiveStreamFlowControlWindowServer)) + Expect(pn.GetMaxReceiveConnectionFlowControlWindow()).To(Equal(maxReceiveConnectionFlowControlWindowServer)) + Expect(pnClient.GetMaxReceiveStreamFlowControlWindow()).To(Equal(maxReceiveStreamFlowControlWindowClient)) + Expect(pnClient.GetMaxReceiveConnectionFlowControlWindow()).To(Equal(maxReceiveConnectionFlowControlWindowClient)) }) It("sets a new stream-level flow control window for sending", func() { values := map[Tag][]byte{TagSFCW: {0xDE, 0xAD, 0xBE, 0xEF}} - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).ToNot(HaveOccurred()) - Expect(cpm.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) + Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) }) It("does not change the stream-level flow control window when given an invalid value", func() { values := map[Tag][]byte{TagSFCW: {0xDE, 0xAD, 0xBE}} // 1 byte too short - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).To(MatchError(errMalformedTag)) - Expect(cpm.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) + Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) }) It("sets a new connection-level flow control window for sending", func() { values := map[Tag][]byte{TagCFCW: {0xDE, 0xAD, 0xBE, 0xEF}} - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).ToNot(HaveOccurred()) - Expect(cpm.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) + Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) }) It("does not change the connection-level flow control window when given an invalid value", func() { values := map[Tag][]byte{TagCFCW: {0xDE, 0xAD, 0xBE}} // 1 byte too short - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).To(MatchError(errMalformedTag)) - Expect(cpm.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) + Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) }) It("does not allow renegotiation of flow control parameters", func() { @@ -198,30 +198,30 @@ var _ = Describe("ConnectionsParameterManager", func() { TagCFCW: {0xDE, 0xAD, 0xBE, 0xEF}, TagSFCW: {0xDE, 0xAD, 0xBE, 0xEF}, } - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).ToNot(HaveOccurred()) values = map[Tag][]byte{ TagCFCW: {0x13, 0x37, 0x13, 0x37}, TagSFCW: {0x13, 0x37, 0x13, 0x37}, } - err = cpm.SetFromMap(values) + err = pn.SetFromMap(values) Expect(err).To(MatchError(errFlowControlRenegotiationNotSupported)) - Expect(cpm.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) - Expect(cpm.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) + Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) + Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) }) }) Context("idle connection state lifetime", func() { It("has initial idle connection state lifetime", func() { - Expect(cpm.GetIdleConnectionStateLifetime()).To(Equal(idleTimeout)) + Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(idleTimeout)) }) It("negotiates correctly when the peer wants a longer lifetime", func() { - Expect(cpm.negotiateIdleConnectionStateLifetime(idleTimeout + 10*time.Second)).To(Equal(idleTimeout)) + Expect(pn.negotiateIdleConnectionStateLifetime(idleTimeout + 10*time.Second)).To(Equal(idleTimeout)) }) It("negotiates correctly when the peer wants a shorter lifetime", func() { - Expect(cpm.negotiateIdleConnectionStateLifetime(idleTimeout - 3*time.Second)).To(Equal(idleTimeout - 3*time.Second)) + Expect(pn.negotiateIdleConnectionStateLifetime(idleTimeout - 3*time.Second)).To(Equal(idleTimeout - 3*time.Second)) }) It("sets the negotiated lifetime", func() { @@ -229,29 +229,29 @@ var _ = Describe("ConnectionsParameterManager", func() { values := map[Tag][]byte{ TagICSL: {10, 0, 0, 0}, } - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).ToNot(HaveOccurred()) - Expect(cpm.GetIdleConnectionStateLifetime()).To(Equal(10 * time.Second)) + Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(10 * time.Second)) }) It("does not change the idle connection state lifetime when given an invalid value", func() { values := map[Tag][]byte{ TagSFCW: {0xDE, 0xAD, 0xBE}, // 1 byte too short } - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).To(MatchError(errMalformedTag)) - Expect(cpm.GetIdleConnectionStateLifetime()).To(Equal(idleTimeout)) + Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(idleTimeout)) }) It("gets idle connection state lifetime", func() { value := 0xDECAFBAD * time.Second - cpm.idleConnectionStateLifetime = value - Expect(cpm.GetIdleConnectionStateLifetime()).To(Equal(value)) + pn.idleConnectionStateLifetime = value + Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(value)) }) It("errors when given an invalid value", func() { values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).To(MatchError(errMalformedTag)) }) }) @@ -259,42 +259,42 @@ var _ = Describe("ConnectionsParameterManager", func() { Context("max streams per connection", func() { It("errors when given an invalid max streams per connection value", func() { values := map[Tag][]byte{TagMSPC: {2, 0, 0}} // 1 byte too short - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).To(MatchError(errMalformedTag)) }) It("errors when given an invalid max dynamic incoming streams per connection value", func() { values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short - err := cpm.SetFromMap(values) + err := pn.SetFromMap(values) Expect(err).To(MatchError(errMalformedTag)) }) Context("outgoing connections", func() { It("sets the negotiated max streams per connection value", func() { // this test only works if the value given here is smaller than protocol.MaxStreamsPerConnection - err := cpm.SetFromMap(map[Tag][]byte{ + err := pn.SetFromMap(map[Tag][]byte{ TagMIDS: {2, 0, 0, 0}, TagMSPC: {1, 0, 0, 0}, }) Expect(err).ToNot(HaveOccurred()) - Expect(cpm.GetMaxOutgoingStreams()).To(Equal(uint32(2))) + Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(2))) }) It("uses the the MSPC value, if no MIDS is given", func() { - err := cpm.SetFromMap(map[Tag][]byte{TagMIDS: {3, 0, 0, 0}}) + err := pn.SetFromMap(map[Tag][]byte{TagMIDS: {3, 0, 0, 0}}) Expect(err).ToNot(HaveOccurred()) - Expect(cpm.GetMaxOutgoingStreams()).To(Equal(uint32(3))) + Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(3))) }) }) Context("incoming connections", func() { It("always uses the constant value, no matter what the client sent", func() { - err := cpm.SetFromMap(map[Tag][]byte{ + err := pn.SetFromMap(map[Tag][]byte{ TagMSPC: {3, 0, 0, 0}, TagMIDS: {3, 0, 0, 0}, }) Expect(err).ToNot(HaveOccurred()) - Expect(cpm.GetMaxIncomingStreams()).To(BeNumerically(">", protocol.MaxStreamsPerConnection)) + Expect(pn.GetMaxIncomingStreams()).To(BeNumerically(">", protocol.MaxStreamsPerConnection)) }) }) }) diff --git a/internal/mocks/cpm.go b/internal/mocks/cpm.go deleted file mode 100644 index 0492abee..00000000 --- a/internal/mocks/cpm.go +++ /dev/null @@ -1,155 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ../handshake/base_connection_parameters_manager.go - -package mocks - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// MockConnectionParametersManager is a mock of ConnectionParametersManager interface -type MockConnectionParametersManager struct { - ctrl *gomock.Controller - recorder *MockConnectionParametersManagerMockRecorder -} - -// MockConnectionParametersManagerMockRecorder is the mock recorder for MockConnectionParametersManager -type MockConnectionParametersManagerMockRecorder struct { - mock *MockConnectionParametersManager -} - -// NewMockConnectionParametersManager creates a new mock instance -func NewMockConnectionParametersManager(ctrl *gomock.Controller) *MockConnectionParametersManager { - mock := &MockConnectionParametersManager{ctrl: ctrl} - mock.recorder = &MockConnectionParametersManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (_m *MockConnectionParametersManager) EXPECT() *MockConnectionParametersManagerMockRecorder { - return _m.recorder -} - -// GetSendStreamFlowControlWindow mocks base method -func (_m *MockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetSendStreamFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetSendStreamFlowControlWindow indicates an expected call of GetSendStreamFlowControlWindow -func (_mr *MockConnectionParametersManagerMockRecorder) GetSendStreamFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetSendStreamFlowControlWindow", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetSendStreamFlowControlWindow)) -} - -// GetSendConnectionFlowControlWindow mocks base method -func (_m *MockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetSendConnectionFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetSendConnectionFlowControlWindow indicates an expected call of GetSendConnectionFlowControlWindow -func (_mr *MockConnectionParametersManagerMockRecorder) GetSendConnectionFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetSendConnectionFlowControlWindow", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetSendConnectionFlowControlWindow)) -} - -// GetReceiveStreamFlowControlWindow mocks base method -func (_m *MockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetReceiveStreamFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetReceiveStreamFlowControlWindow indicates an expected call of GetReceiveStreamFlowControlWindow -func (_mr *MockConnectionParametersManagerMockRecorder) GetReceiveStreamFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetReceiveStreamFlowControlWindow", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetReceiveStreamFlowControlWindow)) -} - -// GetMaxReceiveStreamFlowControlWindow mocks base method -func (_m *MockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetMaxReceiveStreamFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetMaxReceiveStreamFlowControlWindow indicates an expected call of GetMaxReceiveStreamFlowControlWindow -func (_mr *MockConnectionParametersManagerMockRecorder) GetMaxReceiveStreamFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetMaxReceiveStreamFlowControlWindow", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetMaxReceiveStreamFlowControlWindow)) -} - -// GetReceiveConnectionFlowControlWindow mocks base method -func (_m *MockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetReceiveConnectionFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetReceiveConnectionFlowControlWindow indicates an expected call of GetReceiveConnectionFlowControlWindow -func (_mr *MockConnectionParametersManagerMockRecorder) GetReceiveConnectionFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetReceiveConnectionFlowControlWindow", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetReceiveConnectionFlowControlWindow)) -} - -// GetMaxReceiveConnectionFlowControlWindow mocks base method -func (_m *MockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { - ret := _m.ctrl.Call(_m, "GetMaxReceiveConnectionFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetMaxReceiveConnectionFlowControlWindow indicates an expected call of GetMaxReceiveConnectionFlowControlWindow -func (_mr *MockConnectionParametersManagerMockRecorder) GetMaxReceiveConnectionFlowControlWindow() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetMaxReceiveConnectionFlowControlWindow", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetMaxReceiveConnectionFlowControlWindow)) -} - -// GetMaxOutgoingStreams mocks base method -func (_m *MockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { - ret := _m.ctrl.Call(_m, "GetMaxOutgoingStreams") - ret0, _ := ret[0].(uint32) - return ret0 -} - -// GetMaxOutgoingStreams indicates an expected call of GetMaxOutgoingStreams -func (_mr *MockConnectionParametersManagerMockRecorder) GetMaxOutgoingStreams() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetMaxOutgoingStreams", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetMaxOutgoingStreams)) -} - -// GetMaxIncomingStreams mocks base method -func (_m *MockConnectionParametersManager) GetMaxIncomingStreams() uint32 { - ret := _m.ctrl.Call(_m, "GetMaxIncomingStreams") - ret0, _ := ret[0].(uint32) - return ret0 -} - -// GetMaxIncomingStreams indicates an expected call of GetMaxIncomingStreams -func (_mr *MockConnectionParametersManagerMockRecorder) GetMaxIncomingStreams() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetMaxIncomingStreams", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetMaxIncomingStreams)) -} - -// GetIdleConnectionStateLifetime mocks base method -func (_m *MockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { - ret := _m.ctrl.Call(_m, "GetIdleConnectionStateLifetime") - ret0, _ := ret[0].(time.Duration) - return ret0 -} - -// GetIdleConnectionStateLifetime indicates an expected call of GetIdleConnectionStateLifetime -func (_mr *MockConnectionParametersManagerMockRecorder) GetIdleConnectionStateLifetime() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetIdleConnectionStateLifetime", reflect.TypeOf((*MockConnectionParametersManager)(nil).GetIdleConnectionStateLifetime)) -} - -// TruncateConnectionID mocks base method -func (_m *MockConnectionParametersManager) TruncateConnectionID() bool { - ret := _m.ctrl.Call(_m, "TruncateConnectionID") - ret0, _ := ret[0].(bool) - return ret0 -} - -// TruncateConnectionID indicates an expected call of TruncateConnectionID -func (_mr *MockConnectionParametersManagerMockRecorder) TruncateConnectionID() *gomock.Call { - return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "TruncateConnectionID", reflect.TypeOf((*MockConnectionParametersManager)(nil).TruncateConnectionID)) -} diff --git a/internal/mocks/gen.go b/internal/mocks/gen.go index 0b40addc..064a0ad5 100644 --- a/internal/mocks/gen.go +++ b/internal/mocks/gen.go @@ -4,5 +4,5 @@ package mocks // so we have to use sed to correct for that //go:generate sh -c "mockgen -package mocks_fc -source ../flowcontrol/interface.go | sed \"s/\\[\\]WindowUpdate/[]flowcontrol.WindowUpdate/g\" > mocks_fc/flow_control_manager.go" -//go:generate sh -c "mockgen -package mocks -source ../handshake/base_connection_parameters_manager.go > cpm.go" +//go:generate sh -c "mockgen -package mocks -source ../handshake/params_negotiator_base.go > params_negotiator.go" //go:generate sh -c "goimports -w ." diff --git a/internal/mocks/params_negotiator.go b/internal/mocks/params_negotiator.go new file mode 100644 index 00000000..dd058979 --- /dev/null +++ b/internal/mocks/params_negotiator.go @@ -0,0 +1,155 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../handshake/params_negotiator_base.go + +package mocks + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockParamsNegotiator is a mock of ParamsNegotiator interface +type MockParamsNegotiator struct { + ctrl *gomock.Controller + recorder *MockParamsNegotiatorMockRecorder +} + +// MockParamsNegotiatorMockRecorder is the mock recorder for MockParamsNegotiator +type MockParamsNegotiatorMockRecorder struct { + mock *MockParamsNegotiator +} + +// NewMockParamsNegotiator creates a new mock instance +func NewMockParamsNegotiator(ctrl *gomock.Controller) *MockParamsNegotiator { + mock := &MockParamsNegotiator{ctrl: ctrl} + mock.recorder = &MockParamsNegotiatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockParamsNegotiator) EXPECT() *MockParamsNegotiatorMockRecorder { + return _m.recorder +} + +// GetSendStreamFlowControlWindow mocks base method +func (_m *MockParamsNegotiator) GetSendStreamFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetSendStreamFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetSendStreamFlowControlWindow indicates an expected call of GetSendStreamFlowControlWindow +func (_mr *MockParamsNegotiatorMockRecorder) GetSendStreamFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetSendStreamFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetSendStreamFlowControlWindow)) +} + +// GetSendConnectionFlowControlWindow mocks base method +func (_m *MockParamsNegotiator) GetSendConnectionFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetSendConnectionFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetSendConnectionFlowControlWindow indicates an expected call of GetSendConnectionFlowControlWindow +func (_mr *MockParamsNegotiatorMockRecorder) GetSendConnectionFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetSendConnectionFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetSendConnectionFlowControlWindow)) +} + +// GetReceiveStreamFlowControlWindow mocks base method +func (_m *MockParamsNegotiator) GetReceiveStreamFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetReceiveStreamFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetReceiveStreamFlowControlWindow indicates an expected call of GetReceiveStreamFlowControlWindow +func (_mr *MockParamsNegotiatorMockRecorder) GetReceiveStreamFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetReceiveStreamFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetReceiveStreamFlowControlWindow)) +} + +// GetMaxReceiveStreamFlowControlWindow mocks base method +func (_m *MockParamsNegotiator) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetMaxReceiveStreamFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetMaxReceiveStreamFlowControlWindow indicates an expected call of GetMaxReceiveStreamFlowControlWindow +func (_mr *MockParamsNegotiatorMockRecorder) GetMaxReceiveStreamFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetMaxReceiveStreamFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxReceiveStreamFlowControlWindow)) +} + +// GetReceiveConnectionFlowControlWindow mocks base method +func (_m *MockParamsNegotiator) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetReceiveConnectionFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetReceiveConnectionFlowControlWindow indicates an expected call of GetReceiveConnectionFlowControlWindow +func (_mr *MockParamsNegotiatorMockRecorder) GetReceiveConnectionFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetReceiveConnectionFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetReceiveConnectionFlowControlWindow)) +} + +// GetMaxReceiveConnectionFlowControlWindow mocks base method +func (_m *MockParamsNegotiator) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "GetMaxReceiveConnectionFlowControlWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// GetMaxReceiveConnectionFlowControlWindow indicates an expected call of GetMaxReceiveConnectionFlowControlWindow +func (_mr *MockParamsNegotiatorMockRecorder) GetMaxReceiveConnectionFlowControlWindow() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetMaxReceiveConnectionFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxReceiveConnectionFlowControlWindow)) +} + +// GetMaxOutgoingStreams mocks base method +func (_m *MockParamsNegotiator) GetMaxOutgoingStreams() uint32 { + ret := _m.ctrl.Call(_m, "GetMaxOutgoingStreams") + ret0, _ := ret[0].(uint32) + return ret0 +} + +// GetMaxOutgoingStreams indicates an expected call of GetMaxOutgoingStreams +func (_mr *MockParamsNegotiatorMockRecorder) GetMaxOutgoingStreams() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetMaxOutgoingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxOutgoingStreams)) +} + +// GetMaxIncomingStreams mocks base method +func (_m *MockParamsNegotiator) GetMaxIncomingStreams() uint32 { + ret := _m.ctrl.Call(_m, "GetMaxIncomingStreams") + ret0, _ := ret[0].(uint32) + return ret0 +} + +// GetMaxIncomingStreams indicates an expected call of GetMaxIncomingStreams +func (_mr *MockParamsNegotiatorMockRecorder) GetMaxIncomingStreams() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetMaxIncomingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxIncomingStreams)) +} + +// GetIdleConnectionStateLifetime mocks base method +func (_m *MockParamsNegotiator) GetIdleConnectionStateLifetime() time.Duration { + ret := _m.ctrl.Call(_m, "GetIdleConnectionStateLifetime") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// GetIdleConnectionStateLifetime indicates an expected call of GetIdleConnectionStateLifetime +func (_mr *MockParamsNegotiatorMockRecorder) GetIdleConnectionStateLifetime() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetIdleConnectionStateLifetime", reflect.TypeOf((*MockParamsNegotiator)(nil).GetIdleConnectionStateLifetime)) +} + +// TruncateConnectionID mocks base method +func (_m *MockParamsNegotiator) TruncateConnectionID() bool { + ret := _m.ctrl.Call(_m, "TruncateConnectionID") + ret0, _ := ret[0].(bool) + return ret0 +} + +// TruncateConnectionID indicates an expected call of TruncateConnectionID +func (_mr *MockParamsNegotiatorMockRecorder) TruncateConnectionID() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "TruncateConnectionID", reflect.TypeOf((*MockParamsNegotiator)(nil).TruncateConnectionID)) +} diff --git a/packet_packer.go b/packet_packer.go index 89a99a6f..d289ff67 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -25,7 +25,7 @@ type packetPacker struct { cryptoSetup handshake.CryptoSetup packetNumberGenerator *packetNumberGenerator - connectionParameters handshake.ConnectionParametersManager + connParams handshake.ParamsNegotiator streamFramer *streamFramer controlFrames []wire.Frame @@ -36,7 +36,7 @@ type packetPacker struct { func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, - connectionParameters handshake.ConnectionParametersManager, + connParams handshake.ParamsNegotiator, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber, @@ -44,7 +44,7 @@ func newPacketPacker(connectionID protocol.ConnectionID, return &packetPacker{ cryptoSetup: cryptoSetup, connectionID: connectionID, - connectionParameters: connectionParameters, + connParams: connParams, perspective: perspective, version: version, streamFramer: streamFramer, @@ -271,7 +271,7 @@ func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *wire. ConnectionID: p.connectionID, PacketNumber: pnum, PacketNumberLen: packetNumberLen, - TruncateConnectionID: p.connectionParameters.TruncateConnectionID(), + TruncateConnectionID: p.connParams.TruncateConnectionID(), } if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { diff --git a/packet_packer_test.go b/packet_packer_test.go index aafa4549..b6ddc6aa 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -61,8 +61,8 @@ var _ = Describe("Packet packer", func() { ) BeforeEach(func() { - mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) - mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes() + mockPn := mocks.NewMockParamsNegotiator(mockCtrl) + mockPn.EXPECT().TruncateConnectionID().Return(false).AnyTimes() cryptoStream = &stream{} @@ -73,7 +73,7 @@ var _ = Describe("Packet packer", func() { packer = &packetPacker{ cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, - connectionParameters: mockCpm, + connParams: mockPn, connectionID: 0x1337, packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), streamFramer: streamFramer, diff --git a/session.go b/session.go index d1f8d35f..478a7d62 100644 --- a/session.go +++ b/session.go @@ -100,7 +100,7 @@ type session struct { // it receives at most 3 handshake events: 2 when the encryption level changes, and one error handshakeChan chan<- handshakeEvent - connectionParameters handshake.ConnectionParametersManager + connParams handshake.ParamsNegotiator lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire @@ -194,7 +194,7 @@ func (s *session) setup( return s.config.AcceptCookie(clientAddr, cookie) } if s.version.UsesTLS() { - s.cryptoSetup, s.connectionParameters, err = handshake.NewCryptoSetupTLS( + s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLS( "", s.perspective, s.version, @@ -202,7 +202,7 @@ func (s *session) setup( aeadChanged, ) } else { - s.cryptoSetup, s.connectionParameters, err = newCryptoSetup( + s.cryptoSetup, s.connParams, err = newCryptoSetup( s.connectionID, s.conn.RemoteAddr(), s.version, @@ -215,7 +215,7 @@ func (s *session) setup( } } else { if s.version.UsesTLS() { - s.cryptoSetup, s.connectionParameters, err = handshake.NewCryptoSetupTLS( + s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLS( hostname, s.perspective, s.version, @@ -224,7 +224,7 @@ func (s *session) setup( ) } else { transportParams.RequestConnectionIDTruncation = s.config.RequestConnectionIDTruncation - s.cryptoSetup, s.connectionParameters, err = newCryptoSetupClient( + s.cryptoSetup, s.connParams, err = newCryptoSetupClient( hostname, s.connectionID, s.version, @@ -239,12 +239,12 @@ func (s *session) setup( return nil, nil, err } - s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) - s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) + s.flowControlManager = flowcontrol.NewFlowControlManager(s.connParams, s.rttStats) + s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connParams) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) s.packer = newPacketPacker(s.connectionID, s.cryptoSetup, - s.connectionParameters, + s.connParams, s.streamFramer, s.perspective, s.version, @@ -389,7 +389,7 @@ func (s *session) maybeResetTimer() { } func (s *session) idleTimeout() time.Duration { - return s.connectionParameters.GetIdleConnectionStateLifetime() + return s.connParams.GetIdleConnectionStateLifetime() } func (s *session) handlePacketImpl(p *receivedPacket) error { diff --git a/session_test.go b/session_test.go index a43d0786..5ef4ab04 100644 --- a/session_test.go +++ b/session_test.go @@ -143,36 +143,34 @@ func areSessionsRunning() bool { return strings.Contains(b.String(), "quic-go.(*session).run") } -type mockConnectionParametersManager struct{} +type mockParamsNegotiator struct{} -var _ handshake.ConnectionParametersManager = &mockConnectionParametersManager{} +var _ handshake.ParamsNegotiator = &mockParamsNegotiator{} -func (m *mockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { +func (m *mockParamsNegotiator) GetSendStreamFlowControlWindow() protocol.ByteCount { return protocol.InitialStreamFlowControlWindow } -func (m *mockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { +func (m *mockParamsNegotiator) GetSendConnectionFlowControlWindow() protocol.ByteCount { return protocol.InitialConnectionFlowControlWindow } -func (m *mockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { +func (m *mockParamsNegotiator) GetReceiveStreamFlowControlWindow() protocol.ByteCount { return protocol.ReceiveStreamFlowControlWindow } -func (m *mockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { +func (m *mockParamsNegotiator) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { return protocol.DefaultMaxReceiveStreamFlowControlWindowServer } -func (m *mockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { +func (m *mockParamsNegotiator) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { return protocol.ReceiveConnectionFlowControlWindow } -func (m *mockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { +func (m *mockParamsNegotiator) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { return protocol.DefaultMaxReceiveConnectionFlowControlWindowServer } -func (m *mockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { return 100 } -func (m *mockConnectionParametersManager) GetMaxIncomingStreams() uint32 { return 100 } -func (m *mockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { +func (m *mockParamsNegotiator) GetMaxOutgoingStreams() uint32 { return 100 } +func (m *mockParamsNegotiator) GetMaxIncomingStreams() uint32 { return 100 } +func (m *mockParamsNegotiator) GetIdleConnectionStateLifetime() time.Duration { return time.Hour } -func (m *mockConnectionParametersManager) TruncateConnectionID() bool { return false } - -var _ handshake.ConnectionParametersManager = &mockConnectionParametersManager{} +func (m *mockParamsNegotiator) TruncateConnectionID() bool { return false } var _ = Describe("Session", func() { var ( @@ -197,9 +195,9 @@ var _ = Describe("Session", func() { _ []protocol.VersionNumber, _ func(net.Addr, *Cookie) bool, aeadChangedP chan<- protocol.EncryptionLevel, - ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) { + ) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) { aeadChanged = aeadChangedP - return cryptoSetup, &mockConnectionParametersManager{}, nil + return cryptoSetup, &mockParamsNegotiator{}, nil } mconn = newMockConnection() @@ -221,7 +219,7 @@ var _ = Describe("Session", func() { sess = pSess.(*session) Expect(sess.streamsMap.openStreams).To(BeEmpty()) // the crypto stream is opened in session.run() - sess.connectionParameters = &mockConnectionParametersManager{} + sess.connParams = &mockParamsNegotiator{} }) AfterEach(func() { @@ -247,9 +245,9 @@ var _ = Describe("Session", func() { _ []protocol.VersionNumber, cookieFunc func(net.Addr, *Cookie) bool, _ chan<- protocol.EncryptionLevel, - ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) { + ) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) { cookieVerify = cookieFunc - return cryptoSetup, &mockConnectionParametersManager{}, nil + return cryptoSetup, &mockParamsNegotiator{}, nil } conf := populateServerConfig(&Config{}) @@ -1559,11 +1557,11 @@ var _ = Describe("Session", func() { It("does not use ICSL before handshake", func() { defer sess.Close(nil) sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) - mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) - mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(9999 * time.Second).AnyTimes() - mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes() - sess.connectionParameters = mockCpm - sess.packer.connectionParameters = mockCpm + mockPn := mocks.NewMockParamsNegotiator(mockCtrl) + mockPn.EXPECT().GetIdleConnectionStateLifetime().Return(9999 * time.Second).AnyTimes() + mockPn.EXPECT().TruncateConnectionID().Return(false).AnyTimes() + sess.connParams = mockPn + sess.packer.connParams = mockPn // the handshake timeout is irrelevant here, since it depends on the time the session was created, // and not on the last network activity done := make(chan struct{}) @@ -1576,12 +1574,12 @@ var _ = Describe("Session", func() { It("uses ICSL after handshake", func(done Done) { close(aeadChanged) - mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) - mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second) - mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes() - sess.connectionParameters = mockCpm - sess.packer.connectionParameters = mockCpm - mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second).AnyTimes() + mockPn := mocks.NewMockParamsNegotiator(mockCtrl) + mockPn.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second) + mockPn.EXPECT().TruncateConnectionID().Return(false).AnyTimes() + sess.connParams = mockPn + sess.packer.connParams = mockPn + mockPn.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second).AnyTimes() err := sess.run() // Would normally not return Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) @@ -1630,8 +1628,8 @@ var _ = Describe("Session", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { - mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) - mockCpm.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() + mockPn := mocks.NewMockParamsNegotiator(mockCtrl) + mockPn.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() for i := 0; i < 10; i++ { _, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) @@ -1641,8 +1639,8 @@ var _ = Describe("Session", func() { }) It("does not error when many streams are opened and closed", func() { - mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) - mockCpm.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() + mockPn := mocks.NewMockParamsNegotiator(mockCtrl) + mockPn.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() for i := 2; i <= 1000; i++ { s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) @@ -1732,9 +1730,9 @@ var _ = Describe("Client Session", func() { _ *handshake.TransportParameters, aeadChangedP chan<- protocol.EncryptionLevel, _ []protocol.VersionNumber, - ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) { + ) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) { aeadChanged = aeadChangedP - return cryptoSetup, &mockConnectionParametersManager{}, nil + return cryptoSetup, &mockParamsNegotiator{}, nil } mconn = newMockConnection() diff --git a/streams_map.go b/streams_map.go index 4638e8ce..269669f5 100644 --- a/streams_map.go +++ b/streams_map.go @@ -13,8 +13,8 @@ import ( type streamsMap struct { mutex sync.RWMutex - perspective protocol.Perspective - connectionParameters handshake.ConnectionParametersManager + connParams handshake.ParamsNegotiator + perspective protocol.Perspective streams map[protocol.StreamID]*stream // needed for round-robin scheduling @@ -42,13 +42,13 @@ var ( errMapAccess = errors.New("streamsMap: Error accessing the streams map") ) -func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connectionParameters handshake.ConnectionParametersManager) *streamsMap { +func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, connParams handshake.ParamsNegotiator) *streamsMap { sm := streamsMap{ - perspective: pers, - streams: map[protocol.StreamID]*stream{}, - openStreams: make([]protocol.StreamID, 0), - newStream: newStream, - connectionParameters: connectionParameters, + perspective: pers, + streams: map[protocol.StreamID]*stream{}, + openStreams: make([]protocol.StreamID, 0), + newStream: newStream, + connParams: connParams, } sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex @@ -125,7 +125,7 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) { } func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { - if m.numIncomingStreams >= m.connectionParameters.GetMaxIncomingStreams() { + if m.numIncomingStreams >= m.connParams.GetMaxIncomingStreams() { return nil, qerr.TooManyOpenStreams } if id+protocol.MaxNewStreamIDDelta < m.highestStreamOpenedByPeer { @@ -149,7 +149,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { func (m *streamsMap) openStreamImpl() (*stream, error) { id := m.nextStream - if m.numOutgoingStreams >= m.connectionParameters.GetMaxOutgoingStreams() { + if m.numOutgoingStreams >= m.connParams.GetMaxOutgoingStreams() { return nil, qerr.TooManyOpenStreams } diff --git a/streams_map_test.go b/streams_map_test.go index 3c7f02fb..c46d4935 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -17,17 +17,17 @@ var _ = Describe("Streams Map", func() { ) var ( - m *streamsMap - mockCpm *mocks.MockConnectionParametersManager + m *streamsMap + mockPn *mocks.MockParamsNegotiator ) setNewStreamsMap := func(p protocol.Perspective) { - mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) + mockPn = mocks.NewMockParamsNegotiator(mockCtrl) - mockCpm.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams)) - mockCpm.EXPECT().GetMaxIncomingStreams().AnyTimes().Return(uint32(maxIncomingStreams)) + mockPn.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams)) + mockPn.EXPECT().GetMaxIncomingStreams().AnyTimes().Return(uint32(maxIncomingStreams)) - m = newStreamsMap(nil, p, mockCpm) + m = newStreamsMap(nil, p, mockPn) m.newStream = func(id protocol.StreamID) *stream { return newStream(id, nil, nil, nil) }