diff --git a/internal/flowcontrol/flow_control_manager.go b/internal/flowcontrol/flow_control_manager.go index 82606286..902eaa79 100644 --- a/internal/flowcontrol/flow_control_manager.go +++ b/internal/flowcontrol/flow_control_manager.go @@ -13,32 +13,31 @@ import ( ) type flowControlManager struct { - connParams handshake.ParamsNegotiator rttStats *congestion.RTTStats maxReceiveStreamWindow protocol.ByteCount streamFlowController map[protocol.StreamID]*flowController connFlowController *flowController mutex sync.RWMutex + + initialStreamSendWindow protocol.ByteCount } var _ FlowControlManager = &flowControlManager{} -var errMapAccess = errors.New("Error accessing the flowController map.") +var errMapAccess = errors.New("Error accessing the flowController map") // NewFlowControlManager creates a new flow control manager func NewFlowControlManager( - connParams handshake.ParamsNegotiator, maxReceiveStreamWindow protocol.ByteCount, maxReceiveConnectionWindow protocol.ByteCount, rttStats *congestion.RTTStats, ) FlowControlManager { return &flowControlManager{ - connParams: connParams, rttStats: rttStats, maxReceiveStreamWindow: maxReceiveStreamWindow, streamFlowController: make(map[protocol.StreamID]*flowController), - connFlowController: newFlowController(0, false, connParams, protocol.ReceiveConnectionFlowControlWindow, maxReceiveConnectionWindow, rttStats), + connFlowController: newFlowController(0, false, protocol.ReceiveConnectionFlowControlWindow, maxReceiveConnectionWindow, 0, rttStats), } } @@ -51,7 +50,7 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo if _, ok := f.streamFlowController[streamID]; ok { return } - f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connParams, protocol.ReceiveStreamFlowControlWindow, f.maxReceiveStreamWindow, f.rttStats) + f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, protocol.ReceiveStreamFlowControlWindow, f.maxReceiveStreamWindow, f.initialStreamSendWindow, f.rttStats) } // RemoveStream removes a closed stream from flow control @@ -61,6 +60,17 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { f.mutex.Unlock() } +func (f *flowControlManager) UpdateTransportParameters(params *handshake.TransportParameters) { + f.mutex.Lock() + defer f.mutex.Unlock() + + f.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow) + f.initialStreamSendWindow = params.StreamFlowControlWindow + for _, fc := range f.streamFlowController { + fc.UpdateSendWindow(params.StreamFlowControlWindow) + } +} + // ResetStream should be called when receiving a RstStreamFrame // it updates the byte offset to the value in the RstStreamFrame // streamID must not be 0 here @@ -233,7 +243,6 @@ func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset pro return false, err } } - return fc.UpdateSendWindow(offset), nil } diff --git a/internal/flowcontrol/flow_control_manager_test.go b/internal/flowcontrol/flow_control_manager_test.go index ea1d3706..09f2f64f 100644 --- a/internal/flowcontrol/flow_control_manager_test.go +++ b/internal/flowcontrol/flow_control_manager_test.go @@ -3,8 +3,9 @@ package flowcontrol import ( "time" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -15,13 +16,18 @@ var _ = Describe("Flow Control Manager", func() { var fcm *flowControlManager BeforeEach(func() { - mockPn := mocks.NewMockParamsNegotiator(mockCtrl) - fcm = NewFlowControlManager(mockPn, protocol.MaxByteCount, protocol.MaxByteCount, &congestion.RTTStats{}).(*flowControlManager) + fcm = NewFlowControlManager( + 0x2000, // maxReceiveStreamWindow + 0x4000, // maxReceiveConnectionWindow + &congestion.RTTStats{}, + ).(*flowControlManager) }) It("creates a connection level flow controller", func() { - Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(0))) + Expect(fcm.streamFlowController).To(BeEmpty()) Expect(fcm.connFlowController.ContributesToConnection()).To(BeFalse()) + Expect(fcm.connFlowController.sendWindow).To(BeZero()) + Expect(fcm.connFlowController.maxReceiveWindowIncrement).To(Equal(protocol.ByteCount(0x4000))) }) Context("creating new streams", func() { @@ -31,6 +37,19 @@ var _ = Describe("Flow Control Manager", func() { fc := fcm.streamFlowController[5] Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.ContributesToConnection()).To(BeFalse()) + // the transport parameters have not yet been received. Start with a window of size 0 + Expect(fc.sendWindow).To(BeZero()) + Expect(fc.maxReceiveWindowIncrement).To(Equal(protocol.ByteCount(0x2000))) + }) + + It("creates a new stream after it has received transport parameters", func() { + fcm.UpdateTransportParameters(&handshake.TransportParameters{ + StreamFlowControlWindow: 0x3000, + }) + fcm.NewStream(5, false) + Expect(fcm.streamFlowController).To(HaveKey(protocol.StreamID(5))) + fc := fcm.streamFlowController[5] + Expect(fc.sendWindow).To(Equal(protocol.ByteCount(0x3000))) }) It("doesn't create a new flow controller if called for an existing stream", func() { @@ -51,6 +70,16 @@ var _ = Describe("Flow Control Manager", func() { Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(5))) }) + It("updates the send windows for existing streams when receiveing the transport parameters", func() { + fcm.NewStream(5, false) + fcm.UpdateTransportParameters(&handshake.TransportParameters{ + StreamFlowControlWindow: 0x3000, + ConnectionFlowControlWindow: 0x6000, + }) + Expect(fcm.connFlowController.sendWindow).To(Equal(protocol.ByteCount(0x6000))) + Expect(fcm.streamFlowController[5].sendWindow).To(Equal(protocol.ByteCount(0x3000))) + }) + Context("receiving data", func() { BeforeEach(func() { fcm.NewStream(1, false) diff --git a/internal/flowcontrol/flow_controller.go b/internal/flowcontrol/flow_controller.go index ae318d59..8bd42492 100644 --- a/internal/flowcontrol/flow_controller.go +++ b/internal/flowcontrol/flow_controller.go @@ -5,7 +5,6 @@ import ( "time" "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -14,8 +13,7 @@ type flowController struct { streamID protocol.StreamID contributesToConnection bool // does the stream contribute to connection level flow control - connParams handshake.ParamsNegotiator - rttStats *congestion.RTTStats + rttStats *congestion.RTTStats bytesSent protocol.ByteCount sendWindow protocol.ByteCount @@ -36,19 +34,19 @@ var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset") func newFlowController( streamID protocol.StreamID, contributesToConnection bool, - connParams handshake.ParamsNegotiator, receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, + initialSendWindow protocol.ByteCount, rttStats *congestion.RTTStats, ) *flowController { return &flowController{ streamID: streamID, contributesToConnection: contributesToConnection, - connParams: connParams, rttStats: rttStats, receiveWindow: receiveWindow, receiveWindowIncrement: receiveWindow, maxReceiveWindowIncrement: maxReceiveWindow, + sendWindow: initialSendWindow, } } @@ -56,16 +54,6 @@ func (c *flowController) ContributesToConnection() bool { return c.contributesToConnection } -func (c *flowController) getSendWindow() protocol.ByteCount { - if c.sendWindow == 0 { - if c.streamID == 0 { - return c.connParams.GetSendConnectionFlowControlWindow() - } - return c.connParams.GetSendStreamFlowControlWindow() - } - return c.sendWindow -} - func (c *flowController) AddBytesSent(n protocol.ByteCount) { c.bytesSent += n } @@ -81,16 +69,11 @@ func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { } func (c *flowController) SendWindowSize() protocol.ByteCount { - sendWindow := c.getSendWindow() - - if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here + // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters + if c.bytesSent > c.sendWindow { return 0 } - return sendWindow - c.bytesSent -} - -func (c *flowController) SendWindowOffset() protocol.ByteCount { - return c.getSendWindow() + return c.sendWindow - c.bytesSent } // UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher diff --git a/internal/flowcontrol/flow_controller_test.go b/internal/flowcontrol/flow_controller_test.go index a6c70b0f..161dafc1 100644 --- a/internal/flowcontrol/flow_controller_test.go +++ b/internal/flowcontrol/flow_controller_test.go @@ -4,7 +4,6 @@ import ( "time" "github.com/lucas-clemente/quic-go/congestion" - "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -19,61 +18,28 @@ var _ = Describe("Flow controller", func() { }) Context("Constructor", func() { - var rttStats *congestion.RTTStats - var mockPn *mocks.MockParamsNegotiator + rttStats := &congestion.RTTStats{} - receiveStreamWindow := protocol.ByteCount(2000) - receiveConnectionWindow := protocol.ByteCount(4000) - maxReceiveStreamWindow := protocol.ByteCount(8000) - maxReceiveConnectionWindow := protocol.ByteCount(9000) - - BeforeEach(func() { - mockPn = mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetSendStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(1000)) - mockPn.EXPECT().GetSendConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(3000)) - rttStats = &congestion.RTTStats{} - }) - - It("reads the stream send and receive windows when acting as stream-level flow controller", func() { - fc := newFlowController(5, true, mockPn, receiveStreamWindow, maxReceiveStreamWindow, rttStats) + It("sets the send and receive windows", func() { + receiveWindow := protocol.ByteCount(2000) + maxReceiveWindow := protocol.ByteCount(3000) + sendWindow := protocol.ByteCount(4000) + fc := newFlowController(5, true, receiveWindow, maxReceiveWindow, sendWindow, rttStats) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) - Expect(fc.receiveWindow).To(Equal(receiveStreamWindow)) - Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveStreamWindow)) - }) - - It("reads the stream send and receive windows when acting as connection-level flow controller", func() { - fc := newFlowController(0, false, mockPn, receiveConnectionWindow, maxReceiveConnectionWindow, rttStats) - Expect(fc.streamID).To(Equal(protocol.StreamID(0))) - Expect(fc.receiveWindow).To(Equal(receiveConnectionWindow)) - Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveConnectionWindow)) - }) - - It("does not set the stream flow control windows for sending", func() { - fc := newFlowController(5, true, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats) - Expect(fc.sendWindow).To(BeZero()) - }) - - It("does not set the connection flow control windows for sending", func() { - fc := newFlowController(0, false, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats) - Expect(fc.sendWindow).To(BeZero()) + Expect(fc.receiveWindow).To(Equal(receiveWindow)) + Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow)) + Expect(fc.sendWindow).To(Equal(sendWindow)) }) It("says if it contributes to connection-level flow control", func() { - fc := newFlowController(1, false, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats) + fc := newFlowController(1, false, protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, rttStats) Expect(fc.ContributesToConnection()).To(BeFalse()) - fc = newFlowController(5, true, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats) + fc = newFlowController(5, true, protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, rttStats) Expect(fc.ContributesToConnection()).To(BeTrue()) }) }) Context("send flow control", func() { - var mockPn *mocks.MockParamsNegotiator - - BeforeEach(func() { - mockPn = mocks.NewMockParamsNegotiator(mockCtrl) - controller.connParams = mockPn - }) - It("adds bytes sent", func() { controller.bytesSent = 5 controller.AddBytesSent(6) @@ -89,14 +55,14 @@ var _ = Describe("Flow controller", func() { It("gets the offset of the flow control window", func() { controller.bytesSent = 5 controller.sendWindow = 12 - Expect(controller.SendWindowOffset()).To(Equal(protocol.ByteCount(12))) + Expect(controller.sendWindow).To(Equal(protocol.ByteCount(12))) }) It("updates the size of the flow control window", func() { controller.bytesSent = 5 updateSuccessful := controller.UpdateSendWindow(15) Expect(updateSuccessful).To(BeTrue()) - Expect(controller.SendWindowOffset()).To(Equal(protocol.ByteCount(15))) + Expect(controller.sendWindow).To(Equal(protocol.ByteCount(15))) Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(15 - 5))) }) @@ -108,36 +74,6 @@ var _ = Describe("Flow controller", func() { Expect(updateSuccessful).To(BeFalse()) Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(20))) }) - - It("asks the ConnectionParametersManager for the stream flow control window size", func() { - controller.streamID = 5 - mockPn.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(1000)) - Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(1000))) - // make sure the value is not cached - mockPn.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(2000)) - Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(2000))) - }) - - It("stops asking the ConnectionParametersManager for the flow control stream window size once a window update has arrived", func() { - controller.streamID = 5 - Expect(controller.UpdateSendWindow(8000)) - Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(8000))) - }) - - It("asks the ConnectionParametersManager for the connection flow control window size", func() { - controller.streamID = 0 - mockPn.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(3000)) - Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(3000))) - // make sure the value is not cached - mockPn.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(5000)) - Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(5000))) - }) - - It("stops asking the ConnectionParametersManager for the connection flow control window size once a window update has arrived", func() { - controller.streamID = 0 - Expect(controller.UpdateSendWindow(7000)) - Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(7000))) - }) }) Context("receive flow control", func() { diff --git a/internal/flowcontrol/interface.go b/internal/flowcontrol/interface.go index 1b29bd78..d62ba79a 100644 --- a/internal/flowcontrol/interface.go +++ b/internal/flowcontrol/interface.go @@ -1,6 +1,7 @@ package flowcontrol import "github.com/lucas-clemente/quic-go/internal/protocol" +import "github.com/lucas-clemente/quic-go/internal/handshake" // WindowUpdate provides the data for WindowUpdateFrames. type WindowUpdate struct { @@ -12,6 +13,7 @@ type WindowUpdate struct { type FlowControlManager interface { NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) RemoveStream(streamID protocol.StreamID) + UpdateTransportParameters(*handshake.TransportParameters) // methods needed for receiving data ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index f6eb6e5e..76c1fb35 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -49,10 +49,11 @@ type cryptoSetupClient struct { nullAEAD crypto.AEAD secureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD - aeadChanged chan<- protocol.EncryptionLevel - requestConnIDOmission bool - params *paramsNegotiatorGQUIC + paramsChan chan<- TransportParameters + aeadChanged chan<- protocol.EncryptionLevel + + params *TransportParameters } var _ CryptoSetup = &cryptoSetupClient{} @@ -70,24 +71,24 @@ func NewCryptoSetupClient( version protocol.VersionNumber, tlsConfig *tls.Config, params *TransportParameters, + paramsChan chan<- TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, negotiatedVersions []protocol.VersionNumber, -) (CryptoSetup, ParamsNegotiator, error) { - pn := newParamsNegotiatorGQUIC(protocol.PerspectiveClient, version, params) +) (CryptoSetup, error) { return &cryptoSetupClient{ - hostname: hostname, - connID: connID, - version: version, - certManager: crypto.NewCertManager(tlsConfig), - params: pn, - requestConnIDOmission: params.RequestConnectionIDOmission, - keyDerivation: crypto.DeriveQuicCryptoAESKeys, - keyExchange: getEphermalKEX, - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), - aeadChanged: aeadChanged, - negotiatedVersions: negotiatedVersions, - divNonceChan: make(chan []byte), - }, pn, nil + hostname: hostname, + connID: connID, + version: version, + certManager: crypto.NewCertManager(tlsConfig), + params: params, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + keyExchange: getEphermalKEX, + nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), + paramsChan: paramsChan, + aeadChanged: aeadChanged, + negotiatedVersions: negotiatedVersions, + divNonceChan: make(chan []byte), + }, nil } func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error { @@ -141,15 +142,21 @@ func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error { utils.Debugf("Got %s", message) switch message.Tag { case TagREJ: - err = h.handleREJMessage(message.Data) + if err := h.handleREJMessage(message.Data); err != nil { + return err + } case TagSHLO: - err = h.handleSHLOMessage(message.Data) + params, err := h.handleSHLOMessage(message.Data) + if err != nil { + return err + } + // blocks until the session has received the parameters + h.paramsChan <- *params + h.aeadChanged <- protocol.EncryptionForwardSecure + close(h.aeadChanged) default: return qerr.InvalidCryptoMessageType } - if err != nil { - return err - } } } @@ -215,12 +222,12 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { return nil } -func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { +func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) { h.mutex.Lock() defer h.mutex.Unlock() if !h.receivedSecurePacket { - return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") + return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message") } if sno, ok := cryptoData[TagSNO]; ok { @@ -229,22 +236,22 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { serverPubs, ok := cryptoData[TagPUBS] if !ok { - return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") + return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") } verTag, ok := cryptoData[TagVER] if !ok { - return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") + return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") } if !h.validateVersionList(verTag) { - return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") + return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") } nonce := append(h.nonc, h.sno...) ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs) if err != nil { - return err + return nil, err } leafCert := h.certManager.GetLeafCert() @@ -261,18 +268,14 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { protocol.PerspectiveClient, ) if err != nil { - return err + return nil, err } - err = h.params.SetFromMap(cryptoData) + params, err := readHelloMap(cryptoData) if err != nil { - return qerr.InvalidCryptoMessageParameter + return nil, qerr.InvalidCryptoMessageParameter } - - h.aeadChanged <- protocol.EncryptionForwardSecure - close(h.aeadChanged) - - return nil + return params, nil } func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool { @@ -405,10 +408,7 @@ func (h *cryptoSetupClient) sendCHLO() error { } func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { - tags, err := h.params.GetHelloMap() - if err != nil { - return nil, err - } + tags := h.params.getHelloMap() tags[TagSNI] = []byte(h.hostname) tags[TagPDMD] = []byte("X509") @@ -421,9 +421,6 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) tags[TagVER] = versionTag - if h.requestConnIDOmission { - tags[TagTCID] = []byte{0, 0, 0, 0} - } if len(h.stk) > 0 { tags[TagSTK] = h.stk } diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index b90d7814..00a9426a 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -79,6 +79,7 @@ var _ = Describe("Client Crypto Setup", func() { keyDerivationCalledWith *keyDerivationValues shloMap map[Tag][]byte aeadChanged chan protocol.EncryptionLevel + paramsChan chan TransportParameters ) BeforeEach(func() { @@ -108,13 +109,16 @@ var _ = Describe("Client Crypto Setup", func() { stream = newMockStream() certManager = &mockCertManager{} version := protocol.Version37 + // use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking + paramsChan = make(chan TransportParameters, 1) aeadChanged = make(chan protocol.EncryptionLevel, 2) - csInt, _, err := NewCryptoSetupClient( + csInt, err := NewCryptoSetupClient( "hostname", 0, version, nil, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, + paramsChan, aeadChanged, nil, ) @@ -222,7 +226,7 @@ var _ = Describe("Client Crypto Setup", func() { It("returns the right error when detecting a downgrade attack", func() { cs.negotiatedVersions = []protocol.VersionNumber{protocol.VersionWhatever} cs.receivedSecurePacket = true - err := cs.handleSHLOMessage(map[Tag][]byte{ + _, err := cs.handleSHLOMessage(map[Tag][]byte{ TagPUBS: []byte{0}, TagVER: []byte{0, 1}, }) @@ -385,7 +389,7 @@ var _ = Describe("Client Crypto Setup", func() { It("rejects unencrypted SHLOs", func() { cs.receivedSecurePacket = false - err := cs.handleSHLOMessage(shloMap) + _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message"))) Expect(aeadChanged).ToNot(Receive()) Expect(aeadChanged).ToNot(BeClosed()) @@ -393,14 +397,14 @@ var _ = Describe("Client Crypto Setup", func() { It("rejects SHLOs without a PUBS", func() { delete(shloMap, TagPUBS) - err := cs.handleSHLOMessage(shloMap) + _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS"))) Expect(aeadChanged).ToNot(BeClosed()) }) It("rejects SHLOs without a version list", func() { delete(shloMap, TagVER) - err := cs.handleSHLOMessage(shloMap) + _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list"))) Expect(aeadChanged).ToNot(BeClosed()) }) @@ -412,36 +416,58 @@ var _ = Describe("Client Crypto Setup", func() { b := &bytes.Buffer{} utils.LittleEndian.WriteUint32(b, protocol.VersionNumberToTag(ver)) shloMap[TagVER] = b.Bytes() - err := cs.handleSHLOMessage(shloMap) + _, err := cs.handleSHLOMessage(shloMap) Expect(err).ToNot(HaveOccurred()) }) It("reads the server nonce, if set", func() { shloMap[TagSNO] = []byte("server nonce") - err := cs.handleSHLOMessage(shloMap) + _, err := cs.handleSHLOMessage(shloMap) Expect(err).ToNot(HaveOccurred()) Expect(cs.sno).To(Equal(shloMap[TagSNO])) }) It("creates a forwardSecureAEAD", func() { shloMap[TagSNO] = []byte("server nonce") - err := cs.handleSHLOMessage(shloMap) + _, err := cs.handleSHLOMessage(shloMap) Expect(err).ToNot(HaveOccurred()) Expect(cs.forwardSecureAEAD).ToNot(BeNil()) - Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) - Expect(aeadChanged).To(BeClosed()) }) It("reads the connection paramaters", func() { shloMap[TagICSL] = []byte{13, 0, 0, 0} // 13 seconds - err := cs.handleSHLOMessage(shloMap) + params, err := cs.handleSHLOMessage(shloMap) Expect(err).ToNot(HaveOccurred()) - Expect(cs.params.GetRemoteIdleTimeout()).To(Equal(13 * time.Second)) + Expect(params.IdleTimeout).To(Equal(13 * time.Second)) + }) + + It("closes the aeadChanged when receiving an SHLO", func() { + HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream(stream) + Expect(err).ToNot(HaveOccurred()) + }() + Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure))) + Eventually(aeadChanged).Should(BeClosed()) + }) + + It("passes the transport parameters on the channel", func() { + shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba} + HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) + go func() { + defer GinkgoRecover() + err := cs.HandleCryptoStream(stream) + Expect(err).ToNot(HaveOccurred()) + }() + var params TransportParameters + Eventually(paramsChan).Should(Receive(¶ms)) + Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d))) }) It("errors if it can't read a connection parameter", func() { shloMap[TagICSL] = []byte{3, 0, 0} // 1 byte too short - err := cs.handleSHLOMessage(shloMap) + _, err := cs.handleSHLOMessage(shloMap) Expect(err).To(MatchError(qerr.InvalidCryptoMessageParameter)) }) }) @@ -488,15 +514,14 @@ var _ = Describe("Client Crypto Setup", func() { }) It("requests to omit the connection ID", func() { - cs.requestConnIDOmission = true + cs.params.OmitConnectionID = true tags, err := cs.getTags() Expect(err).ToNot(HaveOccurred()) Expect(tags).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0})) }) It("adds the tags returned from the connectionParametersManager to the CHLO", func() { - pnTags, err := cs.params.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) + pnTags := cs.params.getHelloMap() Expect(pnTags).ToNot(BeEmpty()) tags, err := cs.getTags() Expect(err).ToNot(HaveOccurred()) @@ -588,7 +613,7 @@ var _ = Describe("Client Crypto Setup", func() { doSHLO := func() { cs.receivedSecurePacket = true - err := cs.handleSHLOMessage(shloMap) + _, err := cs.handleSHLOMessage(shloMap) Expect(err).ToNot(HaveOccurred()) } diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index f2dfa0b9..ab1a30ba 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -40,14 +40,17 @@ type cryptoSetupServer struct { receivedForwardSecurePacket bool receivedSecurePacket bool sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written - aeadChanged chan<- protocol.EncryptionLevel + + receivedParams bool + paramsChan chan<- TransportParameters + aeadChanged chan<- protocol.EncryptionLevel keyDerivation QuicCryptoKeyDerivationFunction keyExchange KeyExchangeFunction cryptoStream io.ReadWriter - params *paramsNegotiatorGQUIC + params *TransportParameters mutex sync.RWMutex } @@ -72,14 +75,14 @@ func NewCryptoSetup( params *TransportParameters, supportedVersions []protocol.VersionNumber, acceptSTK func(net.Addr, *Cookie) bool, + paramsChan chan<- TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, -) (CryptoSetup, ParamsNegotiator, error) { +) (CryptoSetup, error) { stkGenerator, err := NewCookieGenerator() if err != nil { - return nil, nil, err + return nil, err } - pn := newParamsNegotiatorGQUIC(protocol.PerspectiveServer, version, params) return &cryptoSetupServer{ connID: connID, remoteAddr: remoteAddr, @@ -90,11 +93,12 @@ func NewCryptoSetup( keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyExchange: getEphermalKEX, nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), - params: pn, + params: params, acceptSTKCallback: acceptSTK, sentSHLO: make(chan struct{}), + paramsChan: paramsChan, aeadChanged: aeadChanged, - }, pn, nil + }, nil } // HandleCryptoStream reads and writes messages on the crypto stream @@ -163,6 +167,16 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] return false, err } + params, err := readHelloMap(cryptoData) + if err != nil { + return false, err + } + // blocks until the session has received the parameters + if !h.receivedParams { + h.receivedParams = true + h.paramsChan <- *params + } + if !h.isInchoateCHLO(cryptoData, certUncompressed) { // We have a CHLO with a proper server config ID, do a 0-RTT handshake reply, err = h.handleCHLO(sni, chloData, cryptoData) @@ -418,14 +432,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T return nil, err } - if err := h.params.SetFromMap(cryptoData); err != nil { - return nil, err - } - - replyMap, err := h.params.GetHelloMap() - if err != nil { - return nil, err - } + replyMap := h.params.getHelloMap() // add crypto parameters verTag := &bytes.Buffer{} for _, v := range h.supportedVersions { diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index 018d7b84..0819e1b2 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "net" + "time" "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -167,6 +168,7 @@ var _ = Describe("Server Crypto Setup", func() { scfg *ServerConfig cs *cryptoSetupServer stream *mockStream + paramsChan chan TransportParameters aeadChanged chan protocol.EncryptionLevel nonce32 []byte versionTag []byte @@ -183,6 +185,8 @@ var _ = Describe("Server Crypto Setup", func() { remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} expectedInitialNonceLen = 32 expectedFSNonceLen = 64 + // use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking + paramsChan = make(chan TransportParameters, 1) aeadChanged = make(chan protocol.EncryptionLevel, 2) stream = newMockStream() kex = &mockKEX{} @@ -197,7 +201,7 @@ var _ = Describe("Server Crypto Setup", func() { Expect(err).NotTo(HaveOccurred()) version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1] supportedVersions = []protocol.VersionNumber{version, 98, 99} - csInt, _, err := NewCryptoSetup( + csInt, err := NewCryptoSetup( protocol.ConnectionID(42), remoteAddr, version, @@ -205,6 +209,7 @@ var _ = Describe("Server Crypto Setup", func() { &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}, supportedVersions, nil, + paramsChan, aeadChanged, ) Expect(err).NotTo(HaveOccurred()) @@ -285,6 +290,16 @@ var _ = Describe("Server Crypto Setup", func() { Expect(err).To(MatchError(ErrNSTPExperiment)) }) + It("reads the transport parameters sent by the client", func() { + sourceAddrValid = true + fullCHLO[TagICSL] = []byte{0x37, 0x13, 0, 0} + _, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), fullCHLO) + Expect(err).ToNot(HaveOccurred()) + var params TransportParameters + Expect(paramsChan).To(Receive(¶ms)) + Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second)) + }) + It("generates REJ messages", func() { sourceAddrValid = false response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index ceda3af9..2293ae04 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -38,52 +38,52 @@ var newMintController = func(conn *mint.Conn) crypto.MintController { // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server func NewCryptoSetupTLSServer( tlsConfig *tls.Config, - transportParams *TransportParameters, + params *TransportParameters, + paramsChan chan<- TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, -) (CryptoSetup, ParamsNegotiator, error) { +) (CryptoSetup, error) { mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer) if err != nil { - return nil, nil, err + return nil, err } - params := newParamsNegotiator(protocol.PerspectiveServer, version, transportParams) return &cryptoSetupTLS{ perspective: protocol.PerspectiveServer, mintConf: mintConf, nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, - extensionHandler: newExtensionHandlerServer(params, supportedVersions, version), - }, params, nil + extensionHandler: newExtensionHandlerServer(params, paramsChan, supportedVersions, version), + }, nil } // NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client func NewCryptoSetupTLSClient( hostname string, // only needed for the client tlsConfig *tls.Config, - transportParams *TransportParameters, + params *TransportParameters, + paramsChan chan<- TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, -) (CryptoSetup, ParamsNegotiator, error) { +) (CryptoSetup, error) { mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient) if err != nil { - return nil, nil, err + return nil, err } mintConf.ServerName = hostname - params := newParamsNegotiator(protocol.PerspectiveClient, version, transportParams) return &cryptoSetupTLS{ perspective: protocol.PerspectiveClient, mintConf: mintConf, nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, - extensionHandler: newExtensionHandlerClient(params, initialVersion, supportedVersions, version), - }, params, nil + extensionHandler: newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version), + }, nil } func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error { diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 583fdc6e..547e4991 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -33,16 +33,19 @@ func mockKeyDerivation(crypto.MintController, protocol.Perspective) (crypto.AEAD var _ = Describe("TLS Crypto Setup", func() { var ( cs *cryptoSetupTLS + paramsChan chan TransportParameters aeadChanged chan protocol.EncryptionLevel mintControllerConstructor = newMintController ) BeforeEach(func() { + paramsChan = make(chan TransportParameters) aeadChanged = make(chan protocol.EncryptionLevel, 2) - csInt, _, err := NewCryptoSetupTLSServer( + csInt, err := NewCryptoSetupTLSServer( testdata.GetTLSConfig(), &TransportParameters{}, + paramsChan, aeadChanged, nil, protocol.VersionTLS, diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 39780400..a3ba7edd 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -2,7 +2,6 @@ package handshake import ( "io" - "time" "github.com/lucas-clemente/quic-go/internal/protocol" ) @@ -25,9 +24,3 @@ type CryptoSetup interface { GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) } - -// TransportParameters are parameters sent to the peer during the handshake -type TransportParameters struct { - RequestConnectionIDOmission bool - IdleTimeout time.Duration -} diff --git a/internal/handshake/params_negotiator.go b/internal/handshake/params_negotiator.go deleted file mode 100644 index be8e26a3..00000000 --- a/internal/handshake/params_negotiator.go +++ /dev/null @@ -1,111 +0,0 @@ -package handshake - -import ( - "encoding/binary" - "errors" - "fmt" - "math" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" -) - -type paramsNegotiator struct { - paramsNegotiatorBase -} - -var _ ParamsNegotiator = ¶msNegotiator{} - -// newParamsNegotiator creates a new connection parameters manager -func newParamsNegotiator(pers protocol.Perspective, v protocol.VersionNumber, params *TransportParameters) *paramsNegotiator { - h := ¶msNegotiator{} - h.perspective = pers - h.version = v - h.init(params) - return h -} - -func (h *paramsNegotiator) SetFromTransportParameters(params []transportParameter) error { - h.mutex.Lock() - defer h.mutex.Unlock() - - var foundInitialMaxStreamData bool - var foundInitialMaxData bool - var foundInitialMaxStreamID bool - var foundIdleTimeout bool - - for _, p := range params { - switch p.Parameter { - case initialMaxStreamDataParameterID: - foundInitialMaxStreamData = true - if len(p.Value) != 4 { - return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value)) - } - h.sendStreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) - utils.Debugf("h.sendStreamFlowControlWindow: %#x", h.sendStreamFlowControlWindow) - case initialMaxDataParameterID: - foundInitialMaxData = true - if len(p.Value) != 4 { - return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value)) - } - h.sendConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) - utils.Debugf("h.sendConnectionFlowControlWindow: %#x", h.sendConnectionFlowControlWindow) - case initialMaxStreamIDParameterID: - foundInitialMaxStreamID = true - if len(p.Value) != 4 { - return fmt.Errorf("wrong length for initial_max_stream_id: %d (expected 4)", len(p.Value)) - } - // TODO: handle this value - case idleTimeoutParameterID: - foundIdleTimeout = true - if len(p.Value) != 2 { - return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value)) - } - h.setRemoteIdleTimeout(time.Duration(binary.BigEndian.Uint16(p.Value)) * time.Second) - case omitConnectionIDParameterID: - if len(p.Value) != 0 { - return fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) - } - h.omitConnectionID = true - } - } - - if !(foundInitialMaxStreamData && foundInitialMaxData && foundInitialMaxStreamID && foundIdleTimeout) { - return errors.New("missing parameter") - } - return nil -} - -func (h *paramsNegotiator) GetTransportParameters() []transportParameter { - initialMaxStreamData := make([]byte, 4) - binary.BigEndian.PutUint32(initialMaxStreamData, uint32(protocol.ReceiveStreamFlowControlWindow)) - initialMaxData := make([]byte, 4) - binary.BigEndian.PutUint32(initialMaxData, uint32(protocol.ReceiveConnectionFlowControlWindow)) - initialMaxStreamID := make([]byte, 4) - // TODO: use a reasonable value here - binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32) - idleTimeout := make([]byte, 2) - binary.BigEndian.PutUint16(idleTimeout, uint16(h.idleTimeout)) - maxPacketSize := make([]byte, 2) - binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize)) - params := []transportParameter{ - {initialMaxStreamDataParameterID, initialMaxStreamData}, - {initialMaxDataParameterID, initialMaxData}, - {initialMaxStreamIDParameterID, initialMaxStreamID}, - {idleTimeoutParameterID, idleTimeout}, - {maxPacketSizeParameterID, maxPacketSize}, - } - h.mutex.RLock() - defer h.mutex.RUnlock() - if h.omitConnectionID { - params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}}) - } - return params -} - -func (h *paramsNegotiator) OmitConnectionID() bool { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.omitConnectionID -} diff --git a/internal/handshake/params_negotiator_base.go b/internal/handshake/params_negotiator_base.go deleted file mode 100644 index c2562e99..00000000 --- a/internal/handshake/params_negotiator_base.go +++ /dev/null @@ -1,85 +0,0 @@ -package handshake - -import ( - "sync" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" -) - -// The ParamsNegotiator negotiates and stores the connection parameters. -// It can be used for a server as well as a client. -type ParamsNegotiator interface { - GetSendStreamFlowControlWindow() protocol.ByteCount - GetSendConnectionFlowControlWindow() protocol.ByteCount - GetMaxOutgoingStreams() uint32 - // get the idle timeout that was sent by the peer - GetRemoteIdleTimeout() time.Duration - // determines if the client requests omission of connection IDs. - OmitConnectionID() bool -} - -// For the server: -// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation -// 2. call GetHelloMap to get the values to send in the SHLO -// For the client: -// 1. call GetHelloMap to get the values to send in a CHLO -// 2. call SetFromMap with the values received in the SHLO -type paramsNegotiatorBase struct { - mutex sync.RWMutex - - version protocol.VersionNumber - perspective protocol.Perspective - - flowControlNegotiated bool - - omitConnectionID bool - requestConnectionIDOmission bool - - maxOutgoingStreams uint32 - idleTimeout time.Duration - remoteIdleTimeout time.Duration - sendStreamFlowControlWindow protocol.ByteCount - sendConnectionFlowControlWindow protocol.ByteCount -} - -func (h *paramsNegotiatorBase) init(params *TransportParameters) { - h.sendStreamFlowControlWindow = protocol.InitialStreamFlowControlWindow // can only be changed by the client - h.sendConnectionFlowControlWindow = protocol.InitialConnectionFlowControlWindow // can only be changed by the client - h.requestConnectionIDOmission = params.RequestConnectionIDOmission - - h.idleTimeout = params.IdleTimeout - // use this as a default value. As soon as the client sends its value, this gets updated - h.maxOutgoingStreams = protocol.MaxIncomingStreams -} - -// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *paramsNegotiatorBase) GetSendStreamFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.sendStreamFlowControlWindow -} - -// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data -func (h *paramsNegotiatorBase) GetSendConnectionFlowControlWindow() protocol.ByteCount { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.sendConnectionFlowControlWindow -} - -func (h *paramsNegotiatorBase) GetMaxOutgoingStreams() uint32 { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.maxOutgoingStreams -} - -func (h *paramsNegotiatorBase) setRemoteIdleTimeout(t time.Duration) { - h.remoteIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, t) -} - -func (h *paramsNegotiatorBase) GetRemoteIdleTimeout() time.Duration { - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.remoteIdleTimeout -} diff --git a/internal/handshake/params_negotiator_gquic.go b/internal/handshake/params_negotiator_gquic.go deleted file mode 100644 index ce41be94..00000000 --- a/internal/handshake/params_negotiator_gquic.go +++ /dev/null @@ -1,116 +0,0 @@ -package handshake - -import ( - "bytes" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/qerr" -) - -// errMalformedTag is returned when the tag value cannot be read -var ( - errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") - errFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported") -) - -type paramsNegotiatorGQUIC struct { - paramsNegotiatorBase -} - -var _ ParamsNegotiator = ¶msNegotiatorGQUIC{} - -// newParamsNegotiatorGQUIC creates a new connection parameters manager -func newParamsNegotiatorGQUIC(pers protocol.Perspective, v protocol.VersionNumber, params *TransportParameters) *paramsNegotiatorGQUIC { - h := ¶msNegotiatorGQUIC{} - h.perspective = pers - h.version = v - h.init(params) - return h -} - -// SetFromMap reads all params. -func (h *paramsNegotiatorGQUIC) SetFromMap(params map[Tag][]byte) error { - h.mutex.Lock() - defer h.mutex.Unlock() - - if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer { - clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return errMalformedTag - } - h.omitConnectionID = (clientValue == 0) - } - if value, ok := params[TagMIDS]; ok { - clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return errMalformedTag - } - h.maxOutgoingStreams = clientValue - } - if value, ok := params[TagICSL]; ok { - clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return errMalformedTag - } - h.setRemoteIdleTimeout(time.Duration(clientValue) * time.Second) - } - if value, ok := params[TagSFCW]; ok { - if h.flowControlNegotiated { - return errFlowControlRenegotiationNotSupported - } - sendStreamFlowControlWindow, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return errMalformedTag - } - h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow) - } - if value, ok := params[TagCFCW]; ok { - if h.flowControlNegotiated { - return errFlowControlRenegotiationNotSupported - } - sendConnectionFlowControlWindow, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) - if err != nil { - return errMalformedTag - } - h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow) - } - - _, containsSFCW := params[TagSFCW] - _, containsCFCW := params[TagCFCW] - if containsCFCW || containsSFCW { - h.flowControlNegotiated = true - } - - return nil -} - -// GetHelloMap gets all parameters needed for the Hello message. -func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) { - sfcw := bytes.NewBuffer([]byte{}) - utils.LittleEndian.WriteUint32(sfcw, uint32(protocol.ReceiveStreamFlowControlWindow)) - cfcw := bytes.NewBuffer([]byte{}) - utils.LittleEndian.WriteUint32(cfcw, uint32(protocol.ReceiveConnectionFlowControlWindow)) - mids := bytes.NewBuffer([]byte{}) - utils.LittleEndian.WriteUint32(mids, protocol.MaxIncomingStreams) - icsl := bytes.NewBuffer([]byte{}) - utils.LittleEndian.WriteUint32(icsl, uint32(h.idleTimeout/time.Second)) - - return map[Tag][]byte{ - TagICSL: icsl.Bytes(), - TagMIDS: mids.Bytes(), - TagCFCW: cfcw.Bytes(), - TagSFCW: sfcw.Bytes(), - }, nil -} - -func (h *paramsNegotiatorGQUIC) OmitConnectionID() bool { - if h.perspective == protocol.PerspectiveClient { - return false - } - - h.mutex.RLock() - defer h.mutex.RUnlock() - return h.omitConnectionID -} diff --git a/internal/handshake/params_negotiator_gquic_test.go b/internal/handshake/params_negotiator_gquic_test.go deleted file mode 100644 index 6967b703..00000000 --- a/internal/handshake/params_negotiator_gquic_test.go +++ /dev/null @@ -1,231 +0,0 @@ -package handshake - -import ( - "encoding/binary" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Params Negotiator (for gQUIC)", func() { - var pn *paramsNegotiatorGQUIC // a connectionParametersManager for a server - var pnClient *paramsNegotiatorGQUIC - idleTimeout := 42 * time.Second - BeforeEach(func() { - pn = newParamsNegotiatorGQUIC( - protocol.PerspectiveServer, - protocol.VersionWhatever, - &TransportParameters{ - IdleTimeout: idleTimeout, - }, - ) - pnClient = newParamsNegotiatorGQUIC( - protocol.PerspectiveClient, - protocol.VersionWhatever, - &TransportParameters{ - IdleTimeout: idleTimeout, - }, - ) - }) - - Context("SHLO", func() { - BeforeEach(func() { - // these tests should only use the server connectionParametersManager. Make them panic if they don't - pnClient = nil - }) - - It("returns all parameters necessary for the SHLO", func() { - entryMap, err := pn.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) - Expect(entryMap).To(HaveKey(TagICSL)) - Expect(entryMap).To(HaveKey(TagMIDS)) - }) - - It("sets the stream-level flow control windows in SHLO", func() { - entryMap, err := pn.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) - expected := make([]byte, 4) - binary.LittleEndian.PutUint32(expected, uint32(protocol.ReceiveStreamFlowControlWindow)) - Expect(entryMap).To(HaveKeyWithValue(TagSFCW, expected)) - }) - - It("sets the connection-level flow control windows in SHLO", func() { - entryMap, err := pn.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) - expected := make([]byte, 4) - binary.LittleEndian.PutUint32(expected, uint32(protocol.ReceiveConnectionFlowControlWindow)) - Expect(entryMap).To(HaveKeyWithValue(TagCFCW, expected)) - }) - - It("sets the connection-level flow control windows in SHLO", func() { - pn.idleTimeout = 0xdecafbad * time.Second - entryMap, err := pn.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) - Expect(entryMap).To(HaveKey(TagICSL)) - Expect(entryMap[TagICSL]).To(Equal([]byte{0xad, 0xfb, 0xca, 0xde})) - }) - - It("always sends its own value for the maximum incoming dynamic streams in the SHLO", func() { - err := pn.SetFromMap(map[Tag][]byte{TagMIDS: []byte{5, 0, 0, 0}}) - Expect(err).ToNot(HaveOccurred()) - entryMap, err := pn.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) - Expect(entryMap[TagMIDS]).To(Equal([]byte{byte(protocol.MaxIncomingStreams), 0, 0, 0})) - }) - }) - - Context("CHLO", func() { - BeforeEach(func() { - // these tests should only use the client connectionParametersManager. Make them panic if they don't - pn = nil - }) - - It("has the right values", func() { - entryMap, err := pnClient.GetHelloMap() - Expect(err).ToNot(HaveOccurred()) - Expect(entryMap).To(HaveKey(TagICSL)) - Expect(binary.LittleEndian.Uint32(entryMap[TagICSL])).To(BeEquivalentTo(idleTimeout / time.Second)) - Expect(entryMap).To(HaveKey(TagMIDS)) - Expect(binary.LittleEndian.Uint32(entryMap[TagMIDS])).To(BeEquivalentTo(protocol.MaxIncomingStreams)) - Expect(entryMap).To(HaveKey(TagSFCW)) - Expect(binary.LittleEndian.Uint32(entryMap[TagSFCW])).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow)) - Expect(entryMap).To(HaveKey(TagCFCW)) - Expect(binary.LittleEndian.Uint32(entryMap[TagCFCW])).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow)) - }) - }) - - Context("Omitted connection IDs", func() { - It("does not send omitted connection IDs if the TCID tag is missing", func() { - Expect(pn.OmitConnectionID()).To(BeFalse()) - }) - - It("reads the tag for omitted connection IDs", func() { - values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}} - pn.SetFromMap(values) - Expect(pn.OmitConnectionID()).To(BeTrue()) - }) - - It("ignores the TCID tag, as a client", func() { - values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}} - pnClient.SetFromMap(values) - Expect(pnClient.OmitConnectionID()).To(BeFalse()) - }) - - It("errors when given an invalid value", func() { - values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short - err := pn.SetFromMap(values) - Expect(err).To(MatchError(errMalformedTag)) - }) - }) - - Context("flow control", func() { - It("has the correct default flow control windows for sending", func() { - Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) - Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) - Expect(pnClient.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) - Expect(pnClient.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) - }) - - It("sets a new stream-level flow control window for sending", func() { - values := map[Tag][]byte{TagSFCW: {0xDE, 0xAD, 0xBE, 0xEF}} - err := pn.SetFromMap(values) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) - }) - - It("does not change the stream-level flow control window when given an invalid value", func() { - values := map[Tag][]byte{TagSFCW: {0xDE, 0xAD, 0xBE}} // 1 byte too short - err := pn.SetFromMap(values) - Expect(err).To(MatchError(errMalformedTag)) - Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow)) - }) - - It("sets a new connection-level flow control window for sending", func() { - values := map[Tag][]byte{TagCFCW: {0xDE, 0xAD, 0xBE, 0xEF}} - err := pn.SetFromMap(values) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) - }) - - It("does not change the connection-level flow control window when given an invalid value", func() { - values := map[Tag][]byte{TagCFCW: {0xDE, 0xAD, 0xBE}} // 1 byte too short - err := pn.SetFromMap(values) - Expect(err).To(MatchError(errMalformedTag)) - Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow)) - }) - - It("does not allow renegotiation of flow control parameters", func() { - values := map[Tag][]byte{ - TagCFCW: {0xDE, 0xAD, 0xBE, 0xEF}, - TagSFCW: {0xDE, 0xAD, 0xBE, 0xEF}, - } - err := pn.SetFromMap(values) - Expect(err).ToNot(HaveOccurred()) - values = map[Tag][]byte{ - TagCFCW: {0x13, 0x37, 0x13, 0x37}, - TagSFCW: {0x13, 0x37, 0x13, 0x37}, - } - err = pn.SetFromMap(values) - Expect(err).To(MatchError(errFlowControlRenegotiationNotSupported)) - Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) - Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE))) - }) - }) - - Context("idle timeout", func() { - It("sets the remote idle timeout", func() { - values := map[Tag][]byte{ - TagICSL: {10, 0, 0, 0}, - } - err := pn.SetFromMap(values) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetRemoteIdleTimeout()).To(Equal(10 * time.Second)) - }) - - It("doesn't allow values below the minimum remote idle timeout", func() { - t := 2 * time.Second - Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout)) - values := map[Tag][]byte{ - TagICSL: {uint8(t.Seconds()), 0, 0, 0}, - } - err := pn.SetFromMap(values) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetRemoteIdleTimeout()).To(Equal(protocol.MinRemoteIdleTimeout)) - }) - - It("errors when given an invalid value", func() { - values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short - err := pn.SetFromMap(values) - Expect(err).To(MatchError(errMalformedTag)) - }) - }) - - Context("max streams per connection", func() { - It("errors when given an invalid max dynamic incoming streams per connection value", func() { - values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short - err := pn.SetFromMap(values) - Expect(err).To(MatchError(errMalformedTag)) - }) - - Context("outgoing connections", func() { - It("sets the negotiated max streams per connection value", func() { - // this test only works if the value given here is smaller than protocol.MaxStreamsPerConnection - err := pn.SetFromMap(map[Tag][]byte{ - TagMIDS: {2, 0, 0, 0}, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(2))) - }) - - It("uses the the MSPC value, if no MIDS is given", func() { - err := pn.SetFromMap(map[Tag][]byte{ - TagMIDS: {3, 0, 0, 0}, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(3))) - }) - }) - }) -}) diff --git a/internal/handshake/params_negotiator_test.go b/internal/handshake/params_negotiator_test.go deleted file mode 100644 index a60b94e4..00000000 --- a/internal/handshake/params_negotiator_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package handshake - -import ( - "encoding/binary" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Params Negotiator (for TLS)", func() { - var params map[transportParameterID][]byte - var pn *paramsNegotiator - - paramsMapToList := func(p map[transportParameterID][]byte) []transportParameter { - var list []transportParameter - for id, val := range p { - list = append(list, transportParameter{id, val}) - } - return list - } - - paramsListToMap := func(l []transportParameter) map[transportParameterID][]byte { - p := make(map[transportParameterID][]byte) - for _, v := range l { - p[v.Parameter] = v.Value - } - return p - } - - BeforeEach(func() { - pn = newParamsNegotiator( - protocol.PerspectiveServer, - protocol.VersionWhatever, - &TransportParameters{}, - ) - params = map[transportParameterID][]byte{ - initialMaxStreamDataParameterID: []byte{0x11, 0x22, 0x33, 0x44}, - initialMaxDataParameterID: []byte{0x22, 0x33, 0x44, 0x55}, - initialMaxStreamIDParameterID: []byte{0x33, 0x44, 0x55, 0x66}, - idleTimeoutParameterID: []byte{0x13, 0x37}, - } - }) - - Context("getting", func() { - It("creates the parameters list", func() { - pn.idleTimeout = 0xcafe - buf := make([]byte, 4) - values := paramsListToMap(pn.GetTransportParameters()) - Expect(values).To(HaveLen(5)) - binary.BigEndian.PutUint32(buf, uint32(protocol.ReceiveStreamFlowControlWindow)) - Expect(values).To(HaveKeyWithValue(initialMaxStreamDataParameterID, buf)) - binary.BigEndian.PutUint32(buf, uint32(protocol.ReceiveConnectionFlowControlWindow)) - Expect(values).To(HaveKeyWithValue(initialMaxDataParameterID, buf)) - Expect(values).To(HaveKeyWithValue(initialMaxStreamIDParameterID, []byte{0xff, 0xff, 0xff, 0xff})) - Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe})) - Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac - }) - - It("request ommision of the connection ID", func() { - pn.omitConnectionID = true - values := paramsListToMap(pn.GetTransportParameters()) - Expect(values).To(HaveKeyWithValue(omitConnectionIDParameterID, []byte{})) - }) - }) - - Context("setting", func() { - It("reads parameters", func() { - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0x11223344))) - Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0x22334455))) - Expect(pn.GetRemoteIdleTimeout()).To(Equal(0x1337 * time.Second)) - Expect(pn.OmitConnectionID()).To(BeFalse()) - }) - - It("saves if it should omit the connection ID", func() { - params[omitConnectionIDParameterID] = []byte{} - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.OmitConnectionID()).To(BeTrue()) - }) - - It("rejects the parameters if the initial_max_stream_data is missing", func() { - delete(params, initialMaxStreamDataParameterID) - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("missing parameter")) - }) - - It("rejects the parameters if the initial_max_data is missing", func() { - delete(params, initialMaxDataParameterID) - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("missing parameter")) - }) - - It("rejects the parameters if the initial_max_stream_id is missing", func() { - delete(params, initialMaxStreamIDParameterID) - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("missing parameter")) - }) - - It("rejects the parameters if the idle_timeout is missing", func() { - delete(params, idleTimeoutParameterID) - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("missing parameter")) - }) - - It("doesn't allow values below the minimum remote idle timeout", func() { - t := 2 * time.Second - Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout)) - params[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())} - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetRemoteIdleTimeout()).To(Equal(protocol.MinRemoteIdleTimeout)) - }) - - It("rejects the parameters if the initial_max_stream_data has the wrong length", func() { - params[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)")) - }) - - It("rejects the parameters if the initial_max_data has the wrong length", func() { - params[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)")) - }) - - It("rejects the parameters if the initial_max_stream_id has the wrong length", func() { - params[initialMaxStreamIDParameterID] = []byte{0x11, 0x22, 0x33, 0x44, 0x55} // should be 4 bytes - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("wrong length for initial_max_stream_id: 5 (expected 4)")) - }) - - It("rejects the parameters if the initial_idle_timeout has the wrong length", func() { - params[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)")) - }) - - It("rejects the parameters if omit_connection_id is non-empty", func() { - params[omitConnectionIDParameterID] = []byte{0} // should be empty - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).To(MatchError("wrong length for omit_connection_id: 1 (expected empty)")) - }) - - It("ignores unknown parameters", func() { - params[1337] = []byte{42} - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).ToNot(HaveOccurred()) - }) - }) -}) diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index 8cf150ba..59f42310 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -3,6 +3,7 @@ package handshake import ( "errors" "fmt" + "math" "github.com/lucas-clemente/quic-go/qerr" @@ -12,7 +13,8 @@ import ( ) type extensionHandlerClient struct { - params *paramsNegotiator + params *TransportParameters + paramsChan chan<- TransportParameters initialVersion protocol.VersionNumber supportedVersions []protocol.VersionNumber @@ -21,9 +23,16 @@ type extensionHandlerClient struct { var _ mint.AppExtensionHandler = &extensionHandlerClient{} -func newExtensionHandlerClient(params *paramsNegotiator, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber) *extensionHandlerClient { +func newExtensionHandlerClient( + params *TransportParameters, + paramsChan chan<- TransportParameters, + initialVersion protocol.VersionNumber, + supportedVersions []protocol.VersionNumber, + version protocol.VersionNumber, +) *extensionHandlerClient { return &extensionHandlerClient{ params: params, + paramsChan: paramsChan, initialVersion: initialVersion, supportedVersions: supportedVersions, version: version, @@ -38,7 +47,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi data, err := syntax.Marshal(clientHelloTransportParameters{ NegotiatedVersion: uint32(h.version), InitialVersion: uint32(h.initialVersion), - Parameters: h.params.GetTransportParameters(), + Parameters: h.params.getTransportParameters(), }) if err != nil { return err @@ -99,5 +108,12 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte // TODO: return the right error here return errors.New("server didn't sent stateless_reset_token") } - return h.params.SetFromTransportParameters(eetp.Parameters) + params, err := readTransportParamters(eetp.Parameters) + if err != nil { + return err + } + // TODO(#878): remove this when implementing the MAX_STREAM_ID frame + params.MaxStreams = math.MaxUint32 + h.paramsChan <- *params + return nil } diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index d8561e5c..787ca6df 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -12,12 +12,16 @@ import ( ) var _ = Describe("TLS Extension Handler, for the client", func() { - var handler *extensionHandlerClient - var el mint.ExtensionList + var ( + handler *extensionHandlerClient + el mint.ExtensionList + paramsChan chan TransportParameters + ) BeforeEach(func() { - pn := ¶msNegotiator{} - handler = newExtensionHandlerClient(pn, protocol.VersionWhatever, nil, protocol.VersionWhatever) + // use a buffered channel here, so that we don't have to receive concurrently when parsing a message + paramsChan = make(chan TransportParameters, 1) + handler = newExtensionHandlerClient(&TransportParameters{}, paramsChan, protocol.VersionWhatever, nil, protocol.VersionWhatever) el = make(mint.ExtensionList, 0) }) @@ -78,7 +82,9 @@ var _ = Describe("TLS Extension Handler, for the client", func() { addEncryptedExtensionsWithParameters(parameters) err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) Expect(err).ToNot(HaveOccurred()) - Expect(handler.params.GetSendStreamFlowControlWindow()).To(BeEquivalentTo(0x11223344)) + var params TransportParameters + Expect(paramsChan).To(Receive(¶ms)) + Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344)) }) It("errors if the EncryptedExtensions message doesn't contain TransportParameters", func() { diff --git a/internal/handshake/tl_extension_handler_server.go b/internal/handshake/tls_extension_handler_server.go similarity index 81% rename from internal/handshake/tl_extension_handler_server.go rename to internal/handshake/tls_extension_handler_server.go index aeea7eb5..49830d8d 100644 --- a/internal/handshake/tl_extension_handler_server.go +++ b/internal/handshake/tls_extension_handler_server.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "math" "github.com/lucas-clemente/quic-go/qerr" @@ -13,7 +14,8 @@ import ( ) type extensionHandlerServer struct { - params *paramsNegotiator + params *TransportParameters + paramsChan chan<- TransportParameters version protocol.VersionNumber supportedVersions []protocol.VersionNumber @@ -21,9 +23,15 @@ type extensionHandlerServer struct { var _ mint.AppExtensionHandler = &extensionHandlerServer{} -func newExtensionHandlerServer(params *paramsNegotiator, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber) *extensionHandlerServer { +func newExtensionHandlerServer( + params *TransportParameters, + paramsChan chan<- TransportParameters, + supportedVersions []protocol.VersionNumber, + version protocol.VersionNumber, +) *extensionHandlerServer { return &extensionHandlerServer{ params: params, + paramsChan: paramsChan, version: version, supportedVersions: supportedVersions, } @@ -35,7 +43,8 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi } transportParams := append( - h.params.GetTransportParameters(), + h.params.getTransportParameters(), + // TODO(#855): generate a real token transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)}, ) supportedVersions := make([]uint32, len(h.supportedVersions)) @@ -89,5 +98,12 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte return errors.New("client sent a stateless reset token") } } - return h.params.SetFromTransportParameters(chtp.Parameters) + params, err := readTransportParamters(chtp.Parameters) + if err != nil { + return err + } + // TODO(#878): remove this when implementing the MAX_STREAM_ID frame + params.MaxStreams = math.MaxUint32 + h.paramsChan <- *params + return nil } diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index 5f902289..27689536 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -19,12 +19,16 @@ func parameterMapToList(paramMap map[transportParameterID][]byte) []transportPar } var _ = Describe("TLS Extension Handler, for the server", func() { - var handler *extensionHandlerServer - var el mint.ExtensionList + var ( + handler *extensionHandlerServer + el mint.ExtensionList + paramsChan chan TransportParameters + ) BeforeEach(func() { - pn := ¶msNegotiator{} - handler = newExtensionHandlerServer(pn, nil, protocol.VersionWhatever) + // use a buffered channel here, so that we don't have to receive concurrently when parsing a message + paramsChan = make(chan TransportParameters, 1) + handler = newExtensionHandlerServer(&TransportParameters{}, paramsChan, nil, protocol.VersionWhatever) el = make(mint.ExtensionList, 0) }) @@ -79,7 +83,9 @@ var _ = Describe("TLS Extension Handler, for the server", func() { addClientHelloWithParameters(parameters) err := handler.Receive(mint.HandshakeTypeClientHello, &el) Expect(err).ToNot(HaveOccurred()) - Expect(handler.params.GetSendStreamFlowControlWindow()).To(BeEquivalentTo(0x11223344)) + var params TransportParameters + Expect(paramsChan).To(Receive(¶ms)) + Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344)) }) It("errors if the ClientHello doesn't contain TransportParameters", func() { diff --git a/internal/handshake/tls_extension_test.go b/internal/handshake/tls_extension_test.go index 8109f393..4388731d 100644 --- a/internal/handshake/tls_extension_test.go +++ b/internal/handshake/tls_extension_test.go @@ -6,25 +6,6 @@ import ( ) var _ = Describe("TLS extension body", func() { - // var server, client mint.AppExtensionHandler - // var el mint.ExtensionList - - // BeforeEach(func() { - // server = &extensionHandler{perspective: protocol.PerspectiveServer} - // client = &extensionHandler{perspective: protocol.PerspectiveClient} - // // el = make(mint.ExtensionList, 0) - // // TODO: initialize el with some dummy extensions - // }) - - // It("writes and reads a ClientHello", func() { - // err := client.Send(mint.HandshakeTypeClientHello, &el) - // Expect(err).ToNot(HaveOccurred()) - // ch := &tlsExtensionBody{} - // found := el.Find(ch) - // Expect(found).To(BeTrue()) - // err = server.Receive(mint.HandshakeTypeClientHello, &el) - // Expect(err).ToNot(HaveOccurred()) - // }) var extBody *tlsExtensionBody BeforeEach(func() { diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go new file mode 100644 index 00000000..1bcfafee --- /dev/null +++ b/internal/handshake/transport_parameter_test.go @@ -0,0 +1,246 @@ +package handshake + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Transport Parameters", func() { + Context("for gQUIC", func() { + Context("parsing", func() { + It("sets all values", func() { + values := map[Tag][]byte{ + TagSFCW: {0xad, 0xfb, 0xca, 0xde}, + TagCFCW: {0xef, 0xbe, 0xad, 0xde}, + TagICSL: {0x0d, 0xf0, 0xad, 0xba}, + TagMIDS: {0xff, 0x10, 0x00, 0xc0}, + } + params, err := readHelloMap(values) + Expect(err).ToNot(HaveOccurred()) + Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xdecafbad))) + Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0xdeadbeef))) + Expect(params.IdleTimeout).To(Equal(time.Duration(0xbaadf00d) * time.Second)) + Expect(params.MaxStreams).To(Equal(uint32(0xc00010ff))) + Expect(params.OmitConnectionID).To(BeFalse()) + }) + + It("reads if the connection ID should be omitted", func() { + values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}} + params, err := readHelloMap(values) + Expect(err).ToNot(HaveOccurred()) + Expect(params.OmitConnectionID).To(BeTrue()) + }) + + It("doesn't allow idle timeouts below the minimum remote idle timeout", func() { + t := 2 * time.Second + Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout)) + values := map[Tag][]byte{ + TagICSL: {uint8(t.Seconds()), 0, 0, 0}, + } + params, err := readHelloMap(values) + Expect(err).ToNot(HaveOccurred()) + Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout)) + }) + + It("errors when given an invalid SFCW value", func() { + values := map[Tag][]byte{TagSFCW: {2, 0, 0}} // 1 byte too short + _, err := readHelloMap(values) + Expect(err).To(MatchError(errMalformedTag)) + }) + + It("errors when given an invalid CFCW value", func() { + values := map[Tag][]byte{TagCFCW: {2, 0, 0}} // 1 byte too short + _, err := readHelloMap(values) + Expect(err).To(MatchError(errMalformedTag)) + }) + + It("errors when given an invalid TCID value", func() { + values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short + _, err := readHelloMap(values) + Expect(err).To(MatchError(errMalformedTag)) + }) + + It("errors when given an invalid ICSL value", func() { + values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short + _, err := readHelloMap(values) + Expect(err).To(MatchError(errMalformedTag)) + }) + + It("errors when given an invalid MIDS value", func() { + values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short + _, err := readHelloMap(values) + Expect(err).To(MatchError(errMalformedTag)) + }) + }) + + Context("writing", func() { + It("returns all necessary parameters ", func() { + params := &TransportParameters{ + StreamFlowControlWindow: 0xdeadbeef, + ConnectionFlowControlWindow: 0xdecafbad, + IdleTimeout: 0xbaaaaaad * time.Second, + MaxStreams: 0x1337, + } + entryMap := params.getHelloMap() + Expect(entryMap).To(HaveLen(4)) + Expect(entryMap).ToNot(HaveKey(TagTCID)) + Expect(entryMap).To(HaveKeyWithValue(TagSFCW, []byte{0xef, 0xbe, 0xad, 0xde})) + Expect(entryMap).To(HaveKeyWithValue(TagCFCW, []byte{0xad, 0xfb, 0xca, 0xde})) + Expect(entryMap).To(HaveKeyWithValue(TagICSL, []byte{0xad, 0xaa, 0xaa, 0xba})) + Expect(entryMap).To(HaveKeyWithValue(TagMIDS, []byte{0x37, 0x13, 0, 0})) + }) + + It("requests omission of the connection ID", func() { + params := &TransportParameters{OmitConnectionID: true} + entryMap := params.getHelloMap() + Expect(entryMap).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0})) + }) + }) + }) + + Context("for TLS", func() { + paramsMapToList := func(p map[transportParameterID][]byte) []transportParameter { + var list []transportParameter + for id, val := range p { + list = append(list, transportParameter{id, val}) + } + return list + } + + Context("parsing", func() { + var parameters map[transportParameterID][]byte + + BeforeEach(func() { + parameters = map[transportParameterID][]byte{ + initialMaxStreamDataParameterID: []byte{0x11, 0x22, 0x33, 0x44}, + initialMaxDataParameterID: []byte{0x22, 0x33, 0x44, 0x55}, + initialMaxStreamIDParameterID: []byte{0x33, 0x44, 0x55, 0x66}, + idleTimeoutParameterID: []byte{0x13, 0x37}, + } + }) + It("reads parameters", func() { + params, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).ToNot(HaveOccurred()) + Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344))) + Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455))) + Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second)) + Expect(params.OmitConnectionID).To(BeFalse()) + }) + + It("saves if it should omit the connection ID", func() { + parameters[omitConnectionIDParameterID] = []byte{} + params, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).ToNot(HaveOccurred()) + Expect(params.OmitConnectionID).To(BeTrue()) + }) + + It("rejects the parameters if the initial_max_stream_data is missing", func() { + delete(parameters, initialMaxStreamDataParameterID) + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("missing parameter")) + }) + + It("rejects the parameters if the initial_max_data is missing", func() { + delete(parameters, initialMaxDataParameterID) + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("missing parameter")) + }) + + It("rejects the parameters if the initial_max_stream_id is missing", func() { + delete(parameters, initialMaxStreamIDParameterID) + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("missing parameter")) + }) + + It("rejects the parameters if the idle_timeout is missing", func() { + delete(parameters, idleTimeoutParameterID) + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("missing parameter")) + }) + + It("doesn't allow values below the minimum remote idle timeout", func() { + t := 2 * time.Second + Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout)) + parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())} + params, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).ToNot(HaveOccurred()) + Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout)) + }) + + It("rejects the parameters if the initial_max_stream_data has the wrong length", func() { + parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)")) + }) + + It("rejects the parameters if the initial_max_data has the wrong length", func() { + parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)")) + }) + + It("rejects the parameters if the initial_max_stream_id has the wrong length", func() { + parameters[initialMaxStreamIDParameterID] = []byte{0x11, 0x22, 0x33, 0x44, 0x55} // should be 4 bytes + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("wrong length for initial_max_stream_id: 5 (expected 4)")) + }) + + It("rejects the parameters if the initial_idle_timeout has the wrong length", func() { + parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)")) + }) + + It("rejects the parameters if omit_connection_id is non-empty", func() { + parameters[omitConnectionIDParameterID] = []byte{0} // should be empty + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).To(MatchError("wrong length for omit_connection_id: 1 (expected empty)")) + }) + + It("ignores unknown parameters", func() { + parameters[1337] = []byte{42} + _, err := readTransportParamters(paramsMapToList(parameters)) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Context("writing", func() { + var params *TransportParameters + + paramsListToMap := func(l []transportParameter) map[transportParameterID][]byte { + p := make(map[transportParameterID][]byte) + for _, v := range l { + p[v.Parameter] = v.Value + } + return p + } + + BeforeEach(func() { + params = &TransportParameters{ + StreamFlowControlWindow: 0xdeadbeef, + ConnectionFlowControlWindow: 0xdecafbad, + IdleTimeout: 0xcafe, + } + }) + + It("creates the parameters list", func() { + values := paramsListToMap(params.getTransportParameters()) + Expect(values).To(HaveLen(5)) + Expect(values).To(HaveKeyWithValue(initialMaxStreamDataParameterID, []byte{0xde, 0xad, 0xbe, 0xef})) + Expect(values).To(HaveKeyWithValue(initialMaxDataParameterID, []byte{0xde, 0xca, 0xfb, 0xad})) + Expect(values).To(HaveKeyWithValue(initialMaxStreamIDParameterID, []byte{0xff, 0xff, 0xff, 0xff})) + Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe})) + Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac + }) + + It("request ommision of the connection ID", func() { + params.OmitConnectionID = true + values := paramsListToMap(params.getTransportParameters()) + Expect(values).To(HaveKeyWithValue(omitConnectionIDParameterID, []byte{})) + }) + }) + }) +}) diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go new file mode 100644 index 00000000..8b54d970 --- /dev/null +++ b/internal/handshake/transport_parameters.go @@ -0,0 +1,167 @@ +package handshake + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "math" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" +) + +// errMalformedTag is returned when the tag value cannot be read +var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value") + +// TransportParameters are parameters sent to the peer during the handshake +type TransportParameters struct { + StreamFlowControlWindow protocol.ByteCount + ConnectionFlowControlWindow protocol.ByteCount + + MaxStreams uint32 + + OmitConnectionID bool + IdleTimeout time.Duration +} + +// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message +func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) { + params := &TransportParameters{} + if value, ok := tags[TagTCID]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.OmitConnectionID = (v == 0) + } + if value, ok := tags[TagMIDS]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.MaxStreams = v + } + if value, ok := tags[TagICSL]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second) + } + if value, ok := tags[TagSFCW]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.StreamFlowControlWindow = protocol.ByteCount(v) + } + if value, ok := tags[TagCFCW]; ok { + v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return nil, errMalformedTag + } + params.ConnectionFlowControlWindow = protocol.ByteCount(v) + } + return params, nil +} + +// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake. +func (p *TransportParameters) getHelloMap() map[Tag][]byte { + sfcw := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow)) + cfcw := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow)) + mids := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(mids, p.MaxStreams) + icsl := bytes.NewBuffer([]byte{}) + utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second)) + + tags := map[Tag][]byte{ + TagICSL: icsl.Bytes(), + TagMIDS: mids.Bytes(), + TagCFCW: cfcw.Bytes(), + TagSFCW: sfcw.Bytes(), + } + if p.OmitConnectionID { + tags[TagTCID] = []byte{0, 0, 0, 0} + } + return tags +} + +// readTransportParameters reads the transport parameters sent in the QUIC TLS extension +func readTransportParamters(paramsList []transportParameter) (*TransportParameters, error) { + params := &TransportParameters{} + + var foundInitialMaxStreamData bool + var foundInitialMaxData bool + var foundInitialMaxStreamID bool + var foundIdleTimeout bool + + for _, p := range paramsList { + switch p.Parameter { + case initialMaxStreamDataParameterID: + foundInitialMaxStreamData = true + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value)) + } + params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) + case initialMaxDataParameterID: + foundInitialMaxData = true + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value)) + } + params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value)) + case initialMaxStreamIDParameterID: + foundInitialMaxStreamID = true + if len(p.Value) != 4 { + return nil, fmt.Errorf("wrong length for initial_max_stream_id: %d (expected 4)", len(p.Value)) + } + // TODO: handle this value + case idleTimeoutParameterID: + foundIdleTimeout = true + if len(p.Value) != 2 { + return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value)) + } + params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second) + case omitConnectionIDParameterID: + if len(p.Value) != 0 { + return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) + } + params.OmitConnectionID = true + } + } + + if !(foundInitialMaxStreamData && foundInitialMaxData && foundInitialMaxStreamID && foundIdleTimeout) { + return nil, errors.New("missing parameter") + } + return params, nil +} + +// GetTransportParameters gets the parameters needed for the TLS handshake. +func (p *TransportParameters) getTransportParameters() []transportParameter { + initialMaxStreamData := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow)) + initialMaxData := make([]byte, 4) + binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow)) + initialMaxStreamID := make([]byte, 4) + // TODO: use a reasonable value here + binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32) + idleTimeout := make([]byte, 2) + binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout)) + maxPacketSize := make([]byte, 2) + binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize)) + params := []transportParameter{ + {initialMaxStreamDataParameterID, initialMaxStreamData}, + {initialMaxDataParameterID, initialMaxData}, + {initialMaxStreamIDParameterID, initialMaxStreamID}, + {idleTimeoutParameterID, idleTimeout}, + {maxPacketSizeParameterID, maxPacketSize}, + } + if p.OmitConnectionID { + params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}}) + } + return params +} diff --git a/internal/mocks/flow_control_manager.go b/internal/mocks/flow_control_manager.go new file mode 100644 index 00000000..86a11337 --- /dev/null +++ b/internal/mocks/flow_control_manager.go @@ -0,0 +1,177 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../flowcontrol/interface.go + +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + handshake "github.com/lucas-clemente/quic-go/internal/handshake" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockFlowControlManager is a mock of FlowControlManager interface +type MockFlowControlManager struct { + ctrl *gomock.Controller + recorder *MockFlowControlManagerMockRecorder +} + +// MockFlowControlManagerMockRecorder is the mock recorder for MockFlowControlManager +type MockFlowControlManagerMockRecorder struct { + mock *MockFlowControlManager +} + +// NewMockFlowControlManager creates a new mock instance +func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager { + mock := &MockFlowControlManager{ctrl: ctrl} + mock.recorder = &MockFlowControlManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockFlowControlManager) EXPECT() *MockFlowControlManagerMockRecorder { + return _m.recorder +} + +// NewStream mocks base method +func (_m *MockFlowControlManager) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) { + _m.ctrl.Call(_m, "NewStream", streamID, contributesToConnectionFlow) +} + +// NewStream indicates an expected call of NewStream +func (_mr *MockFlowControlManagerMockRecorder) NewStream(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "NewStream", reflect.TypeOf((*MockFlowControlManager)(nil).NewStream), arg0, arg1) +} + +// RemoveStream mocks base method +func (_m *MockFlowControlManager) RemoveStream(streamID protocol.StreamID) { + _m.ctrl.Call(_m, "RemoveStream", streamID) +} + +// RemoveStream indicates an expected call of RemoveStream +func (_mr *MockFlowControlManagerMockRecorder) RemoveStream(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RemoveStream", reflect.TypeOf((*MockFlowControlManager)(nil).RemoveStream), arg0) +} + +// UpdateTransportParameters mocks base method +func (_m *MockFlowControlManager) UpdateTransportParameters(_param0 *handshake.TransportParameters) { + _m.ctrl.Call(_m, "UpdateTransportParameters", _param0) +} + +// UpdateTransportParameters indicates an expected call of UpdateTransportParameters +func (_mr *MockFlowControlManagerMockRecorder) UpdateTransportParameters(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateTransportParameters", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateTransportParameters), arg0) +} + +// ResetStream mocks base method +func (_m *MockFlowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { + ret := _m.ctrl.Call(_m, "ResetStream", streamID, byteOffset) + ret0, _ := ret[0].(error) + return ret0 +} + +// ResetStream indicates an expected call of ResetStream +func (_mr *MockFlowControlManagerMockRecorder) ResetStream(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ResetStream", reflect.TypeOf((*MockFlowControlManager)(nil).ResetStream), arg0, arg1) +} + +// UpdateHighestReceived mocks base method +func (_m *MockFlowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { + ret := _m.ctrl.Call(_m, "UpdateHighestReceived", streamID, byteOffset) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateHighestReceived indicates an expected call of UpdateHighestReceived +func (_mr *MockFlowControlManagerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateHighestReceived), arg0, arg1) +} + +// AddBytesRead mocks base method +func (_m *MockFlowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error { + ret := _m.ctrl.Call(_m, "AddBytesRead", streamID, n) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddBytesRead indicates an expected call of AddBytesRead +func (_mr *MockFlowControlManagerMockRecorder) AddBytesRead(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesRead", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesRead), arg0, arg1) +} + +// GetWindowUpdates mocks base method +func (_m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate { + ret := _m.ctrl.Call(_m, "GetWindowUpdates") + ret0, _ := ret[0].([]flowcontrol.WindowUpdate) + return ret0 +} + +// GetWindowUpdates indicates an expected call of GetWindowUpdates +func (_mr *MockFlowControlManagerMockRecorder) GetWindowUpdates() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdates", reflect.TypeOf((*MockFlowControlManager)(nil).GetWindowUpdates)) +} + +// GetReceiveWindow mocks base method +func (_m *MockFlowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) { + ret := _m.ctrl.Call(_m, "GetReceiveWindow", streamID) + ret0, _ := ret[0].(protocol.ByteCount) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetReceiveWindow indicates an expected call of GetReceiveWindow +func (_mr *MockFlowControlManagerMockRecorder) GetReceiveWindow(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetReceiveWindow", reflect.TypeOf((*MockFlowControlManager)(nil).GetReceiveWindow), arg0) +} + +// AddBytesSent mocks base method +func (_m *MockFlowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { + ret := _m.ctrl.Call(_m, "AddBytesSent", streamID, n) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddBytesSent indicates an expected call of AddBytesSent +func (_mr *MockFlowControlManagerMockRecorder) AddBytesSent(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesSent", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesSent), arg0, arg1) +} + +// SendWindowSize mocks base method +func (_m *MockFlowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) { + ret := _m.ctrl.Call(_m, "SendWindowSize", streamID) + ret0, _ := ret[0].(protocol.ByteCount) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SendWindowSize indicates an expected call of SendWindowSize +func (_mr *MockFlowControlManagerMockRecorder) SendWindowSize(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SendWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).SendWindowSize), arg0) +} + +// RemainingConnectionWindowSize mocks base method +func (_m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { + ret := _m.ctrl.Call(_m, "RemainingConnectionWindowSize") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// RemainingConnectionWindowSize indicates an expected call of RemainingConnectionWindowSize +func (_mr *MockFlowControlManagerMockRecorder) RemainingConnectionWindowSize() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RemainingConnectionWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).RemainingConnectionWindowSize)) +} + +// UpdateWindow mocks base method +func (_m *MockFlowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) { + ret := _m.ctrl.Call(_m, "UpdateWindow", streamID, offset) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateWindow indicates an expected call of UpdateWindow +func (_mr *MockFlowControlManagerMockRecorder) UpdateWindow(arg0, arg1 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateWindow", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateWindow), arg0, arg1) +} diff --git a/internal/mocks/gen.go b/internal/mocks/gen.go index 064a0ad5..21c7c028 100644 --- a/internal/mocks/gen.go +++ b/internal/mocks/gen.go @@ -3,6 +3,5 @@ package mocks // mockgen source mode doesn't properly recognize structs defined in the same package // so we have to use sed to correct for that -//go:generate sh -c "mockgen -package mocks_fc -source ../flowcontrol/interface.go | sed \"s/\\[\\]WindowUpdate/[]flowcontrol.WindowUpdate/g\" > mocks_fc/flow_control_manager.go" -//go:generate sh -c "mockgen -package mocks -source ../handshake/params_negotiator_base.go > params_negotiator.go" +//go:generate sh -c "mockgen -package mocks -source ../flowcontrol/interface.go | sed \"s/\\[\\]WindowUpdate/[]flowcontrol.WindowUpdate/g\" > flow_control_manager.go" //go:generate sh -c "goimports -w ." diff --git a/internal/mocks/mocks_fc/flow_control_manager.go b/internal/mocks/mocks_fc/flow_control_manager.go deleted file mode 100644 index 8089da47..00000000 --- a/internal/mocks/mocks_fc/flow_control_manager.go +++ /dev/null @@ -1,167 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ../flowcontrol/interface.go - -// Package mocks_fc is a generated GoMock package. -package mocks_fc - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - "github.com/lucas-clemente/quic-go/internal/flowcontrol" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// MockFlowControlManager is a mock of FlowControlManager interface -type MockFlowControlManager struct { - ctrl *gomock.Controller - recorder *MockFlowControlManagerMockRecorder -} - -// MockFlowControlManagerMockRecorder is the mock recorder for MockFlowControlManager -type MockFlowControlManagerMockRecorder struct { - mock *MockFlowControlManager -} - -// NewMockFlowControlManager creates a new mock instance -func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager { - mock := &MockFlowControlManager{ctrl: ctrl} - mock.recorder = &MockFlowControlManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockFlowControlManager) EXPECT() *MockFlowControlManagerMockRecorder { - return m.recorder -} - -// NewStream mocks base method -func (m *MockFlowControlManager) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) { - m.ctrl.Call(m, "NewStream", streamID, contributesToConnectionFlow) -} - -// NewStream indicates an expected call of NewStream -func (mr *MockFlowControlManagerMockRecorder) NewStream(streamID, contributesToConnectionFlow interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewStream", reflect.TypeOf((*MockFlowControlManager)(nil).NewStream), streamID, contributesToConnectionFlow) -} - -// RemoveStream mocks base method -func (m *MockFlowControlManager) RemoveStream(streamID protocol.StreamID) { - m.ctrl.Call(m, "RemoveStream", streamID) -} - -// RemoveStream indicates an expected call of RemoveStream -func (mr *MockFlowControlManagerMockRecorder) RemoveStream(streamID interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveStream", reflect.TypeOf((*MockFlowControlManager)(nil).RemoveStream), streamID) -} - -// ResetStream mocks base method -func (m *MockFlowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - ret := m.ctrl.Call(m, "ResetStream", streamID, byteOffset) - ret0, _ := ret[0].(error) - return ret0 -} - -// ResetStream indicates an expected call of ResetStream -func (mr *MockFlowControlManagerMockRecorder) ResetStream(streamID, byteOffset interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetStream", reflect.TypeOf((*MockFlowControlManager)(nil).ResetStream), streamID, byteOffset) -} - -// UpdateHighestReceived mocks base method -func (m *MockFlowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error { - ret := m.ctrl.Call(m, "UpdateHighestReceived", streamID, byteOffset) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateHighestReceived indicates an expected call of UpdateHighestReceived -func (mr *MockFlowControlManagerMockRecorder) UpdateHighestReceived(streamID, byteOffset interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateHighestReceived), streamID, byteOffset) -} - -// AddBytesRead mocks base method -func (m *MockFlowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error { - ret := m.ctrl.Call(m, "AddBytesRead", streamID, n) - ret0, _ := ret[0].(error) - return ret0 -} - -// AddBytesRead indicates an expected call of AddBytesRead -func (mr *MockFlowControlManagerMockRecorder) AddBytesRead(streamID, n interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesRead), streamID, n) -} - -// GetWindowUpdates mocks base method -func (m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate { - ret := m.ctrl.Call(m, "GetWindowUpdates") - ret0, _ := ret[0].([]flowcontrol.WindowUpdate) - return ret0 -} - -// GetWindowUpdates indicates an expected call of GetWindowUpdates -func (mr *MockFlowControlManagerMockRecorder) GetWindowUpdates() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdates", reflect.TypeOf((*MockFlowControlManager)(nil).GetWindowUpdates)) -} - -// GetReceiveWindow mocks base method -func (m *MockFlowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) { - ret := m.ctrl.Call(m, "GetReceiveWindow", streamID) - ret0, _ := ret[0].(protocol.ByteCount) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetReceiveWindow indicates an expected call of GetReceiveWindow -func (mr *MockFlowControlManagerMockRecorder) GetReceiveWindow(streamID interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReceiveWindow", reflect.TypeOf((*MockFlowControlManager)(nil).GetReceiveWindow), streamID) -} - -// AddBytesSent mocks base method -func (m *MockFlowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { - ret := m.ctrl.Call(m, "AddBytesSent", streamID, n) - ret0, _ := ret[0].(error) - return ret0 -} - -// AddBytesSent indicates an expected call of AddBytesSent -func (mr *MockFlowControlManagerMockRecorder) AddBytesSent(streamID, n interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesSent), streamID, n) -} - -// SendWindowSize mocks base method -func (m *MockFlowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) { - ret := m.ctrl.Call(m, "SendWindowSize", streamID) - ret0, _ := ret[0].(protocol.ByteCount) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SendWindowSize indicates an expected call of SendWindowSize -func (mr *MockFlowControlManagerMockRecorder) SendWindowSize(streamID interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).SendWindowSize), streamID) -} - -// RemainingConnectionWindowSize mocks base method -func (m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount { - ret := m.ctrl.Call(m, "RemainingConnectionWindowSize") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// RemainingConnectionWindowSize indicates an expected call of RemainingConnectionWindowSize -func (mr *MockFlowControlManagerMockRecorder) RemainingConnectionWindowSize() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemainingConnectionWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).RemainingConnectionWindowSize)) -} - -// UpdateWindow mocks base method -func (m *MockFlowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) { - ret := m.ctrl.Call(m, "UpdateWindow", streamID, offset) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateWindow indicates an expected call of UpdateWindow -func (mr *MockFlowControlManagerMockRecorder) UpdateWindow(streamID, offset interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWindow", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateWindow), streamID, offset) -} diff --git a/internal/mocks/params_negotiator.go b/internal/mocks/params_negotiator.go deleted file mode 100644 index 40e6b3b1..00000000 --- a/internal/mocks/params_negotiator.go +++ /dev/null @@ -1,96 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ../handshake/params_negotiator_base.go - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// MockParamsNegotiator is a mock of ParamsNegotiator interface -type MockParamsNegotiator struct { - ctrl *gomock.Controller - recorder *MockParamsNegotiatorMockRecorder -} - -// MockParamsNegotiatorMockRecorder is the mock recorder for MockParamsNegotiator -type MockParamsNegotiatorMockRecorder struct { - mock *MockParamsNegotiator -} - -// NewMockParamsNegotiator creates a new mock instance -func NewMockParamsNegotiator(ctrl *gomock.Controller) *MockParamsNegotiator { - mock := &MockParamsNegotiator{ctrl: ctrl} - mock.recorder = &MockParamsNegotiatorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockParamsNegotiator) EXPECT() *MockParamsNegotiatorMockRecorder { - return m.recorder -} - -// GetSendStreamFlowControlWindow mocks base method -func (m *MockParamsNegotiator) GetSendStreamFlowControlWindow() protocol.ByteCount { - ret := m.ctrl.Call(m, "GetSendStreamFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetSendStreamFlowControlWindow indicates an expected call of GetSendStreamFlowControlWindow -func (mr *MockParamsNegotiatorMockRecorder) GetSendStreamFlowControlWindow() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSendStreamFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetSendStreamFlowControlWindow)) -} - -// GetSendConnectionFlowControlWindow mocks base method -func (m *MockParamsNegotiator) GetSendConnectionFlowControlWindow() protocol.ByteCount { - ret := m.ctrl.Call(m, "GetSendConnectionFlowControlWindow") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// GetSendConnectionFlowControlWindow indicates an expected call of GetSendConnectionFlowControlWindow -func (mr *MockParamsNegotiatorMockRecorder) GetSendConnectionFlowControlWindow() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSendConnectionFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetSendConnectionFlowControlWindow)) -} - -// GetMaxOutgoingStreams mocks base method -func (m *MockParamsNegotiator) GetMaxOutgoingStreams() uint32 { - ret := m.ctrl.Call(m, "GetMaxOutgoingStreams") - ret0, _ := ret[0].(uint32) - return ret0 -} - -// GetMaxOutgoingStreams indicates an expected call of GetMaxOutgoingStreams -func (mr *MockParamsNegotiatorMockRecorder) GetMaxOutgoingStreams() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxOutgoingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxOutgoingStreams)) -} - -// GetRemoteIdleTimeout mocks base method -func (m *MockParamsNegotiator) GetRemoteIdleTimeout() time.Duration { - ret := m.ctrl.Call(m, "GetRemoteIdleTimeout") - ret0, _ := ret[0].(time.Duration) - return ret0 -} - -// GetRemoteIdleTimeout indicates an expected call of GetRemoteIdleTimeout -func (mr *MockParamsNegotiatorMockRecorder) GetRemoteIdleTimeout() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRemoteIdleTimeout", reflect.TypeOf((*MockParamsNegotiator)(nil).GetRemoteIdleTimeout)) -} - -// OmitConnectionID mocks base method -func (m *MockParamsNegotiator) OmitConnectionID() bool { - ret := m.ctrl.Call(m, "OmitConnectionID") - ret0, _ := ret[0].(bool) - return ret0 -} - -// OmitConnectionID indicates an expected call of OmitConnectionID -func (mr *MockParamsNegotiatorMockRecorder) OmitConnectionID() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OmitConnectionID", reflect.TypeOf((*MockParamsNegotiator)(nil).OmitConnectionID)) -} diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index cf9cf056..4459d24a 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -43,12 +43,6 @@ const MaxReceivePacketSize ByteCount = 1452 // Used in QUIC for congestion window computations in bytes. const DefaultTCPMSS ByteCount = 1460 -// InitialStreamFlowControlWindow is the initial stream-level flow control window for sending -const InitialStreamFlowControlWindow ByteCount = (1 << 14) // 16 kB - -// InitialConnectionFlowControlWindow is the initial connection-level flow control window for sending -const InitialConnectionFlowControlWindow ByteCount = (1 << 14) // 16 kB - // ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. const ClientHelloMinimumSize = 1024 diff --git a/packet_packer.go b/packet_packer.go index 429bb9b6..efc102d6 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -25,18 +25,17 @@ type packetPacker struct { cryptoSetup handshake.CryptoSetup packetNumberGenerator *packetNumberGenerator - connParams handshake.ParamsNegotiator streamFramer *streamFramer - controlFrames []wire.Frame - stopWaiting *wire.StopWaitingFrame - ackFrame *wire.AckFrame - leastUnacked protocol.PacketNumber + controlFrames []wire.Frame + stopWaiting *wire.StopWaitingFrame + ackFrame *wire.AckFrame + leastUnacked protocol.PacketNumber + omitConnectionID bool } func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, - connParams handshake.ParamsNegotiator, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber, @@ -44,7 +43,6 @@ func newPacketPacker(connectionID protocol.ConnectionID, return &packetPacker{ cryptoSetup: cryptoSetup, connectionID: connectionID, - connParams: connParams, perspective: perspective, version: version, streamFramer: streamFramer, @@ -268,12 +266,14 @@ func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *wire. pnum := p.packetNumberGenerator.Peek() packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked) publicHeader := &wire.PublicHeader{ - ConnectionID: p.connectionID, - PacketNumber: pnum, - PacketNumberLen: packetNumberLen, - OmitConnectionID: p.connParams.OmitConnectionID(), + ConnectionID: p.connectionID, + PacketNumber: pnum, + PacketNumberLen: packetNumberLen, } + if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { + publicHeader.OmitConnectionID = true + } if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce() } @@ -329,3 +329,7 @@ func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool { func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) { p.leastUnacked = leastUnacked } + +func (p *packetPacker) SetOmitConnectionID() { + p.omitConnectionID = true +} diff --git a/packet_packer_test.go b/packet_packer_test.go index 6a3b8d92..fea46350 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -7,7 +7,6 @@ import ( "github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/internal/handshake" - "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -61,19 +60,15 @@ var _ = Describe("Packet packer", func() { ) BeforeEach(func() { - mockPn := mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().OmitConnectionID().Return(false).AnyTimes() - cryptoStream = &stream{} - streamsMap := newStreamsMap(nil, nil, protocol.PerspectiveServer, nil) + streamsMap := newStreamsMap(nil, nil, protocol.PerspectiveServer) streamsMap.streams[1] = cryptoStream streamsMap.openStreams = []protocol.StreamID{1} streamFramer = newStreamFramer(streamsMap, nil) packer = &packetPacker{ cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, - connParams: mockPn, connectionID: 0x1337, packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), streamFramer: streamFramer, @@ -234,6 +229,20 @@ var _ = Describe("Packet packer", func() { Expect(p).ToNot(BeNil()) }) + It("it omits the connection ID for forward-secure packets", func() { + ph := packer.getPublicHeader(protocol.EncryptionForwardSecure) + Expect(ph.OmitConnectionID).To(BeFalse()) + packer.SetOmitConnectionID() + ph = packer.getPublicHeader(protocol.EncryptionForwardSecure) + Expect(ph.OmitConnectionID).To(BeTrue()) + }) + + It("doesn't omit the connection ID for non-forware-secure packets", func() { + packer.SetOmitConnectionID() + ph := packer.getPublicHeader(protocol.EncryptionSecure) + Expect(ph.OmitConnectionID).To(BeFalse()) + }) + It("adds the version flag to the public header before the crypto handshake is finished", func() { packer.perspective = protocol.PerspectiveClient packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure diff --git a/session.go b/session.go index c3d27de8..f9cf1939 100644 --- a/session.go +++ b/session.go @@ -88,6 +88,8 @@ type session struct { undecryptablePackets []*receivedPacket receivedTooManyUndecrytablePacketsTime time.Time + // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them + paramsChan <-chan handshake.TransportParameters // this channel is passed to the CryptoSetup and receives the current encryption level // it is closed as soon as the handshake is complete aeadChanged <-chan protocol.EncryptionLevel @@ -100,8 +102,6 @@ type session struct { // it receives at most 3 handshake events: 2 when the encryption level changes, and one error handshakeChan chan<- handshakeEvent - connParams handshake.ParamsNegotiator - lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire // representation, and sent back in public reset packets @@ -109,6 +109,7 @@ type session struct { sessionCreationTime time.Time lastNetworkActivityTime time.Time + remoteIdleTimeout time.Duration timer *utils.Timer // keepAlivePingSent stores whether a Ping frame was sent to the peer or not @@ -166,7 +167,9 @@ func (s *session) setup( negotiatedVersions []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { aeadChanged := make(chan protocol.EncryptionLevel, 2) + paramsChan := make(chan handshake.TransportParameters) s.aeadChanged = aeadChanged + s.paramsChan = paramsChan handshakeChan := make(chan handshakeEvent, 3) s.handshakeChan = handshakeChan s.handshakeCompleteChan = make(chan error, 1) @@ -183,7 +186,10 @@ func (s *session) setup( s.rttStats = &congestion.RTTStats{} transportParams := &handshake.TransportParameters{ - IdleTimeout: s.config.IdleTimeout, + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + MaxStreams: protocol.MaxIncomingStreams, + IdleTimeout: s.config.IdleTimeout, } s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) @@ -194,15 +200,16 @@ func (s *session) setup( return s.config.AcceptCookie(clientAddr, cookie) } if s.version.UsesTLS() { - s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLSServer( + s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer( tlsConf, transportParams, + paramsChan, aeadChanged, s.config.Versions, s.version, ) } else { - s.cryptoSetup, s.connParams, err = newCryptoSetup( + s.cryptoSetup, err = newCryptoSetup( s.connectionID, s.conn.RemoteAddr(), s.version, @@ -210,28 +217,31 @@ func (s *session) setup( transportParams, s.config.Versions, verifySourceAddr, + paramsChan, aeadChanged, ) } } else { + transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission if s.version.UsesTLS() { - s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLSClient( + s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient( hostname, tlsConf, transportParams, + paramsChan, aeadChanged, initialVersion, s.config.Versions, s.version, ) } else { - transportParams.RequestConnectionIDOmission = s.config.RequestConnectionIDOmission - s.cryptoSetup, s.connParams, err = newCryptoSetupClient( + s.cryptoSetup, err = newCryptoSetupClient( hostname, s.connectionID, s.version, tlsConf, transportParams, + paramsChan, aeadChanged, negotiatedVersions, ) @@ -242,16 +252,14 @@ func (s *session) setup( } s.flowControlManager = flowcontrol.NewFlowControlManager( - s.connParams, protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), s.rttStats, ) - s.streamsMap = newStreamsMap(s.newStream, s.flowControlManager.RemoveStream, s.perspective, s.connParams) + s.streamsMap = newStreamsMap(s.newStream, s.flowControlManager.RemoveStream, s.perspective) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) s.packer = newPacketPacker(s.connectionID, s.cryptoSetup, - s.connParams, s.streamFramer, s.perspective, s.version, @@ -318,6 +326,8 @@ runLoop: // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. putPacketBuffer(p.publicHeader.Raw) + case p := <-s.paramsChan: + s.processTransportParameters(&p) case l, ok := <-aeadChanged: if !ok { // the aeadChanged chan was closed. This means that the handshake is completed. s.handshakeComplete = true @@ -338,7 +348,7 @@ runLoop: s.sentPacketHandler.OnAlarm() } - if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.connParams.GetRemoteIdleTimeout()/2 { + if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.remoteIdleTimeout/2 { // send the PING frame since there is no activity in the session s.packer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true @@ -379,7 +389,7 @@ func (s *session) Context() context.Context { func (s *session) maybeResetTimer() { var deadline time.Time if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { - deadline = s.lastNetworkActivityTime.Add(s.connParams.GetRemoteIdleTimeout() / 2) + deadline = s.lastNetworkActivityTime.Add(s.remoteIdleTimeout / 2) } else { deadline = s.lastNetworkActivityTime.Add(s.config.IdleTimeout) } @@ -613,6 +623,15 @@ func (s *session) handleCloseError(closeErr closeError) error { return s.sendConnectionClose(quicErr) } +func (s *session) processTransportParameters(params *handshake.TransportParameters) { + s.remoteIdleTimeout = params.IdleTimeout + s.flowControlManager.UpdateTransportParameters(params) + s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams) + if params.OmitConnectionID { + s.packer.SetOmitConnectionID() + } +} + func (s *session) sendPacket() error { s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) diff --git a/session_test.go b/session_test.go index 99b28554..27b24143 100644 --- a/session_test.go +++ b/session_test.go @@ -18,7 +18,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" - "github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/wire" @@ -142,20 +141,6 @@ func areSessionsRunning() bool { return strings.Contains(b.String(), "quic-go.(*session).run") } -type mockParamsNegotiator struct{} - -var _ handshake.ParamsNegotiator = &mockParamsNegotiator{} - -func (m *mockParamsNegotiator) GetSendStreamFlowControlWindow() protocol.ByteCount { - return protocol.InitialStreamFlowControlWindow -} -func (m *mockParamsNegotiator) GetSendConnectionFlowControlWindow() protocol.ByteCount { - return protocol.InitialConnectionFlowControlWindow -} -func (m *mockParamsNegotiator) GetMaxOutgoingStreams() uint32 { return 100 } -func (m *mockParamsNegotiator) GetRemoteIdleTimeout() time.Duration { return time.Hour } -func (m *mockParamsNegotiator) OmitConnectionID() bool { return false } - var _ = Describe("Session", func() { var ( sess *session @@ -178,10 +163,11 @@ var _ = Describe("Session", func() { _ *handshake.TransportParameters, _ []protocol.VersionNumber, _ func(net.Addr, *Cookie) bool, + _ chan<- handshake.TransportParameters, aeadChangedP chan<- protocol.EncryptionLevel, - ) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) { + ) (handshake.CryptoSetup, error) { aeadChanged = aeadChangedP - return cryptoSetup, &mockParamsNegotiator{}, nil + return cryptoSetup, nil } mconn = newMockConnection() @@ -202,8 +188,6 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // 1 stream: the crypto stream - - sess.connParams = &mockParamsNegotiator{} }) AfterEach(func() { @@ -228,10 +212,11 @@ var _ = Describe("Session", func() { _ *handshake.TransportParameters, _ []protocol.VersionNumber, cookieFunc func(net.Addr, *Cookie) bool, + _ chan<- handshake.TransportParameters, _ chan<- protocol.EncryptionLevel, - ) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) { + ) (handshake.CryptoSetup, error) { cookieVerify = cookieFunc - return cryptoSetup, &mockParamsNegotiator{}, nil + return cryptoSetup, nil } conf := populateServerConfig(&Config{}) @@ -270,6 +255,10 @@ var _ = Describe("Session", func() { }) Context("when handling stream frames", func() { + BeforeEach(func() { + sess.streamsMap.UpdateMaxStreamLimit(100) + }) + It("makes new streams", func() { sess.handleStreamFrame(&wire.StreamFrame{ StreamID: 5, @@ -464,7 +453,7 @@ var _ = Describe("Session", func() { It("passes the byte offset to the flow controller", func() { sess.streamsMap.GetOrOpenStream(5) - fcm := mocks_fc.NewMockFlowControlManager(mockCtrl) + fcm := mocks.NewMockFlowControlManager(mockCtrl) sess.flowControlManager = fcm fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337)) err := sess.handleRstStreamFrame(&wire.RstStreamFrame{ @@ -477,7 +466,7 @@ var _ = Describe("Session", func() { It("returns errors from the flow controller", func() { testErr := errors.New("flow control violation") sess.streamsMap.GetOrOpenStream(5) - fcm := mocks_fc.NewMockFlowControlManager(mockCtrl) + fcm := mocks.NewMockFlowControlManager(mockCtrl) sess.flowControlManager = fcm fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337)).Return(testErr) err := sess.handleRstStreamFrame(&wire.RstStreamFrame{ @@ -525,6 +514,10 @@ var _ = Describe("Session", func() { }) Context("handling WINDOW_UPDATE frames", func() { + BeforeEach(func() { + sess.flowControlManager.UpdateTransportParameters(&handshake.TransportParameters{ConnectionFlowControlWindow: 0x1000}) + }) + It("updates the Flow Control Window of a stream", func() { _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) @@ -1093,7 +1086,7 @@ var _ = Describe("Session", func() { It("retransmits a WindowUpdate if it hasn't already sent a WindowUpdate with a higher ByteOffset", func() { _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - fcm := mocks_fc.NewMockFlowControlManager(mockCtrl) + fcm := mocks.NewMockFlowControlManager(mockCtrl) sess.flowControlManager = fcm fcm.EXPECT().GetWindowUpdates() fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x1000), nil) @@ -1114,7 +1107,7 @@ var _ = Describe("Session", func() { It("doesn't retransmit WindowUpdates if it already sent a WindowUpdate with a higher ByteOffset", func() { _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - fcm := mocks_fc.NewMockFlowControlManager(mockCtrl) + fcm := mocks.NewMockFlowControlManager(mockCtrl) sess.flowControlManager = fcm fcm.EXPECT().GetWindowUpdates() fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x2000), nil) @@ -1140,7 +1133,7 @@ var _ = Describe("Session", func() { err = sess.streamsMap.DeleteClosedStreams() Expect(err).ToNot(HaveOccurred()) _, err = sess.flowControlManager.SendWindowSize(5) - Expect(err).To(MatchError("Error accessing the flowController map.")) + Expect(err).To(MatchError("Error accessing the flowController map")) sph.retransmissionQueue = []*ackhandler.Packet{{ Frames: []wire.Frame{&wire.WindowUpdateFrame{ StreamID: 5, @@ -1183,6 +1176,11 @@ var _ = Describe("Session", func() { Context("scheduling sending", func() { BeforeEach(func() { + sess.processTransportParameters(&handshake.TransportParameters{ + StreamFlowControlWindow: protocol.MaxByteCount, + ConnectionFlowControlWindow: protocol.MaxByteCount, + MaxStreams: 1000, + }) sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} }) @@ -1420,16 +1418,33 @@ var _ = Describe("Session", func() { close(done) }) + It("process transport parameters received from the peer", func() { + paramsChan := make(chan handshake.TransportParameters) + sess.paramsChan = paramsChan + _, err := sess.GetOrOpenStream(5) + Expect(err).ToNot(HaveOccurred()) + go sess.run() + paramsChan <- handshake.TransportParameters{ + MaxStreams: 123, + IdleTimeout: 90 * time.Second, + StreamFlowControlWindow: 0x5000, + ConnectionFlowControlWindow: 0x5000, + OmitConnectionID: true, + } + Eventually(func() time.Duration { return sess.remoteIdleTimeout }).Should(Equal(90 * time.Second)) + Eventually(func() uint32 { return sess.streamsMap.maxOutgoingStreams }).Should(Equal(uint32(123))) + Eventually(func() (protocol.ByteCount, error) { return sess.flowControlManager.SendWindowSize(5) }).Should(Equal(protocol.ByteCount(0x5000))) + Eventually(func() bool { return sess.packer.omitConnectionID }).Should(BeTrue()) + Expect(sess.Close(nil)).To(Succeed()) + }) + Context("keep-alives", func() { - var mockPn *mocks.MockParamsNegotiator // should be shorter than the local timeout for these tests // otherwise we'd send a CONNECTION_CLOSE in the tests where we're testing that no PING is sent remoteIdleTimeout := 20 * time.Second BeforeEach(func() { - mockPn = mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetRemoteIdleTimeout().Return(remoteIdleTimeout).AnyTimes() - sess.connParams = mockPn + sess.remoteIdleTimeout = remoteIdleTimeout }) It("sends a PING", func() { @@ -1523,6 +1538,10 @@ var _ = Describe("Session", func() { }, 0.5) Context("getting streams", func() { + BeforeEach(func() { + sess.processTransportParameters(&handshake.TransportParameters{MaxStreams: 1000}) + }) + It("returns a new stream", func() { str, err := sess.GetOrOpenStream(11) Expect(err).ToNot(HaveOccurred()) @@ -1653,11 +1672,12 @@ var _ = Describe("Client Session", func() { _ protocol.VersionNumber, _ *tls.Config, _ *handshake.TransportParameters, + _ chan<- handshake.TransportParameters, aeadChangedP chan<- protocol.EncryptionLevel, _ []protocol.VersionNumber, - ) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) { + ) (handshake.CryptoSetup, error) { aeadChanged = aeadChangedP - return cryptoSetup, &mockParamsNegotiator{}, nil + return cryptoSetup, nil } mconn = newMockConnection() diff --git a/stream_framer_test.go b/stream_framer_test.go index e8a9f4bf..1036faa7 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -3,7 +3,7 @@ package quic import ( "bytes" - "github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -21,7 +21,7 @@ var _ = Describe("Stream Framer", func() { framer *streamFramer streamsMap *streamsMap stream1, stream2 *stream - mockFcm *mocks_fc.MockFlowControlManager + mockFcm *mocks.MockFlowControlManager ) BeforeEach(func() { @@ -37,11 +37,11 @@ var _ = Describe("Stream Framer", func() { stream1 = &stream{streamID: id1} stream2 = &stream{streamID: id2} - streamsMap = newStreamsMap(nil, nil, protocol.PerspectiveServer, nil) + streamsMap = newStreamsMap(nil, nil, protocol.PerspectiveServer) streamsMap.putStream(stream1) streamsMap.putStream(stream2) - mockFcm = mocks_fc.NewMockFlowControlManager(mockCtrl) + mockFcm = mocks.NewMockFlowControlManager(mockCtrl) framer = newStreamFramer(streamsMap, mockFcm) }) diff --git a/stream_test.go b/stream_test.go index ad73469f..9d757851 100644 --- a/stream_test.go +++ b/stream_test.go @@ -9,7 +9,7 @@ import ( "os" - "github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -30,7 +30,7 @@ var _ = Describe("Stream", func() { resetCalledForStream protocol.StreamID resetCalledAtOffset protocol.ByteCount - mockFcm *mocks_fc.MockFlowControlManager + mockFcm *mocks.MockFlowControlManager ) // in the tests for the stream deadlines we set a deadline @@ -58,7 +58,7 @@ var _ = Describe("Stream", func() { BeforeEach(func() { onDataCalled = false resetCalled = false - mockFcm = mocks_fc.NewMockFlowControlManager(mockCtrl) + mockFcm = mocks.NewMockFlowControlManager(mockCtrl) str = newStream(streamID, onData, onReset, mockFcm) timeout := scaleDuration(250 * time.Millisecond) diff --git a/streams_map.go b/streams_map.go index 5bd04d13..69797313 100644 --- a/streams_map.go +++ b/streams_map.go @@ -5,7 +5,6 @@ import ( "fmt" "sync" - "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" @@ -14,7 +13,6 @@ import ( type streamsMap struct { mutex sync.RWMutex - connParams handshake.ParamsNegotiator perspective protocol.Perspective streams map[protocol.StreamID]*stream @@ -36,6 +34,7 @@ type streamsMap struct { numOutgoingStreams uint32 numIncomingStreams uint32 maxIncomingStreams uint32 + maxOutgoingStreams uint32 } type streamLambda func(*stream) (bool, error) @@ -44,7 +43,7 @@ type newStreamLambda func(protocol.StreamID) *stream var errMapAccess = errors.New("streamsMap: Error accessing the streams map") -func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective, connParams handshake.ParamsNegotiator) *streamsMap { +func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective) *streamsMap { // add some tolerance to the maximum incoming streams value maxStreams := uint32(protocol.MaxIncomingStreams) maxIncomingStreams := utils.MaxUint32( @@ -57,7 +56,6 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC openStreams: make([]protocol.StreamID, 0), newStream: newStream, removeStreamCallback: removeStreamCallback, - connParams: connParams, maxIncomingStreams: maxIncomingStreams, } sm.nextStreamOrErrCond.L = &sm.mutex @@ -66,6 +64,8 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC if pers == protocol.PerspectiveClient { sm.nextStream = 1 sm.nextStreamToAccept = 2 + // TODO: find a better solution for opening the crypto stream + sm.maxOutgoingStreams = 1 // allow the crypto stream } else { sm.nextStream = 2 sm.nextStreamToAccept = 1 @@ -159,7 +159,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) { func (m *streamsMap) openStreamImpl() (*stream, error) { id := m.nextStream - if m.numOutgoingStreams >= m.connParams.GetMaxOutgoingStreams() { + if m.numOutgoingStreams >= m.maxOutgoingStreams { return nil, qerr.TooManyOpenStreams } @@ -340,3 +340,9 @@ func (m *streamsMap) CloseWithError(err error) { m.streams[s].Cancel(err) } } + +func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.maxOutgoingStreams = limit +} diff --git a/streams_map_test.go b/streams_map_test.go index dffaa105..f8d1c384 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -3,7 +3,6 @@ package quic import ( "errors" - "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -11,22 +10,16 @@ import ( ) var _ = Describe("Streams Map", func() { - const maxOutgoingStreams = 60 - var ( - m *streamsMap - mockPn *mocks.MockParamsNegotiator + m *streamsMap ) setNewStreamsMap := func(p protocol.Perspective) { - mockPn = mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams)) - newStream := func(id protocol.StreamID) *stream { return newStream(id, func() {}, nil, nil) } removeStreamCallback := func(protocol.StreamID) {} - m = newStreamsMap(newStream, removeStreamCallback, p, mockPn) + m = newStreamsMap(newStream, removeStreamCallback, p) } AfterEach(func() { @@ -132,7 +125,13 @@ var _ = Describe("Streams Map", func() { }) Context("server-side streams", func() { + It("doesn't allow opening streams before receiving the transport parameters", func() { + _, err := m.OpenStream() + Expect(err).To(MatchError(qerr.TooManyOpenStreams)) + }) + It("opens a stream 2 first", func() { + m.UpdateMaxStreamLimit(100) s, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) @@ -149,6 +148,7 @@ var _ = Describe("Streams Map", func() { }) It("doesn't reopen an already closed stream", func() { + m.UpdateMaxStreamLimit(100) str, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(str.StreamID()).To(Equal(protocol.StreamID(2))) @@ -160,6 +160,12 @@ var _ = Describe("Streams Map", func() { }) Context("counting streams", func() { + const maxOutgoingStreams = 50 + + BeforeEach(func() { + m.UpdateMaxStreamLimit(maxOutgoingStreams) + }) + It("errors when too many streams are opened", func() { for i := 1; i <= maxOutgoingStreams; i++ { _, err := m.OpenStream() @@ -190,6 +196,12 @@ var _ = Describe("Streams Map", func() { }) Context("opening streams synchronously", func() { + const maxOutgoingStreams = 10 + + BeforeEach(func() { + m.UpdateMaxStreamLimit(maxOutgoingStreams) + }) + openMaxNumStreams := func() { for i := 1; i <= maxOutgoingStreams; i++ { _, err := m.OpenStream()