diff --git a/internal/handshake/connection_parameters_manager.go b/internal/handshake/connection_parameters_manager.go index a69f10bc..d2098922 100644 --- a/internal/handshake/connection_parameters_manager.go +++ b/internal/handshake/connection_parameters_manager.go @@ -66,9 +66,7 @@ var ( func NewConnectionParamatersManager( pers protocol.Perspective, v protocol.VersionNumber, - maxReceiveStreamFlowControlWindow protocol.ByteCount, - maxReceiveConnectionFlowControlWindow protocol.ByteCount, - idleTimeout time.Duration, + params *TransportParameters, ) ConnectionParametersManager { h := &connectionParametersManager{ perspective: pers, @@ -77,11 +75,11 @@ func NewConnectionParamatersManager( sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, - maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, + maxReceiveStreamFlowControlWindow: params.MaxReceiveStreamFlowControlWindow, + maxReceiveConnectionFlowControlWindow: params.MaxReceiveConnectionFlowControlWindow, } - h.idleConnectionStateLifetime = idleTimeout + h.idleConnectionStateLifetime = params.IdleTimeout if h.perspective == protocol.PerspectiveServer { h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective diff --git a/internal/handshake/connection_parameters_manager_test.go b/internal/handshake/connection_parameters_manager_test.go index e5d02972..d7deec28 100644 --- a/internal/handshake/connection_parameters_manager_test.go +++ b/internal/handshake/connection_parameters_manager_test.go @@ -23,16 +23,20 @@ var _ = Describe("ConnectionsParameterManager", func() { cpm = NewConnectionParamatersManager( protocol.PerspectiveServer, protocol.VersionWhatever, - maxReceiveStreamFlowControlWindowServer, - maxReceiveConnectionFlowControlWindowServer, - idleTimeout, + &TransportParameters{ + MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindowServer, + MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindowServer, + IdleTimeout: idleTimeout, + }, ).(*connectionParametersManager) cpmClient = NewConnectionParamatersManager( protocol.PerspectiveClient, protocol.VersionWhatever, - maxReceiveStreamFlowControlWindowClient, - maxReceiveConnectionFlowControlWindowClient, - idleTimeout, + &TransportParameters{ + MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindowClient, + MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindowClient, + IdleTimeout: idleTimeout, + }, ).(*connectionParametersManager) }) diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index ecb404a4..d1274075 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -51,8 +51,8 @@ type cryptoSetupClient struct { forwardSecureAEAD crypto.AEAD aeadChanged chan<- protocol.EncryptionLevel - params *TransportParameters - connectionParameters ConnectionParametersManager + requestConnIDTruncation bool + connectionParameters ConnectionParametersManager } var _ CryptoSetup = &cryptoSetupClient{} @@ -68,34 +68,34 @@ func NewCryptoSetupClient( hostname string, connID protocol.ConnectionID, version protocol.VersionNumber, - cryptoStream io.ReadWriter, tlsConfig *tls.Config, - connectionParameters ConnectionParametersManager, - aeadChanged chan<- protocol.EncryptionLevel, params *TransportParameters, + aeadChanged chan<- protocol.EncryptionLevel, negotiatedVersions []protocol.VersionNumber, -) (CryptoSetup, error) { +) (CryptoSetup, ConnectionParametersManager, error) { + cpm := NewConnectionParamatersManager(protocol.PerspectiveClient, version, params) return &cryptoSetupClient{ - hostname: hostname, - connID: connID, - version: version, - cryptoStream: cryptoStream, - certManager: crypto.NewCertManager(tlsConfig), - connectionParameters: connectionParameters, - keyDerivation: crypto.DeriveQuicCryptoAESKeys, - keyExchange: getEphermalKEX, - nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), - aeadChanged: aeadChanged, - negotiatedVersions: negotiatedVersions, - divNonceChan: make(chan []byte), - params: params, - }, nil + hostname: hostname, + connID: connID, + version: version, + certManager: crypto.NewCertManager(tlsConfig), + connectionParameters: cpm, + requestConnIDTruncation: params.RequestConnectionIDTruncation, + keyDerivation: crypto.DeriveQuicCryptoAESKeys, + keyExchange: getEphermalKEX, + nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), + aeadChanged: aeadChanged, + negotiatedVersions: negotiatedVersions, + divNonceChan: make(chan []byte), + }, cpm, nil } -func (h *cryptoSetupClient) HandleCryptoStream() error { +func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error { messageChan := make(chan HandshakeMessage) errorChan := make(chan error) + h.cryptoStream = stream + go func() { for { message, err := ParseHandshakeMessage(h.cryptoStream) @@ -401,7 +401,6 @@ func (h *cryptoSetupClient) sendCHLO() error { } h.lastSentCHLO = b.Bytes() - return nil } @@ -422,7 +421,7 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) { binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) tags[TagVER] = versionTag - if h.params.RequestConnectionIDTruncation { + if h.requestConnIDTruncation { tags[TagTCID] = []byte{0, 0, 0, 0} } if len(h.stk) > 0 { diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index 181427bc..dc54b74f 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -109,20 +109,17 @@ var _ = Describe("Client Crypto Setup", func() { certManager = &mockCertManager{} version := protocol.Version37 aeadChanged = make(chan protocol.EncryptionLevel, 2) - csInt, err := NewCryptoSetupClient( + csInt, _, err := NewCryptoSetupClient( "hostname", 0, version, - stream, nil, - NewConnectionParamatersManager( - protocol.PerspectiveClient, - version, - protocol.DefaultMaxReceiveStreamFlowControlWindowClient, protocol.DefaultMaxReceiveConnectionFlowControlWindowClient, - protocol.DefaultIdleTimeout, - ), + &TransportParameters{ + MaxReceiveStreamFlowControlWindow: protocol.DefaultMaxReceiveStreamFlowControlWindowClient, + MaxReceiveConnectionFlowControlWindow: protocol.DefaultMaxReceiveConnectionFlowControlWindowClient, + IdleTimeout: protocol.DefaultIdleTimeout, + }, aeadChanged, - &TransportParameters{}, nil, ) Expect(err).ToNot(HaveOccurred()) @@ -131,6 +128,7 @@ var _ = Describe("Client Crypto Setup", func() { cs.keyDerivation = keyDerivation cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } cs.nullAEAD = &mockAEAD{encLevel: protocol.EncryptionUnencrypted} + cs.cryptoStream = stream }) AfterEach(func() { @@ -146,13 +144,13 @@ var _ = Describe("Client Crypto Setup", func() { It("rejects handshake messages with the wrong message tag", func() { HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) }) It("errors on invalid handshake messages", func() { stream.dataToRead.Write([]byte("invalid message")) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeFailed)) }) @@ -161,7 +159,7 @@ var _ = Describe("Client Crypto Setup", func() { stk := []byte("foobar") tagMap[TagSTK] = stk HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead) - go cs.HandleCryptoStream() + go cs.HandleCryptoStream(stream) Eventually(func() []byte { return cs.stk }).Should(Equal(stk)) }) @@ -494,7 +492,7 @@ var _ = Describe("Client Crypto Setup", func() { }) It("requests to truncate the connection ID", func() { - cs.params.RequestConnectionIDTruncation = true + cs.requestConnIDTruncation = true tags, err := cs.getTags() Expect(err).ToNot(HaveOccurred()) Expect(tags).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0})) @@ -663,7 +661,7 @@ var _ = Describe("Client Crypto Setup", func() { It("tries to escalate before reading a handshake message", func() { Expect(cs.secureAEAD).To(BeNil()) cs.serverVerified = true - go cs.HandleCryptoStream() + go cs.HandleCryptoStream(stream) Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Expect(cs.secureAEAD).ToNot(BeNil()) Expect(aeadChanged).ToNot(Receive()) @@ -673,7 +671,7 @@ var _ = Describe("Client Crypto Setup", func() { It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) { go func() { defer GinkgoRecover() - cs.HandleCryptoStream() + cs.HandleCryptoStream(stream) Fail("HandleCryptoStream should not have returned") }() cs.diversificationNonce = nil @@ -836,14 +834,14 @@ var _ = Describe("Client Crypto Setup", func() { Context("Diversification Nonces", func() { It("sets a diversification nonce", func() { - go cs.HandleCryptoStream() + go cs.HandleCryptoStream(stream) nonce := []byte("foobar") cs.SetDiversificationNonce(nonce) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) }) It("doesn't do anything when called multiple times with the same nonce", func(done Done) { - go cs.HandleCryptoStream() + go cs.HandleCryptoStream(stream) nonce := []byte("foobar") cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce) @@ -854,7 +852,7 @@ var _ = Describe("Client Crypto Setup", func() { It("rejects a different diversification nonce", func() { var err error go func() { - err = cs.HandleCryptoStream() + err = cs.HandleCryptoStream(stream) }() nonce1 := []byte("foobar") diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index be55080c..b62e593c 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -69,17 +69,17 @@ func NewCryptoSetup( remoteAddr net.Addr, version protocol.VersionNumber, scfg *ServerConfig, - cryptoStream io.ReadWriter, - connectionParametersManager ConnectionParametersManager, + params *TransportParameters, supportedVersions []protocol.VersionNumber, acceptSTK func(net.Addr, *Cookie) bool, aeadChanged chan<- protocol.EncryptionLevel, -) (CryptoSetup, error) { +) (CryptoSetup, ConnectionParametersManager, error) { stkGenerator, err := NewCookieGenerator() if err != nil { - return nil, err + return nil, nil, err } + cpm := NewConnectionParamatersManager(protocol.PerspectiveServer, version, params) return &cryptoSetupServer{ connID: connID, remoteAddr: remoteAddr, @@ -90,16 +90,17 @@ func NewCryptoSetup( keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyExchange: getEphermalKEX, nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), - cryptoStream: cryptoStream, - connectionParameters: connectionParametersManager, + connectionParameters: cpm, acceptSTKCallback: acceptSTK, sentSHLO: make(chan struct{}), aeadChanged: aeadChanged, - }, nil + }, cpm, nil } // HandleCryptoStream reads and writes messages on the crypto stream -func (h *cryptoSetupServer) HandleCryptoStream() error { +func (h *cryptoSetupServer) HandleCryptoStream(stream io.ReadWriter) error { + h.cryptoStream = stream + for { var chloData bytes.Buffer message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData)) diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index c4f6a91b..cc900e5a 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -167,7 +167,6 @@ var _ = Describe("Server Crypto Setup", func() { scfg *ServerConfig cs *cryptoSetupServer stream *mockStream - cpm ConnectionParametersManager aeadChanged chan protocol.EncryptionLevel nonce32 []byte versionTag []byte @@ -198,19 +197,16 @@ var _ = Describe("Server Crypto Setup", func() { Expect(err).NotTo(HaveOccurred()) version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1] supportedVersions = []protocol.VersionNumber{version, 98, 99} - cpm = NewConnectionParamatersManager( - protocol.PerspectiveServer, - protocol.VersionWhatever, - protocol.DefaultMaxReceiveStreamFlowControlWindowServer, protocol.DefaultMaxReceiveConnectionFlowControlWindowServer, - protocol.DefaultIdleTimeout, - ) - csInt, err := NewCryptoSetup( + csInt, _, err := NewCryptoSetup( protocol.ConnectionID(42), remoteAddr, version, scfg, - stream, - cpm, + &TransportParameters{ + MaxReceiveStreamFlowControlWindow: protocol.DefaultMaxReceiveStreamFlowControlWindowServer, + MaxReceiveConnectionFlowControlWindow: protocol.DefaultMaxReceiveConnectionFlowControlWindowServer, + IdleTimeout: protocol.DefaultIdleTimeout, + }, supportedVersions, nil, aeadChanged, @@ -225,6 +221,7 @@ var _ = Describe("Server Crypto Setup", func() { cs.keyDerivation = mockQuicCryptoKeyDerivation cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } cs.nullAEAD = &mockAEAD{encLevel: protocol.EncryptionUnencrypted} + cs.cryptoStream = stream }) AfterEach(func() { @@ -277,7 +274,7 @@ var _ = Describe("Server Crypto Setup", func() { TagFHL2: []byte("foobar"), }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(ErrHOLExperiment)) }) @@ -288,7 +285,7 @@ var _ = Describe("Server Crypto Setup", func() { TagNSTP: []byte("foobar"), }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(ErrNSTPExperiment)) }) @@ -361,7 +358,7 @@ var _ = Describe("Server Crypto Setup", func() { }, }.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) @@ -373,14 +370,14 @@ var _ = Describe("Server Crypto Setup", func() { It("rejects client nonces that have the wrong length", func() { fullCHLO[TagNONC] = []byte("too short client nonce") HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length"))) }) It("rejects client nonces that have the wrong OBIT value", func() { fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0 HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching"))) }) @@ -388,13 +385,13 @@ var _ = Describe("Server Crypto Setup", func() { testErr := errors.New("test error") kex.sharedKeyError = testErr HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(testErr)) }) It("handles 0-RTT handshake", func() { HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) @@ -451,14 +448,14 @@ var _ = Describe("Server Crypto Setup", func() { TagSNI: []byte("quic.clemente.io"), }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag"))) }) It("rejects CHLOs with a version tag that has the wrong length", func() { fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag"))) }) @@ -471,7 +468,7 @@ var _ = Describe("Server Crypto Setup", func() { binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(lowestSupportedVersion)) fullCHLO[TagVER] = b HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected"))) }) @@ -484,35 +481,35 @@ var _ = Describe("Server Crypto Setup", func() { binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(unsupportedVersion)) fullCHLO[TagVER] = b HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).ToNot(HaveOccurred()) }) It("errors if the AEAD tag is missing", func() { delete(fullCHLO, TagAEAD) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the AEAD tag has the wrong value", func() { fullCHLO[TagAEAD] = []byte("wrong") HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the KEXS tag is missing", func() { delete(fullCHLO, TagKEXS) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the KEXS tag has the wrong value", func() { fullCHLO[TagKEXS] = []byte("wrong") HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) }) @@ -524,7 +521,7 @@ var _ = Describe("Server Crypto Setup", func() { TagSTK: validSTK, }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) }) @@ -536,19 +533,19 @@ var _ = Describe("Server Crypto Setup", func() { TagSNI: nil, }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) }) It("errors with invalid message", func() { stream.dataToRead.Write([]byte("invalid message")) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.HandshakeFailed)) }) It("errors with non-CHLO message", func() { HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(stream) Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 1d940891..c306e67e 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -22,7 +22,6 @@ type cryptoSetupTLS struct { keyDerivation KeyDerivationFunction mintConf *mint.Config - conn crypto.MintController nullAEAD crypto.AEAD aead crypto.AEAD @@ -30,43 +29,47 @@ type cryptoSetupTLS struct { aeadChanged chan<- protocol.EncryptionLevel } +var newMintController = func(conn *mint.Conn) crypto.MintController { + return &mintController{conn} +} + // NewCryptoSetupTLS creates a new CryptoSetup instance for a server func NewCryptoSetupTLS( hostname string, // only needed for the client perspective protocol.Perspective, version protocol.VersionNumber, tlsConfig *tls.Config, - cryptoStream io.ReadWriter, aeadChanged chan<- protocol.EncryptionLevel, -) (CryptoSetup, error) { +) (CryptoSetup, ConnectionParametersManager, error) { mintConf, err := tlsToMintConfig(tlsConfig, perspective) if err != nil { - return nil, err + return nil, nil, err } mintConf.ServerName = hostname - var conn *mint.Conn - if perspective == protocol.PerspectiveServer { - conn = mint.Server(&fakeConn{cryptoStream}, mintConf) - } else { - conn = mint.Client(&fakeConn{cryptoStream}, mintConf) - } + return &cryptoSetupTLS{ perspective: perspective, mintConf: mintConf, - conn: &mintController{conn}, nullAEAD: crypto.NewNullAEAD(perspective, version), keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, - }, nil + }, NewConnectionParamatersManager(perspective, version, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}), nil } -func (h *cryptoSetupTLS) HandleCryptoStream() error { - alert := h.conn.Handshake() - if alert != mint.AlertNoAlert { +func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error { + var conn *mint.Conn + if h.perspective == protocol.PerspectiveServer { + conn = mint.Server(&fakeConn{cryptoStream}, h.mintConf) + } else { + conn = mint.Client(&fakeConn{cryptoStream}, h.mintConf) + } + mc := newMintController(conn) + + if alert := mc.Handshake(); alert != mint.AlertNoAlert { return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) } - aead, err := h.keyDerivation(h.conn, h.perspective) + aead, err := h.keyDerivation(mc, h.perspective) if err != nil { return err } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 35a93f53..5cfb3e28 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -34,33 +34,42 @@ var _ = Describe("TLS Crypto Setup", func() { var ( cs *cryptoSetupTLS aeadChanged chan protocol.EncryptionLevel + + mintControllerConstructor = newMintController ) BeforeEach(func() { aeadChanged = make(chan protocol.EncryptionLevel, 2) - csInt, err := NewCryptoSetupTLS( + csInt, _, err := NewCryptoSetupTLS( "", protocol.PerspectiveServer, protocol.VersionTLS, testdata.GetTLSConfig(), - nil, aeadChanged, ) Expect(err).ToNot(HaveOccurred()) cs = csInt.(*cryptoSetupTLS) }) + AfterEach(func() { + newMintController = mintControllerConstructor + }) + It("errors when the handshake fails", func() { alert := mint.AlertBadRecordMAC - cs.conn = &fakeMintController{result: alert} - err := cs.HandleCryptoStream() + newMintController = func(*mint.Conn) crypto.MintController { + return &fakeMintController{result: alert} + } + err := cs.HandleCryptoStream(nil) Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert))) }) It("derives keys", func() { - cs.conn = &fakeMintController{result: mint.AlertNoAlert} + newMintController = func(*mint.Conn) crypto.MintController { + return &fakeMintController{result: mint.AlertNoAlert} + } cs.keyDerivation = mockKeyDerivation - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(nil) Expect(err).ToNot(HaveOccurred()) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(aeadChanged).To(BeClosed()) @@ -70,9 +79,11 @@ var _ = Describe("TLS Crypto Setup", func() { var foobarFNVSigned []byte // a "foobar", FNV signed doHandshake := func() { - cs.conn = &fakeMintController{result: mint.AlertNoAlert} + newMintController = func(*mint.Conn) crypto.MintController { + return &fakeMintController{result: mint.AlertNoAlert} + } cs.keyDerivation = mockKeyDerivation - err := cs.HandleCryptoStream() + err := cs.HandleCryptoStream(nil) Expect(err).ToNot(HaveOccurred()) } diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index f797cb4e..0f9e5cf6 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -1,6 +1,11 @@ package handshake -import "github.com/lucas-clemente/quic-go/internal/protocol" +import ( + "io" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) // Sealer seals a packet type Sealer interface { @@ -11,7 +16,7 @@ type Sealer interface { // CryptoSetup is a crypto setup type CryptoSetup interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) - HandleCryptoStream() error + HandleCryptoStream(io.ReadWriter) error // TODO: clean up this interface DiversificationNonce() []byte // only needed for cryptoSetupServer SetDiversificationNonce([]byte) // only needed for cryptoSetupClient @@ -23,5 +28,8 @@ type CryptoSetup interface { // TransportParameters are parameters sent to the peer during the handshake type TransportParameters struct { - RequestConnectionIDTruncation bool + RequestConnectionIDTruncation bool + MaxReceiveStreamFlowControlWindow protocol.ByteCount + MaxReceiveConnectionFlowControlWindow protocol.ByteCount + IdleTimeout time.Duration } diff --git a/packet_packer_test.go b/packet_packer_test.go index 9f7c2a9a..aafa4549 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "io" "math" "github.com/lucas-clemente/quic-go/ackhandler" @@ -32,7 +33,7 @@ type mockCryptoSetup struct { var _ handshake.CryptoSetup = &mockCryptoSetup{} -func (m *mockCryptoSetup) HandleCryptoStream() error { +func (m *mockCryptoSetup) HandleCryptoStream(io.ReadWriter) error { return m.handleErr } func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { diff --git a/session.go b/session.go index 8a5e9c9a..d1f8d35f 100644 --- a/session.go +++ b/session.go @@ -180,69 +180,57 @@ func (s *session) setup( s.sessionCreationTime = now s.rttStats = &congestion.RTTStats{} - s.connectionParameters = handshake.NewConnectionParamatersManager( - s.perspective, - s.version, - protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), - protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), - s.config.IdleTimeout, - ) + transportParams := &handshake.TransportParameters{ + MaxReceiveStreamFlowControlWindow: protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + MaxReceiveConnectionFlowControlWindow: protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), + IdleTimeout: s.config.IdleTimeout, + } s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) - s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) - s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) - s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) var err error if s.perspective == protocol.PerspectiveServer { - cryptoStream, _ := s.GetOrOpenStream(1) - _, _ = s.AcceptStream() // don't expose the crypto stream verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool { return s.config.AcceptCookie(clientAddr, cookie) } if s.version.UsesTLS() { - s.cryptoSetup, err = handshake.NewCryptoSetupTLS( + s.cryptoSetup, s.connectionParameters, err = handshake.NewCryptoSetupTLS( "", s.perspective, s.version, tlsConf, - cryptoStream, aeadChanged, ) } else { - s.cryptoSetup, err = newCryptoSetup( + s.cryptoSetup, s.connectionParameters, err = newCryptoSetup( s.connectionID, s.conn.RemoteAddr(), s.version, scfg, - cryptoStream, - s.connectionParameters, + transportParams, s.config.Versions, verifySourceAddr, aeadChanged, ) } } else { - cryptoStream, _ := s.OpenStream() if s.version.UsesTLS() { - s.cryptoSetup, err = handshake.NewCryptoSetupTLS( + s.cryptoSetup, s.connectionParameters, err = handshake.NewCryptoSetupTLS( hostname, s.perspective, s.version, tlsConf, - cryptoStream, aeadChanged, ) } else { - s.cryptoSetup, err = newCryptoSetupClient( + transportParams.RequestConnectionIDTruncation = s.config.RequestConnectionIDTruncation + s.cryptoSetup, s.connectionParameters, err = newCryptoSetupClient( hostname, s.connectionID, s.version, - cryptoStream, tlsConf, - s.connectionParameters, + transportParams, aeadChanged, - &handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation}, negotiatedVersions, ) } @@ -251,6 +239,9 @@ func (s *session) setup( return nil, nil, err } + s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) + s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) + s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) s.packer = newPacketPacker(s.connectionID, s.cryptoSetup, s.connectionParameters, @@ -266,8 +257,16 @@ func (s *session) setup( // run the session main loop func (s *session) run() error { // Start the crypto stream handler + var cryptoStream Stream + if s.perspective == protocol.PerspectiveServer { + cryptoStream, _ = s.GetOrOpenStream(1) + _, _ = s.AcceptStream() // don't expose the crypto stream + } else { + cryptoStream, _ = s.OpenStream() + } + go func() { - if err := s.cryptoSetup.HandleCryptoStream(); err != nil { + if err := s.cryptoSetup.HandleCryptoStream(cryptoStream); err != nil { s.Close(err) } }() diff --git a/session_test.go b/session_test.go index 5eb3921e..05ceb3a1 100644 --- a/session_test.go +++ b/session_test.go @@ -143,12 +143,48 @@ func areSessionsRunning() bool { return strings.Contains(b.String(), "quic-go.(*session).run") } +type mockConnectionParametersManager struct { +} + +func (m *mockConnectionParametersManager) SetFromMap(map[handshake.Tag][]byte) error { + panic("not implement") +} +func (m *mockConnectionParametersManager) GetHelloMap() (map[handshake.Tag][]byte, error) { + panic("not implement") +} + +func (m *mockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount { + return protocol.InitialStreamFlowControlWindow +} +func (m *mockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount { + return protocol.InitialConnectionFlowControlWindow +} +func (m *mockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount { + return protocol.ReceiveStreamFlowControlWindow +} +func (m *mockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { + return protocol.DefaultMaxReceiveStreamFlowControlWindowServer +} +func (m *mockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount { + return protocol.ReceiveConnectionFlowControlWindow +} +func (m *mockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { + return protocol.DefaultMaxReceiveConnectionFlowControlWindowServer +} +func (m *mockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { return 100 } +func (m *mockConnectionParametersManager) GetMaxIncomingStreams() uint32 { return 100 } +func (m *mockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { + return time.Hour +} +func (m *mockConnectionParametersManager) TruncateConnectionID() bool { return false } + +var _ handshake.ConnectionParametersManager = &mockConnectionParametersManager{} + var _ = Describe("Session", func() { var ( sess *session scfg *handshake.ServerConfig mconn *mockConnection - mockCpm *mocks.MockConnectionParametersManager cryptoSetup *mockCryptoSetup handshakeChan <-chan handshakeEvent aeadChanged chan<- protocol.EncryptionLevel @@ -163,14 +199,13 @@ var _ = Describe("Session", func() { _ net.Addr, _ protocol.VersionNumber, _ *handshake.ServerConfig, - _ io.ReadWriter, - _ handshake.ConnectionParametersManager, + _ *handshake.TransportParameters, _ []protocol.VersionNumber, _ func(net.Addr, *Cookie) bool, aeadChangedP chan<- protocol.EncryptionLevel, - ) (handshake.CryptoSetup, error) { + ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) { aeadChanged = aeadChangedP - return cryptoSetup, nil + return cryptoSetup, &mockConnectionParametersManager{}, nil } mconn = newMockConnection() @@ -190,11 +225,9 @@ var _ = Describe("Session", func() { ) Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) - Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream + Expect(sess.streamsMap.openStreams).To(BeEmpty()) // the crypto stream is opened in session.run() - mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) - mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(time.Minute).AnyTimes() - sess.connectionParameters = mockCpm + sess.connectionParameters = &mockConnectionParametersManager{} }) AfterEach(func() { @@ -216,14 +249,13 @@ var _ = Describe("Session", func() { _ net.Addr, _ protocol.VersionNumber, _ *handshake.ServerConfig, - _ io.ReadWriter, - _ handshake.ConnectionParametersManager, + _ *handshake.TransportParameters, _ []protocol.VersionNumber, cookieFunc func(net.Addr, *Cookie) bool, _ chan<- protocol.EncryptionLevel, - ) (handshake.CryptoSetup, error) { + ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) { cookieVerify = cookieFunc - return cryptoSetup, nil + return cryptoSetup, &mockConnectionParametersManager{}, nil } conf := populateServerConfig(&Config{}) @@ -730,18 +762,23 @@ var _ = Describe("Session", func() { Context("accepting streams", func() { It("waits for new streams", func() { - var str Stream + strChan := make(chan Stream) go func() { defer GinkgoRecover() - var err error - str, err = sess.AcceptStream() - Expect(err).ToNot(HaveOccurred()) + for { + str, err := sess.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + strChan <- str + } }() - Consistently(func() Stream { return str }).Should(BeNil()) + Consistently(strChan).ShouldNot(Receive()) sess.handleStreamFrame(&wire.StreamFrame{ StreamID: 3, }) - Eventually(func() Stream { return str }).ShouldNot(BeNil()) + var str Stream + Eventually(strChan).Should(Receive(&str)) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + Eventually(strChan).Should(Receive(&str)) Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) }) @@ -944,7 +981,7 @@ var _ = Describe("Session", func() { }) Context("sending packets", func() { - It("sends ack frames", func() { + It("sends ACK frames", func() { packetNumber := protocol.PacketNumber(0x035E) sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) err := sess.sendPacket() @@ -1528,7 +1565,7 @@ var _ = Describe("Session", func() { It("does not use ICSL before handshake", func() { defer sess.Close(nil) sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) - mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) + mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(9999 * time.Second).AnyTimes() mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes() sess.connectionParameters = mockCpm @@ -1545,7 +1582,7 @@ var _ = Describe("Session", func() { It("uses ICSL after handshake", func(done Done) { close(aeadChanged) - mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) + mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second) mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes() sess.connectionParameters = mockCpm @@ -1599,7 +1636,9 @@ var _ = Describe("Session", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { - for i := 0; i < 110; i++ { + mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) + mockCpm.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() + for i := 0; i < 10; i++ { _, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } @@ -1608,6 +1647,8 @@ var _ = Describe("Session", func() { }) It("does not error when many streams are opened and closed", func() { + mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl) + mockCpm.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes() for i := 2; i <= 1000; i++ { s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) @@ -1641,11 +1682,13 @@ var _ = Describe("Session", func() { Context("window updates", func() { It("gets stream level window updates", func() { - err := sess.flowControlManager.AddBytesRead(1, protocol.ReceiveStreamFlowControlWindow) + _, err := sess.GetOrOpenStream(3) + Expect(err).ToNot(HaveOccurred()) + err = sess.flowControlManager.AddBytesRead(3, protocol.ReceiveStreamFlowControlWindow) Expect(err).NotTo(HaveOccurred()) frames := sess.getWindowUpdateFrames() Expect(frames).To(HaveLen(1)) - Expect(frames[0].StreamID).To(Equal(protocol.StreamID(1))) + Expect(frames[0].StreamID).To(Equal(protocol.StreamID(3))) Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow * 2)) }) @@ -1691,15 +1734,13 @@ var _ = Describe("Client Session", func() { _ string, _ protocol.ConnectionID, _ protocol.VersionNumber, - _ io.ReadWriter, _ *tls.Config, - _ handshake.ConnectionParametersManager, - aeadChangedP chan<- protocol.EncryptionLevel, _ *handshake.TransportParameters, + aeadChangedP chan<- protocol.EncryptionLevel, _ []protocol.VersionNumber, - ) (handshake.CryptoSetup, error) { + ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) { aeadChanged = aeadChangedP - return cryptoSetup, nil + return cryptoSetup, &mockConnectionParametersManager{}, nil } mconn = newMockConnection() @@ -1714,7 +1755,7 @@ var _ = Describe("Client Session", func() { ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) - Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream + Expect(sess.streamsMap.openStreams).To(BeEmpty()) // the crypto stream is opened in session.run() }) AfterEach(func() {