use a synchronous API for the crypto setup (#3939)

This commit is contained in:
Marten Seemann 2023-07-21 10:00:42 -07:00 committed by GitHub
parent 2c0e7e02b0
commit 469a6153b6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 696 additions and 1032 deletions

View file

@ -57,6 +57,8 @@ type cryptoStreamHandler interface {
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
GetSessionTicket() ([]byte, error)
NextEvent() handshake.Event
DiscardInitialKeys()
io.Closer
ConnectionState() handshake.ConnectionState
}
@ -96,18 +98,6 @@ type connRunner interface {
RemoveResetToken(protocol.StatelessResetToken)
}
type handshakeRunner struct {
onReceivedParams func(*wire.TransportParameters)
onReceivedReadKeys func()
dropKeys func(protocol.EncryptionLevel)
onHandshakeComplete func()
}
func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) }
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() }
func (r *handshakeRunner) OnReceivedReadKeys() { r.onReceivedReadKeys() }
type closeError struct {
err error
remote bool
@ -165,6 +155,8 @@ type connection struct {
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
initialStream cryptoStream
handshakeStream cryptoStream
oneRTTStream cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler
@ -183,12 +175,10 @@ type connection struct {
undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []receivedPacket
clientHelloWritten <-chan *wire.TransportParameters
earlyConnReadyChan chan struct{}
handshakeCompleteChan chan struct{} // is closed when the handshake completes
sentFirstPacket bool
handshakeComplete bool
handshakeConfirmed bool
earlyConnReadyChan chan struct{}
sentFirstPacket bool
handshakeComplete bool
handshakeConfirmed bool
receivedRetry bool
versionNegotiated bool
@ -248,17 +238,16 @@ var newConnection = func(
v protocol.VersionNumber,
) quicConn {
s := &connection{
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(),
perspective: protocol.PerspectiveServer,
handshakeCompleteChan: make(chan struct{}),
tracer: tracer,
logger: logger,
version: v,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(),
perspective: protocol.PerspectiveServer,
tracer: tracer,
logger: logger,
version: v,
}
if origDestConnID.Len() > 0 {
s.logID = origDestConnID.String()
@ -294,8 +283,6 @@ var newConnection = func(
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -327,20 +314,8 @@ var newConnection = func(
s.tracer.SentTransportParameters(params)
}
cs := handshake.NewCryptoSetupServer(
initialStream,
handshakeStream,
s.oneRTTStream,
clientDestConnID,
params,
&handshakeRunner{
onReceivedParams: s.handleTransportParameters,
dropKeys: s.dropEncryptionLevel,
onReceivedReadKeys: s.receivedReadKeys,
onHandshakeComplete: func() {
runner.Retire(clientDestConnID)
close(s.handshakeCompleteChan)
},
},
tlsConf,
conf.Allow0RTT,
s.rttStats,
@ -349,9 +324,9 @@ var newConnection = func(
s.version,
)
s.cryptoStreamHandler = cs
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, s.oneRTTStream)
return s
}
@ -373,18 +348,17 @@ var newClientConnection = func(
v protocol.VersionNumber,
) quicConn {
s := &connection{
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
handshakeCompleteChan: make(chan struct{}),
logID: destConnID.String(),
logger: logger,
tracer: tracer,
versionNegotiated: hasNegotiatedVersion,
version: v,
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
logID: destConnID.String(),
logger: logger,
tracer: tracer,
versionNegotiated: hasNegotiatedVersion,
version: v,
}
s.connIDManager = newConnIDManager(
destConnID,
@ -415,8 +389,6 @@ var newClientConnection = func(
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -445,18 +417,9 @@ var newClientConnection = func(
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
initialStream,
handshakeStream,
oneRTTStream,
cs := handshake.NewCryptoSetupClient(
destConnID,
params,
&handshakeRunner{
onReceivedParams: s.handleTransportParameters,
dropKeys: s.dropEncryptionLevel,
onReceivedReadKeys: s.receivedReadKeys,
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
},
tlsConf,
enable0RTT,
s.rttStats,
@ -464,11 +427,10 @@ var newClientConnection = func(
logger,
s.version,
)
s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
if len(tlsConf.ServerName) > 0 {
s.tokenStoreKey = tlsConf.ServerName
} else {
@ -483,6 +445,8 @@ var newClientConnection = func(
}
func (s *connection) preSetup() {
s.initialStream = newCryptoStream()
s.handshakeStream = newCryptoStream()
s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue()
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams)
@ -535,6 +499,9 @@ func (s *connection) run() error {
if err := s.cryptoStreamHandler.StartHandshake(); err != nil {
return err
}
if err := s.handleHandshakeEvents(); err != nil {
return err
}
go func() {
if err := s.sendQueue.Run(); err != nil {
s.destroyImpl(err)
@ -542,17 +509,7 @@ func (s *connection) run() error {
}()
if s.perspective == protocol.PerspectiveClient {
select {
case zeroRTTParams := <-s.clientHelloWritten:
s.scheduleSending()
if zeroRTTParams != nil {
s.restoreTransportParameters(zeroRTTParams)
close(s.earlyConnReadyChan)
}
case closeErr := <-s.closeChan:
// put the close error back into the channel, so that the run loop can receive it
s.closeChan <- closeErr
}
s.scheduleSending() // so the ClientHello actually gets sent
}
var sendQueueAvailable <-chan struct{}
@ -563,8 +520,6 @@ runLoop:
select {
case closeErr = <-s.closeChan:
break runLoop
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
default:
}
@ -635,8 +590,6 @@ runLoop:
if !wasProcessed {
continue
}
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
}
}
@ -762,9 +715,8 @@ func (s *connection) idleTimeoutStartTime() time.Time {
return utils.MaxTime(s.lastPacketReceivedTime, s.firstAckElicitingPacketAfterIdleSentTime)
}
func (s *connection) handleHandshakeComplete() {
func (s *connection) handleHandshakeComplete() error {
s.handshakeComplete = true
s.handshakeCompleteChan = nil // prevent this case from ever being selected again
defer s.handshakeCtxCancel()
// Once the handshake completes, we have derived 1-RTT keys.
// There's no point in queueing undecryptable packets for later decryption any more.
@ -775,14 +727,16 @@ func (s *connection) handleHandshakeComplete() {
if s.perspective == protocol.PerspectiveClient {
s.applyTransportParameters()
return
return nil
}
s.handleHandshakeConfirmed()
if err := s.handleHandshakeConfirmed(); err != nil {
return err
}
ticket, err := s.cryptoStreamHandler.GetSessionTicket()
if err != nil {
s.closeLocal(err)
return err
}
if ticket != nil { // may be nil if session tickets are disabled via tls.Config.SessionTicketsDisabled
s.oneRTTStream.Write(ticket)
@ -792,13 +746,18 @@ func (s *connection) handleHandshakeComplete() {
}
token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr())
if err != nil {
s.closeLocal(err)
return err
}
s.queueControlFrame(&wire.NewTokenFrame{Token: token})
s.queueControlFrame(&wire.HandshakeDoneFrame{})
return nil
}
func (s *connection) handleHandshakeConfirmed() {
func (s *connection) handleHandshakeConfirmed() error {
if err := s.dropEncryptionLevel(protocol.EncryptionHandshake); err != nil {
return err
}
s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
s.cryptoStreamHandler.SetHandshakeConfirmed()
@ -810,6 +769,7 @@ func (s *connection) handleHandshakeConfirmed() {
}
s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize))
}
return nil
}
func (s *connection) handlePacketImpl(rp receivedPacket) bool {
@ -1211,6 +1171,14 @@ func (s *connection) handleUnpackedLongHeaderPacket(
}
}
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake {
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
// See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
return err
}
}
s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false
@ -1376,13 +1344,41 @@ func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame
}
func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
if err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil {
return err
}
return s.handleHandshakeEvents()
}
func (s *connection) receivedReadKeys() {
// Queue all packets for decryption that have been undecryptable so far.
s.undecryptablePacketsToProcess = s.undecryptablePackets
s.undecryptablePackets = nil
func (s *connection) handleHandshakeEvents() error {
for {
ev := s.cryptoStreamHandler.NextEvent()
var err error
switch ev.Kind {
case handshake.EventNoEvent:
return nil
case handshake.EventHandshakeComplete:
err = s.handleHandshakeComplete()
case handshake.EventReceivedTransportParameters:
err = s.handleTransportParameters(ev.TransportParameters)
case handshake.EventRestoredTransportParameters:
s.restoreTransportParameters(ev.TransportParameters)
close(s.earlyConnReadyChan)
case handshake.EventReceivedReadKeys:
// Queue all packets for decryption that have been undecryptable so far.
s.undecryptablePacketsToProcess = s.undecryptablePackets
s.undecryptablePackets = nil
case handshake.EventDiscard0RTTKeys:
err = s.dropEncryptionLevel(protocol.Encryption0RTT)
case handshake.EventWriteInitialData:
_, err = s.initialStream.Write(ev.Data)
case handshake.EventWriteHandshakeData:
_, err = s.handshakeStream.Write(ev.Data)
}
if err != nil {
return err
}
}
}
func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error {
@ -1491,7 +1487,9 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr
return nil
}
if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed {
s.handleHandshakeConfirmed()
if err := s.handleHandshakeConfirmed(); err != nil {
return err
}
}
return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked())
}
@ -1623,25 +1621,24 @@ func (s *connection) handleCloseError(closeErr *closeError) {
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
}
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error {
if s.tracer != nil {
s.tracer.DroppedEncryptionLevel(encLevel)
}
s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel)
if err := s.cryptoStreamManager.Drop(encLevel); err != nil {
s.closeLocal(err)
return
}
if encLevel == protocol.Encryption0RTT {
//nolint:exhaustive // only Initial and 0-RTT need special treatment
switch encLevel {
case protocol.EncryptionInitial:
s.cryptoStreamHandler.DiscardInitialKeys()
case protocol.Encryption0RTT:
s.streamsMap.ResetFor0RTT()
if err := s.connFlowController.Reset(); err != nil {
s.closeLocal(err)
}
if err := s.framer.Handle0RTTRejection(); err != nil {
s.closeLocal(err)
return err
}
return s.framer.Handle0RTTRejection()
}
return s.cryptoStreamManager.Drop(encLevel)
}
// is called for the client, when restoring transport parameters saved for 0-RTT
@ -1659,13 +1656,12 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters
s.connStateMutex.Unlock()
}
func (s *connection) handleTransportParameters(params *wire.TransportParameters) {
func (s *connection) handleTransportParameters(params *wire.TransportParameters) error {
if err := s.checkTransportParameters(params); err != nil {
s.closeLocal(&qerr.TransportError{
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
})
return
}
}
s.peerParams = params
// On the client side we have to wait for handshake completion.
@ -1680,6 +1676,7 @@ func (s *connection) handleTransportParameters(params *wire.TransportParameters)
s.connStateMutex.Lock()
s.connState.SupportsDatagrams = s.supportsDatagrams()
s.connStateMutex.Unlock()
return nil
}
func (s *connection) checkTransportParameters(params *wire.TransportParameters) error {
@ -1826,7 +1823,9 @@ func (s *connection) sendPackets(now time.Time) error {
return err
}
s.sentFirstPacket = true
s.sendPackedCoalescedPacket(packet, now)
if err := s.sendPackedCoalescedPacket(packet, now); err != nil {
return err
}
sendMode := s.sentPacketHandler.SendMode(now)
if sendMode == ackhandler.SendPacingLimited {
s.resetPacingDeadline()
@ -1946,8 +1945,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if packet == nil {
return nil
}
s.sendPackedCoalescedPacket(packet, time.Now())
return nil
return s.sendPackedCoalescedPacket(packet, time.Now())
}
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
@ -1991,8 +1989,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
}
s.sendPackedCoalescedPacket(packet, now)
return nil
return s.sendPackedCoalescedPacket(packet, now)
}
// appendPacket appends a new packet to the given packetBuffer.
@ -2022,7 +2019,7 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti
s.connIDManager.SentPacket()
}
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) {
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error {
s.logCoalescedPacket(packet)
for _, p := range packet.longHdrPackets {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
@ -2033,6 +2030,13 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
largestAcked = p.ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false)
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake {
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
// See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil {
return err
}
}
}
if p := packet.shortHdrPacket; p != nil {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
@ -2046,6 +2050,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
}
s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer, packet.buffer.Len())
return nil
}
func (s *connection) sendConnectionClose(e error) ([]byte, error) {

View file

@ -358,6 +358,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
Expect(conn.run()).To(MatchError(expectedErr))
}()
Expect(conn.handleFrame(&wire.ConnectionCloseFrame{
@ -386,6 +387,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
Expect(conn.run()).To(MatchError(testErr))
}()
ccf := &wire.ConnectionCloseFrame{
@ -434,6 +436,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
runErr <- conn.run()
}()
Eventually(areConnsRunning).Should(BeTrue())
@ -815,6 +818,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
expectReplaceWithClosed()
@ -857,6 +861,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
@ -892,6 +897,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
@ -917,6 +923,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
err := conn.run()
Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
@ -941,6 +948,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
runErr <- conn.run()
}()
expectReplaceWithClosed()
@ -965,6 +973,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
err := conn.run()
Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
@ -1054,6 +1063,7 @@ var _ = Describe("Connection", func() {
BeforeEach(func() {
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
})
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, receivedPacket) {
hdr := &wire.ExtendedHeader{
Header: wire.Header{
@ -1082,6 +1092,8 @@ var _ = Describe("Connection", func() {
hdr: &wire.ExtendedHeader{Header: wire.Header{}},
}, nil
})
cryptoSetup.EXPECT().DiscardInitialKeys()
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial)
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
Expect(conn.handlePacketImpl(packet)).To(BeTrue())
})
@ -1111,6 +1123,8 @@ var _ = Describe("Connection", func() {
},
}, nil
})
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes()
cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes()
gomock.InOrder(
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()),
@ -1134,6 +1148,8 @@ var _ = Describe("Connection", func() {
}, nil
}),
)
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes()
cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes()
gomock.InOrder(
tracer.EXPECT().BufferedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data))),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()),
@ -1158,6 +1174,8 @@ var _ = Describe("Connection", func() {
}, nil
})
_, packet2 := getPacketWithLength(wrongConnID, 123)
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes()
cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes()
// don't EXPECT any more calls to unpacker.UnpackLongHeader()
gomock.InOrder(
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()),
@ -1201,6 +1219,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
close(connDone)
}()
@ -1419,6 +1438,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1443,6 +1463,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1467,6 +1488,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1483,6 +1505,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1500,6 +1523,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1518,6 +1542,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1544,6 +1569,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1566,6 +1592,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1584,6 +1611,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1606,6 +1634,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
@ -1637,6 +1666,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
available := make(chan struct{}, 1)
@ -1668,6 +1698,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending() // no packet will get sent
@ -1691,6 +1722,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
conn.scheduleSending()
@ -1738,6 +1770,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
// don't EXPECT any calls to mconn.Write()
@ -1772,6 +1805,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
Eventually(written).Should(BeClosed())
@ -1836,6 +1870,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
@ -1859,18 +1894,21 @@ var _ = Describe("Connection", func() {
finishHandshake := make(chan struct{})
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
conn.sentPacketHandler = sph
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode(gomock.Any()).AnyTimes()
sph.EXPECT().DropPackets(protocol.EncryptionHandshake)
sph.EXPECT().SetHandshakeConfirmed()
connRunner.EXPECT().Retire(clientDestConnID)
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket()
close(conn.handshakeCompleteChan)
conn.run()
}()
handshakeCtx := conn.HandshakeComplete()
@ -1889,18 +1927,21 @@ var _ = Describe("Connection", func() {
Eventually(conn.Context().Done()).Should(BeClosed())
})
It("sends a connection ticket when the handshake completes", func() {
It("sends a session ticket when the handshake completes", func() {
const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2
packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes()
finishHandshake := make(chan struct{})
connRunner.EXPECT().Retire(clientDestConnID)
conn.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake)
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
close(conn.handshakeCompleteChan)
conn.run()
}()
@ -1945,6 +1986,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
handshakeCtx := conn.HandshakeComplete()
@ -1975,13 +2017,16 @@ var _ = Describe("Connection", func() {
return shortHeaderPacket{}, nil
})
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake)
sph.EXPECT().DropPackets(protocol.EncryptionHandshake)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
close(conn.handshakeCompleteChan)
conn.run()
}()
Eventually(done).Should(BeClosed())
@ -2001,6 +2046,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
Expect(conn.run()).To(Succeed())
close(done)
}()
@ -2026,6 +2072,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
err := conn.run()
Expect(err).To(MatchError(expectedErr))
close(done)
@ -2076,6 +2123,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
}
@ -2178,6 +2226,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
err := conn.run()
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
@ -2203,6 +2252,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
err := conn.run()
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
@ -2236,6 +2286,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
@ -2263,6 +2314,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
err := conn.run()
nerr, ok := err.(net.Error)
@ -2275,6 +2327,7 @@ var _ = Describe("Connection", func() {
})
It("closes the connection due to the idle timeout after handshake", func() {
conn.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes()
gomock.InOrder(
connRunner.EXPECT().Retire(clientDestConnID),
@ -2282,6 +2335,7 @@ var _ = Describe("Connection", func() {
)
cryptoSetup.EXPECT().Close()
gomock.InOrder(
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake),
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
Expect(e).To(MatchError(&IdleTimeoutError{}))
}),
@ -2292,9 +2346,10 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
close(conn.handshakeCompleteChan)
err := conn.run()
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
@ -2312,6 +2367,7 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
Consistently(conn.Context().Done()).ShouldNot(BeClosed())
@ -2326,10 +2382,10 @@ var _ = Describe("Connection", func() {
Eventually(conn.Context().Done()).Should(BeClosed())
})
It("time out earliest after 3 times the PTO", func() {
It("times out earliest after 3 times the PTO", func() {
packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes()
connRunner.EXPECT().Retire(clientDestConnID)
connRunner.EXPECT().Remove(gomock.Any())
connRunner.EXPECT().Retire(gomock.Any()).AnyTimes()
connRunner.EXPECT().Remove(gomock.Any()).Times(2)
cryptoSetup.EXPECT().Close()
closeTimeChan := make(chan time.Time)
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
@ -2343,9 +2399,9 @@ var _ = Describe("Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
close(conn.handshakeCompleteChan)
conn.run()
close(done)
}()
@ -2448,15 +2504,12 @@ var _ = Describe("Client Connection", func() {
b, err := hdr.Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
return receivedPacket{
data: append(b, data...),
buffer: getPacketBuffer(),
rcvTime: time.Now(),
data: append(b, data...),
buffer: getPacketBuffer(),
}
}
expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any())
}
BeforeEach(func() {
quicConf = populateConfig(&Config{})
tlsConf = nil
@ -2512,11 +2565,8 @@ var _ = Describe("Client Connection", func() {
}, nil
})
conn.unpacker = unpacker
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
conn.run()
}()
done := make(chan struct{})
packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) { close(done) })
newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7})
p := getPacket(&wire.ExtendedHeader{
Header: wire.Header{
@ -2530,15 +2580,23 @@ var _ = Describe("Client Connection", func() {
}, []byte("foobar"))
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), []logging.Frame{})
Expect(conn.handlePacketImpl(p)).To(BeTrue())
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
conn.run()
}()
Eventually(done).Should(BeClosed())
// make sure the go routine returns
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
expectReplaceWithClosed()
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any())
mconn.EXPECT().Write(gomock.Any(), gomock.Any()).MaxTimes(1)
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
Eventually(conn.Context().Done()).Should(BeClosed())
time.Sleep(200 * time.Millisecond)
})
It("continues accepting Long Header packets after using a new connection ID", func() {
@ -2572,6 +2630,8 @@ var _ = Describe("Client Connection", func() {
conn.peerParams = &wire.TransportParameters{}
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
conn.sentPacketHandler = sph
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake)
sph.EXPECT().DropPackets(protocol.EncryptionHandshake)
sph.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().SetHandshakeConfirmed()
Expect(conn.handleHandshakeDoneFrame()).To(Succeed())
@ -2582,7 +2642,9 @@ var _ = Describe("Client Connection", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
conn.sentPacketHandler = sph
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 3}}}
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake)
sph.EXPECT().ReceivedAck(ack, protocol.Encryption1RTT, gomock.Any()).Return(true, nil)
sph.EXPECT().DropPackets(protocol.EncryptionHandshake)
sph.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().SetLargest1RTTAcked(protocol.PacketNumber(3))
cryptoSetup.EXPECT().SetHandshakeConfirmed()
@ -2598,6 +2660,7 @@ var _ = Describe("Client Connection", func() {
close(running)
conn.closeLocal(errors.New("early error"))
})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().Close()
connRunner.EXPECT().Remove(gomock.Any())
go func() {
@ -2633,8 +2696,9 @@ var _ = Describe("Client Connection", func() {
versions,
)
return receivedPacket{
data: b,
buffer: getPacketBuffer(),
rcvTime: time.Now(),
data: b,
buffer: getPacketBuffer(),
}
}
@ -2645,9 +2709,14 @@ var _ = Describe("Client Connection", func() {
sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4)
conn.config.Versions = []protocol.VersionNumber{1234, 4321}
errChan := make(chan error, 1)
start := make(chan struct{})
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().DoAndReturn(func() handshake.Event {
<-start
return handshake.Event{Kind: handshake.EventNoEvent}
})
errChan <- conn.run()
}()
connRunner.EXPECT().Remove(srcConnID)
@ -2659,6 +2728,7 @@ var _ = Describe("Client Connection", func() {
})
cryptoSetup.EXPECT().Close()
Expect(conn.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse())
close(start)
var err error
Eventually(errChan).Should(Receive(&err))
Expect(err).To(HaveOccurred())
@ -2673,9 +2743,11 @@ var _ = Describe("Client Connection", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
errChan <- conn.run()
}()
connRunner.EXPECT().Remove(srcConnID).MaxTimes(1)
packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
gomock.InOrder(
tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()),
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
@ -2771,23 +2843,38 @@ var _ = Describe("Client Connection", func() {
Context("transport parameters", func() {
var (
closed bool
errChan chan error
closed bool
errChan chan error
paramsChan chan *wire.TransportParameters
)
JustBeforeEach(func() {
errChan = make(chan error, 1)
paramsChan = make(chan *wire.TransportParameters, 1)
closed = false
packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().StartHandshake().MaxTimes(1)
// This is not 100% what would happen in reality.
// The run loop calls NextEvent once when it starts up (to send out the ClientHello),
// and then again every time a CRYPTO frame is handled.
// Injecting a CRYPTO frame is not straightforward though,
// so we inject the transport parameters on the first call to NextEvent.
params := <-paramsChan
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{
Kind: handshake.EventReceivedTransportParameters,
TransportParameters: params,
})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}).MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}).MaxTimes(1)
errChan <- conn.run()
close(errChan)
}()
})
expectClose := func(applicationClose bool) {
if !closed {
expectClose := func(applicationClose, errored bool) {
if !closed && !errored {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any())
if applicationClose {
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1)
@ -2822,9 +2909,10 @@ var _ = Describe("Client Connection", func() {
},
}
packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).MaxTimes(1)
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
conn.handleHandshakeComplete()
processed := make(chan struct{})
tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) })
paramsChan <- params
Eventually(processed).Should(BeClosed())
// make sure the connection ID is not retired
cf, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1)
Expect(cf).To(BeEmpty())
@ -2832,7 +2920,7 @@ var _ = Describe("Client Connection", func() {
Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4})))
// shut down
connRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})
expectClose(true)
expectClose(true, false)
})
It("uses the minimum of the peers' idle timeouts", func() {
@ -2842,11 +2930,15 @@ var _ = Describe("Client Connection", func() {
InitialSourceConnectionID: destConnID,
MaxIdleTimeout: 18 * time.Second,
}
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
conn.handleHandshakeComplete()
processed := make(chan struct{})
tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) })
paramsChan <- params
Eventually(processed).Should(BeClosed())
// close first
expectClose(true, false)
conn.shutdown()
// then check. Avoids race condition when accessing idleTimeout
Expect(conn.idleTimeout).To(Equal(18 * time.Second))
expectClose(true)
})
It("errors if the transport parameters contain a wrong initial_source_connection_id", func() {
@ -2856,9 +2948,11 @@ var _ = Describe("Client Connection", func() {
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose(false)
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
expectClose(false, true)
processed := make(chan struct{})
tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) })
paramsChan <- params
Eventually(processed).Should(BeClosed())
Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: "expected initial_source_connection_id to equal deadbeef, is decafbad",
@ -2873,9 +2967,11 @@ var _ = Describe("Client Connection", func() {
InitialSourceConnectionID: destConnID,
StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose(false)
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
expectClose(false, true)
processed := make(chan struct{})
tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) })
paramsChan <- params
Eventually(processed).Should(BeClosed())
Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: "missing retry_source_connection_id",
@ -2892,9 +2988,11 @@ var _ = Describe("Client Connection", func() {
RetrySourceConnectionID: &rcid2,
StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose(false)
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
expectClose(false, true)
processed := make(chan struct{})
tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) })
paramsChan <- params
Eventually(processed).Should(BeClosed())
Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: "expected retry_source_connection_id to equal deadbeef, is deadc0de",
@ -2909,9 +3007,11 @@ var _ = Describe("Client Connection", func() {
RetrySourceConnectionID: &rcid,
StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose(false)
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
expectClose(false, true)
processed := make(chan struct{})
tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) })
paramsChan <- params
Eventually(processed).Should(BeClosed())
Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: "received retry_source_connection_id, although no Retry was performed",
@ -2925,9 +3025,11 @@ var _ = Describe("Client Connection", func() {
InitialSourceConnectionID: conn.handshakeDestConnID,
StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose(false)
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
expectClose(false, true)
processed := make(chan struct{})
tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) })
paramsChan <- params
Eventually(processed).Should(BeClosed())
Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: "expected original_destination_connection_id to equal deadbeef, is decafbad",

View file

@ -3,12 +3,14 @@ package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() handshake.Event
}
type cryptoStreamManager struct {
@ -74,8 +76,6 @@ func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
return m.initialStream.Finish()
case protocol.EncryptionHandshake:
return m.handshakeStream.Finish()
case protocol.Encryption0RTT:
return nil
default:
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
}

View file

@ -87,8 +87,4 @@ var _ = Describe("Crypto Stream Manager", func() {
handshakeStream.EXPECT().Finish()
Expect(csm.Drop(protocol.EncryptionHandshake)).To(Succeed())
})
It("no-ops when dropping 0-RTT", func() {
Expect(csm.Drop(protocol.Encryption0RTT)).To(Succeed())
})
})

View file

@ -13,70 +13,12 @@ import (
"github.com/quic-go/quic-go/internal/wire"
)
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
}
type stream struct {
chunkChan chan<- chunk
encLevel protocol.EncryptionLevel
}
func (s *stream) Write(b []byte) (int, error) {
data := append([]byte{}, b...)
select {
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
default:
panic("chunkChan too small")
}
return len(b), nil
}
func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 10)
initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial}
handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake}
return chunkChan, initialStream, handshakeStream
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnReceivedReadKeys()
DropKeys(protocol.EncryptionLevel)
}
type runner struct {
handshakeComplete chan<- struct{}
}
var _ handshakeRunner = &runner{}
func newRunner(handshakeComplete chan<- struct{}) *runner {
return &runner{handshakeComplete: handshakeComplete}
}
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
func (r *runner) OnReceivedReadKeys() {}
func (r *runner) OnHandshakeComplete() {
close(r.handshakeComplete)
}
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
const alpn = "fuzz"
func main() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
var client, server handshake.CryptoSetup
clientHandshakeCompleted := make(chan struct{})
client, _ = handshake.NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
client := handshake.NewCryptoSetupClient(
protocol.ConnectionID{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
newRunner(clientHandshakeCompleted),
&tls.Config{
MinVersion: tls.VersionTLS13,
ServerName: "localhost",
@ -91,17 +33,11 @@ func main() {
protocol.Version1,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
config := testdata.GetTLSConfig()
config.NextProtos = []string{alpn}
serverHandshakeCompleted := make(chan struct{})
server = handshake.NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
server := handshake.NewCryptoSetupServer(
protocol.ConnectionID{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
newRunner(serverHandshakeCompleted),
config,
false,
utils.NewRTTStats(),
@ -118,29 +54,55 @@ func main() {
log.Fatal(err)
}
done := make(chan struct{})
go func() {
<-serverHandshakeCompleted
<-clientHandshakeCompleted
close(done)
}()
var clientHandshakeComplete, serverHandshakeComplete bool
var messages [][]byte
messageLoop:
for {
select {
case c := <-cChunkChan:
messages = append(messages, c.data)
if err := server.HandleMessage(c.data, c.encLevel); err != nil {
log.Fatal(err)
clientLoop:
for {
ev := client.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
break clientLoop
case handshake.EventWriteInitialData:
messages = append(messages, ev.Data)
if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
log.Fatal(err)
}
case handshake.EventWriteHandshakeData:
messages = append(messages, ev.Data)
if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
log.Fatal(err)
}
case handshake.EventHandshakeComplete:
clientHandshakeComplete = true
}
case c := <-sChunkChan:
messages = append(messages, c.data)
if err := client.HandleMessage(c.data, c.encLevel); err != nil {
log.Fatal(err)
}
serverLoop:
for {
ev := server.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
break serverLoop
case handshake.EventWriteInitialData:
messages = append(messages, ev.Data)
if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
log.Fatal(err)
}
case handshake.EventWriteHandshakeData:
messages = append(messages, ev.Data)
if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
log.Fatal(err)
}
case handshake.EventHandshakeComplete:
serverHandshakeComplete = true
}
case <-done:
break messageLoop
}
if serverHandshakeComplete && clientHandshakeComplete {
break
}
}

View file

@ -126,57 +126,6 @@ func getClientAuth(rand uint8) tls.ClientAuthType {
}
}
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
}
type stream struct {
chunkChan chan<- chunk
encLevel protocol.EncryptionLevel
}
func (s *stream) Write(b []byte) (int, error) {
data := append([]byte{}, b...)
select {
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
default:
panic("chunkChan too small")
}
return len(b), nil
}
func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 10)
initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial}
handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake}
return chunkChan, initialStream, handshakeStream
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnReceivedReadKeys()
DropKeys(protocol.EncryptionLevel)
}
type runner struct {
handshakeComplete chan<- struct{}
}
var _ handshakeRunner = &runner{}
func newRunner(handshakeComplete chan<- struct{}) *runner {
return &runner{handshakeComplete: handshakeComplete}
}
func (r *runner) OnReceivedParams(*wire.TransportParameters) {}
func (r *runner) OnReceivedReadKeys() {}
func (r *runner) OnHandshakeComplete() {
close(r.handshakeComplete)
}
func (r *runner) DropKeys(protocol.EncryptionLevel) {}
const (
alpn = "fuzzing"
alpnWrong = "wrong"
@ -193,28 +142,6 @@ func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
}
}
func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) protocol.EncryptionLevel {
//nolint:exhaustive
switch encLevel {
case protocol.EncryptionInitial:
return protocol.EncryptionInitial
case protocol.EncryptionHandshake:
// Handshake opener not available. We can't possibly read a Handshake handshake message.
if opener, err := cs.GetHandshakeOpener(); err != nil || opener == nil {
return protocol.EncryptionInitial
}
return protocol.EncryptionHandshake
case protocol.Encryption1RTT:
// 1-RTT opener not available. We can't possibly read a post-handshake message.
if opener, err := cs.Get1RTTOpener(); err != nil || opener == nil {
return maxEncLevel(cs, protocol.EncryptionHandshake)
}
return protocol.Encryption1RTT
default:
panic("unexpected encryption level")
}
}
func getTransportParameters(seed uint8) *wire.TransportParameters {
const maxVarInt = math.MaxUint64 / 4
r := mrand.New(mrand.NewSource(int64(seed)))
@ -357,16 +284,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
messageToReplace := messageConfig % 32
messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
var client, server handshake.CryptoSetup
clientHandshakeCompleted := make(chan struct{})
client, _ = handshake.NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
if len(data) == 0 {
return -1
}
client := handshake.NewCryptoSetupClient(
protocol.ConnectionID{},
clientTP,
newRunner(clientHandshakeCompleted),
clientConf,
enable0RTTClient,
utils.NewRTTStats(),
@ -374,16 +298,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
utils.DefaultLogger.WithPrefix("client"),
protocol.Version1,
)
if err := client.StartHandshake(); err != nil {
log.Fatal(err)
}
serverHandshakeCompleted := make(chan struct{})
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server = handshake.NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
server := handshake.NewCryptoSetupServer(
protocol.ConnectionID{},
serverTP,
newRunner(serverHandshakeCompleted),
serverConf,
enable0RTTServer,
utils.NewRTTStats(),
@ -391,57 +312,69 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
if len(data) == 0 {
return -1
}
if err := client.StartHandshake(); err != nil {
log.Fatal(err)
}
if err := server.StartHandshake(); err != nil {
log.Fatal(err)
}
done := make(chan struct{})
go func() {
<-serverHandshakeCompleted
<-clientHandshakeCompleted
close(done)
}()
messageLoop:
var clientHandshakeComplete, serverHandshakeComplete bool
for {
select {
case c := <-cChunkChan:
b := c.data
encLevel := c.encLevel
if len(b) > 0 && b[0] == messageToReplace {
fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0]))
b = data
encLevel = maxEncLevel(server, messageToReplaceEncLevel)
clientLoop:
for {
var processedEvent bool
ev := client.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
if !processedEvent && !clientHandshakeComplete { // handshake stuck
return 1
}
break clientLoop
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
msg := ev.Data
if msg[0] == messageToReplace {
fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
msg = data
}
if err := server.HandleMessage(msg, messageToReplaceEncLevel); err != nil {
return 1
}
case handshake.EventHandshakeComplete:
clientHandshakeComplete = true
}
if err := server.HandleMessage(b, encLevel); err != nil {
break messageLoop
processedEvent = true
}
serverLoop:
for {
var processedEvent bool
ev := server.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
if !processedEvent && !serverHandshakeComplete { // handshake stuck
return 1
}
break serverLoop
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
msg := ev.Data
if msg[0] == messageToReplace {
fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
msg = data
}
if err := client.HandleMessage(msg, messageToReplaceEncLevel); err != nil {
return 1
}
case handshake.EventHandshakeComplete:
serverHandshakeComplete = true
}
case c := <-sChunkChan:
b := c.data
encLevel := c.encLevel
if len(b) > 0 && b[0] == messageToReplace {
fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0]))
b = data
encLevel = maxEncLevel(client, messageToReplaceEncLevel)
}
if err := client.HandleMessage(b, encLevel); err != nil {
break messageLoop
}
case <-done: // test done
break messageLoop
processedEvent = true
}
if serverHandshakeComplete && clientHandshakeComplete {
break
}
}
<-done
_ = client.ConnectionState()
_ = server.ConnectionState()

View file

@ -47,6 +47,11 @@ func (h *receivedPacketHandler) ReceivedPacket(
case protocol.EncryptionInitial:
return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
case protocol.EncryptionHandshake:
// The Handshake packet number space might already have been dropped as a result
// of processing the CRYPTO frame that was contained in this packet.
if h.handshakePackets == nil {
return nil
}
return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
case protocol.Encryption0RTT:
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {

View file

@ -136,16 +136,6 @@ func newSentPacketHandler(
}
}
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial {
// This function is called when the crypto setup seals a Handshake packet.
// If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space
// before SentPacket() was called for that Initial packet.
return
}
h.dropPackets(encLevel)
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight {
@ -156,7 +146,7 @@ func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
}
}
func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
// The server won't await address validation after the handshake is confirmed.
// This applies even if we didn't receive an ACK for a Handshake packet.
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
@ -165,6 +155,10 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
// remove outstanding packets from bytes_in_flight
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
pnSpace := h.getPacketNumberSpace(encLevel)
// We might already have dropped this packet number space.
if pnSpace == nil {
return
}
pnSpace.history.Iterate(func(p *packet) (bool, error) {
h.removeFromBytesInFlight(p)
return true, nil
@ -238,10 +232,6 @@ func (h *sentPacketHandler) SentPacket(
isPathMTUProbePacket bool,
) {
h.bytesSent += size
// For the client, drop the Initial packet number space when the first Handshake packet is sent.
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake && h.initialPackets != nil {
h.dropPackets(protocol.EncryptionInitial)
}
pnSpace := h.getPacketNumberSpace(encLevel)
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
@ -884,6 +874,12 @@ func (h *sentPacketHandler) ResetForRetry() error {
}
func (h *sentPacketHandler) SetHandshakeConfirmed() {
if h.initialPackets != nil {
panic("didn't drop initial correctly")
}
if h.handshakePackets != nil {
panic("didn't drop handshake correctly")
}
h.handshakeConfirmed = true
// We don't send PTOs for application data packets before the handshake completes.
// Make sure the timer is armed now, if necessary.

View file

@ -130,6 +130,13 @@ var _ = Describe("SentPacketHandler", func() {
ExpectWithOffset(1, handler.rttStats.SmoothedRTT()).To(Equal(rtt))
}
// setHandshakeConfirmed drops both Initial and Handshake packets and then confirms the handshake
setHandshakeConfirmed := func() {
handler.DropPackets(protocol.EncryptionInitial)
handler.DropPackets(protocol.EncryptionHandshake)
handler.SetHandshakeConfirmed()
}
Context("registering sent packets", func() {
It("accepts two consecutive packets", func() {
sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, EncryptionLevel: protocol.EncryptionHandshake}))
@ -705,7 +712,7 @@ var _ = Describe("SentPacketHandler", func() {
It("implements exponential backoff", func() {
handler.peerAddressValidated = true
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
sendTime := time.Now().Add(-time.Hour)
sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: sendTime}))
timeout := handler.GetLossDetectionTimeout().Sub(sendTime)
@ -729,7 +736,7 @@ var _ = Describe("SentPacketHandler", func() {
It("reset the PTO count when receiving an ACK", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
now := time.Now()
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)}))
sentPacket(ackElicitingPacket(&packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)}))
handler.appDataPackets.pns.(*skippingPacketNumberGenerator).next = 3
@ -770,7 +777,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.ptoCount).To(BeEquivalentTo(1))
Expect(handler.SendMode(time.Now())).To(Equal(SendPTOHandshake))
Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeHandshake.Add(handler.rttStats.PTO(false) << 1)))
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
handler.DropPackets(protocol.EncryptionHandshake)
// PTO timer based on the 1-RTT packet
Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeAppData.Add(handler.rttStats.PTO(true)))) // no backoff. PTO count = 0
@ -780,7 +787,7 @@ var _ = Describe("SentPacketHandler", func() {
It("allows two 1-RTT PTOs", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
var lostPackets []protocol.PacketNumber
sentPacket(ackElicitingPacket(&packet{
PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT),
@ -802,7 +809,7 @@ var _ = Describe("SentPacketHandler", func() {
It("only counts ack-eliciting packets as probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
sentPacket(ackElicitingPacket(&packet{
PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT),
SendTime: time.Now().Add(-time.Hour),
@ -821,7 +828,7 @@ var _ = Describe("SentPacketHandler", func() {
It("gets two probe packets if PTO expires", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)}))
sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)}))
@ -869,7 +876,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.GetLossDetectionTimeout()).To(BeZero())
Expect(handler.SendMode(time.Now())).To(Equal(SendAny))
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero())
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData))
@ -877,7 +884,7 @@ var _ = Describe("SentPacketHandler", func() {
It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
pn := handler.PopPacketNumber(protocol.Encryption1RTT)
sentPacket(ackElicitingPacket(&packet{PacketNumber: pn, SendTime: time.Now().Add(-time.Hour)}))
updateRTT(time.Second)
@ -902,7 +909,7 @@ var _ = Describe("SentPacketHandler", func() {
It("doesn't set the PTO timer for Path MTU probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeConfirmed()
setHandshakeConfirmed()
updateRTT(time.Second)
sentPacket(ackElicitingPacket(&packet{PacketNumber: 5, SendTime: time.Now(), IsPathMTUProbePacket: true}))
Expect(handler.GetLossDetectionTimeout()).To(BeZero())
@ -1021,6 +1028,7 @@ var _ = Describe("SentPacketHandler", func() {
// Now receive an ACK for a Handshake packet.
// This tells the client that the server completed address validation.
sentPacket(handshakePacket(&packet{PacketNumber: 1}))
handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space
_, err = handler.ReceivedAck(
&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}},
protocol.EncryptionHandshake,
@ -1040,7 +1048,8 @@ var _ = Describe("SentPacketHandler", func() {
)
Expect(err).ToNot(HaveOccurred())
sentPacket(handshakePacketNonAckEliciting(&packet{PacketNumber: 1})) // also drops Initial packets
sentPacket(handshakePacketNonAckEliciting(&packet{PacketNumber: 1}))
handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space
Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero())
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode(time.Now())).To(Equal(SendPTOHandshake))
@ -1075,7 +1084,7 @@ var _ = Describe("SentPacketHandler", func() {
)
Expect(err).ToNot(HaveOccurred())
sentPacket(handshakePacketNonAckEliciting(&packet{PacketNumber: 1, SendTime: time.Now()}))
Expect(handler.initialPackets).To(BeNil())
handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space
pto := handler.rttStats.PTO(false)
Expect(pto).ToNot(BeZero())
@ -1235,39 +1244,6 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.handshakePackets.history.Len()).ToNot(BeZero())
})
Context("deleting Initials", func() {
BeforeEach(func() { perspective = protocol.PerspectiveClient })
It("deletes Initials, as a client", func() {
for i := 0; i < 6; i++ {
sentPacket(ackElicitingPacket(&packet{
PacketNumber: handler.PopPacketNumber(protocol.EncryptionInitial),
EncryptionLevel: protocol.EncryptionInitial,
Length: 1,
}))
}
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6)))
handler.DropPackets(protocol.EncryptionInitial)
// DropPackets should be ignored for clients and the Initial packet number space.
// It has to be possible to send another Initial packets after this function was called.
sentPacket(ackElicitingPacket(&packet{
PacketNumber: handler.PopPacketNumber(protocol.EncryptionInitial),
EncryptionLevel: protocol.EncryptionInitial,
Length: 1,
}))
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7)))
// Sending a Handshake packet triggers dropping of Initials.
sentPacket(ackElicitingPacket(&packet{
PacketNumber: handler.PopPacketNumber(protocol.EncryptionHandshake),
EncryptionLevel: protocol.EncryptionHandshake,
}))
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1)))
Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission
Expect(handler.initialPackets).To(BeNil())
Expect(handler.handshakePackets.history.Len()).ToNot(BeZero())
})
})
It("deletes Handshake packets", func() {
for i := protocol.PacketNumber(0); i < 6; i++ {
sentPacket(ackElicitingPacket(&packet{

View file

@ -92,69 +92,3 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
}
type handshakeSealer struct {
LongHeaderSealer
dropInitialKeys func()
dropped bool
}
func newHandshakeSealer(
aead cipher.AEAD,
headerProtector headerProtector,
dropInitialKeys func(),
perspective protocol.Perspective,
) LongHeaderSealer {
sealer := newLongHeaderSealer(aead, headerProtector)
// The client drops Initial keys when sending the first Handshake packet.
if perspective == protocol.PerspectiveServer {
return sealer
}
return &handshakeSealer{
LongHeaderSealer: sealer,
dropInitialKeys: dropInitialKeys,
}
}
func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
data := s.LongHeaderSealer.Seal(dst, src, pn, ad)
if !s.dropped {
s.dropInitialKeys()
s.dropped = true
}
return data
}
type handshakeOpener struct {
LongHeaderOpener
dropInitialKeys func()
dropped bool
}
func newHandshakeOpener(
aead cipher.AEAD,
headerProtector headerProtector,
dropInitialKeys func(),
perspective protocol.Perspective,
) LongHeaderOpener {
opener := newLongHeaderOpener(aead, headerProtector)
// The server drops Initial keys when first successfully processing a Handshake packet.
if perspective == protocol.PerspectiveClient {
return opener
}
return &handshakeOpener{
LongHeaderOpener: opener,
dropInitialKeys: dropInitialKeys,
}
}
func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad)
if err == nil && !o.dropped {
o.dropInitialKeys()
o.dropped = true
}
return dec, err
}

View file

@ -133,72 +133,5 @@ var _ = Describe("Long Header AEAD", func() {
})
}
})
Describe("Long Header AEAD", func() {
var (
dropped chan struct{} // use a chan because closing it twice will panic
aead cipher.AEAD
hp headerProtector
)
dropCb := func() { close(dropped) }
msg := []byte("Lorem ipsum dolor sit amet.")
ad := []byte("Donec in velit neque.")
BeforeEach(func() {
dropped = make(chan struct{})
key := make([]byte, 16)
hpKey := make([]byte, 16)
rand.Read(key)
rand.Read(hpKey)
block, err := aes.NewCipher(key)
Expect(err).ToNot(HaveOccurred())
aead, err = cipher.NewGCM(block)
Expect(err).ToNot(HaveOccurred())
hp = newHeaderProtector(cipherSuites[0], hpKey, true, protocol.Version1)
})
Context("for the server", func() {
It("drops keys when first successfully processing a Handshake packet", func() {
serverOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveServer)
// first try to open an invalid message
_, err := serverOpener.Open(nil, []byte("invalid"), 0, []byte("invalid"))
Expect(err).To(HaveOccurred())
Expect(dropped).ToNot(BeClosed())
// then open a valid message
enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 10, ad)
_, err = serverOpener.Open(nil, enc, 10, ad)
Expect(err).ToNot(HaveOccurred())
Expect(dropped).To(BeClosed())
// now open the same message again to make sure the callback is only called once
_, err = serverOpener.Open(nil, enc, 10, ad)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't drop keys when sealing a Handshake packet", func() {
serverSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveServer)
serverSealer.Seal(nil, msg, 1, ad)
Expect(dropped).ToNot(BeClosed())
})
})
Context("for the client", func() {
It("drops keys when first sealing a Handshake packet", func() {
clientSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveClient)
// seal the first message
clientSealer.Seal(nil, msg, 1, ad)
Expect(dropped).To(BeClosed())
// seal another message to make sure the callback is only called once
clientSealer.Seal(nil, msg, 2, ad)
})
It("doesn't drop keys when processing a Handshake packet", func() {
enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 42, ad)
clientOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveClient)
_, err := clientOpener.Open(nil, enc, 42, ad)
Expect(err).ToNot(HaveOccurred())
Expect(dropped).ToNot(BeClosed())
})
})
})
}
})

View file

@ -6,7 +6,6 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
@ -30,16 +29,15 @@ type cryptoSetup struct {
tlsConf *tls.Config
conn *qtls.QUICConn
events []Event
version protocol.VersionNumber
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
runner handshakeRunner
zeroRTTParameters *wire.TransportParameters
zeroRTTParametersChan chan<- *wire.TransportParameters
allow0RTT bool
zeroRTTParameters *wire.TransportParameters
allow0RTT bool
rttStats *utils.RTTStats
@ -55,17 +53,14 @@ type cryptoSetup struct {
zeroRTTOpener LongHeaderOpener // only set for the server
zeroRTTSealer LongHeaderSealer // only set for the client
initialStream io.Writer
initialOpener LongHeaderOpener
initialSealer LongHeaderSealer
handshakeStream io.Writer
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
used0RTT atomic.Bool
oneRTTStream io.Writer
aead *updatableAEAD
has1RTTSealer bool
has1RTTOpener bool
@ -75,24 +70,18 @@ var _ CryptoSetup = &cryptoSetup{}
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
initialStream, handshakeStream, oneRTTStream io.Writer,
connID protocol.ConnectionID,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
cs, clientHelloWritten := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
) CryptoSetup {
cs := newCryptoSetup(
connID,
tp,
runner,
rttStats,
tracer,
logger,
@ -109,15 +98,13 @@ func NewCryptoSetupClient(
cs.conn = qtls.QUICClient(quicConf)
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
return cs, clientHelloWritten
return cs
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
initialStream, handshakeStream, oneRTTStream io.Writer,
connID protocol.ConnectionID,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
allow0RTT bool,
rttStats *utils.RTTStats,
@ -125,13 +112,9 @@ func NewCryptoSetupServer(
logger utils.Logger,
version protocol.VersionNumber,
) CryptoSetup {
cs, _ := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
cs := newCryptoSetup(
connID,
tp,
runner,
rttStats,
tracer,
logger,
@ -150,38 +133,31 @@ func NewCryptoSetupServer(
}
func newCryptoSetup(
initialStream, handshakeStream, oneRTTStream io.Writer,
connID protocol.ConnectionID,
tp *wire.TransportParameters,
runner handshakeRunner,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
perspective protocol.Perspective,
version protocol.VersionNumber,
) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
) *cryptoSetup {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil {
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
zeroRTTParametersChan := make(chan *wire.TransportParameters, 1)
return &cryptoSetup{
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
runner: runner,
ourParams: tp,
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
zeroRTTParametersChan: zeroRTTParametersChan,
version: version,
}, zeroRTTParametersChan
initialSealer: initialSealer,
initialOpener: initialOpener,
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
events: make([]Event, 0, 16),
ourParams: tp,
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
version: version,
}
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
@ -216,10 +192,9 @@ func (h *cryptoSetup) StartHandshake() error {
if h.perspective == protocol.PerspectiveClient {
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.zeroRTTParametersChan <- h.zeroRTTParameters
h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters})
} else {
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
h.zeroRTTParametersChan <- nil
}
}
return nil
@ -275,7 +250,8 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
h.rejected0RTT()
return false, nil
case qtls.QUICWriteData:
return false, h.WriteRecord(ev.Level, ev.Data)
h.WriteRecord(ev.Level, ev.Data)
return false, nil
case qtls.QUICHandshakeDone:
h.handshakeComplete()
return false, nil
@ -284,13 +260,22 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
}
}
func (h *cryptoSetup) NextEvent() Event {
if len(h.events) == 0 {
return Event{Kind: EventNoEvent}
}
ev := h.events[0]
h.events = h.events[1:]
return ev
}
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
return err
}
h.peerParams = &tp
h.runner.OnReceivedParams(h.peerParams)
h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams})
return nil
}
@ -392,7 +377,7 @@ func (h *cryptoSetup) rejected0RTT() {
h.mutex.Unlock()
if had0RTTKeys {
h.runner.DropKeys(protocol.Encryption0RTT)
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
}
}
@ -414,11 +399,9 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case qtls.QUICEncryptionLevelHandshake:
h.handshakeOpener = newHandshakeOpener(
h.handshakeOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
h.dropInitialKeys,
h.perspective,
)
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
@ -433,7 +416,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr
panic("unexpected read encryption level")
}
h.mutex.Unlock()
h.runner.OnReceivedReadKeys()
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
}
@ -462,11 +445,9 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
// don't set used0RTT here. 0-RTT might still get rejected.
return
case qtls.QUICEncryptionLevelHandshake:
h.handshakeSealer = newHandshakeSealer(
h.handshakeSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
h.dropInitialKeys,
h.perspective,
)
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
@ -496,40 +477,34 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) error {
h.mutex.Lock()
defer h.mutex.Unlock()
var str io.Writer
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) {
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
switch encLevel {
case qtls.QUICEncryptionLevelInitial:
// assume that the first WriteRecord call contains the ClientHello
str = h.initialStream
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
case qtls.QUICEncryptionLevelHandshake:
str = h.handshakeStream
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
case qtls.QUICEncryptionLevelApplication:
str = h.oneRTTStream
panic("unexpected write")
default:
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
}
_, err := str.Write(p)
return err
}
// used a callback in the handshakeSealer and handshakeOpener
func (h *cryptoSetup) dropInitialKeys() {
func (h *cryptoSetup) DiscardInitialKeys() {
h.mutex.Lock()
dropped := h.initialOpener != nil
h.initialOpener = nil
h.initialSealer = nil
h.mutex.Unlock()
h.runner.DropKeys(protocol.EncryptionInitial)
h.logger.Debugf("Dropping Initial keys.")
if dropped {
h.logger.Debugf("Dropping Initial keys.")
}
}
func (h *cryptoSetup) handshakeComplete() {
h.handshakeCompleteTime = time.Now()
h.runner.OnHandshakeComplete()
h.events = append(h.events, Event{Kind: EventHandshakeComplete})
}
func (h *cryptoSetup) SetHandshakeConfirmed() {
@ -544,7 +519,6 @@ func (h *cryptoSetup) SetHandshakeConfirmed() {
}
h.mutex.Unlock()
if dropped {
h.runner.DropKeys(protocol.EncryptionHandshake)
h.logger.Debugf("Dropping Handshake keys.")
}
}

View file

@ -27,46 +27,9 @@ const (
typeNewSessionTicket = 4
)
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
}
type stream struct {
encLevel protocol.EncryptionLevel
chunkChan chan<- chunk
}
func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream {
return &stream{
chunkChan: chunkChan,
encLevel: encLevel,
}
}
func (s *stream) Write(b []byte) (int, error) {
data := make([]byte, len(b))
copy(data, b)
select {
case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
default:
panic("chunkChan too small")
}
return len(b), nil
}
var _ = Describe("Crypto Setup TLS", func() {
var clientConf, serverConf *tls.Config
// unparam incorrectly complains that the first argument is never used.
//nolint:unparam
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 100)
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
return chunkChan, initialStream, handshakeStream
}
BeforeEach(func() {
serverConf = testdata.GetTLSConfig()
serverConf.NextProtos = []string{"crypto-setup"}
@ -78,17 +41,12 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("handles qtls errors occurring before during ClientHello generation", func() {
_, sInitialStream, sHandshakeStream := initStreams()
tlsConf := testdata.GetTLSConfig()
tlsConf.InsecureSkipVerify = true
tlsConf.NextProtos = []string{""}
cl, _ := NewCryptoSetupClient(
sInitialStream,
sHandshakeStream,
nil,
cl := NewCryptoSetupClient(
protocol.ConnectionID{},
&wire.TransportParameters{},
NewMockHandshakeRunner(mockCtrl),
tlsConf,
false,
&utils.RTTStats{},
@ -104,16 +62,10 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("errors when a message is received at the wrong encryption level", func() {
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
&utils.RTTStats{},
@ -158,32 +110,73 @@ var _ = Describe("Crypto Setup TLS", func() {
return rttStats
}
handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) {
// The clientEvents and serverEvents contain all events that were not processed by the function,
// i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete.
handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) {
Expect(client.StartHandshake()).To(Succeed())
Expect(server.StartHandshake()).To(Succeed())
for {
select {
case c := <-cChunkChan:
Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed())
continue
default:
}
select {
case c := <-sChunkChan:
Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed())
continue
default:
}
// no more messages to send from client and server. Handshake complete?
break
}
var clientHandshakeComplete, serverHandshakeComplete bool
ticket, err := server.GetSessionTicket()
Expect(err).ToNot(HaveOccurred())
if ticket != nil {
Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
for {
clientLoop:
for {
ev := client.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case EventNoEvent:
break clientLoop
case EventWriteInitialData:
if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
serverErr = err
return
}
case EventWriteHandshakeData:
if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
serverErr = err
return
}
case EventHandshakeComplete:
clientHandshakeComplete = true
default:
clientEvents = append(clientEvents, ev)
}
}
serverLoop:
for {
ev := server.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case EventNoEvent:
break serverLoop
case EventWriteInitialData:
if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
clientErr = err
return
}
case EventWriteHandshakeData:
if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
clientErr = err
return
}
case EventHandshakeComplete:
serverHandshakeComplete = true
ticket, err := server.GetSessionTicket()
Expect(err).ToNot(HaveOccurred())
if ticket != nil {
Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
}
default:
serverEvents = append(serverEvents, ev)
}
}
if clientHandshakeComplete && serverHandshakeComplete {
break
}
}
return
}
handshakeWithTLSConf := func(
@ -191,22 +184,12 @@ var _ = Describe("Crypto Setup TLS", func() {
clientRTTStats, serverRTTStats *utils.RTTStats,
clientTransportParameters, serverTransportParameters *wire.TransportParameters,
enable0RTT bool,
) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1)
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
client, clientHelloWrittenChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */
CryptoSetup /* server */, []Event /* more server events */, error, /* server error */
) {
client := NewCryptoSetupClient(
protocol.ConnectionID{},
clientTransportParameters,
cRunner,
clientConf,
enable0RTT,
clientRTTStats,
@ -215,24 +198,13 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
if serverTransportParameters.StatelessResetToken == nil {
var token protocol.StatelessResetToken
serverTransportParameters.StatelessResetToken = &token
}
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
serverTransportParameters,
sRunner,
serverConf,
enable0RTT,
serverRTTStats,
@ -240,24 +212,12 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
handshake(client, cChunkChan, server, sChunkChan)
var cErr, sErr error
select {
case sErr = <-sErrChan:
default:
Expect(sHandshakeComplete).To(BeTrue())
}
select {
case cErr = <-cErrChan:
default:
Expect(cHandshakeComplete).To(BeTrue())
}
return clientHelloWrittenChan, client, cErr, server, sErr
cEvents, cErr, sEvents, sErr := handshake(client, server)
return client, cEvents, cErr, server, sEvents, sErr
}
It("handshakes", func() {
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -269,7 +229,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("performs a HelloRetryRequst", func() {
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -282,7 +242,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("handshakes with client auth", func() {
clientConf.Certificates = []tls.Certificate{generateCert()}
serverConf.ClientAuth = tls.RequireAnyClientCert
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -292,50 +252,11 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(serverErr).ToNot(HaveOccurred())
})
It("signals when it has written the ClientHello", func() {
runner := NewMockHandshakeRunner(mockCtrl)
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, chChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
&wire.TransportParameters{},
runner,
&tls.Config{InsecureSkipVerify: true},
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.Version1,
)
Expect(client.StartHandshake()).To(Succeed())
var ch chunk
Eventually(cChunkChan).Should(Receive(&ch))
Eventually(chChan).Should(Receive(BeNil()))
// make sure the whole ClientHello was written
Expect(len(ch.data)).To(BeNumerically(">=", 4))
Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello))
length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
Expect(len(ch.data) - 4).To(Equal(length))
})
It("receives transport parameters", func() {
var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second}
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp })
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second}
client := NewCryptoSetupClient(
protocol.ConnectionID{},
cTransportParameters,
cRunner,
clientConf,
false,
&utils.RTTStats{},
@ -344,24 +265,15 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
var token protocol.StatelessResetToken
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp })
sRunner.EXPECT().OnHandshakeComplete()
sTransportParameters := &wire.TransportParameters{
MaxIdleTimeout: 0x1337 * time.Second,
MaxIdleTimeout: 1337 * time.Second,
StatelessResetToken: &token,
ActiveConnectionIDLimit: 2,
}
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
sTransportParameters,
sRunner,
serverConf,
false,
&utils.RTTStats{},
@ -370,68 +282,38 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout))
Expect(sTransportParametersRcvd).ToNot(BeNil())
Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout))
clientEvents, cErr, serverEvents, sErr := handshake(client, server)
Expect(cErr).ToNot(HaveOccurred())
Expect(sErr).ToNot(HaveOccurred())
var clientReceivedTransportParameters *wire.TransportParameters
for _, ev := range clientEvents {
if ev.Kind == EventReceivedTransportParameters {
clientReceivedTransportParameters = ev.TransportParameters
}
}
Expect(clientReceivedTransportParameters).ToNot(BeNil())
Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second))
var serverReceivedTransportParameters *wire.TransportParameters
for _, ev := range serverEvents {
if ev.Kind == EventReceivedTransportParameters {
serverReceivedTransportParameters = ev.TransportParameters
}
}
Expect(serverReceivedTransportParameters).ToNot(BeNil())
Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second))
})
Context("with session tickets", func() {
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
cRunner,
clientConf,
client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.Version1,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnHandshakeComplete()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
// inject an invalid session ticket
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
@ -441,54 +323,14 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("errors when handling the NewSessionTicket fails", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnReceivedReadKeys().Times(2)
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
nil,
protocol.ConnectionID{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
cRunner,
clientConf,
client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.Version1,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnReceivedReadKeys().Times(2)
sRunner.EXPECT().OnHandshakeComplete()
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
nil,
protocol.ConnectionID{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
// inject an invalid session ticket
b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
@ -509,7 +351,7 @@ var _ = Describe("Crypto Setup TLS", func() {
clientConf.ClientSessionCache = csc
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -520,12 +362,11 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -537,7 +378,6 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
})
It("doesn't use session resumption if the server disabled it", func() {
@ -550,7 +390,7 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
_, client, clientErr, server, serverErr := handshakeWithTLSConf(
client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -564,7 +404,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverConf.SessionTicketsDisabled = true
csc.EXPECT().Get(gomock.Any()).Return(state, true)
_, client, clientErr, server, serverErr = handshakeWithTLSConf(
client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -592,7 +432,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, serverOrigRTTStats,
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -604,14 +444,13 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
serverRTTStats := &utils.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, serverRTTStats,
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -624,9 +463,30 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
var clientReceived0RTTKeys bool
for _, ev := range clientEvents {
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case EventRestoredTransportParameters:
tp = ev.TransportParameters
case EventReceivedReadKeys:
clientReceived0RTTKeys = true
}
}
Expect(clientReceived0RTTKeys).To(BeTrue())
Expect(tp).ToNot(BeNil())
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
var serverReceived0RTTKeys bool
for _, ev := range serverEvents {
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case EventReceivedReadKeys:
serverReceived0RTTKeys = true
}
}
Expect(serverReceived0RTTKeys).To(BeTrue())
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeTrue())
@ -646,7 +506,7 @@ var _ = Describe("Crypto Setup TLS", func() {
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -658,13 +518,12 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &utils.RTTStats{},
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
@ -676,7 +535,18 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
var clientReceived0RTTKeys bool
for _, ev := range clientEvents {
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case EventRestoredTransportParameters:
tp = ev.TransportParameters
case EventReceivedReadKeys:
clientReceived0RTTKeys = true
}
}
Expect(clientReceived0RTTKeys).To(BeTrue())
Expect(tp).ToNot(BeNil())
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
Expect(server.ConnectionState().DidResume).To(BeTrue())

View file

@ -53,18 +53,42 @@ type ShortHeaderSealer interface {
KeyPhase() protocol.KeyPhaseBit
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnReceivedReadKeys()
DropKeys(protocol.EncryptionLevel)
}
type ConnectionState struct {
tls.ConnectionState
Used0RTT bool
}
// EventKind is the kind of handshake event.
type EventKind uint8
const (
// EventNoEvent signals that there are no new handshake events
EventNoEvent EventKind = iota + 1
// EventWriteInitialData contains new CRYPTO data to send at the Initial encryption level
EventWriteInitialData
// EventWriteHandshakeData contains new CRYPTO data to send at the Handshake encryption level
EventWriteHandshakeData
// EventReceivedReadKeys signals that new decryption keys are available.
// It doesn't say which encryption level those keys are for.
EventReceivedReadKeys
// EventDiscard0RTTKeys signals that the Handshake keys were discarded.
EventDiscard0RTTKeys
// EventReceivedTransportParameters contains the transport parameters sent by the peer.
EventReceivedTransportParameters
// EventRestoredTransportParameters contains the transport parameters restored from the session ticket.
// It is only used for the client.
EventRestoredTransportParameters
// EventHandshakeComplete signals that the TLS handshake was completed.
EventHandshakeComplete
)
// Event is a handshake event.
type Event struct {
Kind EventKind
Data []byte
TransportParameters *wire.TransportParameters
}
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
StartHandshake() error
@ -73,7 +97,10 @@ type CryptoSetup interface {
GetSessionTicket() ([]byte, error)
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() Event
SetLargest1RTTAcked(protocol.PacketNumber) error
DiscardInitialKeys()
SetHandshakeConfirmed()
ConnectionState() ConnectionState

View file

@ -1,84 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go/internal/handshake (interfaces: HandshakeRunner)
// Package handshake is a generated GoMock package.
package handshake
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/quic-go/quic-go/internal/protocol"
wire "github.com/quic-go/quic-go/internal/wire"
)
// MockHandshakeRunner is a mock of HandshakeRunner interface.
type MockHandshakeRunner struct {
ctrl *gomock.Controller
recorder *MockHandshakeRunnerMockRecorder
}
// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner.
type MockHandshakeRunnerMockRecorder struct {
mock *MockHandshakeRunner
}
// NewMockHandshakeRunner creates a new mock instance.
func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner {
mock := &MockHandshakeRunner{ctrl: ctrl}
mock.recorder = &MockHandshakeRunnerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder {
return m.recorder
}
// DropKeys mocks base method.
func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DropKeys", arg0)
}
// DropKeys indicates an expected call of DropKeys.
func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0)
}
// OnHandshakeComplete mocks base method.
func (m *MockHandshakeRunner) OnHandshakeComplete() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnHandshakeComplete")
}
// OnHandshakeComplete indicates an expected call of OnHandshakeComplete.
func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockHandshakeRunner)(nil).OnHandshakeComplete))
}
// OnReceivedParams mocks base method.
func (m *MockHandshakeRunner) OnReceivedParams(arg0 *wire.TransportParameters) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnReceivedParams", arg0)
}
// OnReceivedParams indicates an expected call of OnReceivedParams.
func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0)
}
// OnReceivedReadKeys mocks base method.
func (m *MockHandshakeRunner) OnReceivedReadKeys() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnReceivedReadKeys")
}
// OnReceivedReadKeys indicates an expected call of OnReceivedReadKeys.
func (mr *MockHandshakeRunnerMockRecorder) OnReceivedReadKeys() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedReadKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedReadKeys))
}

