receive 0-RTT sealers and openers from qtls when using 0-RTT

This commit is contained in:
Marten Seemann 2019-08-04 10:47:37 +07:00
parent a9f4195fd0
commit c5f74595ca
4 changed files with 128 additions and 10 deletions

View file

@ -98,6 +98,9 @@ type cryptoSetup struct {
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
zeroRTTOpener LongHeaderOpener // only set for the server
zeroRTTSealer LongHeaderSealer // only set for the client
initialStream io.Writer
initialOpener LongHeaderOpener
initialSealer LongHeaderSealer
@ -482,6 +485,16 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
if h.perspective == protocol.PerspectiveClient {
panic("Received 0-RTT read key for the client")
}
h.zeroRTTOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
)
h.mutex.Unlock()
return
case qtls.EncryptionHandshake:
h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = newHandshakeOpener(
@ -506,6 +519,16 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
if h.perspective == protocol.PerspectiveServer {
panic("Received 0-RTT write key for the server")
}
h.zeroRTTSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
)
h.mutex.Unlock()
return
case qtls.EncryptionHandshake:
h.writeEncLevel = protocol.EncryptionHandshake
h.handshakeSealer = newHandshakeSealer(
@ -592,6 +615,16 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
return h.initialSealer, nil
}
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTSealer == nil {
return nil, errors.New("CryptoSetup: 0-RTT sealer not available")
}
return h.zeroRTTSealer, nil
}
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
@ -625,6 +658,20 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
return h.initialOpener, nil
}
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.zeroRTTOpener, nil
}
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()

View file

@ -312,7 +312,7 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed())
}
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams()
cErrChan := make(chan error, 1)
@ -329,7 +329,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{},
cRunner,
clientConf,
false,
enable0RTT,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
@ -351,7 +351,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
enable0RTT,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
@ -372,14 +372,14 @@ var _ = Describe("Crypto Setup TLS", func() {
}
It("handshakes", func() {
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf)
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("performs a HelloRetryRequst", func() {
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf)
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
@ -387,7 +387,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("handshakes with client auth", func() {
clientConf.Certificates = []tls.Certificate{generateCert()}
serverConf.ClientAuth = qtls.RequireAnyClientCert
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf)
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
@ -613,7 +613,7 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf)
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
@ -622,7 +622,7 @@ var _ = Describe("Crypto Setup TLS", func() {
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
@ -640,7 +640,7 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf)
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
@ -649,13 +649,52 @@ var _ = Describe("Crypto Setup TLS", func() {
serverConf.SessionTicketsDisabled = true
csc.EXPECT().Get(gomock.Any()).Return(state, true)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
})
It("uses 0-RTT", func() {
csc := NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, true)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, true)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
opener, err := server.Get0RTTOpener()
Expect(err).ToNot(HaveOccurred())
Expect(opener).ToNot(BeNil())
sealer, err := client.Get0RTTSealer()
Expect(err).ToNot(HaveOccurred())
Expect(sealer).ToNot(BeNil())
// use the 0-RTT sealer and opener to encrypt and decrypt a message
plaintext := []byte("Lorem ipsum dolor sit amet")
msg := sealer.Seal(nil, plaintext, 0x1337, []byte("foobar"))
decrypted, err := opener.Open(nil, msg, 0x1337, []byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(decrypted).To(Equal(plaintext))
})
})
})
})

View file

@ -78,9 +78,11 @@ type CryptoSetup interface {
GetInitialOpener() (LongHeaderOpener, error)
GetHandshakeOpener() (LongHeaderOpener, error)
Get0RTTOpener() (LongHeaderOpener, error)
Get1RTTOpener() (ShortHeaderOpener, error)
GetInitialSealer() (LongHeaderSealer, error)
GetHandshakeSealer() (LongHeaderSealer, error)
Get0RTTSealer() (LongHeaderSealer, error)
Get1RTTSealer() (ShortHeaderSealer, error)
}

View file

@ -88,6 +88,36 @@ func (mr *MockCryptoSetupMockRecorder) DropHandshakeKeys() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropHandshakeKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DropHandshakeKeys))
}
// Get0RTTOpener mocks base method
func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get0RTTOpener")
ret0, _ := ret[0].(handshake.LongHeaderOpener)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get0RTTOpener indicates an expected call of Get0RTTOpener
func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener))
}
// Get0RTTSealer mocks base method
func (m *MockCryptoSetup) Get0RTTSealer() (handshake.LongHeaderSealer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get0RTTSealer")
ret0, _ := ret[0].(handshake.LongHeaderSealer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get0RTTSealer indicates an expected call of Get0RTTSealer
func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer))
}
// Get1RTTOpener mocks base method
func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) {
m.ctrl.T.Helper()