From b63c81f0bf5485a84714e38fa81a00573199c8db Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 20 Oct 2018 11:40:33 +0900 Subject: [PATCH] try decrypting undecryptable packets when the encryption level changes There's no need to do this asynchronously any more when using TLS. --- crypto_stream_manager.go | 10 ++++----- crypto_stream_manager_test.go | 25 +++++++++++++++------ internal/handshake/crypto_setup_tls.go | 12 ---------- internal/handshake/crypto_setup_tls_test.go | 7 ------ session.go | 16 +++++++------ 5 files changed, 32 insertions(+), 38 deletions(-) diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 0498b516..330b26da 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -30,7 +30,7 @@ func newCryptoStreamManager( } } -func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { +func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) { var str cryptoStream switch encLevel { case protocol.EncryptionInitial: @@ -38,18 +38,18 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve case protocol.EncryptionHandshake: str = m.handshakeStream default: - return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) + return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) } if err := str.HandleCryptoFrame(frame); err != nil { - return err + return false, err } for { data := str.GetCryptoData() if data == nil { - return nil + return false, nil } if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { - return str.Finish() + return true, str.Finish() } } } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index aada3197..b57a0299 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -33,7 +33,9 @@ var _ = Describe("Crypto Stream Manager", func() { initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) initialStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) }) It("passes messages to the handshake stream", func() { @@ -42,7 +44,9 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) handshakeStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) }) It("doesn't call the message handler, if there's no message", func() { @@ -50,7 +54,9 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().HandleCryptoFrame(cf) handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle // don't EXPECT any calls to HandleMessage() - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) }) It("processes all messages", func() { @@ -61,7 +67,9 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) }) It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { @@ -72,7 +80,9 @@ var _ = Describe("Crypto Stream Manager", func() { cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), handshakeStream.EXPECT().Finish(), ) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeTrue()) }) It("returns errors that occur when finishing a stream", func() { @@ -84,11 +94,12 @@ var _ = Describe("Crypto Stream Manager", func() { cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), handshakeStream.EXPECT().Finish().Return(testErr), ) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(MatchError(testErr)) + _, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).To(MatchError(err)) }) It("errors for unknown encryption levels", func() { - err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) + _, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) }) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index e8a06ea8..1ced3b47 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -63,8 +63,6 @@ type cryptoSetupTLS struct { handshakeErrChan chan struct{} // HandleData() sends errors on the messageErrChan messageErrChan chan error - // handshakeEvent signals a change of encryption level to the session - handshakeEvent chan<- struct{} // handshakeComplete is closed when the handshake completes handshakeComplete chan<- struct{} // transport parameters are sent on the receivedTransportParams, as soon as they are received @@ -108,7 +106,6 @@ func NewCryptoSetupTLSClient( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, initialVersion protocol.VersionNumber, @@ -123,7 +120,6 @@ func NewCryptoSetupTLSClient( connID, params, handleParams, - handshakeEvent, handshakeComplete, tlsConf, versionInfo{ @@ -143,7 +139,6 @@ func NewCryptoSetupTLSServer( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, supportedVersions []protocol.VersionNumber, @@ -157,7 +152,6 @@ func NewCryptoSetupTLSServer( connID, params, handleParams, - handshakeEvent, handshakeComplete, tlsConf, versionInfo{ @@ -176,7 +170,6 @@ func newCryptoSetupTLS( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, versionInfo versionInfo, @@ -194,7 +187,6 @@ func newCryptoSetupTLS( readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, handleParamsCallback: handleParams, - handshakeEvent: handshakeEvent, handshakeComplete: handshakeComplete, logger: logger, perspective: perspective, @@ -339,7 +331,6 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool { case <-h.handshakeErrChan: return false } - h.handshakeEvent <- struct{}{} return true case typeCertificate, typeCertificateVerify: // nothing to do @@ -351,7 +342,6 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool { case <-h.handshakeErrChan: return false } - h.handshakeEvent <- struct{}{} return true default: panic("unexpected handshake message") @@ -367,7 +357,6 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool { case <-h.handshakeErrChan: return false } - h.handshakeEvent <- struct{}{} return true case typeEncryptedExtensions: select { @@ -401,7 +390,6 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool { case <-h.handshakeErrChan: return false } - h.handshakeEvent <- struct{}{} return true default: panic("unexpected handshake message: ") diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 0f520677..c0e2cb16 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -63,7 +63,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, @@ -95,7 +94,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, @@ -178,7 +176,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), clientConf, protocol.VersionTLS, @@ -196,7 +193,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), serverConf, []protocol.VersionNumber{protocol.VersionTLS}, @@ -237,7 +233,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), &tls.Config{InsecureSkipVerify: true}, protocol.VersionTLS, @@ -278,7 +273,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, cTransportParameters, func(p *TransportParameters) { sTransportParametersRcvd = p }, - make(chan struct{}, 100), make(chan struct{}), &tls.Config{ServerName: "quic.clemente.io"}, protocol.VersionTLS, @@ -300,7 +294,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, sTransportParameters, func(p *TransportParameters) { cTransportParametersRcvd = p }, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, diff --git a/session.go b/session.go index c8ca516a..ab9ce584 100644 --- a/session.go +++ b/session.go @@ -120,6 +120,7 @@ type session struct { paramsChan <-chan handshake.TransportParameters // the handshakeEvent channel is passed to the CryptoSetup. // It receives when it makes sense to try decrypting undecryptable packets. + // Only used for gQUIC. handshakeEvent <-chan struct{} handshakeCompleteChan <-chan struct{} // is closed when the handshake completes handshakeComplete bool @@ -325,7 +326,6 @@ func newTLSServerSession( logger utils.Logger, v protocol.VersionNumber, ) (quicSession, error) { - handshakeEvent := make(chan struct{}, 2) // TODO: explain cap handshakeCompleteChan := make(chan struct{}) s := &session{ conn: conn, @@ -334,7 +334,6 @@ func newTLSServerSession( srcConnID: srcConnID, destConnID: destConnID, perspective: protocol.PerspectiveServer, - handshakeEvent: handshakeEvent, handshakeCompleteChan: handshakeCompleteChan, logger: logger, version: v, @@ -350,7 +349,6 @@ func newTLSServerSession( origConnID, params, s.processTransportParameters, - handshakeEvent, handshakeCompleteChan, tlsConf, conf.Versions, @@ -403,7 +401,6 @@ var newTLSClientSession = func( logger utils.Logger, v protocol.VersionNumber, ) (quicSession, error) { - handshakeEvent := make(chan struct{}, 2) // TODO: explain cap handshakeCompleteChan := make(chan struct{}) s := &session{ conn: conn, @@ -412,7 +409,6 @@ var newTLSClientSession = func( srcConnID: srcConnID, destConnID: destConnID, perspective: protocol.PerspectiveClient, - handshakeEvent: handshakeEvent, handshakeCompleteChan: handshakeCompleteChan, logger: logger, version: v, @@ -426,7 +422,6 @@ var newTLSClientSession = func( s.destConnID, params, s.processTransportParameters, - handshakeEvent, handshakeCompleteChan, tlsConf, initialVersion, @@ -804,7 +799,14 @@ func (s *session) handlePacket(p *receivedPacket) { } func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { - return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + if err != nil { + return err + } + if encLevelChanged { + s.tryDecryptingQueuedPackets() + } + return nil } func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error {