mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
fix handling of multiple handshake messages in the case of errors
When receiving a handshake message after another handshake messages that doesn't cause any action from the TLS stack (i.e. Certificate and CertificateVerify), the handshake would run into a deadlock if the first of these handshake messages caused an error in the TLS stack. We need to make sure that we wait until a message has been fully processed before proceeding with the handshake.
This commit is contained in:
parent
8bf5c782e3
commit
c9bfde9ac0
2 changed files with 75 additions and 156 deletions
|
@ -87,7 +87,9 @@ type cryptoSetup struct {
|
||||||
extraConf *qtls.ExtraConfig
|
extraConf *qtls.ExtraConfig
|
||||||
conn *qtls.Conn
|
conn *qtls.Conn
|
||||||
|
|
||||||
messageChan chan []byte
|
messageChan chan []byte
|
||||||
|
isReadingHandshakeMessage chan struct{}
|
||||||
|
readFirstHandshakeMessage bool
|
||||||
|
|
||||||
ourParams *wire.TransportParameters
|
ourParams *wire.TransportParameters
|
||||||
peerParams *wire.TransportParameters
|
peerParams *wire.TransportParameters
|
||||||
|
@ -105,15 +107,6 @@ type cryptoSetup struct {
|
||||||
clientHelloWritten bool
|
clientHelloWritten bool
|
||||||
clientHelloWrittenChan chan *wire.TransportParameters
|
clientHelloWrittenChan chan *wire.TransportParameters
|
||||||
|
|
||||||
receivedWriteKey chan struct{}
|
|
||||||
receivedReadKey chan struct{}
|
|
||||||
// WriteRecord does a non-blocking send on this channel.
|
|
||||||
// This way, handleMessage can see if qtls tries to write a message.
|
|
||||||
// This is necessary:
|
|
||||||
// for servers: to see if a HelloRetryRequest should be sent in response to a ClientHello
|
|
||||||
// for clients: to see if a ServerHello is a HelloRetryRequest
|
|
||||||
writeRecord chan struct{}
|
|
||||||
|
|
||||||
rttStats *utils.RTTStats
|
rttStats *utils.RTTStats
|
||||||
|
|
||||||
tracer logging.ConnectionTracer
|
tracer logging.ConnectionTracer
|
||||||
|
@ -231,29 +224,27 @@ func newCryptoSetup(
|
||||||
}
|
}
|
||||||
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
|
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
|
||||||
cs := &cryptoSetup{
|
cs := &cryptoSetup{
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
initialStream: initialStream,
|
initialStream: initialStream,
|
||||||
initialSealer: initialSealer,
|
initialSealer: initialSealer,
|
||||||
initialOpener: initialOpener,
|
initialOpener: initialOpener,
|
||||||
handshakeStream: handshakeStream,
|
handshakeStream: handshakeStream,
|
||||||
aead: newUpdatableAEAD(rttStats, tracer, logger),
|
aead: newUpdatableAEAD(rttStats, tracer, logger),
|
||||||
readEncLevel: protocol.EncryptionInitial,
|
readEncLevel: protocol.EncryptionInitial,
|
||||||
writeEncLevel: protocol.EncryptionInitial,
|
writeEncLevel: protocol.EncryptionInitial,
|
||||||
runner: runner,
|
runner: runner,
|
||||||
ourParams: tp,
|
ourParams: tp,
|
||||||
paramsChan: extHandler.TransportParameters(),
|
paramsChan: extHandler.TransportParameters(),
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
perspective: perspective,
|
perspective: perspective,
|
||||||
handshakeDone: make(chan struct{}),
|
handshakeDone: make(chan struct{}),
|
||||||
alertChan: make(chan uint8),
|
alertChan: make(chan uint8),
|
||||||
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
|
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
|
||||||
messageChan: make(chan []byte, 100),
|
messageChan: make(chan []byte, 100),
|
||||||
receivedReadKey: make(chan struct{}),
|
isReadingHandshakeMessage: make(chan struct{}),
|
||||||
receivedWriteKey: make(chan struct{}),
|
closeChan: make(chan struct{}),
|
||||||
writeRecord: make(chan struct{}, 1),
|
|
||||||
closeChan: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
var maxEarlyData uint32
|
var maxEarlyData uint32
|
||||||
if enable0RTT {
|
if enable0RTT {
|
||||||
|
@ -344,20 +335,25 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev
|
||||||
h.messageChan <- data
|
h.messageChan <- data
|
||||||
if encLevel == protocol.Encryption1RTT {
|
if encLevel == protocol.Encryption1RTT {
|
||||||
h.handlePostHandshakeMessage()
|
h.handlePostHandshakeMessage()
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
var strFinished bool
|
readLoop:
|
||||||
switch h.perspective {
|
for {
|
||||||
case protocol.PerspectiveClient:
|
select {
|
||||||
strFinished = h.handleMessageForClient(msgType)
|
case data := <-h.paramsChan:
|
||||||
case protocol.PerspectiveServer:
|
h.handleTransportParameters(data)
|
||||||
strFinished = h.handleMessageForServer(msgType)
|
case <-h.isReadingHandshakeMessage:
|
||||||
default:
|
break readLoop
|
||||||
panic("")
|
case <-h.handshakeDone:
|
||||||
|
break readLoop
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if strFinished {
|
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
|
||||||
h.logger.Debugf("Done with encryption level %s.", encLevel)
|
// but only if a handshake opener and sealer was created.
|
||||||
}
|
// Otherwise, a HelloRetryRequest was performed.
|
||||||
return strFinished
|
// We're done with the Handshake encryption level after processing the Finished message.
|
||||||
|
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
|
||||||
|
msgType == typeFinished
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||||
|
@ -383,108 +379,6 @@ func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protoco
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
|
|
||||||
switch msgType {
|
|
||||||
case typeClientHello:
|
|
||||||
select {
|
|
||||||
case <-h.writeRecord:
|
|
||||||
// If qtls sends a HelloRetryRequest, it will only write the record.
|
|
||||||
// If it accepts the ClientHello, it will first read the transport parameters.
|
|
||||||
h.logger.Debugf("Sending HelloRetryRequest")
|
|
||||||
return false
|
|
||||||
case data := <-h.paramsChan:
|
|
||||||
h.handleTransportParameters(data)
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// get the handshake read key
|
|
||||||
select {
|
|
||||||
case <-h.receivedReadKey:
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// get the handshake write key
|
|
||||||
select {
|
|
||||||
case <-h.receivedWriteKey:
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// get the 1-RTT write key
|
|
||||||
select {
|
|
||||||
case <-h.receivedWriteKey:
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
case typeCertificate, typeCertificateVerify:
|
|
||||||
// nothing to do
|
|
||||||
return false
|
|
||||||
case typeFinished:
|
|
||||||
// get the 1-RTT read key
|
|
||||||
select {
|
|
||||||
case <-h.receivedReadKey:
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
// unexpected message
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
|
|
||||||
switch msgType {
|
|
||||||
case typeServerHello:
|
|
||||||
// get the handshake write key
|
|
||||||
select {
|
|
||||||
case <-h.writeRecord:
|
|
||||||
// If qtls writes in response to a ServerHello, this means that this ServerHello
|
|
||||||
// is a HelloRetryRequest.
|
|
||||||
// Otherwise, we'd just wait for the Certificate message.
|
|
||||||
h.logger.Debugf("ServerHello is a HelloRetryRequest")
|
|
||||||
return false
|
|
||||||
case <-h.receivedWriteKey:
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// get the handshake read key
|
|
||||||
select {
|
|
||||||
case <-h.receivedReadKey:
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
case typeEncryptedExtensions:
|
|
||||||
select {
|
|
||||||
case data := <-h.paramsChan:
|
|
||||||
h.handleTransportParameters(data)
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
|
|
||||||
// nothing to do
|
|
||||||
return false
|
|
||||||
case typeFinished:
|
|
||||||
// get the 1-RTT read key
|
|
||||||
select {
|
|
||||||
case <-h.receivedReadKey:
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// get the handshake write key
|
|
||||||
select {
|
|
||||||
case <-h.receivedWriteKey:
|
|
||||||
case <-h.handshakeDone:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetup) handleTransportParameters(data []byte) {
|
func (h *cryptoSetup) handleTransportParameters(data []byte) {
|
||||||
var tp wire.TransportParameters
|
var tp wire.TransportParameters
|
||||||
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
||||||
|
@ -591,6 +485,7 @@ func (h *cryptoSetup) handlePostHandshakeMessage() {
|
||||||
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
|
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
|
||||||
alertChan := make(chan uint8, 1)
|
alertChan := make(chan uint8, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
<-h.isReadingHandshakeMessage
|
||||||
select {
|
select {
|
||||||
case alert := <-h.alertChan:
|
case alert := <-h.alertChan:
|
||||||
alertChan <- alert
|
alertChan <- alert
|
||||||
|
@ -606,6 +501,11 @@ func (h *cryptoSetup) handlePostHandshakeMessage() {
|
||||||
// ReadHandshakeMessage is called by TLS.
|
// ReadHandshakeMessage is called by TLS.
|
||||||
// It blocks until a new handshake message is available.
|
// It blocks until a new handshake message is available.
|
||||||
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||||
|
if !h.readFirstHandshakeMessage {
|
||||||
|
h.readFirstHandshakeMessage = true
|
||||||
|
} else {
|
||||||
|
h.isReadingHandshakeMessage <- struct{}{}
|
||||||
|
}
|
||||||
msg, ok := <-h.messageChan
|
msg, ok := <-h.messageChan
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("error while handling the handshake message")
|
return nil, errors.New("error while handling the handshake message")
|
||||||
|
@ -651,7 +551,6 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
|
||||||
if h.tracer != nil {
|
if h.tracer != nil {
|
||||||
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
|
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
|
||||||
}
|
}
|
||||||
h.receivedReadKey <- struct{}{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||||
|
@ -696,7 +595,6 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
|
||||||
if h.tracer != nil {
|
if h.tracer != nil {
|
||||||
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
|
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
|
||||||
}
|
}
|
||||||
h.receivedWriteKey <- struct{}{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteRecord is called when TLS writes data
|
// WriteRecord is called when TLS writes data
|
||||||
|
@ -717,11 +615,6 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
||||||
h.logger.Debugf("Not doing 0-RTT.")
|
h.logger.Debugf("Not doing 0-RTT.")
|
||||||
h.clientHelloWrittenChan <- nil
|
h.clientHelloWrittenChan <- nil
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// We need additional signaling to properly detect HelloRetryRequests.
|
|
||||||
// For servers: when the ServerHello is written.
|
|
||||||
// For clients: when a reply is sent in response to a ServerHello.
|
|
||||||
h.writeRecord <- struct{}{}
|
|
||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
case protocol.EncryptionHandshake:
|
case protocol.EncryptionHandshake:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -22,6 +23,13 @@ import (
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
|
||||||
|
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
|
||||||
|
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
|
||||||
|
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
|
||||||
|
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
|
||||||
|
}
|
||||||
|
|
||||||
type chunk struct {
|
type chunk struct {
|
||||||
data []byte
|
data []byte
|
||||||
encLevel protocol.EncryptionLevel
|
encLevel protocol.EncryptionLevel
|
||||||
|
@ -257,9 +265,27 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case c := <-cChunkChan:
|
case c := <-cChunkChan:
|
||||||
server.HandleMessage(c.data, c.encLevel)
|
msgType := messageType(c.data[0])
|
||||||
|
finished := server.HandleMessage(c.data, c.encLevel)
|
||||||
|
if msgType == typeFinished {
|
||||||
|
Expect(finished).To(BeTrue())
|
||||||
|
} else if msgType == typeClientHello {
|
||||||
|
// If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys.
|
||||||
|
_, err := server.GetHandshakeOpener()
|
||||||
|
Expect(finished).To(Equal(err == nil))
|
||||||
|
} else {
|
||||||
|
Expect(finished).To(BeFalse())
|
||||||
|
}
|
||||||
case c := <-sChunkChan:
|
case c := <-sChunkChan:
|
||||||
client.HandleMessage(c.data, c.encLevel)
|
msgType := messageType(c.data[0])
|
||||||
|
finished := client.HandleMessage(c.data, c.encLevel)
|
||||||
|
if msgType == typeFinished {
|
||||||
|
Expect(finished).To(BeTrue())
|
||||||
|
} else if msgType == typeServerHello {
|
||||||
|
Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom)))
|
||||||
|
} else {
|
||||||
|
Expect(finished).To(BeFalse())
|
||||||
|
}
|
||||||
case <-done: // handshake complete
|
case <-done: // handshake complete
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue