uquic/internal/handshake/crypto_setup_test.go
2022-10-11 16:38:44 +04:00

864 lines
28 KiB
Go

package handshake
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"time"
mocktls "github.com/lucas-clemente/quic-go/internal/mocks/tls"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// helloRetryRequestRandom is the fixed 32-byte ServerHello random value that
// identifies a HelloRetryRequest (RFC 8446, Section 4.1.3).
var helloRetryRequestRandom = []byte{
	0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
	0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
}
// chunk is one piece of data written to a stream, tagged with the encryption
// level of the stream it was written to.
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
}
// stream is a test stand-in for a crypto stream: every Write is forwarded as a
// chunk on chunkChan instead of being sent anywhere.
type stream struct {
// encLevel is attached to every chunk this stream emits.
encLevel protocol.EncryptionLevel
// chunkChan is the send-only sink that receives the written chunks.
chunkChan chan<- chunk
}
// newStream returns a stream that forwards everything written to it to
// chunkChan, labelled with encLevel.
func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream {
	s := new(stream)
	s.encLevel = encLevel
	s.chunkChan = chunkChan
	return s
}
// Write queues a private copy of b on the stream's chunk channel and reports
// all of b as written. It never blocks: if the channel has no free capacity
// the test setup is considered broken and Write panics.
func (s *stream) Write(b []byte) (int, error) {
	// Copy b so that later changes to the caller's slice cannot affect the
	// queued chunk.
	msg := chunk{data: make([]byte, len(b)), encLevel: s.encLevel}
	copy(msg.data, b)

	select {
	case s.chunkChan <- msg:
		return len(b), nil
	default:
		panic("chunkChan too small")
	}
}
var _ = Describe("Crypto Setup TLS", func() {
var clientConf, serverConf *tls.Config
// unparam incorrectly complains that the first argument is never used.
//nolint:unparam
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 100)
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
return chunkChan, initialStream, handshakeStream
}
BeforeEach(func() {
serverConf = testdata.GetTLSConfig()
serverConf.NextProtos = []string{"crypto-setup"}
clientConf = &tls.Config{
ServerName: "localhost",
RootCAs: testdata.GetRootCA(),
NextProtos: []string{"crypto-setup"},
}
})
// A ClientHello with a junk body is fed to the server. The failure must be
// reported through the runner's OnError callback, and both RunHandshake and
// HandleMessage must return.
It("returns Handshake() when an error occurs in qtls", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
// forward the reported error to the spec through the buffered sErrChan
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
// Expect (not Eventually): the error has to be available by the time
// RunHandshake returns.
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "local error: tls: unexpected message",
})))
close(done)
}()
// 4-byte handshake header (type ClientHello, 3-byte length = 6) followed by a 6-byte junk body
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
handledMessage := make(chan struct{})
go func() {
defer GinkgoRecover()
server.HandleMessage(fakeCH, protocol.EncryptionInitial)
close(handledMessage)
}()
Eventually(handledMessage).Should(BeClosed())
Eventually(done).Should(BeClosed())
})
// A client whose TLS config contains an empty ALPN entry cannot produce a
// ClientHello. RunHandshake has to return on its own and the failure has to be
// reported to the runner as an internal error.
It("handles qtls errors occurring before during ClientHello generation", func() {
	errChan := make(chan error, 1)
	runner := NewMockHandshakeRunner(mockCtrl)
	runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { errChan <- e })
	_, initialStream, handshakeStream := initStreams()

	// The empty protocol name is presumably what provokes
	// "tls: invalid NextProtos value".
	conf := testdata.GetTLSConfig()
	conf.InsecureSkipVerify = true
	conf.NextProtos = []string{""}

	client, _ := NewCryptoSetupClient(
		initialStream,
		handshakeStream,
		protocol.ConnectionID{},
		nil,
		nil,
		&wire.TransportParameters{},
		runner,
		conf,
		false,
		&utils.RTTStats{},
		nil,
		utils.DefaultLogger.WithPrefix("client"),
		protocol.VersionTLS,
	)

	returned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		client.RunHandshake()
		close(returned)
	}()
	Eventually(returned).Should(BeClosed())

	wantErr := &qerr.TransportError{
		ErrorCode:    qerr.InternalError,
		ErrorMessage: "tls: invalid NextProtos value",
	}
	Expect(errChan).To(Receive(MatchError(wantErr)))
})
// A ClientHello delivered at the Handshake level instead of the Initial level
// must be rejected by HandleMessage with an unexpected_message error.
It("errors when a message is received at the wrong encryption level", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
// forward the reported error to the spec through the buffered sErrChan
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
close(done)
}()
// 4-byte handshake header (type ClientHello, 3-byte length = 6) followed by a 6-byte junk body
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
// HandleMessage is called synchronously, so the error is expected to be
// queued already (Expect, not Eventually).
Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake",
})))
// make the go routine return
Expect(server.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
// A server is handed a ServerHello. Handling it must fail with an
// unexpected_message error, and RunHandshake must return afterwards.
It("returns Handshake() when handling a message fails", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
// forward the reported error to the spec through the buffered sErrChan
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
runner,
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
// Only the error type and code are checked here, not the message text.
var err error
Expect(sErrChan).To(Receive(&err))
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
close(done)
}()
// 4-byte handshake header (type ServerHello, 3-byte length = 6) followed by a 6-byte junk body
fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...)
server.HandleMessage(fakeCH, protocol.EncryptionInitial) // unexpected message type: a server is handed a ServerHello
Eventually(done).Should(BeClosed())
})
It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.RunHandshake()
close(done)
}()
Expect(server.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
Context("doing the handshake", func() {
generateCert := func() tls.Certificate {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred())
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{},
SignatureAlgorithm: x509.SHA256WithRSA,
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour), // valid for an hour
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
Expect(err).ToNot(HaveOccurred())
return tls.Certificate{
PrivateKey: priv,
Certificate: [][]byte{certDER},
}
}
// newRTTStatsWithRTT returns RTT statistics that have been fed a single sample
// of rtt, and verifies that the smoothed RTT then equals rtt.
newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
	stats := new(utils.RTTStats)
	stats.UpdateRTT(rtt, 0, time.Now())
	// Offset 1 attributes a failure to the caller's line rather than to this helper.
	ExpectWithOffset(1, stats.SmoothedRTT()).To(Equal(rtt))
	return stats
}
// handshake runs client and server against each other until the server side
// is finished. A pump goroutine shuttles every chunk written by one side to
// the other side's HandleMessage and checks the returned "finished" flag
// against the message type. The client's RunHandshake runs on the calling
// goroutine, the server's on its own goroutine, which afterwards also delivers
// a session ticket to the client if the server produced one.
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
	server CryptoSetup, sChunkChan <-chan chunk,
) {
	serverDone := make(chan struct{})

	// Message pump.
	go func() {
		defer GinkgoRecover()
		for {
			select {
			case msg := <-cChunkChan:
				// client -> server
				kind := messageType(msg.data[0])
				finished := server.HandleMessage(msg.data, msg.encLevel)
				switch kind {
				case typeFinished:
					Expect(finished).To(BeTrue())
				case typeClientHello:
					// If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys.
					_, err := server.GetHandshakeOpener()
					Expect(finished).To(Equal(err == nil))
				default:
					Expect(finished).To(BeFalse())
				}
			case msg := <-sChunkChan:
				// server -> client
				kind := messageType(msg.data[0])
				finished := client.HandleMessage(msg.data, msg.encLevel)
				switch kind {
				case typeFinished:
					Expect(finished).To(BeTrue())
				case typeServerHello:
					// Bytes 6..38 are compared with the HelloRetryRequest marker:
					// a regular ServerHello finishes the level, a retry request does not.
					isHRR := bytes.Equal(msg.data[6:6+32], helloRetryRequestRandom)
					Expect(finished).To(Equal(!isHRR))
				default:
					Expect(finished).To(BeFalse())
				}
			case <-serverDone: // handshake complete
				return
			}
		}
	}()

	// Server side.
	go func() {
		defer GinkgoRecover()
		defer close(serverDone)
		server.RunHandshake()
		ticket, err := server.GetSessionTicket()
		Expect(err).ToNot(HaveOccurred())
		if ticket != nil {
			client.HandleMessage(ticket, protocol.Encryption1RTT)
		}
	}()

	// Client side.
	client.RunHandshake()
	Eventually(serverDone).Should(BeClosed())
}
// handshakeWithTLSConf builds a client and a server crypto setup from the
// given TLS configs, RTT stats and transport parameters, runs a handshake
// between them and returns the channel obtained from NewCryptoSetupClient,
// both crypto setups, and the error (if any) each side reported through
// OnError. A side that reported no error must have signalled handshake
// completion. If serverTransportParameters carries no stateless reset token, a
// zero-valued one is filled in (the passed struct is modified).
handshakeWithTLSConf := func(
clientConf, serverConf *tls.Config,
clientRTTStats, serverRTTStats *utils.RTTStats,
clientTransportParameters, serverTransportParameters *wire.TransportParameters,
enable0RTT bool,
) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1)
cRunner := NewMockHandshakeRunner(mockCtrl)
// OnReceivedParams is required; the other callbacks may each fire at most once.
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
client, clientHelloWrittenChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
clientTransportParameters,
cRunner,
clientConf,
enable0RTT,
clientRTTStats,
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
sRunner := NewMockHandshakeRunner(mockCtrl)
// Same expectations as for the client, except that no DropKeys call is allowed.
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
// every server setup in this file is given a reset token; default to the zero token
if serverTransportParameters.StatelessResetToken == nil {
var token protocol.StatelessResetToken
serverTransportParameters.StatelessResetToken = &token
}
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
serverTransportParameters,
sRunner,
serverConf,
enable0RTT,
serverRTTStats,
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
handshake(client, cChunkChan, server, sChunkChan)
var cErr, sErr error
// Non-blocking reads: either an error was reported, or the handshake must
// have completed on that side.
select {
case sErr = <-sErrChan:
default:
Expect(sHandshakeComplete).To(BeTrue())
}
select {
case cErr = <-cErrChan:
default:
Expect(cHandshakeComplete).To(BeTrue())
}
return clientHelloWrittenChan, client, cErr, server, sErr
}
It("handshakes", func() {
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("performs a HelloRetryRequst", func() {
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("handshakes with client auth", func() {
clientConf.Certificates = []tls.Certificate{generateCert()}
serverConf.ClientAuth = tls.RequireAnyClientCert
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
// Starts only a client and checks that (a) a complete ClientHello shows up on
// the initial stream and (b) the channel returned by NewCryptoSetupClient
// fires afterwards.
It("signals when it has written the ClientHello", func() {
// no expectations: the runner must not be called in this spec
runner := NewMockHandshakeRunner(mockCtrl)
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, chChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
runner,
&tls.Config{InsecureSkipVerify: true},
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RunHandshake()
close(done)
}()
var ch chunk
Eventually(cChunkChan).Should(Receive(&ch))
// a nil value is sent on the channel here (0-RTT is disabled for this client)
Eventually(chChan).Should(Receive(BeNil()))
// make sure the whole ClientHello was written
Expect(len(ch.data)).To(BeNumerically(">=", 4))
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
// bytes 1..3 hold the big-endian body length; it must match what follows the 4-byte header
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return
Expect(client.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
// Each side must hand the peer's transport parameters to its runner via
// OnReceivedParams. Naming: cTransportParametersRcvd holds the *client's*
// parameters as received by the server, sTransportParametersRcvd the
// *server's* parameters as received by the client.
It("receives transport parameters", func() {
	var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters

	// Client side.
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	cTransportParameters := &wire.TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
	cRunner := NewMockHandshakeRunner(mockCtrl)
	// the client's runner is handed the server's parameters
	cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp })
	cRunner.EXPECT().OnHandshakeComplete()
	client, _ := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		protocol.ConnectionID{},
		nil,
		nil,
		cTransportParameters,
		cRunner,
		clientConf,
		false,
		&utils.RTTStats{},
		nil,
		utils.DefaultLogger.WithPrefix("client"),
		protocol.VersionTLS,
	)

	// Server side.
	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
	var token protocol.StatelessResetToken
	sRunner := NewMockHandshakeRunner(mockCtrl)
	// the server's runner is handed the client's parameters
	sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp })
	sRunner.EXPECT().OnHandshakeComplete()
	sTransportParameters := &wire.TransportParameters{
		MaxIdleTimeout:      0x1337 * time.Second,
		StatelessResetToken: &token,
	}
	server := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		protocol.ConnectionID{},
		nil,
		nil,
		sTransportParameters,
		sRunner,
		serverConf,
		false,
		&utils.RTTStats{},
		nil,
		utils.DefaultLogger.WithPrefix("server"),
		protocol.VersionTLS,
	)

	done := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		handshake(client, cChunkChan, server, sChunkChan)
		close(done)
	}()
	Eventually(done).Should(BeClosed())

	// Check for nil before dereferencing, so that a callback that never stored
	// a value produces a proper assertion failure instead of a nil-pointer panic.
	Expect(cTransportParametersRcvd).ToNot(BeNil())
	Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout))
	Expect(sTransportParametersRcvd).ToNot(BeNil())
	Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout))
})
Context("with session tickets", func() {
// After a successful handshake, a NewSessionTicket delivered at the Handshake
// level (instead of 1-RTT) must be reported to the client's runner as an
// unexpected_message error.
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
cRunner,
clientConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
// complete a regular handshake first
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
// inject an invalid session ticket
// The expectation is registered only now, after the handshake has finished.
cRunner.EXPECT().OnError(&qerr.TransportError{
ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage),
ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake",
})
// 4-byte handshake header (type NewSessionTicket, 3-byte length = 6) followed by a 6-byte junk body
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.EncryptionHandshake)
})
// After a successful handshake, a NewSessionTicket with a junk body delivered
// at the correct (1-RTT) level must be reported to the client's runner as a
// crypto error.
It("errors when handling the NewSessionTicket fails", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
cRunner,
clientConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
// complete a regular handshake first
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
// inject an invalid session ticket
// Only the error's type and its being a crypto error are checked, not the
// exact code or message.
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
})
// 4-byte handshake header (type NewSessionTicket, 3-byte length = 6) followed by a 6-byte junk body
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
client.HandleMessage(b, protocol.Encryption1RTT)
})
// Two handshakes with the same client session cache: the first stores a
// session, the second resumes it. The client's RTT from the first connection
// must be restored into the fresh RTT stats of the second.
It("uses session resumption", func() {
csc := mocktls.NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
// first handshake: the cache lookup returns nothing, the Put captures the stored session
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
// second handshake: the cache hands back the captured session
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
// NOTE(review): receivedSessionTicket was already closed during the first
// handshake, so this assertion passes trivially.
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
// clientRTTStats started out empty, so this value was restored during resumption
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
})
// Two handshakes with the same client session cache. The server disables
// session tickets before the second one, so the session offered by the client
// must not be resumed.
It("doesn't use session resumption if the server disabled it", func() {
csc := mocktls.NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
// first handshake: the cache lookup returns nothing, the Put captures the stored session
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
_, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
// second handshake: the client still offers the captured session, but the
// server no longer supports tickets. No further Put on the cache is expected.
serverConf.SessionTicketsDisabled = true
csc.EXPECT().Get(gomock.Any()).Return(state, true)
_, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
// NOTE(review): receivedSessionTicket was already closed during the first
// handshake, so this assertion passes trivially.
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
})
// Two handshakes with 0-RTT enabled and the same client session cache: the
// first stores a session, the second resumes it and must use 0-RTT. Both
// sides' RTT measurements from the first connection must be restored, and the
// client must be handed the server's remembered transport parameters when it
// writes its ClientHello.
It("uses 0-RTT", func() {
csc := mocktls.NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
// first handshake: the cache lookup returns nothing, the Put captures the stored session
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, serverOrigRTTStats,
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
// nothing to remember on the first connection: a nil value is sent on the channel
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
// second handshake: the cache hands back the captured session
csc.EXPECT().Get(gomock.Any()).Return(state, true)
// A Put with a nil session is required exactly once; presumably the used
// ticket is removed from the cache — TODO confirm.
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
serverRTTStats := &utils.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, serverRTTStats,
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
// both stats objects started out empty, so these values were restored during resumption
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
// this time the channel carries transport parameters with the server's InitialMaxData
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeTrue())
Expect(client.ConnectionState().Used0RTT).To(BeTrue())
})
// Like the 0-RTT spec, but the server's InitialMaxData differs between the
// two connections. The session must still be resumed, but 0-RTT must not be
// used on either side.
It("rejects 0-RTT, when the transport parameters changed", func() {
csc := mocktls.NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
// first handshake: the cache lookup returns nothing, the Put captures the stored session
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
// nothing to remember on the first connection: a nil value is sent on the channel
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
// second handshake: the cache hands back the captured session
csc.EXPECT().Get(gomock.Any()).Return(state, true)
// A Put with a nil session is required exactly once; presumably the used
// ticket is removed from the cache — TODO confirm.
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
// the server now advertises a different InitialMaxData than on the first connection
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData - 1},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
// clientRTTStats started out empty, so this value was restored during resumption
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
// the client is still handed the value remembered from the first connection
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeFalse())
Expect(client.ConnectionState().Used0RTT).To(BeFalse())
})
})
})
})