diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 807ebd18..54c39f32 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -87,7 +87,9 @@ type cryptoSetup struct { extraConf *qtls.ExtraConfig conn *qtls.Conn - messageChan chan []byte + messageChan chan []byte + isReadingHandshakeMessage chan struct{} + readFirstHandshakeMessage bool ourParams *wire.TransportParameters peerParams *wire.TransportParameters @@ -105,15 +107,6 @@ type cryptoSetup struct { clientHelloWritten bool clientHelloWrittenChan chan *wire.TransportParameters - receivedWriteKey chan struct{} - receivedReadKey chan struct{} - // WriteRecord does a non-blocking send on this channel. - // This way, handleMessage can see if qtls tries to write a message. - // This is necessary: - // for servers: to see if a HelloRetryRequest should be sent in response to a ClientHello - // for clients: to see if a ServerHello is a HelloRetryRequest - writeRecord chan struct{} - rttStats *utils.RTTStats tracer logging.ConnectionTracer @@ -231,29 +224,27 @@ func newCryptoSetup( } extHandler := newExtensionHandler(tp.Marshal(perspective), perspective) cs := &cryptoSetup{ - tlsConf: tlsConf, - initialStream: initialStream, - initialSealer: initialSealer, - initialOpener: initialOpener, - handshakeStream: handshakeStream, - aead: newUpdatableAEAD(rttStats, tracer, logger), - readEncLevel: protocol.EncryptionInitial, - writeEncLevel: protocol.EncryptionInitial, - runner: runner, - ourParams: tp, - paramsChan: extHandler.TransportParameters(), - rttStats: rttStats, - tracer: tracer, - logger: logger, - perspective: perspective, - handshakeDone: make(chan struct{}), - alertChan: make(chan uint8), - clientHelloWrittenChan: make(chan *wire.TransportParameters, 1), - messageChan: make(chan []byte, 100), - receivedReadKey: make(chan struct{}), - receivedWriteKey: make(chan struct{}), - writeRecord: make(chan struct{}, 1), - closeChan: make(chan struct{}), + tlsConf: tlsConf, + initialStream: initialStream, + initialSealer: initialSealer, + initialOpener: initialOpener, + handshakeStream: handshakeStream, + aead: newUpdatableAEAD(rttStats, tracer, logger), + readEncLevel: protocol.EncryptionInitial, + writeEncLevel: protocol.EncryptionInitial, + runner: runner, + ourParams: tp, + paramsChan: extHandler.TransportParameters(), + rttStats: rttStats, + tracer: tracer, + logger: logger, + perspective: perspective, + handshakeDone: make(chan struct{}), + alertChan: make(chan uint8), + clientHelloWrittenChan: make(chan *wire.TransportParameters, 1), + messageChan: make(chan []byte, 100), + isReadingHandshakeMessage: make(chan struct{}), + closeChan: make(chan struct{}), } var maxEarlyData uint32 if enable0RTT { @@ -344,20 +335,25 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev h.messageChan <- data if encLevel == protocol.Encryption1RTT { h.handlePostHandshakeMessage() + return false } - var strFinished bool - switch h.perspective { - case protocol.PerspectiveClient: - strFinished = h.handleMessageForClient(msgType) - case protocol.PerspectiveServer: - strFinished = h.handleMessageForServer(msgType) - default: - panic("") +readLoop: + for { + select { + case data := <-h.paramsChan: + h.handleTransportParameters(data) + case <-h.isReadingHandshakeMessage: + break readLoop + case <-h.handshakeDone: + break readLoop + } } - if strFinished { - h.logger.Debugf("Done with encryption level %s.", encLevel) - } - return strFinished + // We're done with the Initial encryption level after processing a ClientHello / ServerHello, + // but only if a handshake opener and sealer was created. + // Otherwise, a HelloRetryRequest was performed. + // We're done with the Handshake encryption level after processing the Finished message. + return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) || + msgType == typeFinished } func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { @@ -383,108 +379,6 @@ func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protoco return nil } -func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool { - switch msgType { - case typeClientHello: - select { - case <-h.writeRecord: - // If qtls sends a HelloRetryRequest, it will only write the record. - // If it accepts the ClientHello, it will first read the transport parameters. - h.logger.Debugf("Sending HelloRetryRequest") - return false - case data := <-h.paramsChan: - h.handleTransportParameters(data) - case <-h.handshakeDone: - return false - } - // get the handshake read key - select { - case <-h.receivedReadKey: - case <-h.handshakeDone: - return false - } - // get the handshake write key - select { - case <-h.receivedWriteKey: - case <-h.handshakeDone: - return false - } - // get the 1-RTT write key - select { - case <-h.receivedWriteKey: - case <-h.handshakeDone: - return false - } - return true - case typeCertificate, typeCertificateVerify: - // nothing to do - return false - case typeFinished: - // get the 1-RTT read key - select { - case <-h.receivedReadKey: - case <-h.handshakeDone: - return false - } - return true - default: - // unexpected message - return false - } -} - -func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool { - switch msgType { - case typeServerHello: - // get the handshake write key - select { - case <-h.writeRecord: - // If qtls writes in response to a ServerHello, this means that this ServerHello - // is a HelloRetryRequest. - // Otherwise, we'd just wait for the Certificate message. - h.logger.Debugf("ServerHello is a HelloRetryRequest") - return false - case <-h.receivedWriteKey: - case <-h.handshakeDone: - return false - } - // get the handshake read key - select { - case <-h.receivedReadKey: - case <-h.handshakeDone: - return false - } - return true - case typeEncryptedExtensions: - select { - case data := <-h.paramsChan: - h.handleTransportParameters(data) - case <-h.handshakeDone: - return false - } - return false - case typeCertificateRequest, typeCertificate, typeCertificateVerify: - // nothing to do - return false - case typeFinished: - // get the 1-RTT read key - select { - case <-h.receivedReadKey: - case <-h.handshakeDone: - return false - } - // get the handshake write key - select { - case <-h.receivedWriteKey: - case <-h.handshakeDone: - return false - } - return true - default: - return false - } -} - func (h *cryptoSetup) handleTransportParameters(data []byte) { var tp wire.TransportParameters if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { @@ -591,6 +485,7 @@ func (h *cryptoSetup) handlePostHandshakeMessage() { // Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock. alertChan := make(chan uint8, 1) go func() { + <-h.isReadingHandshakeMessage select { case alert := <-h.alertChan: alertChan <- alert @@ -606,6 +501,11 @@ func (h *cryptoSetup) handlePostHandshakeMessage() { // ReadHandshakeMessage is called by TLS. // It blocks until a new handshake message is available. func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { + if !h.readFirstHandshakeMessage { + h.readFirstHandshakeMessage = true + } else { + h.isReadingHandshakeMessage <- struct{}{} + } msg, ok := <-h.messageChan if !ok { return nil, errors.New("error while handling the handshake message") @@ -651,7 +551,6 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph if h.tracer != nil { h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite()) } - h.receivedReadKey <- struct{}{} } func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -696,7 +595,6 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip if h.tracer != nil { h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective) } - h.receivedWriteKey <- struct{}{} } // WriteRecord is called when TLS writes data @@ -717,11 +615,6 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { h.logger.Debugf("Not doing 0-RTT.") h.clientHelloWrittenChan <- nil } - } else { - // We need additional signaling to properly detect HelloRetryRequests. - // For servers: when the ServerHello is written. - // For clients: when a reply is sent in response to a ServerHello. - h.writeRecord <- struct{}{} } return n, err case protocol.EncryptionHandshake: diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index df06f007..f138710c 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -1,6 +1,7 @@ package handshake import ( + "bytes" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -22,6 +23,13 @@ import ( . "github.com/onsi/gomega" ) +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + type chunk struct { data []byte encLevel protocol.EncryptionLevel @@ -257,9 +265,27 @@ var _ = Describe("Crypto Setup TLS", func() { for { select { case c := <-cChunkChan: - server.HandleMessage(c.data, c.encLevel) + msgType := messageType(c.data[0]) + finished := server.HandleMessage(c.data, c.encLevel) + if msgType == typeFinished { + Expect(finished).To(BeTrue()) + } else if msgType == typeClientHello { + // If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys. + _, err := server.GetHandshakeOpener() + Expect(finished).To(Equal(err == nil)) + } else { + Expect(finished).To(BeFalse()) + } case c := <-sChunkChan: - client.HandleMessage(c.data, c.encLevel) + msgType := messageType(c.data[0]) + finished := client.HandleMessage(c.data, c.encLevel) + if msgType == typeFinished { + Expect(finished).To(BeTrue()) + } else if msgType == typeServerHello { + Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom))) + } else { + Expect(finished).To(BeFalse()) + } case <-done: // handshake complete return }