diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 747b280a..0c1d4694 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -8,7 +8,7 @@ import ( ) type cryptoDataHandler interface { - HandleData([]byte, protocol.EncryptionLevel) error + HandleData([]byte, protocol.EncryptionLevel) } type cryptoStreamManager struct { @@ -48,8 +48,6 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve if data == nil { return nil } - if err := m.cryptoHandler.HandleData(data, encLevel); err != nil { - return err - } + m.cryptoHandler.HandleData(data, encLevel) } } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index bc281505..a7b777c5 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -1,8 +1,6 @@ package quic import ( - "errors" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -54,12 +52,4 @@ var _ = Describe("Crypto Stream Manager", func() { cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake) Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed()) }) - - It("returns the error if handling crypto data fails", func() { - testErr := errors.New("test error") - f := &wire.CryptoFrame{Data: []byte("foobar")} - cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake).Return(testErr) - err := csm.HandleCryptoFrame(f, protocol.EncryptionHandshake) - Expect(err).To(MatchError(testErr)) - }) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 4a375412..3b0970d5 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -55,9 +55,20 @@ type cryptoSetupTLS struct { readEncLevel protocol.EncryptionLevel writeEncLevel protocol.EncryptionLevel - handleParamsCallback func(*TransportParameters) - handshakeEvent chan<- struct{} - handshakeComplete chan<- struct{} + handleParamsCallback func(*TransportParameters) + + // There are two ways that an error can occur during the handshake: + // 1. as a return value from qtls.Handshake() + // 2. when new data is passed to the crypto setup via HandleData() + // handshakeErrChan is closed when qtls.Handshake() errors + handshakeErrChan chan struct{} + // HandleData() sends errors on the messageErrChan + messageErrChan chan error + // handshakeEvent signals a change of encryption level to the session + handshakeEvent chan<- struct{} + // handshakeComplete is closed when the handshake completes + handshakeComplete chan<- struct{} + // transport parameters are sent on the receivedTransportParams, as soon as they are received receivedTransportParams <-chan TransportParameters clientHelloWritten bool @@ -190,6 +201,8 @@ func newCryptoSetupTLS( handshakeComplete: handshakeComplete, logger: logger, perspective: perspective, + handshakeErrChan: make(chan struct{}), + messageErrChan: make(chan error, 1), clientHelloWrittenChan: make(chan struct{}), messageChan: make(chan []byte, 100), receivedReadKey: make(chan struct{}), @@ -229,16 +242,37 @@ func (h *cryptoSetupTLS) RunHandshake() error { case protocol.PerspectiveServer: conn = qtls.Server(nil, h.tlsConf) } - if err := conn.Handshake(); err != nil { - close(h.receivedReadKey) - close(h.receivedWriteKey) + // Handle errors that might occur when HandleData() is called. + handshakeErrChan := make(chan error, 1) + handshakeComplete := make(chan struct{}) + go func() { + if err := conn.Handshake(); err != nil { + handshakeErrChan <- err + return + } + close(handshakeComplete) + }() + + select { + case <-handshakeComplete: // return when the handshake is done + close(h.handshakeComplete) + return nil + case err := <-handshakeErrChan: + // if handleMessageFor{server,client} are waiting for some qtls action, make them return + close(h.handshakeErrChan) + return err + case err := <-h.messageErrChan: + // If the handshake errored because of an error that occurred during HandleData(), + // that error message will be more useful than the error message generated by Handshake(). + // Close the message chan that qtls is receiving messages from. + // This will make qtls.Handshake() return. + // Thereby the go routine running qtls.Handshake() will return. + close(h.messageChan) return err } - close(h.handshakeComplete) - return nil } -func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) error { +func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) { var buf *bytes.Buffer switch encLevel { case protocol.EncryptionInitial: @@ -246,7 +280,8 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev case protocol.EncryptionHandshake: buf = &h.handshakeReadBuf default: - return fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel) + h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel) + return } buf.Write(data) for buf.Len() >= 4 { @@ -254,15 +289,14 @@ func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLev // read the TLS message length length := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) if buf.Len() < 4+length { // message not yet complete - return nil + return } msg := make([]byte, length+4) buf.Read(msg) if err := h.handleMessage(msg, encLevel); err != nil { - return err + h.messageErrChan <- err } } - return nil } // handleMessage handles a TLS handshake message. @@ -276,12 +310,13 @@ func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.Encryption h.messageChan <- data switch h.perspective { case protocol.PerspectiveClient: - return h.handleMessageForClient(msgType) + h.handleMessageForClient(msgType) case protocol.PerspectiveServer: - return h.handleMessageForServer(msgType) + h.handleMessageForServer(msgType) default: panic("") } + return nil } func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { @@ -300,65 +335,114 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot return fmt.Errorf("unexpected handshake message: %d", msgType) } if encLevel != expected { - return fmt.Errorf("expected handshake message %d to have encryption level %s, has %s", msgType, expected, encLevel) + return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) } return nil } -func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) error { +func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) { switch msgType { case typeClientHello: - params := <-h.receivedTransportParams - h.handleParamsCallback(¶ms) - <-h.receivedWriteKey // get the handshake write key - <-h.receivedWriteKey // get the 1-RTT write key - <-h.receivedReadKey // get the handshake read key - h.handshakeEvent <- struct{}{} + select { + case params := <-h.receivedTransportParams: + h.handleParamsCallback(¶ms) + case <-h.handshakeErrChan: + return + } + // get the handshake write key + select { + case <-h.receivedWriteKey: + case <-h.handshakeErrChan: + return + } + // get the 1-RTT write key + select { + case <-h.receivedWriteKey: + case <-h.handshakeErrChan: + return + } + // get the handshake read key // TODO: check that the initial stream doesn't have any more data + select { + case <-h.receivedReadKey: + case <-h.handshakeErrChan: + return + } + h.handshakeEvent <- struct{}{} case typeCertificate, typeCertificateVerify: // nothing to do case typeFinished: - <-h.receivedReadKey // get the 1-RTT read key - h.handshakeEvent <- struct{}{} + // get the 1-RTT read key // TODO: check that the handshake stream doesn't have any more data + select { + case <-h.receivedReadKey: + case <-h.handshakeErrChan: + return + } + h.handshakeEvent <- struct{}{} default: - // TODO: think about what to do with unknown message types - return fmt.Errorf("Received unknown handshake message: %d", msgType) + panic("unexpected handshake message") } - return nil } -func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) error { +func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { switch msgType { case typeServerHello: - <-h.receivedReadKey // get the handshake read key + // get the handshake read key + // TODO: check that the initial stream doesn't have any more data + select { + case <-h.receivedReadKey: + case <-h.handshakeErrChan: + return + } h.handshakeEvent <- struct{}{} case typeEncryptedExtensions: - params := <-h.receivedTransportParams - h.handleParamsCallback(¶ms) + select { + case params := <-h.receivedTransportParams: + h.handleParamsCallback(¶ms) + case <-h.handshakeErrChan: + return + } case typeCertificateRequest, typeCertificate, typeCertificateVerify: // nothing to do case typeFinished: - <-h.receivedWriteKey // get the handshake write key + // get the handshake write key // TODO: check that the initial stream doesn't have any more data + select { + case <-h.receivedWriteKey: + case <-h.handshakeErrChan: + return + } // While the order of these two is not defined by the TLS spec, // we have to do it on the same order as our TLS library does it. - <-h.receivedWriteKey // get the handshake write key - <-h.receivedReadKey // get the 1-RTT read key + // get the handshake write key + select { + case <-h.receivedWriteKey: + case <-h.handshakeErrChan: + return + } + // get the 1-RTT read key + select { + case <-h.receivedReadKey: + case <-h.handshakeErrChan: + return + } // TODO: check that the handshake stream doesn't have any more data h.handshakeEvent <- struct{}{} default: - // TODO: think about what to do with unknown extensions - return fmt.Errorf("Received unknown handshake message: %d", msgType) + panic("unexpected handshake message: ") } - return nil } // ReadHandshakeMessage is called by TLS. // It blocks until a new handshake message is available. func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) { // TODO: add some error handling here (when the session is closed) - return <-h.messageChan, nil + msg, ok := <-h.messageChan + if !ok { + return nil, errors.New("error while handling the handshake message") + } + return msg, nil } func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) { diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 5d98e09e..4ff37d04 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -48,25 +48,6 @@ func (s *stream) Write(b []byte) (int, error) { } var _ = Describe("Crypto Setup TLS", func() { - generateCert := func() tls.Certificate { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - Expect(err).ToNot(HaveOccurred()) - tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{}, - SignatureAlgorithm: x509.SHA256WithRSA, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), // valid for an hour - BasicConstraintsValid: true, - } - certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) - Expect(err).ToNot(HaveOccurred()) - return tls.Certificate{ - PrivateKey: priv, - Certificate: [][]byte{certDER}, - } - } - initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { chunkChan := make(chan chunk, 100) initialStream := newStream(chunkChan, protocol.EncryptionInitial) @@ -74,172 +55,16 @@ var _ = Describe("Crypto Setup TLS", func() { return chunkChan, initialStream, handshakeStream } - handshake := func( - client CryptoSetupTLS, - cChunkChan <-chan chunk, - server CryptoSetupTLS, - sChunkChan <-chan chunk) (error /* client error */, error /* server error */) { - done := make(chan struct{}) - defer close(done) - go func() { - defer GinkgoRecover() - for { - select { - case c := <-cChunkChan: - err := server.HandleData(c.data, c.encLevel) - Expect(err).ToNot(HaveOccurred()) - case c := <-sChunkChan: - err := client.HandleData(c.data, c.encLevel) - Expect(err).ToNot(HaveOccurred()) - case <-done: // handshake complete - } - } - }() - - serverErrChan := make(chan error) - go func() { - defer GinkgoRecover() - serverErrChan <- server.RunHandshake() - }() - - clientErr := client.RunHandshake() - var serverErr error - Eventually(serverErrChan).Should(Receive(&serverErr)) - return clientErr, serverErr - } - - handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - client, _, err := NewCryptoSetupTLSClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - &TransportParameters{}, - func(p *TransportParameters) {}, - make(chan struct{}, 100), - make(chan struct{}), - clientConf, - protocol.VersionTLS, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - utils.DefaultLogger.WithPrefix("client"), - protocol.PerspectiveClient, - ) - Expect(err).ToNot(HaveOccurred()) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() + It("returns Handshake() when an error occurs", func() { + _, sInitialStream, sHandshakeStream := initStreams() server, err := NewCryptoSetupTLSServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, - &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, - func(p *TransportParameters) {}, - make(chan struct{}, 100), - make(chan struct{}), - serverConf, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - utils.DefaultLogger.WithPrefix("server"), - protocol.PerspectiveServer, - ) - Expect(err).ToNot(HaveOccurred()) - - return handshake(client, cChunkChan, server, sChunkChan) - } - - It("handshakes", func() { - clientConf := &tls.Config{ServerName: "quic.clemente.io"} - serverConf := testdata.GetTLSConfig() - clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("handshakes with client auth", func() { - clientConf := &tls.Config{ - ServerName: "quic.clemente.io", - Certificates: []tls.Certificate{generateCert()}, - } - serverConf := testdata.GetTLSConfig() - serverConf.ClientAuth = qtls.RequireAnyClientCert - clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - }) - - It("signals when it has written the ClientHello", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - client, chChan, err := NewCryptoSetupTLSClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, make(chan struct{}, 100), make(chan struct{}), - &tls.Config{InsecureSkipVerify: true}, - protocol.VersionTLS, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - utils.DefaultLogger.WithPrefix("client"), - protocol.PerspectiveClient, - ) - Expect(err).ToNot(HaveOccurred()) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - client.RunHandshake() - close(done) - }() - var ch chunk - Eventually(cChunkChan).Should(Receive(&ch)) - Eventually(chChan).Should(BeClosed()) - // make sure the whole ClientHello was written - Expect(len(ch.data)).To(BeNumerically(">=", 4)) - Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) - length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) - Expect(len(ch.data) - 4).To(Equal(length)) - - // make the go routine return - client.HandleData([]byte{1, 0, 0, 1, 0}, protocol.EncryptionInitial) - Eventually(done).Should(BeClosed()) - }) - - It("receives transport parameters", func() { - var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} - client, _, err := NewCryptoSetupTLSClient( - cInitialStream, - cHandshakeStream, - protocol.ConnectionID{}, - cTransportParameters, - func(p *TransportParameters) { sTransportParametersRcvd = p }, - make(chan struct{}, 100), - make(chan struct{}), - &tls.Config{ServerName: "quic.clemente.io"}, - protocol.VersionTLS, - []protocol.VersionNumber{protocol.VersionTLS}, - protocol.VersionTLS, - utils.DefaultLogger.WithPrefix("client"), - protocol.PerspectiveClient, - ) - Expect(err).ToNot(HaveOccurred()) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sTransportParameters := &TransportParameters{ - IdleTimeout: 0x1337 * time.Second, - StatelessResetToken: bytes.Repeat([]byte{42}, 16), - } - server, err := NewCryptoSetupTLSServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - sTransportParameters, - func(p *TransportParameters) { cTransportParametersRcvd = p }, - make(chan struct{}, 100), - make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, protocol.VersionTLS, @@ -251,15 +76,253 @@ var _ = Describe("Crypto Setup TLS", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) + err := server.RunHandshake() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("received unexpected handshake message")) close(done) }() + + fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) + server.HandleData(fakeCH, protocol.EncryptionInitial) Eventually(done).Should(BeClosed()) - Expect(cTransportParametersRcvd).ToNot(BeNil()) - Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout)) - Expect(sTransportParametersRcvd).ToNot(BeNil()) - Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) + }) + + It("returns Handshake() when handling a message fails", func() { + _, sInitialStream, sHandshakeStream := initStreams() + server, err := NewCryptoSetupTLSServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + testdata.GetTLSConfig(), + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("server"), + protocol.PerspectiveServer, + ) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := server.RunHandshake() + Expect(err).To(MatchError("expected handshake message ClientHello to have encryption level Initial, has Handshake")) + close(done) + }() + + fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) + server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level + Eventually(done).Should(BeClosed()) + }) + + Context("doing the handshake", func() { + generateCert := func() tls.Certificate { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), // valid for an hour + BasicConstraintsValid: true, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) + Expect(err).ToNot(HaveOccurred()) + return tls.Certificate{ + PrivateKey: priv, + Certificate: [][]byte{certDER}, + } + } + + handshake := func( + client CryptoSetupTLS, + cChunkChan <-chan chunk, + server CryptoSetupTLS, + sChunkChan <-chan chunk) (error /* client error */, error /* server error */) { + done := make(chan struct{}) + defer close(done) + go func() { + defer GinkgoRecover() + for { + select { + case c := <-cChunkChan: + server.HandleData(c.data, c.encLevel) + case c := <-sChunkChan: + client.HandleData(c.data, c.encLevel) + case <-done: // handshake complete + } + } + }() + + serverErrChan := make(chan error) + go func() { + defer GinkgoRecover() + serverErrChan <- server.RunHandshake() + }() + + clientErr := client.RunHandshake() + var serverErr error + Eventually(serverErrChan).Should(Receive(&serverErr)) + return clientErr, serverErr + } + + handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + client, _, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + clientConf, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, + ) + Expect(err).ToNot(HaveOccurred()) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + server, err := NewCryptoSetupTLSServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + serverConf, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("server"), + protocol.PerspectiveServer, + ) + Expect(err).ToNot(HaveOccurred()) + + return handshake(client, cChunkChan, server, sChunkChan) + } + + It("handshakes", func() { + clientConf := &tls.Config{ServerName: "quic.clemente.io"} + serverConf := testdata.GetTLSConfig() + clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("handshakes with client auth", func() { + clientConf := &tls.Config{ + ServerName: "quic.clemente.io", + Certificates: []tls.Certificate{generateCert()}, + } + serverConf := testdata.GetTLSConfig() + serverConf.ClientAuth = qtls.RequireAnyClientCert + clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) + + It("signals when it has written the ClientHello", func() { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + client, chChan, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + &tls.Config{InsecureSkipVerify: true}, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, + ) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + client.RunHandshake() + close(done) + }() + var ch chunk + Eventually(cChunkChan).Should(Receive(&ch)) + Eventually(chChan).Should(BeClosed()) + // make sure the whole ClientHello was written + Expect(len(ch.data)).To(BeNumerically(">=", 4)) + Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) + length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) + Expect(len(ch.data) - 4).To(Equal(length)) + + // make the go routine return + client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) + Eventually(done).Should(BeClosed()) + }) + + It("receives transport parameters", func() { + var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} + client, _, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + cTransportParameters, + func(p *TransportParameters) { sTransportParametersRcvd = p }, + make(chan struct{}, 100), + make(chan struct{}), + &tls.Config{ServerName: "quic.clemente.io"}, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, + ) + Expect(err).ToNot(HaveOccurred()) + + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sTransportParameters := &TransportParameters{ + IdleTimeout: 0x1337 * time.Second, + StatelessResetToken: bytes.Repeat([]byte{42}, 16), + } + server, err := NewCryptoSetupTLSServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + sTransportParameters, + func(p *TransportParameters) { cTransportParametersRcvd = p }, + make(chan struct{}, 100), + make(chan struct{}), + testdata.GetTLSConfig(), + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("server"), + protocol.PerspectiveServer, + ) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + close(done) + }() + Eventually(done).Should(BeClosed()) + Expect(cTransportParametersRcvd).ToNot(BeNil()) + Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout)) + Expect(sTransportParametersRcvd).ToNot(BeNil()) + Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) + }) }) }) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index a9d71934..7acb6a13 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -44,7 +44,7 @@ type CryptoSetup interface { type CryptoSetupTLS interface { baseCryptoSetup - HandleData([]byte, protocol.EncryptionLevel) error + HandleData([]byte, protocol.EncryptionLevel) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) diff --git a/mock_crypto_data_handler.go b/mock_crypto_data_handler.go index 58efb560..b789ca8b 100644 --- a/mock_crypto_data_handler.go +++ b/mock_crypto_data_handler.go @@ -35,10 +35,8 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { } // HandleData mocks base method -func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) error { - ret := m.ctrl.Call(m, "HandleData", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 +func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) { + m.ctrl.Call(m, "HandleData", arg0, arg1) } // HandleData indicates an expected call of HandleData