View file

@ -1,6 +0,0 @@
//go:build gomock || generate
package handshake
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package handshake -destination mock_handshake_runner_test.go github.com/quic-go/quic-go/internal/handshake HandshakeRunner"
type HandshakeRunner = handshakeRunner

View file

@ -75,6 +75,18 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
}
// DiscardInitialKeys mocks base method.
func (m *MockCryptoSetup) DiscardInitialKeys() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DiscardInitialKeys")
}
// DiscardInitialKeys indicates an expected call of DiscardInitialKeys.
func (mr *MockCryptoSetupMockRecorder) DiscardInitialKeys() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscardInitialKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DiscardInitialKeys))
}
// Get0RTTOpener mocks base method.
func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) {
m.ctrl.T.Helper()
@ -224,6 +236,20 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
}
// NextEvent mocks base method.
func (m *MockCryptoSetup) NextEvent() handshake.Event {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextEvent")
ret0, _ := ret[0].(handshake.Event)
return ret0
}
// NextEvent indicates an expected call of NextEvent.
func (mr *MockCryptoSetupMockRecorder) NextEvent() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoSetup)(nil).NextEvent))
}
// SetHandshakeConfirmed mocks base method.
func (m *MockCryptoSetup) SetHandshakeConfirmed() {
m.ctrl.T.Helper()

View file

@ -8,6 +8,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
handshake "github.com/quic-go/quic-go/internal/handshake"
protocol "github.com/quic-go/quic-go/internal/protocol"
)
@ -47,3 +48,17 @@ func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1)
}
// NextEvent mocks base method.
func (m *MockCryptoDataHandler) NextEvent() handshake.Event {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextEvent")
ret0, _ := ret[0].(handshake.Event)
return ret0
}
// NextEvent indicates an expected call of NextEvent.
func (mr *MockCryptoDataHandlerMockRecorder) NextEvent() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoDataHandler)(nil).NextEvent))
}