mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
fix handling of multiple handshake messages in the case of errors
When receiving a handshake message after another handshake messages that doesn't cause any action from the TLS stack (i.e. Certificate and CertificateVerify), the handshake would run into a deadlock if the first of these handshake messages caused an error in the TLS stack. We need to make sure that we wait until a message has been fully processed before proceeding with the handshake.
This commit is contained in:
parent
8bf5c782e3
commit
c9bfde9ac0
2 changed files with 75 additions and 156 deletions
|
@ -87,7 +87,9 @@ type cryptoSetup struct {
|
|||
extraConf *qtls.ExtraConfig
|
||||
conn *qtls.Conn
|
||||
|
||||
messageChan chan []byte
|
||||
messageChan chan []byte
|
||||
isReadingHandshakeMessage chan struct{}
|
||||
readFirstHandshakeMessage bool
|
||||
|
||||
ourParams *wire.TransportParameters
|
||||
peerParams *wire.TransportParameters
|
||||
|
@ -105,15 +107,6 @@ type cryptoSetup struct {
|
|||
clientHelloWritten bool
|
||||
clientHelloWrittenChan chan *wire.TransportParameters
|
||||
|
||||
receivedWriteKey chan struct{}
|
||||
receivedReadKey chan struct{}
|
||||
// WriteRecord does a non-blocking send on this channel.
|
||||
// This way, handleMessage can see if qtls tries to write a message.
|
||||
// This is necessary:
|
||||
// for servers: to see if a HelloRetryRequest should be sent in response to a ClientHello
|
||||
// for clients: to see if a ServerHello is a HelloRetryRequest
|
||||
writeRecord chan struct{}
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
tracer logging.ConnectionTracer
|
||||
|
@ -231,29 +224,27 @@ func newCryptoSetup(
|
|||
}
|
||||
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
|
||||
cs := &cryptoSetup{
|
||||
tlsConf: tlsConf,
|
||||
initialStream: initialStream,
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
handshakeStream: handshakeStream,
|
||||
aead: newUpdatableAEAD(rttStats, tracer, logger),
|
||||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
runner: runner,
|
||||
ourParams: tp,
|
||||
paramsChan: extHandler.TransportParameters(),
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
handshakeDone: make(chan struct{}),
|
||||
alertChan: make(chan uint8),
|
||||
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
|
||||
messageChan: make(chan []byte, 100),
|
||||
receivedReadKey: make(chan struct{}),
|
||||
receivedWriteKey: make(chan struct{}),
|
||||
writeRecord: make(chan struct{}, 1),
|
||||
closeChan: make(chan struct{}),
|
||||
tlsConf: tlsConf,
|
||||
initialStream: initialStream,
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
handshakeStream: handshakeStream,
|
||||
aead: newUpdatableAEAD(rttStats, tracer, logger),
|
||||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
runner: runner,
|
||||
ourParams: tp,
|
||||
paramsChan: extHandler.TransportParameters(),
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
handshakeDone: make(chan struct{}),
|
||||
alertChan: make(chan uint8),
|
||||
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
|
||||
messageChan: make(chan []byte, 100),
|
||||
isReadingHandshakeMessage: make(chan struct{}),
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
var maxEarlyData uint32
|
||||
if enable0RTT {
|
||||
|
@ -344,20 +335,25 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev
|
|||
h.messageChan <- data
|
||||
if encLevel == protocol.Encryption1RTT {
|
||||
h.handlePostHandshakeMessage()
|
||||
return false
|
||||
}
|
||||
var strFinished bool
|
||||
switch h.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
strFinished = h.handleMessageForClient(msgType)
|
||||
case protocol.PerspectiveServer:
|
||||
strFinished = h.handleMessageForServer(msgType)
|
||||
default:
|
||||
panic("")
|
||||
readLoop:
|
||||
for {
|
||||
select {
|
||||
case data := <-h.paramsChan:
|
||||
h.handleTransportParameters(data)
|
||||
case <-h.isReadingHandshakeMessage:
|
||||
break readLoop
|
||||
case <-h.handshakeDone:
|
||||
break readLoop
|
||||
}
|
||||
}
|
||||
if strFinished {
|
||||
h.logger.Debugf("Done with encryption level %s.", encLevel)
|
||||
}
|
||||
return strFinished
|
||||
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
|
||||
// but only if a handshake opener and sealer was created.
|
||||
// Otherwise, a HelloRetryRequest was performed.
|
||||
// We're done with the Handshake encryption level after processing the Finished message.
|
||||
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
|
||||
msgType == typeFinished
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||
|
@ -383,108 +379,6 @@ func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protoco
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
|
||||
switch msgType {
|
||||
case typeClientHello:
|
||||
select {
|
||||
case <-h.writeRecord:
|
||||
// If qtls sends a HelloRetryRequest, it will only write the record.
|
||||
// If it accepts the ClientHello, it will first read the transport parameters.
|
||||
h.logger.Debugf("Sending HelloRetryRequest")
|
||||
return false
|
||||
case data := <-h.paramsChan:
|
||||
h.handleTransportParameters(data)
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
// get the handshake read key
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
// get the handshake write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
// get the 1-RTT write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case typeCertificate, typeCertificateVerify:
|
||||
// nothing to do
|
||||
return false
|
||||
case typeFinished:
|
||||
// get the 1-RTT read key
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
// unexpected message
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
|
||||
switch msgType {
|
||||
case typeServerHello:
|
||||
// get the handshake write key
|
||||
select {
|
||||
case <-h.writeRecord:
|
||||
// If qtls writes in response to a ServerHello, this means that this ServerHello
|
||||
// is a HelloRetryRequest.
|
||||
// Otherwise, we'd just wait for the Certificate message.
|
||||
h.logger.Debugf("ServerHello is a HelloRetryRequest")
|
||||
return false
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
// get the handshake read key
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case typeEncryptedExtensions:
|
||||
select {
|
||||
case data := <-h.paramsChan:
|
||||
h.handleTransportParameters(data)
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
return false
|
||||
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
|
||||
// nothing to do
|
||||
return false
|
||||
case typeFinished:
|
||||
// get the 1-RTT read key
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
// get the handshake write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeDone:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleTransportParameters(data []byte) {
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
||||
|
@ -591,6 +485,7 @@ func (h *cryptoSetup) handlePostHandshakeMessage() {
|
|||
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
|
||||
alertChan := make(chan uint8, 1)
|
||||
go func() {
|
||||
<-h.isReadingHandshakeMessage
|
||||
select {
|
||||
case alert := <-h.alertChan:
|
||||
alertChan <- alert
|
||||
|
@ -606,6 +501,11 @@ func (h *cryptoSetup) handlePostHandshakeMessage() {
|
|||
// ReadHandshakeMessage is called by TLS.
|
||||
// It blocks until a new handshake message is available.
|
||||
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||
if !h.readFirstHandshakeMessage {
|
||||
h.readFirstHandshakeMessage = true
|
||||
} else {
|
||||
h.isReadingHandshakeMessage <- struct{}{}
|
||||
}
|
||||
msg, ok := <-h.messageChan
|
||||
if !ok {
|
||||
return nil, errors.New("error while handling the handshake message")
|
||||
|
@ -651,7 +551,6 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
|
|||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
|
||||
}
|
||||
h.receivedReadKey <- struct{}{}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
|
@ -696,7 +595,6 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
|
|||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
|
||||
}
|
||||
h.receivedWriteKey <- struct{}{}
|
||||
}
|
||||
|
||||
// WriteRecord is called when TLS writes data
|
||||
|
@ -717,11 +615,6 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
|||
h.logger.Debugf("Not doing 0-RTT.")
|
||||
h.clientHelloWrittenChan <- nil
|
||||
}
|
||||
} else {
|
||||
// We need additional signaling to properly detect HelloRetryRequests.
|
||||
// For servers: when the ServerHello is written.
|
||||
// For clients: when a reply is sent in response to a ServerHello.
|
||||
h.writeRecord <- struct{}{}
|
||||
}
|
||||
return n, err
|
||||
case protocol.EncryptionHandshake:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
|
@ -22,6 +23,13 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
|
||||
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
|
||||
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
|
||||
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
|
||||
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
|
||||
}
|
||||
|
||||
type chunk struct {
|
||||
data []byte
|
||||
encLevel protocol.EncryptionLevel
|
||||
|
@ -257,9 +265,27 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
for {
|
||||
select {
|
||||
case c := <-cChunkChan:
|
||||
server.HandleMessage(c.data, c.encLevel)
|
||||
msgType := messageType(c.data[0])
|
||||
finished := server.HandleMessage(c.data, c.encLevel)
|
||||
if msgType == typeFinished {
|
||||
Expect(finished).To(BeTrue())
|
||||
} else if msgType == typeClientHello {
|
||||
// If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys.
|
||||
_, err := server.GetHandshakeOpener()
|
||||
Expect(finished).To(Equal(err == nil))
|
||||
} else {
|
||||
Expect(finished).To(BeFalse())
|
||||
}
|
||||
case c := <-sChunkChan:
|
||||
client.HandleMessage(c.data, c.encLevel)
|
||||
msgType := messageType(c.data[0])
|
||||
finished := client.HandleMessage(c.data, c.encLevel)
|
||||
if msgType == typeFinished {
|
||||
Expect(finished).To(BeTrue())
|
||||
} else if msgType == typeServerHello {
|
||||
Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom)))
|
||||
} else {
|
||||
Expect(finished).To(BeFalse())
|
||||
}
|
||||
case <-done: // handshake complete
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue