mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-01 19:27:35 +03:00
628 lines
22 KiB
Go
628 lines
22 KiB
Go
package handshake
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"math/big"
|
|
"net"
|
|
"time"
|
|
|
|
tls "github.com/refraction-networking/utls"
|
|
|
|
mocktls "github.com/refraction-networking/uquic/internal/mocks/tls"
|
|
"github.com/refraction-networking/uquic/internal/protocol"
|
|
"github.com/refraction-networking/uquic/internal/qerr"
|
|
"github.com/refraction-networking/uquic/internal/testdata"
|
|
"github.com/refraction-networking/uquic/internal/utils"
|
|
"github.com/refraction-networking/uquic/internal/wire"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
// TLS handshake message type identifiers, used below as the first byte of
// hand-crafted handshake messages that are injected into a crypto setup.
const (
	typeClientHello      = 1 // first byte of the fake ClientHello
	typeNewSessionTicket = 4 // first byte of the fake NewSessionTicket
)
|
|
|
|
var _ = Describe("Crypto Setup TLS", func() {
|
|
generateCert := func() tls.Certificate {
|
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
tmpl := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{},
|
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
|
BasicConstraintsValid: true,
|
|
}
|
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
return tls.Certificate{
|
|
PrivateKey: priv,
|
|
Certificate: [][]byte{certDER},
|
|
}
|
|
}
|
|
|
|
var clientConf, serverConf *tls.Config
|
|
|
|
BeforeEach(func() {
|
|
serverConf = testdata.GetTLSConfig()
|
|
serverConf.NextProtos = []string{"crypto-setup"}
|
|
clientConf = &tls.Config{
|
|
ServerName: "localhost",
|
|
RootCAs: testdata.GetRootCA(),
|
|
NextProtos: []string{"crypto-setup"},
|
|
}
|
|
})
|
|
|
|
// The TLS config carries an empty ALPN protocol name. Starting the client
// handshake is expected to fail, and the TLS error has to be reported as an
// internal-error transport error carrying the TLS error message.
It("handles qtls errors occurring before or during ClientHello generation", func() {
	tlsConf := testdata.GetTLSConfig()
	tlsConf.InsecureSkipVerify = true
	tlsConf.NextProtos = []string{""} // empty protocol name

	client := NewCryptoSetupClient(
		protocol.ConnectionID{},
		&wire.TransportParameters{},
		tlsConf,
		false,
		&utils.RTTStats{},
		nil,
		utils.DefaultLogger.WithPrefix("client"),
		protocol.Version1,
	)

	expected := &qerr.TransportError{
		ErrorCode:    qerr.InternalError,
		ErrorMessage: "tls: invalid NextProtos value",
	}
	Expect(client.StartHandshake()).To(MatchError(expected))
})
|
|
|
|
// A message that starts with the ClientHello type byte is handed to a freshly
// started server at the Handshake encryption level instead of the Initial
// level. The server has to reject it with a "wrong level" error.
It("errors when a message is received at the wrong encryption level", func() {
	var resetToken protocol.StatelessResetToken
	srv := NewCryptoSetupServer(
		protocol.ConnectionID{},
		&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
		&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
		&wire.TransportParameters{StatelessResetToken: &resetToken},
		testdata.GetTLSConfig(),
		false,
		&utils.RTTStats{},
		nil,
		utils.DefaultLogger.WithPrefix("server"),
		protocol.Version1,
	)
	Expect(srv.StartHandshake()).To(Succeed())

	// 4-byte header (message type, then 0, 0, 6) followed by a 6-byte dummy body.
	header := []byte{typeClientHello, 0, 0, 6}
	fakeCH := append(header, []byte("foobar")...)

	// wrong encryption level
	handleErr := srv.HandleMessage(fakeCH, protocol.EncryptionHandshake)
	Expect(handleErr).To(HaveOccurred())
	Expect(handleErr.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})
|
|
|
|
// These specs check that, after addConnToClientHelloInfo has been applied to a
// tls.Config, the config's callbacks receive a ClientHelloInfo whose Conn
// reports the local and remote address that were passed in.
Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
	local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
	remote := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}

	It("wraps GetCertificate", func() {
		var seenLocal, seenRemote net.Addr
		getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
			seenLocal = info.Conn.LocalAddr()
			seenRemote = info.Conn.RemoteAddr()
			c := generateCert()
			return &c, nil
		}
		tlsConf := &tls.Config{GetCertificate: getCert}
		addConnToClientHelloInfo(tlsConf, local, remote)

		// The ClientHelloInfo passed in deliberately has no Conn set.
		_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
		Expect(err).ToNot(HaveOccurred())
		Expect(seenLocal).To(Equal(local))
		Expect(seenRemote).To(Equal(remote))
	})

	It("wraps GetConfigForClient", func() {
		var seenLocal, seenRemote net.Addr
		getConf := func(info *tls.ClientHelloInfo) (*tls.Config, error) {
			seenLocal = info.Conn.LocalAddr()
			seenRemote = info.Conn.RemoteAddr()
			return &tls.Config{}, nil
		}
		tlsConf := &tls.Config{GetConfigForClient: getConf}
		addConnToClientHelloInfo(tlsConf, local, remote)

		_, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
		Expect(err).ToNot(HaveOccurred())
		Expect(seenLocal).To(Equal(local))
		Expect(seenRemote).To(Equal(remote))
	})

	// The config returned by GetConfigForClient has its own GetCertificate
	// callback; that nested callback has to see the connection as well.
	It("wraps GetConfigForClient, recursively", func() {
		var seenLocal, seenRemote net.Addr
		tlsConf := &tls.Config{}
		tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
			perClient := tlsConf.Clone()
			perClient.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
				seenLocal = info.Conn.LocalAddr()
				seenRemote = info.Conn.RemoteAddr()
				c := generateCert()
				return &c, nil
			}
			return perClient, nil
		}
		addConnToClientHelloInfo(tlsConf, local, remote)

		returned, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
		Expect(err).ToNot(HaveOccurred())
		_, err = returned.GetCertificate(&tls.ClientHelloInfo{})
		Expect(err).ToNot(HaveOccurred())
		Expect(seenLocal).To(Equal(local))
		Expect(seenRemote).To(Equal(remote))
	})
})
|
|
|
|
Context("doing the handshake", func() {
|
|
// newRTTStatsWithRTT returns RTT statistics that have been fed exactly one
// sample of rtt (with an ack delay of 0). It asserts that the smoothed RTT
// then equals rtt; the offset of 1 attributes a failure to the caller.
newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
	stats := new(utils.RTTStats)
	stats.UpdateRTT(rtt, 0, time.Now())
	ExpectWithOffset(1, stats.SmoothedRTT()).To(Equal(rtt))
	return stats
}
|
|
|
|
// The clientEvents and serverEvents contain all events that were not processed by the function,
|
|
// i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete.
|
|
handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) {
|
|
Expect(client.StartHandshake()).To(Succeed())
|
|
Expect(server.StartHandshake()).To(Succeed())
|
|
|
|
var clientHandshakeComplete, serverHandshakeComplete bool
|
|
|
|
for {
|
|
clientLoop:
|
|
for {
|
|
ev := client.NextEvent()
|
|
//nolint:exhaustive // only need to process a few events
|
|
switch ev.Kind {
|
|
case EventNoEvent:
|
|
break clientLoop
|
|
case EventWriteInitialData:
|
|
if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
|
|
serverErr = err
|
|
return
|
|
}
|
|
case EventWriteHandshakeData:
|
|
if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
|
|
serverErr = err
|
|
return
|
|
}
|
|
case EventHandshakeComplete:
|
|
clientHandshakeComplete = true
|
|
default:
|
|
clientEvents = append(clientEvents, ev)
|
|
}
|
|
}
|
|
|
|
serverLoop:
|
|
for {
|
|
ev := server.NextEvent()
|
|
//nolint:exhaustive // only need to process a few events
|
|
switch ev.Kind {
|
|
case EventNoEvent:
|
|
break serverLoop
|
|
case EventWriteInitialData:
|
|
if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
|
|
clientErr = err
|
|
return
|
|
}
|
|
case EventWriteHandshakeData:
|
|
if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
|
|
clientErr = err
|
|
return
|
|
}
|
|
case EventHandshakeComplete:
|
|
serverHandshakeComplete = true
|
|
ticket, err := server.GetSessionTicket()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
if ticket != nil {
|
|
Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
|
|
}
|
|
default:
|
|
serverEvents = append(serverEvents, ev)
|
|
}
|
|
}
|
|
|
|
if clientHandshakeComplete && serverHandshakeComplete {
|
|
break
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
handshakeWithTLSConf := func(
|
|
clientConf, serverConf *tls.Config,
|
|
clientRTTStats, serverRTTStats *utils.RTTStats,
|
|
clientTransportParameters, serverTransportParameters *wire.TransportParameters,
|
|
enable0RTT bool,
|
|
) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */
|
|
CryptoSetup /* server */, []Event /* more server events */, error, /* server error */
|
|
) {
|
|
client := NewCryptoSetupClient(
|
|
protocol.ConnectionID{},
|
|
clientTransportParameters,
|
|
clientConf,
|
|
enable0RTT,
|
|
clientRTTStats,
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("client"),
|
|
protocol.Version1,
|
|
)
|
|
|
|
if serverTransportParameters.StatelessResetToken == nil {
|
|
var token protocol.StatelessResetToken
|
|
serverTransportParameters.StatelessResetToken = &token
|
|
}
|
|
server := NewCryptoSetupServer(
|
|
protocol.ConnectionID{},
|
|
&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
|
|
&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
|
|
serverTransportParameters,
|
|
serverConf,
|
|
enable0RTT,
|
|
serverRTTStats,
|
|
nil,
|
|
utils.DefaultLogger.WithPrefix("server"),
|
|
protocol.Version1,
|
|
)
|
|
cEvents, cErr, sEvents, sErr := handshake(client, server)
|
|
return client, cEvents, cErr, server, sEvents, sErr
|
|
}
|
|
|
|
// Baseline: with the default configs, neither side may report an error.
It("handshakes", func() {
	clientParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	serverParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}

	_, _, cErr, _, _, sErr := handshakeWithTLSConf(
		clientConf, serverConf,
		&utils.RTTStats{}, &utils.RTTStats{},
		clientParams, serverParams,
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
})
|
|
|
|
// The server only accepts P-384 for the key exchange. The handshake still has
// to succeed on both sides.
// NOTE(review): presumably the client's initial key share uses a different
// curve, which is what triggers the HelloRetryRequest — confirm against utls.
It("performs a HelloRetryRequest", func() {
	serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}

	clientParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	serverParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	_, _, cErr, _, _, sErr := handshakeWithTLSConf(
		clientConf, serverConf,
		&utils.RTTStats{}, &utils.RTTStats{},
		clientParams, serverParams,
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
})
|
|
|
|
// The server demands a client certificate (any certificate is accepted), and
// the client presents a freshly generated one.
It("handshakes with client auth", func() {
	clientConf.Certificates = []tls.Certificate{generateCert()}
	serverConf.ClientAuth = tls.RequireAnyClientCert

	clientParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	serverParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	_, _, cErr, _, _, sErr := handshakeWithTLSConf(
		clientConf, serverConf,
		&utils.RTTStats{}, &utils.RTTStats{},
		clientParams, serverParams,
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
})
|
|
|
|
// Client and server advertise different idle timeouts. After the handshake,
// each side must have emitted an EventReceivedTransportParameters event that
// carries the idle timeout advertised by its peer.
It("receives transport parameters", func() {
	// lastReceivedParams returns the parameters of the last
	// EventReceivedTransportParameters event, or nil if there is no such event.
	lastReceivedParams := func(events []Event) *wire.TransportParameters {
		var params *wire.TransportParameters
		for i := range events {
			if events[i].Kind == EventReceivedTransportParameters {
				params = events[i].TransportParameters
			}
		}
		return params
	}

	clientParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second}
	cl := NewCryptoSetupClient(
		protocol.ConnectionID{},
		clientParams,
		clientConf,
		false,
		&utils.RTTStats{},
		nil,
		utils.DefaultLogger.WithPrefix("client"),
		protocol.Version1,
	)

	var resetToken protocol.StatelessResetToken
	serverParams := &wire.TransportParameters{
		MaxIdleTimeout:          1337 * time.Second,
		StatelessResetToken:     &resetToken,
		ActiveConnectionIDLimit: 2,
	}
	srv := NewCryptoSetupServer(
		protocol.ConnectionID{},
		&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
		&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
		serverParams,
		serverConf,
		false,
		&utils.RTTStats{},
		nil,
		utils.DefaultLogger.WithPrefix("server"),
		protocol.Version1,
	)

	clientEvents, cErr, serverEvents, sErr := handshake(cl, srv)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())

	// The client must have learned the server's idle timeout ...
	seenByClient := lastReceivedParams(clientEvents)
	Expect(seenByClient).ToNot(BeNil())
	Expect(seenByClient.MaxIdleTimeout).To(Equal(1337 * time.Second))

	// ... and the server the client's.
	seenByServer := lastReceivedParams(serverEvents)
	Expect(seenByServer).ToNot(BeNil())
	Expect(seenByServer.MaxIdleTimeout).To(Equal(42 * time.Second))
})
|
|
|
|
Context("with session tickets", func() {
|
|
// After a successful handshake, a message starting with the NewSessionTicket
// type byte is delivered at the Handshake level instead of the 1-RTT level.
// The client has to reject it with a "wrong level" error.
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
	clientParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	serverParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	cl, _, cErr, _, _, sErr := handshakeWithTLSConf(
		clientConf, serverConf,
		&utils.RTTStats{}, &utils.RTTStats{},
		clientParams, serverParams,
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())

	// inject an invalid session ticket
	header := []byte{uint8(typeNewSessionTicket), 0, 0, 6}
	msg := append(header, []byte("foobar")...)
	handleErr := cl.HandleMessage(msg, protocol.EncryptionHandshake)
	Expect(handleErr).To(HaveOccurred())
	Expect(handleErr.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})
|
|
|
|
// After a successful handshake, a malformed NewSessionTicket is delivered at
// the correct (1-RTT) level. The failure has to be reported as a transport
// error whose error code is a crypto error.
It("errors when handling the NewSessionTicket fails", func() {
	clientParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	serverParams := &wire.TransportParameters{ActiveConnectionIDLimit: 2}
	cl, _, cErr, _, _, sErr := handshakeWithTLSConf(
		clientConf, serverConf,
		&utils.RTTStats{}, &utils.RTTStats{},
		clientParams, serverParams,
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())

	// inject an invalid session ticket
	header := []byte{uint8(typeNewSessionTicket), 0, 0, 6}
	msg := append(header, []byte("foobar")...)
	handleErr := cl.HandleMessage(msg, protocol.Encryption1RTT)
	// The type assertion below is only reached if this expectation holds.
	Expect(handleErr).To(BeAssignableToTypeOf(&qerr.TransportError{}))
	Expect(handleErr.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
})
|
|
|
|
// Two handshakes share one mocked client session cache. The first one is a
// full handshake that stores a session in the cache; the second one gets that
// session back from the cache and has to resume. The RTT measured by the
// client during the first connection must be restored on the second one.
It("uses session resumption", func() {
	cache := mocktls.NewMockClientSessionCache(mockCtrl)
	var savedState *tls.ClientSessionState
	ticketStored := make(chan struct{})
	cache.EXPECT().Get(gomock.Any())
	cache.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
		savedState = css
		close(ticketStored)
	})
	clientConf.ClientSessionCache = cache

	const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
	origRTTStats := newRTTStatsWithRTT(clientRTT)

	// First handshake: nothing to resume from.
	cl, _, cErr, srv, _, sErr := handshakeWithTLSConf(
		clientConf, serverConf,
		origRTTStats, &utils.RTTStats{},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(srv.ConnectionState().DidResume).To(BeFalse())
	Expect(cl.ConnectionState().DidResume).To(BeFalse())

	// Second handshake: the cache hands out the stored session.
	cache.EXPECT().Get(gomock.Any()).Return(savedState, true)
	cache.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
	freshRTTStats := &utils.RTTStats{}
	cl, _, cErr, srv, _, sErr = handshakeWithTLSConf(
		clientConf, serverConf,
		freshRTTStats, &utils.RTTStats{},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(srv.ConnectionState().DidResume).To(BeTrue())
	Expect(cl.ConnectionState().DidResume).To(BeTrue())
	Expect(freshRTTStats.SmoothedRTT()).To(Equal(clientRTT))
})
|
|
|
|
// The first handshake stores a session in the mocked cache. Session tickets
// are then disabled on the server: although the cache hands the stored session
// to the client, the second handshake must not resume. No Put expectation is
// registered for the second handshake.
It("doesn't use session resumption if the server disabled it", func() {
	cache := mocktls.NewMockClientSessionCache(mockCtrl)
	var savedState *tls.ClientSessionState
	ticketStored := make(chan struct{})
	cache.EXPECT().Get(gomock.Any())
	cache.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
		savedState = css
		close(ticketStored)
	})
	clientConf.ClientSessionCache = cache

	// First handshake, with session tickets still enabled.
	cl, _, cErr, srv, _, sErr := handshakeWithTLSConf(
		clientConf, serverConf,
		&utils.RTTStats{}, &utils.RTTStats{},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(srv.ConnectionState().DidResume).To(BeFalse())
	Expect(cl.ConnectionState().DidResume).To(BeFalse())

	// Second handshake, with session tickets disabled on the server.
	serverConf.SessionTicketsDisabled = true
	cache.EXPECT().Get(gomock.Any()).Return(savedState, true)
	cl, _, cErr, srv, _, sErr = handshakeWithTLSConf(
		clientConf, serverConf,
		&utils.RTTStats{}, &utils.RTTStats{},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
		false,
	)
	Expect(cErr).ToNot(HaveOccurred())
	Expect(sErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(srv.ConnectionState().DidResume).To(BeFalse())
	Expect(cl.ConnectionState().DidResume).To(BeFalse())
})
|
|
|
|
// Two handshakes with 0-RTT enabled share one mocked client session cache.
// The first one stores a session. On the second one, with unchanged server
// transport parameters, both sides must report resumption and use of 0-RTT,
// both must emit an EventReceivedReadKeys event, the client must get the
// stored transport parameters back, and the RTTs measured during the first
// connection must be restored on both sides.
It("uses 0-RTT", func() {
	cache := mocktls.NewMockClientSessionCache(mockCtrl)
	var savedState *tls.ClientSessionState
	ticketStored := make(chan struct{})
	cache.EXPECT().Get(gomock.Any())
	cache.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
		savedState = css
		close(ticketStored)
	})
	clientConf.ClientSessionCache = cache

	const (
		serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
		clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.

		initialMaxData protocol.ByteCount = 1337
	)
	serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
	clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)

	// First handshake: nothing to resume from.
	client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
		clientConf, serverConf,
		clientOrigRTTStats, serverOrigRTTStats,
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
		true,
	)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(server.ConnectionState().DidResume).To(BeFalse())
	Expect(client.ConnectionState().DidResume).To(BeFalse())

	// Second handshake: the cache hands out the stored session.
	cache.EXPECT().Get(gomock.Any()).Return(savedState, true)
	cache.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)

	freshClientRTT := &utils.RTTStats{}
	freshServerRTT := &utils.RTTStats{}
	client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf(
		clientConf, serverConf,
		freshClientRTT, freshServerRTT,
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
		true,
	)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Expect(freshClientRTT.SmoothedRTT()).To(Equal(clientRTT))
	Expect(freshServerRTT.SmoothedRTT()).To(Equal(serverRTT))

	// Client side: restored transport parameters and read keys.
	var restored *wire.TransportParameters
	var clientGotReadKeys bool
	for _, event := range clientEvents {
		if event.Kind == EventRestoredTransportParameters {
			restored = event.TransportParameters
		} else if event.Kind == EventReceivedReadKeys {
			clientGotReadKeys = true
		}
	}
	Expect(clientGotReadKeys).To(BeTrue())
	Expect(restored).ToNot(BeNil())
	Expect(restored.InitialMaxData).To(Equal(initialMaxData))

	// Server side: read keys.
	var serverGotReadKeys bool
	for _, event := range serverEvents {
		if event.Kind == EventReceivedReadKeys {
			serverGotReadKeys = true
		}
	}
	Expect(serverGotReadKeys).To(BeTrue())

	Expect(server.ConnectionState().DidResume).To(BeTrue())
	Expect(client.ConnectionState().DidResume).To(BeTrue())
	Expect(server.ConnectionState().Used0RTT).To(BeTrue())
	Expect(client.ConnectionState().Used0RTT).To(BeTrue())
})
|
|
|
|
// Same setup as the 0-RTT spec, but for the second handshake the server
// advertises a smaller InitialMaxData than it did when the session was stored.
// The session must still be resumed, and the client must still restore the
// stored transport parameters and receive read keys, but neither side may
// report that 0-RTT was used.
It("rejects 0-RTT, when the transport parameters changed", func() {
	cache := mocktls.NewMockClientSessionCache(mockCtrl)
	var savedState *tls.ClientSessionState
	ticketStored := make(chan struct{})
	cache.EXPECT().Get(gomock.Any())
	cache.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
		savedState = css
		close(ticketStored)
	})
	clientConf.ClientSessionCache = cache

	const (
		clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.

		initialMaxData protocol.ByteCount = 1337
	)
	clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)

	// First handshake: nothing to resume from.
	client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
		clientConf, serverConf,
		clientOrigRTTStats, &utils.RTTStats{},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
		true,
	)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Eventually(ticketStored).Should(BeClosed())
	Expect(server.ConnectionState().DidResume).To(BeFalse())
	Expect(client.ConnectionState().DidResume).To(BeFalse())

	// Second handshake: the server's InitialMaxData is reduced by one.
	cache.EXPECT().Get(gomock.Any()).Return(savedState, true)
	cache.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)

	freshClientRTT := &utils.RTTStats{}
	client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf(
		clientConf, serverConf,
		freshClientRTT, &utils.RTTStats{},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2},
		&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1},
		true,
	)
	Expect(clientErr).ToNot(HaveOccurred())
	Expect(serverErr).ToNot(HaveOccurred())
	Expect(freshClientRTT.SmoothedRTT()).To(Equal(clientRTT))

	// The client restores the parameters that were stored with the session,
	// i.e. the original InitialMaxData.
	var restored *wire.TransportParameters
	var clientGotReadKeys bool
	for _, event := range clientEvents {
		if event.Kind == EventRestoredTransportParameters {
			restored = event.TransportParameters
		} else if event.Kind == EventReceivedReadKeys {
			clientGotReadKeys = true
		}
	}
	Expect(clientGotReadKeys).To(BeTrue())
	Expect(restored).ToNot(BeNil())
	Expect(restored.InitialMaxData).To(Equal(initialMaxData))

	Expect(server.ConnectionState().DidResume).To(BeTrue())
	Expect(client.ConnectionState().DidResume).To(BeTrue())
	Expect(server.ConnectionState().Used0RTT).To(BeFalse())
	Expect(client.ConnectionState().Used0RTT).To(BeFalse())
})
|
|
})
|
|
})
|
|
})
|