From 2dbc29a5bd7ad7e86b43e086b20e07153a092c77 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 18 Oct 2018 22:55:02 +0100 Subject: [PATCH] fix error handling in the TLS crypto setup There are two ways that an error can occur during the handshake: 1. as a return value from qtls.Handshake() 2. when new data is passed to the crypto setup via HandleData() We need to make sure that the RunHandshake() as well as HandleData() both return if an error occurs at any step during the handshake. --- crypto_stream_manager.go | 6 +- crypto_stream_manager_test.go | 10 - internal/handshake/crypto_setup_tls.go | 162 ++++++-- internal/handshake/crypto_setup_tls_test.go | 431 +++++++++++--------- internal/handshake/interface.go | 2 +- mock_crypto_data_handler.go | 6 +- 6 files changed, 375 insertions(+), 242 deletions(-) diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 747b280a..0c1d4694 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -8,7 +8,7 @@ import ( ) type cryptoDataHandler interface { - HandleData([]byte, protocol.EncryptionLevel) error + HandleData([]byte, protocol.EncryptionLevel) } type cryptoStreamManager struct { @@ -48,8 +48,6 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve if data == nil { return nil } - if err := m.cryptoHandler.HandleData(data, encLevel); err != nil { - return err - } + m.cryptoHandler.HandleData(data, encLevel) } } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index bc281505..a7b777c5 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -1,8 +1,6 @@ package quic import ( - "errors" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -54,12 +52,4 @@ var _ = Describe("Crypto Stream Manager", func() { cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake) Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed()) }) - - It("returns the error if handling crypto data fails", func() { - testErr := errors.New("test error") - f := &wire.CryptoFrame{Data: []byte("foobar")} - cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake).Return(testErr) - err := csm.HandleCryptoFrame(f, protocol.EncryptionHandshake) - Expect(err).To(MatchError(testErr)) - }) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 4a375412..3b0970d5 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -55,9 +55,20 @@ type cryptoSetupTLS struct { readEncLevel protocol.EncryptionLevel writeEncLevel protocol.EncryptionLevel - handleParamsCallback func(*TransportParameters) - handshakeEvent chan<- struct{} - handshakeComplete chan<- struct{} + handleParamsCallback func(*TransportParameters) + + // There are two ways that an error can occur during the handshake: + // 1. as a return value from qtls.Handshake() + // 2. when new data is passed to the crypto setup via HandleData() + // handshakeErrChan is closed when qtls.Handshake() errors + handshakeErrChan chan struct{} + // HandleData() sends errors on the messageErrChan + messageErrChan chan error + // handshakeEvent signals a change of encryption level to the session + handshakeEvent chan<- struct{} + // handshakeComplete is closed when the handshake completes + handshakeComplete chan<- struct{} + // transport parameters are sent on the receivedTransportParams, as soon as they are received receivedTransportParams <-chan TransportParameters clientHelloWritten bool @@ -190,6 +201,8 @@ func newCryptoSetupTLS( handshakeComplete: handshakeComplete, logger: logger, perspective: perspective, + handshakeErrChan: make(chan struct{}), + messageErrChan: make(chan error, 1), clientHelloWrittenChan: make(chan struct{}), messageChan: make(chan []byte, 100), receivedReadKey: make(chan struct{}), @@ -229,16 +242,37 @@ func (h *cryptoSetupTLS) RunHandshake() error { case protocol.PerspectiveServer: conn = qtls.Server(nil, h.tlsConf) } - if err := conn.Handshake(); err != nil { - close(h.receivedReadKey) - close(h.receivedWriteKey) + // Handle errors that might occur when HandleData() is called. + handshakeErrChan := make(chan error, 1) + handshakeComplete := make(chan struct{}) + go func() { + if err := conn.Handshake(); err != nil { + handshakeErrChan <- err + return + } + close(handshakeComplete) + }() + + select { + case <-handshakeComplete: // return when the handshake is done + close(h.handshakeComplete) + return nil + case err := <-handshakeErrChan: + // if handleMessageFor{server,client} are waiting for some qtls action, make them return + close(h.handshakeErrChan) + return err + case err := <-h.messageErrChan: + // If the handshake errored because of an error that occurred during HandleData(), + // that error message will be more useful than the error message generated by Handshake(). + // Close the message chan that qtls is receiving messages from. + // This will make qtls.Handshake() return. + // Thereby the go routine running qtls.Handshake() will return. + close(h.messageChan) return err } - close(h.handshakeComplete) - return nil } -func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) error { +func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) { var buf *bytes.Buffer switch encLevel { case protocol.EncryptionInitial: @@ -246,7 +280,8 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev case protocol.EncryptionHandshake: buf = &h.handshakeReadBuf default: - return fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel) + h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel) + return } buf.Write(data) for buf.Len() >= 4 { @@ -254,15 +289,14 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev // read the TLS message length length := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) if buf.Len() < 4+length { // message not yet complete - return nil + return } msg := make([]byte, length+4) buf.Read(msg) if err := h.handleMessage(msg, encLevel); err != nil { - return err + h.messageErrChan <- err } } - return nil } // handleMessage handles a TLS handshake message. @@ -276,12 +310,13 @@ func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.Encryption h.messageChan <- data switch h.perspective { case protocol.PerspectiveClient: - return h.handleMessageForClient(msgType) + h.handleMessageForClient(msgType) case protocol.PerspectiveServer: - return h.handleMessageForServer(msgType) + h.handleMessageForServer(msgType) default: panic("") } + return nil } func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { @@ -300,65 +335,114 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot return fmt.Errorf("unexpected handshake message: %d", msgType) } if encLevel != expected { - return fmt.Errorf("expected handshake message %d to have encryption level %s, has %s", msgType, expected, encLevel) + return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) } return nil } -func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) error { +func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) { switch msgType { case typeClientHello: - params := <-h.receivedTransportParams - h.handleParamsCallback(¶ms) - <-h.receivedWriteKey // get the handshake write key - <-h.receivedWriteKey // get the 1-RTT write key - <-h.receivedReadKey // get the handshake read key - h.handshakeEvent <- struct{}{} + select { + case params := <-h.receivedTransportParams: + h.handleParamsCallback(¶ms) + case <-h.handshakeErrChan: + return + } + // get the handshake write key + select { + case <-h.receivedWriteKey: + case <-h.handshakeErrChan: + return + } + // get the 1-RTT write key + select { + case <-h.receivedWriteKey: + case <-h.handshakeErrChan: + return + } + // get the handshake read key // TODO: check that the initial stream doesn't have any more data + select { + case <-h.receivedReadKey: + case <-h.handshakeErrChan: + return + } + h.handshakeEvent <- struct{}{} case typeCertificate, typeCertificateVerify: // nothing to do case typeFinished: - <-h.receivedReadKey // get the 1-RTT read key - h.handshakeEvent <- struct{}{} + // get the 1-RTT read key // TODO: check that the handshake stream doesn't have any more data + select { + case <-h.receivedReadKey: + case <-h.handshakeErrChan: + return + } + h.handshakeEvent <- struct{}{} default: - // TODO: think about what to do with unknown message types - return fmt.Errorf("Received unknown handshake message: %d", msgType) + panic("unexpected handshake message") } - return nil } -func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) error { +func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { switch msgType { case typeServerHello: - <-h.receivedReadKey // get the handshake read key + // get the handshake read key + // TODO: check that the initial stream doesn't have any more data + select { + case <-h.receivedReadKey: + case <-h.handshakeErrChan: + return + } h.handshakeEvent <- struct{}{} case typeEncryptedExtensions: - params := <-h.receivedTransportParams - h.handleParamsCallback(¶ms) + select { + case params := <-h.receivedTransportParams: + h.handleParamsCallback(¶ms) + case <-h.handshakeErrChan: + return + } case typeCertificateRequest, typeCertificate, typeCertificateVerify: // nothing to do case typeFinished: - <-h.receivedWriteKey // get the handshake write key + // get the handshake write key // TODO: check that the initial stream doesn't have any more data + select { + case <-h.receivedWriteKey: + case <-h.handshakeErrChan: + return + } // While the order of these two is not defined by the TLS spec, // we have to do it on the same order as our TLS library does it. - <-h.receivedWriteKey // get the handshake write key - <-h.receivedReadKey // get the 1-RTT read key + // get the handshake write key + select { + case <-h.receivedWriteKey: + case <-h.handshakeErrChan: + return + } + // get the 1-RTT read key + select { + case <-h.receivedReadKey: + case <-h.handshakeErrChan: + return + } // TODO: check that the handshake stream doesn't have any more data h.handshakeEvent <- struct{}{} default: - // TODO: think about what to do with unknown extensions - return fmt.Errorf("Received unknown handshake message: %d", msgType) + panic("unexpected handshake message: ") } - return nil } // ReadHandshakeMessage is called by TLS. // It blocks until a new handshake message is available. func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) { // TODO: add some error handling here (when the session is closed) - return <-h.messageChan, nil + msg, ok := <-h.messageChan + if !ok { + return nil, errors.New("error while handling the handshake message") + } + return msg, nil } func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) { diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 5d98e09e..4ff37d04 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -48,25 +48,6 @@ func (s *stream) Write(b []byte) (int, error) { } var _ = Describe("Crypto Setup TLS", func() { - generateCert := func() tls.Certificate { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - Expect(err).ToNot(HaveOccurred()) - tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{}, - SignatureAlgorithm: x509.SHA256WithRSA, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), // valid for an hour - BasicConstraintsValid: true, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) - Expect(err).ToNot(HaveOccurred()) - return tls.Certificate{ - PrivateKey: priv, - Certificate: [][]byte{certDER}, - } - } - initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { chunkChan := make(chan chunk, 100) initialStream := newStream(chunkChan, protocol.EncryptionInitial) @@ -74,172 +55,16 @@ var _ = Describe("Crypto Setup TLS", func() { return chunkChan, initialStream, handshakeStream } - handshake := func( - client CryptoSetupTLS, - cChunkChan <-chan chunk, - server CryptoSetupTLS, - sChunkChan <-chan chunk) (error /* client error */, error /* server error */) { - done := make(chan struct{}) - defer close(done) - go func() { - defer GinkgoRecover() - for { - select { - case c := <-cChunkChan: - err := server.HandleData(c.data, c.encLevel) - Expect(err).ToNot(HaveOccurred()) - case c := <-sChunkChan: - err := client.HandleData(c.data, c.encLevel) - Expect(err).ToNot(HaveOccurred()) - case <-done: // handshake complete - } - } - }() - - serverErrChan := make(chan error) - go func() { - defer GinkgoRecover() - serverErrChan <- server.RunHandshake() - }() - - clientErr := client.RunHandshake() - var serverErr error - Eventually(serverErrChan).Should(Receive(&serverErr)) - return clientErr, serverErr - } - - handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - client, _, err := NewCryptoSetupTLSClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - &TransportParameters{}, - func(p *TransportParameters) {}, - make(chan struct{}, 100), - make(chan struct{}), - clientConf, - protocol.VersionTLS, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - utils.DefaultLogger.WithPrefix("client"), - protocol.PerspectiveClient, - ) - Expect(err).ToNot(HaveOccurred()) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() + It("returns Handshake() when an error occurs", func() { + _, sInitialStream, sHandshakeStream := initStreams() server, err := NewCryptoSetupTLSServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, - &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, - func(p *TransportParameters) {}, - make(chan struct{}, 100), - make(chan struct{}), - serverConf, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - utils.DefaultLogger.WithPrefix("server"), - protocol.PerspectiveServer, - ) - Expect(err).ToNot(HaveOccurred()) - - return handshake(client, cChunkChan, server, sChunkChan) - } - - It("handshakes", func() { - clientConf := &tls.Config{ServerName: "quic.clemente.io"} - serverConf := testdata.GetTLSConfig() - clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("handshakes with client auth", func() { - clientConf := &tls.Config{ - ServerName: "quic.clemente.io", - Certificates: []tls.Certificate{generateCert()}, - } - serverConf := testdata.GetTLSConfig() - serverConf.ClientAuth = qtls.RequireAnyClientCert - clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("signals when it has written the ClientHello", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - client, chChan, err := NewCryptoSetupTLSClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, make(chan struct{}, 100), make(chan struct{}), - &tls.Config{InsecureSkipVerify: true}, - protocol.VersionTLS, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - utils.DefaultLogger.WithPrefix("client"), - protocol.PerspectiveClient, - ) - Expect(err).ToNot(HaveOccurred()) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - client.RunHandshake() - close(done) - }() - var ch chunk - Eventually(cChunkChan).Should(Receive(&ch)) - Eventually(chChan).Should(BeClosed()) - // make sure the whole ClientHello was written - Expect(len(ch.data)).To(BeNumerically(">=", 4)) - Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) - length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) - Expect(len(ch.data) - 4).To(Equal(length)) - - // make the go routine return - client.HandleData([]byte{1, 0, 0, 1, 0}, protocol.EncryptionInitial) - Eventually(done).Should(BeClosed()) - }) - - It("receives transport parameters", func() { - var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} - client, _, err := NewCryptoSetupTLSClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - cTransportParameters, - func(p *TransportParameters) { sTransportParametersRcvd = p }, - make(chan struct{}, 100), - make(chan struct{}), - &tls.Config{ServerName: "quic.clemente.io"}, - protocol.VersionTLS, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - utils.DefaultLogger.WithPrefix("client"), - protocol.PerspectiveClient, - ) - Expect(err).ToNot(HaveOccurred()) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sTransportParameters := &TransportParameters{ - IdleTimeout: 0x1337 * time.Second, - StatelessResetToken: bytes.Repeat([]byte{42}, 16), - } - server, err := NewCryptoSetupTLSServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - sTransportParameters, - func(p *TransportParameters) { cTransportParametersRcvd = p }, - make(chan struct{}, 100), - make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, protocol.VersionTLS, @@ -251,15 +76,253 @@ var _ = Describe("Crypto Setup TLS", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) + err := server.RunHandshake() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("received unexpected handshake message")) close(done) }() + + fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) + server.HandleData(fakeCH, protocol.EncryptionInitial) Eventually(done).Should(BeClosed()) - Expect(cTransportParametersRcvd).ToNot(BeNil()) - Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout)) - Expect(sTransportParametersRcvd).ToNot(BeNil()) - Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) + }) + + It("returns Handshake() when handling a message fails", func() { + _, sInitialStream, sHandshakeStream := initStreams() + server, err := NewCryptoSetupTLSServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + testdata.GetTLSConfig(), + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("server"), + protocol.PerspectiveServer, + ) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := server.RunHandshake() + Expect(err).To(MatchError("expected handshake message ClientHello to have encryption level Initial, has Handshake")) + close(done) + }() + + fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) + server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level + Eventually(done).Should(BeClosed()) + }) + + Context("doing the handshake", func() { + generateCert := func() tls.Certificate { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), // valid for an hour + BasicConstraintsValid: true, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) + Expect(err).ToNot(HaveOccurred()) + return tls.Certificate{ + PrivateKey: priv, + Certificate: [][]byte{certDER}, + } + } + + handshake := func( + client CryptoSetupTLS, + cChunkChan <-chan chunk, + server CryptoSetupTLS, + sChunkChan <-chan chunk) (error /* client error */, error /* server error */) { + done := make(chan struct{}) + defer close(done) + go func() { + defer GinkgoRecover() + for { + select { + case c := <-cChunkChan: + server.HandleData(c.data, c.encLevel) + case c := <-sChunkChan: + client.HandleData(c.data, c.encLevel) + case <-done: // handshake complete + } + } + }() + + serverErrChan := make(chan error) + go func() { + defer GinkgoRecover() + serverErrChan <- server.RunHandshake() + }() + + clientErr := client.RunHandshake() + var serverErr error + Eventually(serverErrChan).Should(Receive(&serverErr)) + return clientErr, serverErr + } + + handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + client, _, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + clientConf, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, + ) + Expect(err).ToNot(HaveOccurred()) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + server, err := NewCryptoSetupTLSServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + serverConf, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("server"), + protocol.PerspectiveServer, + ) + Expect(err).ToNot(HaveOccurred()) + + return handshake(client, cChunkChan, server, sChunkChan) + } + + It("handshakes", func() { + clientConf := &tls.Config{ServerName: "quic.clemente.io"} + serverConf := testdata.GetTLSConfig() + clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("handshakes with client auth", func() { + clientConf := &tls.Config{ + ServerName: "quic.clemente.io", + Certificates: []tls.Certificate{generateCert()}, + } + serverConf := testdata.GetTLSConfig() + serverConf.ClientAuth = qtls.RequireAnyClientCert + clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("signals when it has written the ClientHello", func() { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + client, chChan, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + &tls.Config{InsecureSkipVerify: true}, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, + ) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + client.RunHandshake() + close(done) + }() + var ch chunk + Eventually(cChunkChan).Should(Receive(&ch)) + Eventually(chChan).Should(BeClosed()) + // make sure the whole ClientHello was written + Expect(len(ch.data)).To(BeNumerically(">=", 4)) + Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) + length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) + Expect(len(ch.data) - 4).To(Equal(length)) + + // make the go routine return + client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) + Eventually(done).Should(BeClosed()) + }) + + It("receives transport parameters", func() { + var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} + client, _, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + cTransportParameters, + func(p *TransportParameters) { sTransportParametersRcvd = p }, + make(chan struct{}, 100), + make(chan struct{}), + &tls.Config{ServerName: "quic.clemente.io"}, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, + ) + Expect(err).ToNot(HaveOccurred()) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sTransportParameters := &TransportParameters{ + IdleTimeout: 0x1337 * time.Second, + StatelessResetToken: bytes.Repeat([]byte{42}, 16), + } + server, err := NewCryptoSetupTLSServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + sTransportParameters, + func(p *TransportParameters) { cTransportParametersRcvd = p }, + make(chan struct{}, 100), + make(chan struct{}), + testdata.GetTLSConfig(), + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("server"), + protocol.PerspectiveServer, + ) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + close(done) + }() + Eventually(done).Should(BeClosed()) + Expect(cTransportParametersRcvd).ToNot(BeNil()) + Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout)) + Expect(sTransportParametersRcvd).ToNot(BeNil()) + Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) + }) }) }) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index a9d71934..7acb6a13 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -44,7 +44,7 @@ type CryptoSetup interface { type CryptoSetupTLS interface { baseCryptoSetup - HandleData([]byte, protocol.EncryptionLevel) error + HandleData([]byte, protocol.EncryptionLevel) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) diff --git a/mock_crypto_data_handler.go b/mock_crypto_data_handler.go index 58efb560..b789ca8b 100644 --- a/mock_crypto_data_handler.go +++ b/mock_crypto_data_handler.go @@ -35,10 +35,8 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { } // HandleData mocks base method -func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) error { - ret := m.ctrl.Call(m, "HandleData", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 +func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) { + m.ctrl.Call(m, "HandleData", arg0, arg1) } // HandleData indicates an expected call of HandleData