mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
try decrypting undecryptable packets when the encryption level changes
There's no need to do this asynchronously any more when using TLS.
This commit is contained in:
parent
387c28d707
commit
b63c81f0bf
5 changed files with 32 additions and 38 deletions
|
@ -30,7 +30,7 @@ func newCryptoStreamManager(
|
|||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
|
||||
var str cryptoStream
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
|
@ -38,18 +38,18 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
|
|||
case protocol.EncryptionHandshake:
|
||||
str = m.handshakeStream
|
||||
default:
|
||||
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
if err := str.HandleCryptoFrame(frame); err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
for {
|
||||
data := str.GetCryptoData()
|
||||
if data == nil {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
|
||||
return str.Finish()
|
||||
return true, str.Finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,7 +33,9 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
initialStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
|
||||
initialStream.EXPECT().GetCryptoData()
|
||||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial)
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed())
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
})
|
||||
|
||||
It("passes messages to the handshake stream", func() {
|
||||
|
@ -42,7 +44,9 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
|
||||
handshakeStream.EXPECT().GetCryptoData()
|
||||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake)
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
})
|
||||
|
||||
It("doesn't call the message handler, if there's no message", func() {
|
||||
|
@ -50,7 +54,9 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
handshakeStream.EXPECT().HandleCryptoFrame(cf)
|
||||
handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle
|
||||
// don't EXPECT any calls to HandleMessage()
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
})
|
||||
|
||||
It("processes all messages", func() {
|
||||
|
@ -61,7 +67,9 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
handshakeStream.EXPECT().GetCryptoData()
|
||||
cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake)
|
||||
cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake)
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeFalse())
|
||||
})
|
||||
|
||||
It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() {
|
||||
|
@ -72,7 +80,9 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
|
||||
handshakeStream.EXPECT().Finish(),
|
||||
)
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
|
||||
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(encLevelChanged).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns errors that occur when finishing a stream", func() {
|
||||
|
@ -84,11 +94,12 @@ var _ = Describe("Crypto Stream Manager", func() {
|
|||
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
|
||||
handshakeStream.EXPECT().Finish().Return(testErr),
|
||||
)
|
||||
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(MatchError(testErr))
|
||||
_, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
|
||||
Expect(err).To(MatchError(err))
|
||||
})
|
||||
|
||||
It("errors for unknown encryption levels", func() {
|
||||
err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
|
||||
_, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
|
||||
Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT"))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -63,8 +63,6 @@ type cryptoSetupTLS struct {
|
|||
handshakeErrChan chan struct{}
|
||||
// HandleData() sends errors on the messageErrChan
|
||||
messageErrChan chan error
|
||||
// handshakeEvent signals a change of encryption level to the session
|
||||
handshakeEvent chan<- struct{}
|
||||
// handshakeComplete is closed when the handshake completes
|
||||
handshakeComplete chan<- struct{}
|
||||
// transport parameters are sent on the receivedTransportParams, as soon as they are received
|
||||
|
@ -108,7 +106,6 @@ func NewCryptoSetupTLSClient(
|
|||
connID protocol.ConnectionID,
|
||||
params *TransportParameters,
|
||||
handleParams func(*TransportParameters),
|
||||
handshakeEvent chan<- struct{},
|
||||
handshakeComplete chan<- struct{},
|
||||
tlsConf *tls.Config,
|
||||
initialVersion protocol.VersionNumber,
|
||||
|
@ -123,7 +120,6 @@ func NewCryptoSetupTLSClient(
|
|||
connID,
|
||||
params,
|
||||
handleParams,
|
||||
handshakeEvent,
|
||||
handshakeComplete,
|
||||
tlsConf,
|
||||
versionInfo{
|
||||
|
@ -143,7 +139,6 @@ func NewCryptoSetupTLSServer(
|
|||
connID protocol.ConnectionID,
|
||||
params *TransportParameters,
|
||||
handleParams func(*TransportParameters),
|
||||
handshakeEvent chan<- struct{},
|
||||
handshakeComplete chan<- struct{},
|
||||
tlsConf *tls.Config,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
|
@ -157,7 +152,6 @@ func NewCryptoSetupTLSServer(
|
|||
connID,
|
||||
params,
|
||||
handleParams,
|
||||
handshakeEvent,
|
||||
handshakeComplete,
|
||||
tlsConf,
|
||||
versionInfo{
|
||||
|
@ -176,7 +170,6 @@ func newCryptoSetupTLS(
|
|||
connID protocol.ConnectionID,
|
||||
params *TransportParameters,
|
||||
handleParams func(*TransportParameters),
|
||||
handshakeEvent chan<- struct{},
|
||||
handshakeComplete chan<- struct{},
|
||||
tlsConf *tls.Config,
|
||||
versionInfo versionInfo,
|
||||
|
@ -194,7 +187,6 @@ func newCryptoSetupTLS(
|
|||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
handleParamsCallback: handleParams,
|
||||
handshakeEvent: handshakeEvent,
|
||||
handshakeComplete: handshakeComplete,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
|
@ -339,7 +331,6 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
|
|||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
return true
|
||||
case typeCertificate, typeCertificateVerify:
|
||||
// nothing to do
|
||||
|
@ -351,7 +342,6 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
|
|||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
return true
|
||||
default:
|
||||
panic("unexpected handshake message")
|
||||
|
@ -367,7 +357,6 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
|
|||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
return true
|
||||
case typeEncryptedExtensions:
|
||||
select {
|
||||
|
@ -401,7 +390,6 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
|
|||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
return true
|
||||
default:
|
||||
panic("unexpected handshake message: ")
|
||||
|
|
|
@ -63,7 +63,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
testdata.GetTLSConfig(),
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
|
@ -95,7 +94,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
testdata.GetTLSConfig(),
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
|
@ -178,7 +176,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
clientConf,
|
||||
protocol.VersionTLS,
|
||||
|
@ -196,7 +193,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.ConnectionID{},
|
||||
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
serverConf,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
|
@ -237,7 +233,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.ConnectionID{},
|
||||
&TransportParameters{},
|
||||
func(p *TransportParameters) {},
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
protocol.VersionTLS,
|
||||
|
@ -278,7 +273,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.ConnectionID{},
|
||||
cTransportParameters,
|
||||
func(p *TransportParameters) { sTransportParametersRcvd = p },
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
&tls.Config{ServerName: "quic.clemente.io"},
|
||||
protocol.VersionTLS,
|
||||
|
@ -300,7 +294,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
protocol.ConnectionID{},
|
||||
sTransportParameters,
|
||||
func(p *TransportParameters) { cTransportParametersRcvd = p },
|
||||
make(chan struct{}, 100),
|
||||
make(chan struct{}),
|
||||
testdata.GetTLSConfig(),
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
|
|
16
session.go
16
session.go
|
@ -120,6 +120,7 @@ type session struct {
|
|||
paramsChan <-chan handshake.TransportParameters
|
||||
// the handshakeEvent channel is passed to the CryptoSetup.
|
||||
// It receives when it makes sense to try decrypting undecryptable packets.
|
||||
// Only used for gQUIC.
|
||||
handshakeEvent <-chan struct{}
|
||||
handshakeCompleteChan <-chan struct{} // is closed when the handshake completes
|
||||
handshakeComplete bool
|
||||
|
@ -325,7 +326,6 @@ func newTLSServerSession(
|
|||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
handshakeEvent := make(chan struct{}, 2) // TODO: explain cap
|
||||
handshakeCompleteChan := make(chan struct{})
|
||||
s := &session{
|
||||
conn: conn,
|
||||
|
@ -334,7 +334,6 @@ func newTLSServerSession(
|
|||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
handshakeEvent: handshakeEvent,
|
||||
handshakeCompleteChan: handshakeCompleteChan,
|
||||
logger: logger,
|
||||
version: v,
|
||||
|
@ -350,7 +349,6 @@ func newTLSServerSession(
|
|||
origConnID,
|
||||
params,
|
||||
s.processTransportParameters,
|
||||
handshakeEvent,
|
||||
handshakeCompleteChan,
|
||||
tlsConf,
|
||||
conf.Versions,
|
||||
|
@ -403,7 +401,6 @@ var newTLSClientSession = func(
|
|||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
handshakeEvent := make(chan struct{}, 2) // TODO: explain cap
|
||||
handshakeCompleteChan := make(chan struct{})
|
||||
s := &session{
|
||||
conn: conn,
|
||||
|
@ -412,7 +409,6 @@ var newTLSClientSession = func(
|
|||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
handshakeEvent: handshakeEvent,
|
||||
handshakeCompleteChan: handshakeCompleteChan,
|
||||
logger: logger,
|
||||
version: v,
|
||||
|
@ -426,7 +422,6 @@ var newTLSClientSession = func(
|
|||
s.destConnID,
|
||||
params,
|
||||
s.processTransportParameters,
|
||||
handshakeEvent,
|
||||
handshakeCompleteChan,
|
||||
tlsConf,
|
||||
initialVersion,
|
||||
|
@ -804,7 +799,14 @@ func (s *session) handlePacket(p *receivedPacket) {
|
|||
}
|
||||
|
||||
func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
|
||||
return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
|
||||
encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if encLevelChanged {
|
||||
s.tryDecryptingQueuedPackets()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue