mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
Previously, RunHandshake() would send the session tickets. Now, the session initiates the sending of the session ticket. This simplifies the setup a bit, and it will make it possible to include the RTT estimate in the session ticket without accessing the RTTStats concurrently.
714 lines
22 KiB
Go
714 lines
22 KiB
Go
package handshake
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"errors"
|
|
"math/big"
|
|
"time"
|
|
|
|
gomock "github.com/golang/mock/gomock"
|
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
"github.com/marten-seemann/qtls"
|
|
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
type chunk struct {
|
|
data []byte
|
|
encLevel protocol.EncryptionLevel
|
|
}
|
|
|
|
type stream struct {
|
|
encLevel protocol.EncryptionLevel
|
|
chunkChan chan<- chunk
|
|
}
|
|
|
|
func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream {
|
|
return &stream{
|
|
chunkChan: chunkChan,
|
|
encLevel: encLevel,
|
|
}
|
|
}
|
|
|
|
func (s *stream) Write(b []byte) (int, error) {
|
|
data := make([]byte, len(b))
|
|
copy(data, b)
|
|
select {
|
|
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
|
|
default:
|
|
panic("chunkChan too small")
|
|
}
|
|
return len(b), nil
|
|
}
|
|
|
|
var _ = Describe("Crypto Setup TLS", func() {
|
|
// TLS configurations used by the specs; both are assigned in BeforeEach.
var (
	clientConf *tls.Config
	serverConf *tls.Config
)
|
|
|
|
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
|
chunkChan := make(chan chunk, 100)
|
|
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
|
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
|
|
return chunkChan, initialStream, handshakeStream
|
|
}
|
|
|
|
BeforeEach(func() {
|
|
serverConf = testdata.GetTLSConfig()
|
|
serverConf.NextProtos = []string{"crypto-setup"}
|
|
clientConf = &tls.Config{
|
|
ServerName: "localhost",
|
|
RootCAs: testdata.GetRootCA(),
|
|
NextProtos: []string{"crypto-setup"},
|
|
}
|
|
})
|
|
|
|
It("creates a qtls.Config", func() {
	// Each callback returns something recognizable, so the spec can check
	// that it is still reachable through the converted config.
	tlsConf := new(tls.Config)
	tlsConf.ServerName = "quic.clemente.io"
	tlsConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
		return nil, errors.New("GetCertificate")
	}
	tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
		return nil, errors.New("GetClientCertificate")
	}
	tlsConf.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
		return &tls.Config{ServerName: ch.ServerName}, nil
	}

	cs := NewCryptoSetupServer(
		&bytes.Buffer{},
		&bytes.Buffer{},
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		NewMockHandshakeRunner(mockCtrl),
		tlsConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)
	converted := cs.(*cryptoSetup).tlsConf
	Expect(converted.ServerName).To(Equal(tlsConf.ServerName))

	_, getCertificateErr := converted.GetCertificate(nil)
	Expect(getCertificateErr).To(MatchError("GetCertificate"))

	_, getClientCertificateErr := converted.GetClientCertificate(nil)
	Expect(getClientCertificateErr).To(MatchError("GetClientCertificate"))

	// The per-client config must carry the client's server name as well as
	// the record layer and extension hooks.
	perClient, err := converted.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "foo.bar"})
	Expect(err).ToNot(HaveOccurred())
	Expect(perClient.ServerName).To(Equal("foo.bar"))
	Expect(perClient.AlternativeRecordLayer).ToNot(BeNil())
	Expect(perClient.GetExtensions).ToNot(BeNil())
	Expect(perClient.ReceivedExtensions).ToNot(BeNil())
})
|
|
|
|
It("returns Handshake() when an error occurs in qtls", func() {
	errs := make(chan error, 1)
	runner := NewMockHandshakeRunner(mockCtrl)
	runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { errs <- e })
	_, initial, hs := initStreams()
	server := NewCryptoSetupServer(
		initial,
		hs,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		testdata.GetTLSConfig(),
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	// RunHandshake has to return once the bogus message has been handled,
	// and the runner must have been told about the error.
	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()
		Expect(errs).To(Receive(MatchError("CRYPTO_ERROR: local error: tls: unexpected message")))
		close(handshakeReturned)
	}()

	// A ClientHello header announcing 6 bytes, followed by a garbage body.
	header := []byte{byte(typeClientHello), 0, 0, 6}
	fakeCH := append(header, []byte("foobar")...)
	messageHandled := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.HandleMessage(fakeCH, protocol.EncryptionInitial)
		close(messageHandled)
	}()
	Eventually(messageHandled).Should(BeClosed())
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
It("errors when a message is received at the wrong encryption level", func() {
	errs := make(chan error, 1)
	_, initial, hs := initStreams()
	runner := NewMockHandshakeRunner(mockCtrl)
	runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { errs <- e })
	server := NewCryptoSetupServer(
		initial,
		hs,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		testdata.GetTLSConfig(),
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()
		close(handshakeReturned)
	}()

	header := []byte{byte(typeClientHello), 0, 0, 6}
	fakeCH := append(header, []byte("foobar")...)
	server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level

	var err error
	Expect(errs).To(Receive(&err))
	Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
	// Named quicErr so that the local does not shadow the qerr package.
	quicErr := err.(*qerr.QuicError)
	Expect(quicErr.IsCryptoError()).To(BeTrue())
	Expect(quicErr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
	Expect(err.Error()).To(ContainSubstring("expected handshake message ClientHello to have encryption level Initial, has Handshake"))

	// make the go routine return
	Expect(server.Close()).To(Succeed())
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
It("returns Handshake() when handling a message fails", func() {
	sErrChan := make(chan error, 1)
	_, sInitialStream, sHandshakeStream := initStreams()
	runner := NewMockHandshakeRunner(mockCtrl)
	runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
	server := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		serverConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	// RunHandshake has to return after the failure, and the error passed to
	// the runner must be a crypto error carrying the unexpected_message alert.
	done := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()
		var err error
		Expect(sErrChan).To(Receive(&err))
		Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
		quicErr := err.(*qerr.QuicError)
		Expect(quicErr.IsCryptoError()).To(BeTrue())
		Expect(quicErr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
		close(done)
	}()

	// Feed the server a ServerHello with a garbage body. The problem here is
	// the message type, not the encryption level it is delivered at.
	fakeSH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...)
	server.HandleMessage(fakeSH, protocol.EncryptionInitial) // unexpected message type for a server
	Eventually(done).Should(BeClosed())
})
|
|
|
|
It("returns Handshake() when it is closed", func() {
	_, initial, hs := initStreams()
	server := NewCryptoSetupServer(
		initial,
		hs,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		NewMockHandshakeRunner(mockCtrl),
		serverConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()
		close(handshakeReturned)
	}()

	// Closing the crypto setup is what makes RunHandshake return.
	Expect(server.Close()).To(Succeed())
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
Context("doing the handshake", func() {
|
|
// testDone is created before and closed after every spec in this context; the
// message-forwarding goroutine started by the handshake helper exits on it.
var testDone chan struct{}
|
|
|
|
generateCert := func() tls.Certificate {
|
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
tmpl := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{},
|
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
|
BasicConstraintsValid: true,
|
|
}
|
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
return tls.Certificate{
|
|
PrivateKey: priv,
|
|
Certificate: [][]byte{certDER},
|
|
}
|
|
}
|
|
|
|
// Every spec gets its own, still open, testDone channel.
BeforeEach(func() {
	fresh := make(chan struct{})
	testDone = fresh
})
|
|
|
|
// Closing testDone signals the end of the spec to goroutines waiting on it.
AfterEach(func() { close(testDone) })
|
|
|
|
// handshake runs a handshake between client and server. Chunks read from
// cChunkChan are handed to the server and chunks read from sChunkChan to the
// client. After the server's RunHandshake returns, its session ticket (if
// there is one) is delivered to the client at the 1-RTT level. The function
// returns once the client's RunHandshake has returned and the server side
// has finished.
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
	server CryptoSetup, sChunkChan <-chan chunk) {
	serverDone := make(chan struct{})

	// Forward messages between the two endpoints until the spec is over.
	go func() {
		defer GinkgoRecover()
		for {
			select {
			case fromClient := <-cChunkChan:
				server.HandleMessage(fromClient.data, fromClient.encLevel)
			case fromServer := <-sChunkChan:
				client.HandleMessage(fromServer.data, fromServer.encLevel)
			case <-testDone: // handshake complete
				return
			}
		}
	}()

	// Server side: run the handshake, then pass on the session ticket.
	go func() {
		defer GinkgoRecover()
		server.RunHandshake()
		ticket, err := server.GetSessionTicket()
		Expect(err).ToNot(HaveOccurred())
		if ticket != nil {
			client.HandleMessage(ticket, protocol.Encryption1RTT)
		}
		close(serverDone)
	}()

	client.RunHandshake()
	Eventually(serverDone).Should(BeClosed())
}
|
|
|
|
// handshakeWithTLSConf builds a client and a server crypto setup from the
// given TLS configs, runs a handshake between them and returns both setups
// along with the error each side's runner was notified of (nil if none). For
// a side that reported no error, it asserts that the handshake completed.
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
	// Client side.
	var clientComplete bool
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	clientErrs := make(chan error, 1)
	cRunner := NewMockHandshakeRunner(mockCtrl)
	cRunner.EXPECT().OnReceivedParams(gomock.Any())
	cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { clientErrs <- e }).MaxTimes(1)
	cRunner.EXPECT().OnHandshakeComplete().Do(func() { clientComplete = true }).MaxTimes(1)
	client, _ := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		cRunner,
		clientConf,
		enable0RTT,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("client"),
	)

	// Server side.
	var serverComplete bool
	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
	serverErrs := make(chan error, 1)
	sRunner := NewMockHandshakeRunner(mockCtrl)
	sRunner.EXPECT().OnReceivedParams(gomock.Any())
	sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { serverErrs <- e }).MaxTimes(1)
	sRunner.EXPECT().OnHandshakeComplete().Do(func() { serverComplete = true }).MaxTimes(1)
	var token [16]byte
	server := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{StatelessResetToken: &token},
		sRunner,
		serverConf,
		enable0RTT,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	handshake(client, cChunkChan, server, sChunkChan)

	// Pick up a reported error without blocking; when there is none, the
	// handshake must have completed on that side.
	var clientErr, serverErr error
	select {
	case serverErr = <-serverErrs:
	default:
		Expect(serverComplete).To(BeTrue())
	}
	select {
	case clientErr = <-clientErrs:
	default:
		Expect(clientComplete).To(BeTrue())
	}
	return client, clientErr, server, serverErr
}
|
|
|
|
It("handshakes", func() {
	// A plain handshake with the default configs must succeed on both sides.
	_, cErr, _, sErr := handshakeWithTLSConf(clientConf, serverConf, false)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
})
|
|
|
|
It("performs a HelloRetryRequest", func() {
	// Restrict the server to P-384 only.
	// NOTE(review): presumably the client's initial key share uses a
	// different curve, which is what forces the retry — confirm.
	serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
	_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
})
|
|
|
|
It("handshakes with client auth", func() {
	// The server demands a client certificate (any will do), and the client
	// offers a freshly generated self-signed one.
	serverConf.ClientAuth = qtls.RequireAnyClientCert
	clientConf.Certificates = []tls.Certificate{generateCert()}

	_, cErr, _, sErr := handshakeWithTLSConf(clientConf, serverConf, false)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
})
|
|
|
|
It("signals when it has written the ClientHello", func() {
	runner := NewMockHandshakeRunner(mockCtrl)
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	client, chChan := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		runner,
		&tls.Config{InsecureSkipVerify: true},
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("client"),
	)

	handshakeReturned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		client.RunHandshake()
		close(handshakeReturned)
	}()

	var ch chunk
	Eventually(cChunkChan).Should(Receive(&ch))
	Eventually(chChan).Should(Receive(BeNil()))

	// make sure the whole ClientHello was written:
	// 1 byte message type followed by a 3 byte big-endian body length
	const headerLen = 4
	Expect(len(ch.data)).To(BeNumerically(">=", headerLen))
	Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
	bodyLen := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
	Expect(len(ch.data) - headerLen).To(Equal(bodyLen))

	// make the go routine return
	Expect(client.Close()).To(Succeed())
	Eventually(handshakeReturned).Should(BeClosed())
})
|
|
|
|
It("receives transport parameters", func() {
	// cTransportParametersRcvd: the client's parameters, as seen by the server.
	// sTransportParametersRcvd: the server's parameters, as seen by the client.
	var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
	cRunner := NewMockHandshakeRunner(mockCtrl)
	cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp })
	cRunner.EXPECT().OnHandshakeComplete()
	client, _ := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		protocol.ConnectionID{},
		nil,
		cTransportParameters,
		cRunner,
		clientConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("client"),
	)

	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
	var token [16]byte
	sRunner := NewMockHandshakeRunner(mockCtrl)
	sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp })
	sRunner.EXPECT().OnHandshakeComplete()
	sTransportParameters := &TransportParameters{
		MaxIdleTimeout:      0x1337 * time.Second,
		StatelessResetToken: &token,
	}
	server := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		protocol.ConnectionID{},
		nil,
		sTransportParameters,
		sRunner,
		serverConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	done := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		handshake(client, cChunkChan, server, sChunkChan)
		close(done)
	}()
	Eventually(done).Should(BeClosed())

	// Check for nil before dereferencing, so that a missing callback shows
	// up as a failed expectation rather than a nil pointer panic.
	Expect(cTransportParametersRcvd).ToNot(BeNil())
	Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout))
	Expect(sTransportParametersRcvd).ToNot(BeNil())
	Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout))
})
|
|
|
|
Context("with session tickets", func() {
|
|
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	cRunner := NewMockHandshakeRunner(mockCtrl)
	cRunner.EXPECT().OnReceivedParams(gomock.Any())
	cRunner.EXPECT().OnHandshakeComplete()
	client, _ := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		cRunner,
		clientConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("client"),
	)

	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
	sRunner := NewMockHandshakeRunner(mockCtrl)
	sRunner.EXPECT().OnReceivedParams(gomock.Any())
	sRunner.EXPECT().OnHandshakeComplete()
	server := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		sRunner,
		serverConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	// Complete a regular handshake first.
	handshakeDone := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		handshake(client, cChunkChan, server, sChunkChan)
		close(handshakeDone)
	}()
	Eventually(handshakeDone).Should(BeClosed())

	// inject an invalid session ticket
	cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
		Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
		quicErr := err.(*qerr.QuicError)
		Expect(quicErr.IsCryptoError()).To(BeTrue())
		Expect(quicErr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage)))
		Expect(quicErr.Error()).To(ContainSubstring("expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake"))
	})
	header := []byte{uint8(typeNewSessionTicket), 0, 0, 6}
	ticket := append(header, []byte("foobar")...)
	client.HandleMessage(ticket, protocol.EncryptionHandshake)
})
|
|
|
|
It("errors when handling the NewSessionTicket fails", func() {
	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	cRunner := NewMockHandshakeRunner(mockCtrl)
	cRunner.EXPECT().OnReceivedParams(gomock.Any())
	cRunner.EXPECT().OnHandshakeComplete()
	client, _ := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		cRunner,
		clientConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("client"),
	)

	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
	sRunner := NewMockHandshakeRunner(mockCtrl)
	sRunner.EXPECT().OnReceivedParams(gomock.Any())
	sRunner.EXPECT().OnHandshakeComplete()
	server := NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		sRunner,
		serverConf,
		false,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	// Complete a regular handshake first.
	handshakeDone := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		handshake(client, cChunkChan, server, sChunkChan)
		close(handshakeDone)
	}()
	Eventually(handshakeDone).Should(BeClosed())

	// inject an invalid session ticket, this time at the 1-RTT level
	cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) {
		Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
		quicErr := err.(*qerr.QuicError)
		Expect(quicErr.IsCryptoError()).To(BeTrue())
	})
	header := []byte{uint8(typeNewSessionTicket), 0, 0, 6}
	ticket := append(header, []byte("foobar")...)
	client.HandleMessage(ticket, protocol.Encryption1RTT)
})
|
|
|
|
It("uses session resumption", func() {
	csc := NewMockClientSessionCache(mockCtrl)
	var cached *tls.ClientSessionState
	ticketStored := make(chan struct{})
	csc.EXPECT().Get(gomock.Any())
	csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
		cached = css
		close(ticketStored)
	})
	clientConf.ClientSessionCache = csc

	// First handshake: nothing to resume, but the session state is stored.
	client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(server.ConnectionState().DidResume).To(BeFalse())
	Expect(client.ConnectionState().DidResume).To(BeFalse())

	// Second handshake: the cache returns the stored state, so both sides
	// have to report a resumed session.
	csc.EXPECT().Get(gomock.Any()).Return(cached, true)
	csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
	client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(server.ConnectionState().DidResume).To(BeTrue())
	Expect(client.ConnectionState().DidResume).To(BeTrue())
})
|
|
|
|
It("doesn't use session resumption if the server disabled it", func() {
	csc := NewMockClientSessionCache(mockCtrl)
	var cached *tls.ClientSessionState
	ticketStored := make(chan struct{})
	csc.EXPECT().Get(gomock.Any())
	csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
		cached = css
		close(ticketStored)
	})
	clientConf.ClientSessionCache = csc

	// First handshake: session tickets are still enabled on the server.
	client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(server.ConnectionState().DidResume).To(BeFalse())
	Expect(client.ConnectionState().DidResume).To(BeFalse())

	// Second handshake: the client offers the stored state, but the server
	// has session tickets disabled, so neither side may report a resumption.
	serverConf.SessionTicketsDisabled = true
	csc.EXPECT().Get(gomock.Any()).Return(cached, true)
	client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(server.ConnectionState().DidResume).To(BeFalse())
	Expect(client.ConnectionState().DidResume).To(BeFalse())
})
|
|
|
|
It("uses 0-RTT", func() {
	csc := NewMockClientSessionCache(mockCtrl)
	var cached *tls.ClientSessionState
	ticketStored := make(chan struct{})
	csc.EXPECT().Get(gomock.Any())
	csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
		cached = css
		close(ticketStored)
	})
	clientConf.ClientSessionCache = csc

	// First handshake (0-RTT enabled on both sides): stores the session state.
	client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, true)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(server.ConnectionState().DidResume).To(BeFalse())
	Expect(client.ConnectionState().DidResume).To(BeFalse())

	// Second handshake: the cache returns the stored state. The spec expects
	// one Put with a nil state and allows at most one further Put.
	csc.EXPECT().Get(gomock.Any()).Return(cached, true)
	csc.EXPECT().Put(gomock.Any(), nil)
	csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)

	cChunkChan, cInitialStream, cHandshakeStream := initStreams()
	cRunner := NewMockHandshakeRunner(mockCtrl)
	cRunner.EXPECT().OnReceivedParams(gomock.Any())
	cRunner.EXPECT().OnHandshakeComplete()
	client, clientHelloChan := NewCryptoSetupClient(
		cInitialStream,
		cHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		cRunner,
		clientConf,
		true,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("client"),
	)

	sChunkChan, sInitialStream, sHandshakeStream := initStreams()
	sRunner := NewMockHandshakeRunner(mockCtrl)
	sRunner.EXPECT().OnReceivedParams(gomock.Any())
	sRunner.EXPECT().OnHandshakeComplete()
	server = NewCryptoSetupServer(
		sInitialStream,
		sHandshakeStream,
		protocol.ConnectionID{},
		nil,
		&TransportParameters{},
		sRunner,
		serverConf,
		true,
		&congestion.RTTStats{},
		utils.DefaultLogger.WithPrefix("server"),
	)

	handshakeDone := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		handshake(client, cChunkChan, server, sChunkChan)
		close(handshakeDone)
	}()
	Eventually(handshakeDone).Should(BeClosed())

	// With 0-RTT, the value signalled along with the ClientHello is non-nil.
	Expect(clientHelloChan).To(Receive(Not(BeNil())))

	Expect(server.ConnectionState().DidResume).To(BeTrue())
	Expect(client.ConnectionState().DidResume).To(BeTrue())
})
|
|
})
|
|
})
|
|
})
|