add unit tests for session resumption

This commit is contained in:
Marten Seemann 2019-08-03 16:32:01 +07:00
parent 511acf1371
commit 382094ba68

View file

@ -55,11 +55,12 @@ func (s *stream) Write(b []byte) (int, error) {
var _ = Describe("Crypto Setup TLS", func() {
var clientConf, serverConf *tls.Config
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */, *stream /* 1-RTT */) {
chunkChan := make(chan chunk, 100)
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
return chunkChan, initialStream, handshakeStream
oneRTTStream := newStream(chunkChan, protocol.Encryption1RTT)
return chunkChan, initialStream, handshakeStream, oneRTTStream
}
BeforeEach(func() {
@ -116,7 +117,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream := initStreams()
_, sInitialStream, sHandshakeStream, _ := initStreams()
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
@ -152,7 +153,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("errors when a message is received at the wrong encryption level", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
_, sInitialStream, sHandshakeStream, _ := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server, err := NewCryptoSetupServer(
@ -192,7 +193,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when handling a message fails", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
_, sInitialStream, sHandshakeStream, _ := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server, err := NewCryptoSetupServer(
@ -228,7 +229,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream := initStreams()
_, sInitialStream, sHandshakeStream, _ := initStreams()
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
@ -254,6 +255,8 @@ var _ = Describe("Crypto Setup TLS", func() {
})
Context("doing the handshake", func() {
var testDone chan struct{}
generateCert := func() tls.Certificate {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred())
@ -273,6 +276,14 @@ var _ = Describe("Crypto Setup TLS", func() {
}
}
BeforeEach(func() {
testDone = make(chan struct{})
})
AfterEach(func() {
close(testDone)
})
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
server CryptoSetup, sChunkChan <-chan chunk) {
done := make(chan struct{})
@ -284,7 +295,7 @@ var _ = Describe("Crypto Setup TLS", func() {
server.HandleMessage(c.data, c.encLevel)
case c := <-sChunkChan:
client.HandleMessage(c.data, c.encLevel)
case <-done: // handshake complete
case <-testDone: // handshake complete
return
}
}
@ -300,9 +311,9 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed())
}
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams()
cErrChan := make(chan error, 1)
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
@ -311,7 +322,7 @@ var _ = Describe("Crypto Setup TLS", func() {
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
cOneRTTStream,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -323,7 +334,7 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(err).ToNot(HaveOccurred())
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams()
sErrChan := make(chan error, 1)
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
@ -333,7 +344,7 @@ var _ = Describe("Crypto Setup TLS", func() {
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
sOneRTTStream,
protocol.ConnectionID{},
nil,
&TransportParameters{StatelessResetToken: &token},
@ -356,18 +367,18 @@ var _ = Describe("Crypto Setup TLS", func() {
default:
Expect(cHandshakeComplete).To(BeTrue())
}
return cErr, sErr
return client, cErr, server, sErr
}
It("handshakes", func() {
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("performs a HelloRetryRequst", func() {
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
@ -375,14 +386,14 @@ var _ = Describe("Crypto Setup TLS", func() {
It("handshakes with client auth", func() {
clientConf.Certificates = []tls.Certificate{generateCert()}
serverConf.ClientAuth = qtls.RequireAnyClientCert
clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf)
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("signals when it has written the ClientHello", func() {
runner := NewMockHandshakeRunner(mockCtrl)
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
client, chChan, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
@ -419,7 +430,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("receives transport parameters", func() {
var cTransportParametersRcvd, sTransportParametersRcvd []byte
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b })
@ -438,7 +449,7 @@ var _ = Describe("Crypto Setup TLS", func() {
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
var token [16]byte
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b })
@ -478,116 +489,172 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(srvTP.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout))
})
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
cRunner,
clientConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
Context("with session tickets", func() {
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
cRunner,
clientConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
sRunner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
sRunner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
// inject an invalid session ticket
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
qerr := err.(*qerr.QuicError)
Expect(qerr.IsCryptoError()).To(BeTrue())
Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
Expect(qerr.Error()).To(ContainSubstring("expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake"))
// inject an invalid session ticket
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
qerr := err.(*qerr.QuicError)
Expect(qerr.IsCryptoError()).To(BeTrue())
Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
Expect(qerr.Error()).To(ContainSubstring("expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake"))
})
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.EncryptionHandshake)
})
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.EncryptionHandshake)
})
It("errors when handling the NewSessionTicket fails", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
cRunner,
clientConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
It("errors when handling the NewSessionTicket fails", func() {
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
cRunner,
clientConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
sRunner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
sRunner,
serverConf,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
// inject an invalid session ticket
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
qerr := err.(*qerr.QuicError)
Expect(qerr.IsCryptoError()).To(BeTrue())
// inject an invalid session ticket
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
qerr := err.(*qerr.QuicError)
Expect(qerr.IsCryptoError()).To(BeTrue())
})
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.Encryption1RTT)
})
It("uses session resumption", func() {
csc := NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
})
It("doesn't use session resumption if the server disabled it", func() {
csc := NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
serverConf.SessionTicketsDisabled = true
csc.EXPECT().Get(gomock.Any()).Return(state, true)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
})
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.Encryption1RTT)
})
})
})