try decrypting undecryptable packets when the encryption level changes

There's no need to do this asynchronously any more when using TLS.
This commit is contained in:
Marten Seemann 2018-10-20 11:40:33 +09:00
parent 387c28d707
commit b63c81f0bf
5 changed files with 32 additions and 38 deletions

View file

@ -30,7 +30,7 @@ func newCryptoStreamManager(
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
var str cryptoStream
switch encLevel {
case protocol.EncryptionInitial:
@ -38,18 +38,18 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
case protocol.EncryptionHandshake:
str = m.handshakeStream
default:
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return err
return false, err
}
for {
data := str.GetCryptoData()
if data == nil {
return nil
return false, nil
}
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
return str.Finish()
return true, str.Finish()
}
}
}

View file

@ -33,7 +33,9 @@ var _ = Describe("Crypto Stream Manager", func() {
initialStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
initialStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial)
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed())
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("passes messages to the handshake stream", func() {
@ -42,7 +44,9 @@ var _ = Describe("Crypto Stream Manager", func() {
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
handshakeStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake)
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("doesn't call the message handler, if there's no message", func() {
@ -50,7 +54,9 @@ var _ = Describe("Crypto Stream Manager", func() {
handshakeStream.EXPECT().HandleCryptoFrame(cf)
handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle
// don't EXPECT any calls to HandleMessage()
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("processes all messages", func() {
@ -61,7 +67,9 @@ var _ = Describe("Crypto Stream Manager", func() {
handshakeStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake)
cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake)
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() {
@ -72,7 +80,9 @@ var _ = Describe("Crypto Stream Manager", func() {
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
handshakeStream.EXPECT().Finish(),
)
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeTrue())
})
It("returns errors that occur when finishing a stream", func() {
@ -84,11 +94,12 @@ var _ = Describe("Crypto Stream Manager", func() {
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
handshakeStream.EXPECT().Finish().Return(testErr),
)
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(MatchError(testErr))
_, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).To(MatchError(err))
})
It("errors for unknown encryption levels", func() {
err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
_, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT"))
})
})

View file

@ -63,8 +63,6 @@ type cryptoSetupTLS struct {
handshakeErrChan chan struct{}
// HandleData() sends errors on the messageErrChan
messageErrChan chan error
// handshakeEvent signals a change of encryption level to the session
handshakeEvent chan<- struct{}
// handshakeComplete is closed when the handshake completes
handshakeComplete chan<- struct{}
// transport parameters are sent on the receivedTransportParams, as soon as they are received
@ -108,7 +106,6 @@ func NewCryptoSetupTLSClient(
connID protocol.ConnectionID,
params *TransportParameters,
handleParams func(*TransportParameters),
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
tlsConf *tls.Config,
initialVersion protocol.VersionNumber,
@ -123,7 +120,6 @@ func NewCryptoSetupTLSClient(
connID,
params,
handleParams,
handshakeEvent,
handshakeComplete,
tlsConf,
versionInfo{
@ -143,7 +139,6 @@ func NewCryptoSetupTLSServer(
connID protocol.ConnectionID,
params *TransportParameters,
handleParams func(*TransportParameters),
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
tlsConf *tls.Config,
supportedVersions []protocol.VersionNumber,
@ -157,7 +152,6 @@ func NewCryptoSetupTLSServer(
connID,
params,
handleParams,
handshakeEvent,
handshakeComplete,
tlsConf,
versionInfo{
@ -176,7 +170,6 @@ func newCryptoSetupTLS(
connID protocol.ConnectionID,
params *TransportParameters,
handleParams func(*TransportParameters),
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
tlsConf *tls.Config,
versionInfo versionInfo,
@ -194,7 +187,6 @@ func newCryptoSetupTLS(
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams,
handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
logger: logger,
perspective: perspective,
@ -339,7 +331,6 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
case <-h.handshakeErrChan:
return false
}
h.handshakeEvent <- struct{}{}
return true
case typeCertificate, typeCertificateVerify:
// nothing to do
@ -351,7 +342,6 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
case <-h.handshakeErrChan:
return false
}
h.handshakeEvent <- struct{}{}
return true
default:
panic("unexpected handshake message")
@ -367,7 +357,6 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
case <-h.handshakeErrChan:
return false
}
h.handshakeEvent <- struct{}{}
return true
case typeEncryptedExtensions:
select {
@ -401,7 +390,6 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
case <-h.handshakeErrChan:
return false
}
h.handshakeEvent <- struct{}{}
return true
default:
panic("unexpected handshake message: ")

View file

@ -63,7 +63,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},
@ -95,7 +94,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},
@ -178,7 +176,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
clientConf,
protocol.VersionTLS,
@ -196,7 +193,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
serverConf,
[]protocol.VersionNumber{protocol.VersionTLS},
@ -237,7 +233,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
&tls.Config{InsecureSkipVerify: true},
protocol.VersionTLS,
@ -278,7 +273,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
cTransportParameters,
func(p *TransportParameters) { sTransportParametersRcvd = p },
make(chan struct{}, 100),
make(chan struct{}),
&tls.Config{ServerName: "quic.clemente.io"},
protocol.VersionTLS,
@ -300,7 +294,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
sTransportParameters,
func(p *TransportParameters) { cTransportParametersRcvd = p },
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},

View file

@ -120,6 +120,7 @@ type session struct {
paramsChan <-chan handshake.TransportParameters
// the handshakeEvent channel is passed to the CryptoSetup.
// It receives when it makes sense to try decrypting undecryptable packets.
// Only used for gQUIC.
handshakeEvent <-chan struct{}
handshakeCompleteChan <-chan struct{} // is closed when the handshake completes
handshakeComplete bool
@ -325,7 +326,6 @@ func newTLSServerSession(
logger utils.Logger,
v protocol.VersionNumber,
) (quicSession, error) {
handshakeEvent := make(chan struct{}, 2) // TODO: explain cap
handshakeCompleteChan := make(chan struct{})
s := &session{
conn: conn,
@ -334,7 +334,6 @@ func newTLSServerSession(
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveServer,
handshakeEvent: handshakeEvent,
handshakeCompleteChan: handshakeCompleteChan,
logger: logger,
version: v,
@ -350,7 +349,6 @@ func newTLSServerSession(
origConnID,
params,
s.processTransportParameters,
handshakeEvent,
handshakeCompleteChan,
tlsConf,
conf.Versions,
@ -403,7 +401,6 @@ var newTLSClientSession = func(
logger utils.Logger,
v protocol.VersionNumber,
) (quicSession, error) {
handshakeEvent := make(chan struct{}, 2) // TODO: explain cap
handshakeCompleteChan := make(chan struct{})
s := &session{
conn: conn,
@ -412,7 +409,6 @@ var newTLSClientSession = func(
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveClient,
handshakeEvent: handshakeEvent,
handshakeCompleteChan: handshakeCompleteChan,
logger: logger,
version: v,
@ -426,7 +422,6 @@ var newTLSClientSession = func(
s.destConnID,
params,
s.processTransportParameters,
handshakeEvent,
handshakeCompleteChan,
tlsConf,
initialVersion,
@ -804,7 +799,14 @@ func (s *session) handlePacket(p *receivedPacket) {
}
func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
if err != nil {
return err
}
if encLevelChanged {
s.tryDecryptingQueuedPackets()
}
return nil
}
func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error {