diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index e0261fac..d0d51151 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -67,6 +67,7 @@ type cryptoSetup struct { messageChan chan []byte ourParams *TransportParameters + peerParams *TransportParameters paramsChan <-chan []byte runner handshakeRunner @@ -77,8 +78,9 @@ type cryptoSetup struct { // is closed when Close() is called closeChan chan struct{} + zeroRTTParameters *TransportParameters clientHelloWritten bool - clientHelloWrittenChan chan struct{} + clientHelloWrittenChan chan *TransportParameters receivedWriteKey chan struct{} receivedReadKey chan struct{} @@ -131,7 +133,7 @@ func NewCryptoSetupClient( enable0RTT bool, rttStats *congestion.RTTStats, logger utils.Logger, -) (CryptoSetup, <-chan struct{} /* ClientHello written */) { +) (CryptoSetup, <-chan *TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { cs, clientHelloWritten := newCryptoSetup( initialStream, handshakeStream, @@ -192,7 +194,7 @@ func newCryptoSetup( rttStats *congestion.RTTStats, logger utils.Logger, perspective protocol.Perspective, -) (*cryptoSetup, <-chan struct{} /* ClientHello written */) { +) (*cryptoSetup, <-chan *TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { initialSealer, initialOpener := NewInitialAEAD(connID, perspective) extHandler := newExtensionHandler(tp.Marshal(), perspective) cs := &cryptoSetup{ @@ -211,14 +213,14 @@ func newCryptoSetup( perspective: perspective, handshakeDone: make(chan struct{}), alertChan: make(chan uint8), - clientHelloWrittenChan: make(chan struct{}), + clientHelloWrittenChan: make(chan *TransportParameters, 1), messageChan: make(chan []byte, 100), receivedReadKey: make(chan struct{}), receivedWriteKey: make(chan struct{}), writeRecord: make(chan struct{}, 1), closeChan: make(chan struct{}), } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, cs.accept0RTT, enable0RTT) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, cs.marshalPeerParamsForSessionState, cs.handlePeerParamsFromSessionState, cs.accept0RTT, enable0RTT) cs.tlsConf = qtlsConf return cs, cs.clientHelloWrittenChan } @@ -436,7 +438,22 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) { if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { h.runner.OnError(qerr.Error(qerr.TransportParameterError, err.Error())) } - h.runner.OnReceivedParams(&tp) + h.peerParams = &tp + h.runner.OnReceivedParams(h.peerParams) +} + +// must be called after receiving the transport parameters +func (h *cryptoSetup) marshalPeerParamsForSessionState() []byte { + return h.peerParams.MarshalForSessionTicket() +} + +func (h *cryptoSetup) handlePeerParamsFromSessionState(data []byte) { + var tp TransportParameters + if err := tp.Unmarshal(data, protocol.PerspectiveServer); err != nil { + h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) + return + } + h.zeroRTTParameters = &tp } // only valid for the server @@ -569,7 +586,13 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { n, err := h.initialStream.Write(p) if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { h.clientHelloWritten = true - close(h.clientHelloWrittenChan) + if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { + h.logger.Debugf("Doing 0-RTT.") + h.clientHelloWrittenChan <- h.zeroRTTParameters + } else { + h.logger.Debugf("Not doing 0-RTT. Has Sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil) + h.clientHelloWrittenChan <- nil + } } else { // We need additional signaling to properly detect HelloRetryRequests. // For servers: when the ServerHello is written. diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index d58f61c3..55d5bc7b 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -417,7 +417,7 @@ var _ = Describe("Crypto Setup TLS", func() { }() var ch chunk Eventually(cChunkChan).Should(Receive(&ch)) - Eventually(chChan).Should(BeClosed()) + Eventually(chChan).Should(Receive(BeNil())) // make sure the whole ClientHello was written Expect(len(ch.data)).To(BeNumerically(">=", 4)) Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) @@ -671,10 +671,53 @@ var _ = Describe("Crypto Setup TLS", func() { csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, true) - Expect(clientErr).ToNot(HaveOccurred()) - Expect(serverErr).ToNot(HaveOccurred()) - Eventually(receivedSessionTicket).Should(BeClosed()) + + cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() + cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnHandshakeComplete() + client, clientHelloChan := NewCryptoSetupClient( + cInitialStream, + cHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + cRunner, + clientConf, + true, + &congestion.RTTStats{}, + utils.DefaultLogger.WithPrefix("client"), + ) + + sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() + sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnHandshakeComplete() + server = NewCryptoSetupServer( + sInitialStream, + sHandshakeStream, + ioutil.Discard, + protocol.ConnectionID{}, + nil, + &TransportParameters{}, + sRunner, + serverConf, + true, + &congestion.RTTStats{}, + utils.DefaultLogger.WithPrefix("server"), + ) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + handshake(client, cChunkChan, server, sChunkChan) + close(done) + }() + Eventually(done).Should(BeClosed()) + + Expect(clientHelloChan).To(Receive(Not(BeNil()))) + Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue()) opener, err := server.Get0RTTOpener() diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index 2528bf17..ab6d3294 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -31,6 +31,8 @@ func tlsConfigToQtlsConfig( c *tls.Config, recordLayer qtls.RecordLayer, extHandler tlsExtensionHandler, + getDataForSessionState func() []byte, + setDataFromSessionState func([]byte), accept0RTT func([]byte) bool, enable0RTT bool, ) *qtls.Config { @@ -59,16 +61,12 @@ func tlsConfigToQtlsConfig( if tlsConf == nil { return nil, nil } - return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, accept0RTT, enable0RTT), nil + return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, getDataForSessionState, setDataFromSessionState, accept0RTT, enable0RTT), nil } } var csc qtls.ClientSessionCache if c.ClientSessionCache != nil { - csc = newClientSessionCache( - c.ClientSessionCache, - func() []byte { return nil }, - func([]byte) {}, - ) + csc = newClientSessionCache(c.ClientSessionCache, getDataForSessionState, setDataFromSessionState) } conf := &qtls.Config{ Rand: c.Rand, diff --git a/internal/handshake/qtls_test.go b/internal/handshake/qtls_test.go index 55ea1acf..6eb0ad1d 100644 --- a/internal/handshake/qtls_test.go +++ b/internal/handshake/qtls_test.go @@ -28,19 +28,19 @@ func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not im var _ = Describe("qtls.Config generation", func() { It("sets MinVersion and MaxVersion", func() { tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12} - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false) Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13)) }) It("works when called with a nil config", func() { - qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, false) + qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, false) Expect(qtlsConf).ToNot(BeNil()) }) It("sets the setter and getter function for TLS extensions", func() { extHandler := &mockExtensionHandler{} - qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, nil, false) + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, nil, nil, nil, false) Expect(extHandler.get).To(BeFalse()) qtlsConf.GetExtensions(10) Expect(extHandler.get).To(BeTrue()) @@ -51,31 +51,31 @@ var _ = Describe("qtls.Config generation", func() { It("sets the Accept0RTT callback", func() { accept0RTT := func([]byte) bool { return true } - qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, accept0RTT, false) + qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, accept0RTT, false) Expect(qtlsConf.Accept0RTT).ToNot(BeNil()) Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue()) }) It("enables 0-RTT", func() { - qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, false) + qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, false) Expect(qtlsConf.Enable0RTT).To(BeFalse()) Expect(qtlsConf.MaxEarlyData).To(BeZero()) - qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, true) + qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, true) Expect(qtlsConf.Enable0RTT).To(BeTrue()) Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff))) }) It("initializes such that the session ticket key remains constant", func() { tlsConf := &tls.Config{} - qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) - qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) + qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false) + qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false) Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey)) }) Context("GetConfigForClient callback", func() { It("doesn't set it if absent", func() { - qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, false) + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, nil, nil, false) Expect(qtlsConf.GetConfigForClient).To(BeNil()) }) @@ -86,7 +86,7 @@ var _ = Describe("qtls.Config generation", func() { }, } extHandler := &mockExtensionHandler{} - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, nil, nil, nil, false) Expect(qtlsConf.GetConfigForClient).ToNot(BeNil()) confForClient, err := qtlsConf.GetConfigForClient(nil) Expect(err).ToNot(HaveOccurred()) @@ -106,7 +106,7 @@ var _ = Describe("qtls.Config generation", func() { return nil, testErr }, } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false) _, err := qtlsConf.GetConfigForClient(nil) Expect(err).To(MatchError(testErr)) }) @@ -117,35 +117,49 @@ var _ = Describe("qtls.Config generation", func() { return nil, nil }, } - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false) Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil()) }) }) Context("ClientSessionCache", func() { It("doesn't set if absent", func() { - qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, false) + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, nil, nil, false) Expect(qtlsConf.ClientSessionCache).To(BeNil()) }) It("sets it, and puts and gets session states", func() { csc := NewMockClientSessionCache(mockCtrl) tlsConf := &tls.Config{ClientSessionCache: csc} - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) + var appData []byte + qtlsConf := tlsConfigToQtlsConfig( + tlsConf, + nil, + &mockExtensionHandler{}, + func() []byte { return []byte("foobar") }, + func(p []byte) { appData = p }, + nil, + false, + ) Expect(qtlsConf.ClientSessionCache).ToNot(BeNil()) + + var state *tls.ClientSessionState // put something - csc.EXPECT().Put("foobar", gomock.Any()) - qtlsConf.ClientSessionCache.Put("foobar", &qtls.ClientSessionState{}) + csc.EXPECT().Put("localhost", gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + }) + qtlsConf.ClientSessionCache.Put("localhost", &qtls.ClientSessionState{}) // get something - csc.EXPECT().Get("raboof").Return(nil, true) - _, ok := qtlsConf.ClientSessionCache.Get("raboof") + csc.EXPECT().Get("localhost").Return(state, true) + _, ok := qtlsConf.ClientSessionCache.Get("localhost") Expect(ok).To(BeTrue()) + Expect(appData).To(Equal([]byte("foobar"))) }) It("puts a nil session state", func() { csc := NewMockClientSessionCache(mockCtrl) tlsConf := &tls.Config{ClientSessionCache: csc} - qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false) // put something csc.EXPECT().Put("foobar", nil) qtlsConf.ClientSessionCache.Put("foobar", nil) diff --git a/session.go b/session.go index 8f58432f..0c40f1cf 100644 --- a/session.go +++ b/session.go @@ -156,7 +156,7 @@ type session struct { undecryptablePackets []*receivedPacket - clientHelloWritten <-chan struct{} + clientHelloWritten <-chan *handshake.TransportParameters earlySessionReadyChan chan struct{} handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeComplete bool @@ -463,8 +463,11 @@ func (s *session) run() error { if s.perspective == protocol.PerspectiveClient { select { - case <-s.clientHelloWritten: + case zeroRTTParams := <-s.clientHelloWritten: s.scheduleSending() + if zeroRTTParams != nil { + s.processTransportParameters(zeroRTTParams) + } case closeErr := <-s.closeChan: // put the close error back into the channel, so that the run loop can receive it s.closeChan <- closeErr @@ -1099,7 +1102,7 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete return } - s.logger.Debugf("Received Transport Parameters: %s", params) + s.logger.Debugf("Processed Transport Parameters: %s", params) s.peerParams = params // Our local idle timeout will always be > 0. s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) @@ -1124,7 +1127,9 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete } // On the server side, the early session is ready as soon as we processed // the client's transport parameters. - close(s.earlySessionReadyChan) + if s.perspective == protocol.PerspectiveServer { + close(s.earlySessionReadyChan) + } } func (s *session) sendPackets() error {