fix handling of multiple handshake messages in the case of errors

When receiving a handshake message after another handshake messages that
doesn't cause any action from the TLS stack (i.e. Certificate and
CertificateVerify), the handshake would run into a deadlock if the first
of these handshake messages caused an error in the TLS stack.

We need to make sure that we wait until a message has been fully
processed before proceeding with the handshake.
This commit is contained in:
Marten Seemann 2020-09-13 20:13:57 +07:00
parent 8bf5c782e3
commit c9bfde9ac0
2 changed files with 75 additions and 156 deletions

View file

@ -87,7 +87,9 @@ type cryptoSetup struct {
extraConf *qtls.ExtraConfig
conn *qtls.Conn
messageChan chan []byte
messageChan chan []byte
isReadingHandshakeMessage chan struct{}
readFirstHandshakeMessage bool
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
@ -105,15 +107,6 @@ type cryptoSetup struct {
clientHelloWritten bool
clientHelloWrittenChan chan *wire.TransportParameters
receivedWriteKey chan struct{}
receivedReadKey chan struct{}
// WriteRecord does a non-blocking send on this channel.
// This way, handleMessage can see if qtls tries to write a message.
// This is necessary:
// for servers: to see if a HelloRetryRequest should be sent in response to a ClientHello
// for clients: to see if a ServerHello is a HelloRetryRequest
writeRecord chan struct{}
rttStats *utils.RTTStats
tracer logging.ConnectionTracer
@ -231,29 +224,27 @@ func newCryptoSetup(
}
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
cs := &cryptoSetup{
tlsConf: tlsConf,
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
aead: newUpdatableAEAD(rttStats, tracer, logger),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
messageChan: make(chan []byte, 100),
receivedReadKey: make(chan struct{}),
receivedWriteKey: make(chan struct{}),
writeRecord: make(chan struct{}, 1),
closeChan: make(chan struct{}),
tlsConf: tlsConf,
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
aead: newUpdatableAEAD(rttStats, tracer, logger),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
messageChan: make(chan []byte, 100),
isReadingHandshakeMessage: make(chan struct{}),
closeChan: make(chan struct{}),
}
var maxEarlyData uint32
if enable0RTT {
@ -344,20 +335,25 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev
h.messageChan <- data
if encLevel == protocol.Encryption1RTT {
h.handlePostHandshakeMessage()
return false
}
var strFinished bool
switch h.perspective {
case protocol.PerspectiveClient:
strFinished = h.handleMessageForClient(msgType)
case protocol.PerspectiveServer:
strFinished = h.handleMessageForServer(msgType)
default:
panic("")
readLoop:
for {
select {
case data := <-h.paramsChan:
h.handleTransportParameters(data)
case <-h.isReadingHandshakeMessage:
break readLoop
case <-h.handshakeDone:
break readLoop
}
}
if strFinished {
h.logger.Debugf("Done with encryption level %s.", encLevel)
}
return strFinished
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
// but only if a handshake opener and sealer was created.
// Otherwise, a HelloRetryRequest was performed.
// We're done with the Handshake encryption level after processing the Finished message.
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
msgType == typeFinished
}
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
@ -383,108 +379,6 @@ func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protoco
return nil
}
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
switch msgType {
case typeClientHello:
select {
case <-h.writeRecord:
// If qtls sends a HelloRetryRequest, it will only write the record.
// If it accepts the ClientHello, it will first read the transport parameters.
h.logger.Debugf("Sending HelloRetryRequest")
return false
case data := <-h.paramsChan:
h.handleTransportParameters(data)
case <-h.handshakeDone:
return false
}
// get the handshake read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
// get the 1-RTT write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
return true
case typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
return true
default:
// unexpected message
return false
}
}
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
switch msgType {
case typeServerHello:
// get the handshake write key
select {
case <-h.writeRecord:
// If qtls writes in response to a ServerHello, this means that this ServerHello
// is a HelloRetryRequest.
// Otherwise, we'd just wait for the Certificate message.
h.logger.Debugf("ServerHello is a HelloRetryRequest")
return false
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
// get the handshake read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
return true
case typeEncryptedExtensions:
select {
case data := <-h.paramsChan:
h.handleTransportParameters(data)
case <-h.handshakeDone:
return false
}
return false
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
return true
default:
return false
}
}
func (h *cryptoSetup) handleTransportParameters(data []byte) {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
@ -591,6 +485,7 @@ func (h *cryptoSetup) handlePostHandshakeMessage() {
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
alertChan := make(chan uint8, 1)
go func() {
<-h.isReadingHandshakeMessage
select {
case alert := <-h.alertChan:
alertChan <- alert
@ -606,6 +501,11 @@ func (h *cryptoSetup) handlePostHandshakeMessage() {
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
if !h.readFirstHandshakeMessage {
h.readFirstHandshakeMessage = true
} else {
h.isReadingHandshakeMessage <- struct{}{}
}
msg, ok := <-h.messageChan
if !ok {
return nil, errors.New("error while handling the handshake message")
@ -651,7 +551,6 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
}
h.receivedReadKey <- struct{}{}
}
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
@ -696,7 +595,6 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
}
h.receivedWriteKey <- struct{}{}
}
// WriteRecord is called when TLS writes data
@ -717,11 +615,6 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
h.logger.Debugf("Not doing 0-RTT.")
h.clientHelloWrittenChan <- nil
}
} else {
// We need additional signaling to properly detect HelloRetryRequests.
// For servers: when the ServerHello is written.
// For clients: when a reply is sent in response to a ServerHello.
h.writeRecord <- struct{}{}
}
return n, err
case protocol.EncryptionHandshake:

View file

@ -1,6 +1,7 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@ -22,6 +23,13 @@ import (
. "github.com/onsi/gomega"
)
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
}
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
@ -257,9 +265,27 @@ var _ = Describe("Crypto Setup TLS", func() {
for {
select {
case c := <-cChunkChan:
server.HandleMessage(c.data, c.encLevel)
msgType := messageType(c.data[0])
finished := server.HandleMessage(c.data, c.encLevel)
if msgType == typeFinished {
Expect(finished).To(BeTrue())
} else if msgType == typeClientHello {
// If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys.
_, err := server.GetHandshakeOpener()
Expect(finished).To(Equal(err == nil))
} else {
Expect(finished).To(BeFalse())
}
case c := <-sChunkChan:
client.HandleMessage(c.data, c.encLevel)
msgType := messageType(c.data[0])
finished := client.HandleMessage(c.data, c.encLevel)
if msgType == typeFinished {
Expect(finished).To(BeTrue())
} else if msgType == typeServerHello {
Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom)))
} else {
Expect(finished).To(BeFalse())
}
case <-done: // handshake complete
return
}