use the new qtls interface for (re)storing app data with a session state

Application data is now retrieved and restored via two callbacks on the
qtls.Config. This allows us the get rid of the rather complex wrapping
of the qtls.ClientSessionCache. Furthermore, it makes sure that we only
restore the application data when qtls decides to actually use the
ticket.
This commit is contained in:
Marten Seemann 2020-06-26 22:50:21 +07:00
parent f926945ae5
commit 07d4fd0991
13 changed files with 226 additions and 380 deletions

View file

@ -287,6 +287,13 @@ var _ = Describe("Crypto Setup TLS", func() {
}
}
newRTTStatsWithRTT := func(rtt time.Duration) *congestion.RTTStats {
rttStats := &congestion.RTTStats{}
rttStats.UpdateRTT(rtt, 0, time.Now())
ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
return rttStats
}
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
server CryptoSetup, sChunkChan <-chan chunk) {
done := make(chan struct{})
@ -319,7 +326,12 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed())
}
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
handshakeWithTLSConf := func(
clientConf, serverConf *tls.Config,
clientRTTStats, serverRTTStats *congestion.RTTStats,
clientTransportParameters, serverTransportParameters *wire.TransportParameters,
enable0RTT bool,
) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1)
@ -327,17 +339,18 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
client, _ := NewCryptoSetupClient(
cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
client, clientHelloWrittenChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
clientTransportParameters,
cRunner,
clientConf,
enable0RTT,
&congestion.RTTStats{},
clientRTTStats,
nil,
utils.DefaultLogger.WithPrefix("client"),
)
@ -349,18 +362,21 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
var token [16]byte
if serverTransportParameters.StatelessResetToken == nil {
var token [16]byte
serverTransportParameters.StatelessResetToken = &token
}
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
serverTransportParameters,
sRunner,
serverConf,
enable0RTT,
&congestion.RTTStats{},
serverRTTStats,
nil,
utils.DefaultLogger.WithPrefix("server"),
)
@ -377,18 +393,28 @@ var _ = Describe("Crypto Setup TLS", func() {
default:
Expect(cHandshakeComplete).To(BeTrue())
}
return client, cErr, server, sErr
return clientHelloWrittenChan, client, cErr, server, sErr
}
It("handshakes", func() {
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("performs a HelloRetryRequst", func() {
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
@ -396,7 +422,12 @@ var _ = Describe("Crypto Setup TLS", func() {
It("handshakes with client auth", func() {
clientConf.Certificates = []tls.Certificate{generateCert()}
serverConf.ClientAuth = qtls.RequireAnyClientCert
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
@ -626,21 +657,37 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false)
clientRTTStats := &congestion.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
})
It("doesn't use session resumption if the server disabled it", func() {
@ -653,7 +700,12 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
_, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
@ -662,7 +714,12 @@ var _ = Describe("Crypto Setup TLS", func() {
serverConf.SessionTicketsDisabled = true
csc.EXPECT().Get(gomock.Any()).Return(state, true)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false)
_, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
@ -680,68 +737,100 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, true)
const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, serverOrigRTTStats,
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, clientHelloChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
cRunner,
clientConf,
clientRTTStats := &congestion.RTTStats{}
serverRTTStats := &congestion.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, serverRTTStats,
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
&congestion.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
var token [16]byte
server = NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
true,
&congestion.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(clientHelloChan).To(Receive(Not(BeNil())))
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeTrue())
Expect(client.ConnectionState().Used0RTT).To(BeTrue())
})
It("rejects 0-RTT, whent the transport parameters changed", func() {
csc := NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &congestion.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeFalse())
Expect(client.ConnectionState().Used0RTT).To(BeFalse())
})
})
})