diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 48dc8f34..a41167fb 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -98,6 +98,9 @@ type cryptoSetup struct { readEncLevel protocol.EncryptionLevel writeEncLevel protocol.EncryptionLevel + zeroRTTOpener LongHeaderOpener // only set for the server + zeroRTTSealer LongHeaderSealer // only set for the client + initialStream io.Writer initialOpener LongHeaderOpener initialSealer LongHeaderSealer @@ -482,6 +485,16 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { h.mutex.Lock() switch encLevel { + case qtls.Encryption0RTT: + if h.perspective == protocol.PerspectiveClient { + panic("Received 0-RTT read key for the client") + } + h.zeroRTTOpener = newLongHeaderOpener( + createAEAD(suite, trafficSecret), + newHeaderProtector(suite, trafficSecret, true), + ) + h.mutex.Unlock() + return case qtls.EncryptionHandshake: h.readEncLevel = protocol.EncryptionHandshake h.handshakeOpener = newHandshakeOpener( @@ -506,6 +519,16 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { h.mutex.Lock() switch encLevel { + case qtls.Encryption0RTT: + if h.perspective == protocol.PerspectiveServer { + panic("Received 0-RTT write key for the server") + } + h.zeroRTTSealer = newLongHeaderSealer( + createAEAD(suite, trafficSecret), + newHeaderProtector(suite, trafficSecret, true), + ) + h.mutex.Unlock() + return case qtls.EncryptionHandshake: h.writeEncLevel = protocol.EncryptionHandshake h.handshakeSealer = newHandshakeSealer( @@ -592,6 +615,16 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { return h.initialSealer, nil } +func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTSealer == nil { + return nil, errors.New("CryptoSetup: 0-RTT sealer not available") + } + return h.zeroRTTSealer, nil +} + func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { h.mutex.Lock() defer h.mutex.Unlock() @@ -625,6 +658,20 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { return h.initialOpener, nil } +func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTOpener == nil { + if h.initialOpener != nil { + return nil, ErrKeysNotYetAvailable + } + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped + } + return h.zeroRTTOpener, nil +} + func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { h.mutex.Lock() defer h.mutex.Unlock() diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 20fa7385..fe4afbe6 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -312,7 +312,7 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) } - handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { + handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { var cHandshakeComplete bool cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams() cErrChan := make(chan error, 1) @@ -329,7 +329,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{}, cRunner, clientConf, - false, + enable0RTT, &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("client"), ) @@ -351,7 +351,7 @@ var _ = Describe("Crypto Setup TLS", func() { &TransportParameters{StatelessResetToken: &token}, sRunner, serverConf, - false, + enable0RTT, &congestion.RTTStats{}, utils.DefaultLogger.WithPrefix("server"), ) @@ -372,14 +372,14 @@ var _ = Describe("Crypto Setup TLS", func() { } It("handshakes", func() { - _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf) + _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) }) It("performs a HelloRetryRequst", func() { serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} - _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf) + _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) }) @@ -387,7 +387,7 @@ var _ = Describe("Crypto Setup TLS", func() { It("handshakes with client auth", func() { clientConf.Certificates = []tls.Certificate{generateCert()} serverConf.ClientAuth = qtls.RequireAnyClientCert - _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf) + _, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) }) @@ -613,7 +613,7 @@ var _ = Describe("Crypto Setup TLS", func() { close(receivedSessionTicket) }) clientConf.ClientSessionCache = csc - client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf) + client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) Eventually(receivedSessionTicket).Should(BeClosed()) @@ -622,7 +622,7 @@ var _ = Describe("Crypto Setup TLS", func() { csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf) + client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) Eventually(receivedSessionTicket).Should(BeClosed()) @@ -640,7 +640,7 @@ var _ = Describe("Crypto Setup TLS", func() { close(receivedSessionTicket) }) clientConf.ClientSessionCache = csc - client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf) + client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) Eventually(receivedSessionTicket).Should(BeClosed()) @@ -649,13 +649,52 @@ var _ = Describe("Crypto Setup TLS", func() { serverConf.SessionTicketsDisabled = true csc.EXPECT().Get(gomock.Any()).Return(state, true) - client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf) + client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false) Expect(clientErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred()) Eventually(receivedSessionTicket).Should(BeClosed()) Expect(server.ConnectionState().DidResume).To(BeFalse()) Expect(client.ConnectionState().DidResume).To(BeFalse()) }) + + It("uses 0-RTT", func() { + csc := NewMockClientSessionCache(mockCtrl) + var state *tls.ClientSessionState + receivedSessionTicket := make(chan struct{}) + csc.EXPECT().Get(gomock.Any()) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { + state = css + close(receivedSessionTicket) + }) + clientConf.ClientSessionCache = csc + client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, true) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeFalse()) + Expect(client.ConnectionState().DidResume).To(BeFalse()) + + csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) + client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, true) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + Eventually(receivedSessionTicket).Should(BeClosed()) + Expect(server.ConnectionState().DidResume).To(BeTrue()) + Expect(client.ConnectionState().DidResume).To(BeTrue()) + opener, err := server.Get0RTTOpener() + Expect(err).ToNot(HaveOccurred()) + Expect(opener).ToNot(BeNil()) + sealer, err := client.Get0RTTSealer() + Expect(err).ToNot(HaveOccurred()) + Expect(sealer).ToNot(BeNil()) + // use the 0-RTT sealer and opener to encrypt and decrypt a message + plaintext := []byte("Lorem ipsum dolor sit amet") + msg := sealer.Seal(nil, plaintext, 0x1337, []byte("foobar")) + decrypted, err := opener.Open(nil, msg, 0x1337, []byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(plaintext)) + }) }) }) }) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 1159266a..1915a71b 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -78,9 +78,11 @@ type CryptoSetup interface { GetInitialOpener() (LongHeaderOpener, error) GetHandshakeOpener() (LongHeaderOpener, error) + Get0RTTOpener() (LongHeaderOpener, error) Get1RTTOpener() (ShortHeaderOpener, error) GetInitialSealer() (LongHeaderSealer, error) GetHandshakeSealer() (LongHeaderSealer, error) + Get0RTTSealer() (LongHeaderSealer, error) Get1RTTSealer() (ShortHeaderSealer, error) } diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 411bbd04..6565bd96 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -88,6 +88,36 @@ func (mr *MockCryptoSetupMockRecorder) DropHandshakeKeys() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropHandshakeKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DropHandshakeKeys)) } +// Get0RTTOpener mocks base method +func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get0RTTOpener") + ret0, _ := ret[0].(handshake.LongHeaderOpener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get0RTTOpener indicates an expected call of Get0RTTOpener +func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener)) +} + +// Get0RTTSealer mocks base method +func (m *MockCryptoSetup) Get0RTTSealer() (handshake.LongHeaderSealer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get0RTTSealer") + ret0, _ := ret[0].(handshake.LongHeaderSealer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get0RTTSealer indicates an expected call of Get0RTTSealer +func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer)) +} + // Get1RTTOpener mocks base method func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { m.ctrl.T.Helper()