mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
The RTTStats are used by the logging package. In order to instrument the congestion package, the RTTStats can't be part of that package any more (to avoid an import loop).
836 lines
28 KiB
Go
836 lines
28 KiB
Go
package handshake
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"errors"
|
|
"math/big"
|
|
"time"
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
"github.com/marten-seemann/qtls"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
type chunk struct {
|
|
data []byte
|
|
encLevel protocol.EncryptionLevel
|
|
}
|
|
|
|
type stream struct {
|
|
encLevel protocol.EncryptionLevel
|
|
chunkChan chan<- chunk
|
|
}
|
|
|
|
func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream {
|
|
return &stream{
|
|
chunkChan: chunkChan,
|
|
encLevel: encLevel,
|
|
}
|
|
}
|
|
|
|
func (s *stream) Write(b []byte) (int, error) {
|
|
data := make([]byte, len(b))
|
|
copy(data, b)
|
|
select {
|
|
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
|
|
default:
|
|
panic("chunkChan too small")
|
|
}
|
|
return len(b), nil
|
|
}
|
|
|
|
var _ = Describe("Crypto Setup TLS", func() {
|
|
var clientConf, serverConf *tls.Config
|
|
|
|
// unparam incorrectly complains that the first argument is never used.
|
|
//nolint:unparam
|
|
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
|
chunkChan := make(chan chunk, 100)
|
|
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
|
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
|
|
return chunkChan, initialStream, handshakeStream
|
|
}
|
|
|
|
BeforeEach(func() {
|
|
serverConf = testdata.GetTLSConfig()
|
|
serverConf.NextProtos = []string{"crypto-setup"}
|
|
clientConf = &tls.Config{
|
|
ServerName: "localhost",
|
|
RootCAs: testdata.GetRootCA(),
|
|
NextProtos: []string{"crypto-setup"},
|
|
}
|
|
})
|
|
|
|
It("creates a qtls.Config", func() {
|
|
tlsConf := &tls.Config{
|
|
ServerName: "quic.clemente.io",
|
|
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return nil, errors.New("GetCertificate")
|
|
},
|
|
GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
return nil, errors.New("GetClientCertificate")
|
|
},
|
|
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return &tls.Config{ServerName: ch.ServerName}, nil
|
|
},
|
|
}
|
|
var token protocol.StatelessResetToken
|
|
server := NewCryptoSetupServer(
|
|
&bytes.Buffer{},
|
|
&bytes.Buffer{},
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{StatelessResetToken: &token},
|
|
NewMockHandshakeRunner(mockCtrl),
|
|
tlsConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
qtlsConf := server.(*cryptoSetup).tlsConf
|
|
Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName))
|
|
_, getCertificateErr := qtlsConf.GetCertificate(nil)
|
|
Expect(getCertificateErr).To(MatchError("GetCertificate"))
|
|
_, getClientCertificateErr := qtlsConf.GetClientCertificate(nil)
|
|
Expect(getClientCertificateErr).To(MatchError("GetClientCertificate"))
|
|
cconf, err := qtlsConf.GetConfigForClient(&qtls.ClientHelloInfo{ServerName: "foo.bar"})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(cconf.ServerName).To(Equal("foo.bar"))
|
|
Expect(cconf.AlternativeRecordLayer).ToNot(BeNil())
|
|
Expect(cconf.GetExtensions).ToNot(BeNil())
|
|
Expect(cconf.ReceivedExtensions).ToNot(BeNil())
|
|
})
|
|
|
|
It("returns Handshake() when an error occurs in qtls", func() {
|
|
sErrChan := make(chan error, 1)
|
|
runner := NewMockHandshakeRunner(mockCtrl)
|
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
|
_, sInitialStream, sHandshakeStream := initStreams()
|
|
var token protocol.StatelessResetToken
|
|
server := NewCryptoSetupServer(
|
|
sInitialStream,
|
|
sHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{StatelessResetToken: &token},
|
|
runner,
|
|
testdata.GetTLSConfig(),
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
server.RunHandshake()
|
|
Expect(sErrChan).To(Receive(MatchError("CRYPTO_ERROR: local error: tls: unexpected message")))
|
|
close(done)
|
|
}()
|
|
|
|
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
|
handledMessage := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
server.HandleMessage(fakeCH, protocol.EncryptionInitial)
|
|
close(handledMessage)
|
|
}()
|
|
Eventually(handledMessage).Should(BeClosed())
|
|
Eventually(done).Should(BeClosed())
|
|
})
|
|
|
|
It("errors when a message is received at the wrong encryption level", func() {
|
|
sErrChan := make(chan error, 1)
|
|
_, sInitialStream, sHandshakeStream := initStreams()
|
|
runner := NewMockHandshakeRunner(mockCtrl)
|
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
|
var token protocol.StatelessResetToken
|
|
server := NewCryptoSetupServer(
|
|
sInitialStream,
|
|
sHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{StatelessResetToken: &token},
|
|
runner,
|
|
testdata.GetTLSConfig(),
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
server.RunHandshake()
|
|
close(done)
|
|
}()
|
|
|
|
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
|
|
server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
|
|
var err error
|
|
Expect(sErrChan).To(Receive(&err))
|
|
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
|
qerr := err.(*qerr.QuicError)
|
|
Expect(qerr.IsCryptoError()).To(BeTrue())
|
|
Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
|
|
Expect(err.Error()).To(ContainSubstring("expected handshake message ClientHello to have encryption level Initial, has Handshake"))
|
|
|
|
// make the go routine return
|
|
Expect(server.Close()).To(Succeed())
|
|
Eventually(done).Should(BeClosed())
|
|
})
|
|
|
|
It("returns Handshake() when handling a message fails", func() {
|
|
sErrChan := make(chan error, 1)
|
|
_, sInitialStream, sHandshakeStream := initStreams()
|
|
runner := NewMockHandshakeRunner(mockCtrl)
|
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
|
var token protocol.StatelessResetToken
|
|
server := NewCryptoSetupServer(
|
|
sInitialStream,
|
|
sHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{StatelessResetToken: &token},
|
|
runner,
|
|
serverConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
server.RunHandshake()
|
|
var err error
|
|
Expect(sErrChan).To(Receive(&err))
|
|
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
|
qerr := err.(*qerr.QuicError)
|
|
Expect(qerr.IsCryptoError()).To(BeTrue())
|
|
Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
|
|
close(done)
|
|
}()
|
|
|
|
fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...)
|
|
server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level
|
|
Eventually(done).Should(BeClosed())
|
|
})
|
|
|
|
It("returns Handshake() when it is closed", func() {
|
|
_, sInitialStream, sHandshakeStream := initStreams()
|
|
var token protocol.StatelessResetToken
|
|
server := NewCryptoSetupServer(
|
|
sInitialStream,
|
|
sHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{StatelessResetToken: &token},
|
|
NewMockHandshakeRunner(mockCtrl),
|
|
serverConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
server.RunHandshake()
|
|
close(done)
|
|
}()
|
|
Expect(server.Close()).To(Succeed())
|
|
Eventually(done).Should(BeClosed())
|
|
})
|
|
|
|
Context("doing the handshake", func() {
|
|
generateCert := func() tls.Certificate {
|
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
tmpl := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{},
|
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
|
BasicConstraintsValid: true,
|
|
}
|
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
return tls.Certificate{
|
|
PrivateKey: priv,
|
|
Certificate: [][]byte{certDER},
|
|
}
|
|
}
|
|
|
|
newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
|
|
rttStats := &utils.RTTStats{}
|
|
rttStats.UpdateRTT(rtt, 0, time.Now())
|
|
ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
|
|
return rttStats
|
|
}
|
|
|
|
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
|
|
server CryptoSetup, sChunkChan <-chan chunk) {
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
for {
|
|
select {
|
|
case c := <-cChunkChan:
|
|
server.HandleMessage(c.data, c.encLevel)
|
|
case c := <-sChunkChan:
|
|
client.HandleMessage(c.data, c.encLevel)
|
|
case <-done: // handshake complete
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
defer close(done)
|
|
server.RunHandshake()
|
|
ticket, err := server.GetSessionTicket()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
if ticket != nil {
|
|
client.HandleMessage(ticket, protocol.Encryption1RTT)
|
|
}
|
|
}()
|
|
|
|
client.RunHandshake()
|
|
Eventually(done).Should(BeClosed())
|
|
}
|
|
|
|
handshakeWithTLSConf := func(
|
|
clientConf, serverConf *tls.Config,
|
|
clientRTTStats, serverRTTStats *utils.RTTStats,
|
|
clientTransportParameters, serverTransportParameters *wire.TransportParameters,
|
|
enable0RTT bool,
|
|
) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
|
|
var cHandshakeComplete bool
|
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
|
cErrChan := make(chan error, 1)
|
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
|
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
|
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
|
|
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
|
|
cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
|
|
client, clientHelloWrittenChan := NewCryptoSetupClient(
|
|
cInitialStream,
|
|
cHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
clientTransportParameters,
|
|
cRunner,
|
|
clientConf,
|
|
enable0RTT,
|
|
clientRTTStats,
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("client"),
|
|
)
|
|
|
|
var sHandshakeComplete bool
|
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
|
sErrChan := make(chan error, 1)
|
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
|
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
|
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
|
|
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
|
|
if serverTransportParameters.StatelessResetToken == nil {
|
|
var token protocol.StatelessResetToken
|
|
serverTransportParameters.StatelessResetToken = &token
|
|
}
|
|
server := NewCryptoSetupServer(
|
|
sInitialStream,
|
|
sHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
serverTransportParameters,
|
|
sRunner,
|
|
serverConf,
|
|
enable0RTT,
|
|
serverRTTStats,
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
|
|
handshake(client, cChunkChan, server, sChunkChan)
|
|
var cErr, sErr error
|
|
select {
|
|
case sErr = <-sErrChan:
|
|
default:
|
|
Expect(sHandshakeComplete).To(BeTrue())
|
|
}
|
|
select {
|
|
case cErr = <-cErrChan:
|
|
default:
|
|
Expect(cHandshakeComplete).To(BeTrue())
|
|
}
|
|
return clientHelloWrittenChan, client, cErr, server, sErr
|
|
}
|
|
|
|
It("handshakes", func() {
|
|
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
&utils.RTTStats{}, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{},
|
|
false,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
})
|
|
|
|
It("performs a HelloRetryRequst", func() {
|
|
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
|
|
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
&utils.RTTStats{}, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{},
|
|
false,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
})
|
|
|
|
It("handshakes with client auth", func() {
|
|
clientConf.Certificates = []tls.Certificate{generateCert()}
|
|
serverConf.ClientAuth = qtls.RequireAnyClientCert
|
|
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
&utils.RTTStats{}, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{},
|
|
false,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
})
|
|
|
|
It("signals when it has written the ClientHello", func() {
|
|
runner := NewMockHandshakeRunner(mockCtrl)
|
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
|
client, chChan := NewCryptoSetupClient(
|
|
cInitialStream,
|
|
cHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{},
|
|
runner,
|
|
&tls.Config{InsecureSkipVerify: true},
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("client"),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
client.RunHandshake()
|
|
close(done)
|
|
}()
|
|
var ch chunk
|
|
Eventually(cChunkChan).Should(Receive(&ch))
|
|
Eventually(chChan).Should(Receive(BeNil()))
|
|
// make sure the whole ClientHello was written
|
|
Expect(len(ch.data)).To(BeNumerically(">=", 4))
|
|
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
|
|
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
|
|
Expect(len(ch.data) - 4).To(Equal(length))
|
|
|
|
// make the go routine return
|
|
Expect(client.Close()).To(Succeed())
|
|
Eventually(done).Should(BeClosed())
|
|
})
|
|
|
|
It("receives transport parameters", func() {
|
|
var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters
|
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
|
cTransportParameters := &wire.TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
|
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
|
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp })
|
|
cRunner.EXPECT().OnHandshakeComplete()
|
|
client, _ := NewCryptoSetupClient(
|
|
cInitialStream,
|
|
cHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
cTransportParameters,
|
|
cRunner,
|
|
clientConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("client"),
|
|
)
|
|
|
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
|
var token protocol.StatelessResetToken
|
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
|
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp })
|
|
sRunner.EXPECT().OnHandshakeComplete()
|
|
sTransportParameters := &wire.TransportParameters{
|
|
MaxIdleTimeout: 0x1337 * time.Second,
|
|
StatelessResetToken: &token,
|
|
}
|
|
server := NewCryptoSetupServer(
|
|
sInitialStream,
|
|
sHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
sTransportParameters,
|
|
sRunner,
|
|
serverConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
handshake(client, cChunkChan, server, sChunkChan)
|
|
close(done)
|
|
}()
|
|
Eventually(done).Should(BeClosed())
|
|
Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout))
|
|
Expect(sTransportParametersRcvd).ToNot(BeNil())
|
|
Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout))
|
|
})
|
|
|
|
Context("with session tickets", func() {
|
|
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
|
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
|
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
|
cRunner.EXPECT().OnHandshakeComplete()
|
|
client, _ := NewCryptoSetupClient(
|
|
cInitialStream,
|
|
cHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{},
|
|
cRunner,
|
|
clientConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("client"),
|
|
)
|
|
|
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
|
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
|
sRunner.EXPECT().OnHandshakeComplete()
|
|
var token protocol.StatelessResetToken
|
|
server := NewCryptoSetupServer(
|
|
sInitialStream,
|
|
sHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{StatelessResetToken: &token},
|
|
sRunner,
|
|
serverConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
handshake(client, cChunkChan, server, sChunkChan)
|
|
close(done)
|
|
}()
|
|
Eventually(done).Should(BeClosed())
|
|
|
|
// inject an invalid session ticket
|
|
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
|
|
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
|
qerr := err.(*qerr.QuicError)
|
|
Expect(qerr.IsCryptoError()).To(BeTrue())
|
|
Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
|
|
Expect(qerr.Error()).To(ContainSubstring("expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake"))
|
|
})
|
|
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
|
|
client.HandleMessage(b, protocol.EncryptionHandshake)
|
|
})
|
|
|
|
It("errors when handling the NewSessionTicket fails", func() {
|
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
|
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
|
cRunner.EXPECT().OnHandshakeComplete()
|
|
client, _ := NewCryptoSetupClient(
|
|
cInitialStream,
|
|
cHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{},
|
|
cRunner,
|
|
clientConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("client"),
|
|
)
|
|
|
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
|
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
|
sRunner.EXPECT().OnHandshakeComplete()
|
|
var token protocol.StatelessResetToken
|
|
server := NewCryptoSetupServer(
|
|
sInitialStream,
|
|
sHandshakeStream,
|
|
protocol.ConnectionID{},
|
|
nil,
|
|
nil,
|
|
&wire.TransportParameters{StatelessResetToken: &token},
|
|
sRunner,
|
|
serverConf,
|
|
false,
|
|
&utils.RTTStats{},
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
handshake(client, cChunkChan, server, sChunkChan)
|
|
close(done)
|
|
}()
|
|
Eventually(done).Should(BeClosed())
|
|
|
|
// inject an invalid session ticket
|
|
cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
|
|
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
|
qerr := err.(*qerr.QuicError)
|
|
Expect(qerr.IsCryptoError()).To(BeTrue())
|
|
})
|
|
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
|
|
client.HandleMessage(b, protocol.Encryption1RTT)
|
|
})
|
|
|
|
It("uses session resumption", func() {
|
|
csc := NewMockClientSessionCache(mockCtrl)
|
|
var state *tls.ClientSessionState
|
|
receivedSessionTicket := make(chan struct{})
|
|
csc.EXPECT().Get(gomock.Any())
|
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
|
|
state = css
|
|
close(receivedSessionTicket)
|
|
})
|
|
clientConf.ClientSessionCache = csc
|
|
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
|
|
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
|
|
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
clientOrigRTTStats, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{},
|
|
false,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
Eventually(receivedSessionTicket).Should(BeClosed())
|
|
Expect(server.ConnectionState().DidResume).To(BeFalse())
|
|
Expect(client.ConnectionState().DidResume).To(BeFalse())
|
|
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
|
|
|
|
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
|
clientRTTStats := &utils.RTTStats{}
|
|
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
clientRTTStats, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{},
|
|
false,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
Eventually(receivedSessionTicket).Should(BeClosed())
|
|
Expect(server.ConnectionState().DidResume).To(BeTrue())
|
|
Expect(client.ConnectionState().DidResume).To(BeTrue())
|
|
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
|
|
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
|
|
})
|
|
|
|
It("doesn't use session resumption if the server disabled it", func() {
|
|
csc := NewMockClientSessionCache(mockCtrl)
|
|
var state *tls.ClientSessionState
|
|
receivedSessionTicket := make(chan struct{})
|
|
csc.EXPECT().Get(gomock.Any())
|
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
|
|
state = css
|
|
close(receivedSessionTicket)
|
|
})
|
|
clientConf.ClientSessionCache = csc
|
|
_, client, clientErr, server, serverErr := handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
&utils.RTTStats{}, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{},
|
|
false,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
Eventually(receivedSessionTicket).Should(BeClosed())
|
|
Expect(server.ConnectionState().DidResume).To(BeFalse())
|
|
Expect(client.ConnectionState().DidResume).To(BeFalse())
|
|
|
|
serverConf.SessionTicketsDisabled = true
|
|
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
|
_, client, clientErr, server, serverErr = handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
&utils.RTTStats{}, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{},
|
|
false,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
Eventually(receivedSessionTicket).Should(BeClosed())
|
|
Expect(server.ConnectionState().DidResume).To(BeFalse())
|
|
Expect(client.ConnectionState().DidResume).To(BeFalse())
|
|
})
|
|
|
|
It("uses 0-RTT", func() {
|
|
csc := NewMockClientSessionCache(mockCtrl)
|
|
var state *tls.ClientSessionState
|
|
receivedSessionTicket := make(chan struct{})
|
|
csc.EXPECT().Get(gomock.Any())
|
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
|
|
state = css
|
|
close(receivedSessionTicket)
|
|
})
|
|
clientConf.ClientSessionCache = csc
|
|
const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
|
|
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
|
|
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
|
|
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
|
|
const initialMaxData protocol.ByteCount = 1337
|
|
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
clientOrigRTTStats, serverOrigRTTStats,
|
|
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
|
|
true,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
Eventually(receivedSessionTicket).Should(BeClosed())
|
|
Expect(server.ConnectionState().DidResume).To(BeFalse())
|
|
Expect(client.ConnectionState().DidResume).To(BeFalse())
|
|
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
|
|
|
|
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
|
csc.EXPECT().Put(gomock.Any(), nil)
|
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
|
|
|
clientRTTStats := &utils.RTTStats{}
|
|
serverRTTStats := &utils.RTTStats{}
|
|
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
clientRTTStats, serverRTTStats,
|
|
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
|
|
true,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
|
|
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
|
|
|
|
var tp *wire.TransportParameters
|
|
Expect(clientHelloWrittenChan).To(Receive(&tp))
|
|
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
|
|
|
|
Expect(server.ConnectionState().DidResume).To(BeTrue())
|
|
Expect(client.ConnectionState().DidResume).To(BeTrue())
|
|
Expect(server.ConnectionState().Used0RTT).To(BeTrue())
|
|
Expect(client.ConnectionState().Used0RTT).To(BeTrue())
|
|
})
|
|
|
|
It("rejects 0-RTT, whent the transport parameters changed", func() {
|
|
csc := NewMockClientSessionCache(mockCtrl)
|
|
var state *tls.ClientSessionState
|
|
receivedSessionTicket := make(chan struct{})
|
|
csc.EXPECT().Get(gomock.Any())
|
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
|
|
state = css
|
|
close(receivedSessionTicket)
|
|
})
|
|
clientConf.ClientSessionCache = csc
|
|
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
|
|
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
|
|
const initialMaxData protocol.ByteCount = 1337
|
|
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
clientOrigRTTStats, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
|
|
true,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
Eventually(receivedSessionTicket).Should(BeClosed())
|
|
Expect(server.ConnectionState().DidResume).To(BeFalse())
|
|
Expect(client.ConnectionState().DidResume).To(BeFalse())
|
|
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
|
|
|
|
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
|
csc.EXPECT().Put(gomock.Any(), nil)
|
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
|
|
|
clientRTTStats := &utils.RTTStats{}
|
|
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
|
|
clientConf, serverConf,
|
|
clientRTTStats, &utils.RTTStats{},
|
|
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1},
|
|
true,
|
|
)
|
|
Expect(clientErr).ToNot(HaveOccurred())
|
|
Expect(serverErr).ToNot(HaveOccurred())
|
|
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
|
|
|
|
var tp *wire.TransportParameters
|
|
Expect(clientHelloWrittenChan).To(Receive(&tp))
|
|
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
|
|
|
|
Expect(server.ConnectionState().DidResume).To(BeTrue())
|
|
Expect(client.ConnectionState().DidResume).To(BeTrue())
|
|
Expect(server.ConnectionState().Used0RTT).To(BeFalse())
|
|
Expect(client.ConnectionState().Used0RTT).To(BeFalse())
|
|
})
|
|
})
|
|
})
|
|
})
|