From 469a6153b6798780bf90e8be37f197d0fcb18af2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 21 Jul 2023 10:00:42 -0700 Subject: [PATCH] use a synchronous API for the crypto setup (#3939) --- connection.go | 245 +++++----- connection_test.go | 204 ++++++--- crypto_stream_manager.go | 4 +- crypto_stream_manager_test.go | 4 - fuzzing/handshake/cmd/corpus.go | 132 ++---- fuzzing/handshake/fuzz.go | 191 +++----- .../ackhandler/received_packet_handler.go | 5 + internal/ackhandler/sent_packet_handler.go | 26 +- .../ackhandler/sent_packet_handler_test.go | 64 +-- internal/handshake/aead.go | 66 --- internal/handshake/aead_test.go | 67 --- internal/handshake/crypto_setup.go | 120 ++--- internal/handshake/crypto_setup_test.go | 428 ++++++------------ internal/handshake/interface.go | 41 +- .../handshake/mock_handshake_runner_test.go | 84 ---- internal/handshake/mockgen.go | 6 - internal/mocks/crypto_setup.go | 26 ++ mock_crypto_data_handler_test.go | 15 + 18 files changed, 696 insertions(+), 1032 deletions(-) delete mode 100644 internal/handshake/mock_handshake_runner_test.go delete mode 100644 internal/handshake/mockgen.go diff --git a/connection.go b/connection.go index 4823cd46..bb4d691e 100644 --- a/connection.go +++ b/connection.go @@ -57,6 +57,8 @@ type cryptoStreamHandler interface { SetLargest1RTTAcked(protocol.PacketNumber) error SetHandshakeConfirmed() GetSessionTicket() ([]byte, error) + NextEvent() handshake.Event + DiscardInitialKeys() io.Closer ConnectionState() handshake.ConnectionState } @@ -96,18 +98,6 @@ type connRunner interface { RemoveResetToken(protocol.StatelessResetToken) } -type handshakeRunner struct { - onReceivedParams func(*wire.TransportParameters) - onReceivedReadKeys func() - dropKeys func(protocol.EncryptionLevel) - onHandshakeComplete func() -} - -func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) } -func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } -func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() } -func (r *handshakeRunner) OnReceivedReadKeys() { r.onReceivedReadKeys() } - type closeError struct { err error remote bool @@ -165,6 +155,8 @@ type connection struct { packer packer mtuDiscoverer mtuDiscoverer // initialized when the handshake completes + initialStream cryptoStream + handshakeStream cryptoStream oneRTTStream cryptoStream // only set for the server cryptoStreamHandler cryptoStreamHandler @@ -183,12 +175,10 @@ type connection struct { undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level undecryptablePacketsToProcess []receivedPacket - clientHelloWritten <-chan *wire.TransportParameters - earlyConnReadyChan chan struct{} - handshakeCompleteChan chan struct{} // is closed when the handshake completes - sentFirstPacket bool - handshakeComplete bool - handshakeConfirmed bool + earlyConnReadyChan chan struct{} + sentFirstPacket bool + handshakeComplete bool + handshakeConfirmed bool receivedRetry bool versionNegotiated bool @@ -248,17 +238,16 @@ var newConnection = func( v protocol.VersionNumber, ) quicConn { s := &connection{ - conn: conn, - config: conf, - handshakeDestConnID: destConnID, - srcConnIDLen: srcConnID.Len(), - tokenGenerator: tokenGenerator, - oneRTTStream: newCryptoStream(), - perspective: protocol.PerspectiveServer, - handshakeCompleteChan: make(chan struct{}), - tracer: tracer, - logger: logger, - version: v, + conn: conn, + config: conf, + handshakeDestConnID: destConnID, + srcConnIDLen: srcConnID.Len(), + tokenGenerator: tokenGenerator, + oneRTTStream: newCryptoStream(), + perspective: protocol.PerspectiveServer, + tracer: tracer, + logger: logger, + version: v, } if origDestConnID.Len() > 0 { s.logID = origDestConnID.String() @@ -294,8 +283,6 @@ var newConnection = func( s.logger, ) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - initialStream := newCryptoStream() - handshakeStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -327,20 +314,8 @@ var newConnection = func( s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupServer( - initialStream, - handshakeStream, - s.oneRTTStream, clientDestConnID, params, - &handshakeRunner{ - onReceivedParams: s.handleTransportParameters, - dropKeys: s.dropEncryptionLevel, - onReceivedReadKeys: s.receivedReadKeys, - onHandshakeComplete: func() { - runner.Retire(clientDestConnID) - close(s.handshakeCompleteChan) - }, - }, tlsConf, conf.Allow0RTT, s.rttStats, @@ -349,9 +324,9 @@ var newConnection = func( s.version, ) s.cryptoStreamHandler = cs - s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) + s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) - s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) + s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, s.oneRTTStream) return s } @@ -373,18 +348,17 @@ var newClientConnection = func( v protocol.VersionNumber, ) quicConn { s := &connection{ - conn: conn, - config: conf, - origDestConnID: destConnID, - handshakeDestConnID: destConnID, - srcConnIDLen: srcConnID.Len(), - perspective: protocol.PerspectiveClient, - handshakeCompleteChan: make(chan struct{}), - logID: destConnID.String(), - logger: logger, - tracer: tracer, - versionNegotiated: hasNegotiatedVersion, - version: v, + conn: conn, + config: conf, + origDestConnID: destConnID, + handshakeDestConnID: destConnID, + srcConnIDLen: srcConnID.Len(), + perspective: protocol.PerspectiveClient, + logID: destConnID.String(), + logger: logger, + tracer: tracer, + versionNegotiated: hasNegotiatedVersion, + version: v, } s.connIDManager = newConnIDManager( destConnID, @@ -415,8 +389,6 @@ var newClientConnection = func( s.logger, ) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - initialStream := newCryptoStream() - handshakeStream := newCryptoStream() oneRTTStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -445,18 +417,9 @@ var newClientConnection = func( if s.tracer != nil { s.tracer.SentTransportParameters(params) } - cs, clientHelloWritten := handshake.NewCryptoSetupClient( - initialStream, - handshakeStream, - oneRTTStream, + cs := handshake.NewCryptoSetupClient( destConnID, params, - &handshakeRunner{ - onReceivedParams: s.handleTransportParameters, - dropKeys: s.dropEncryptionLevel, - onReceivedReadKeys: s.receivedReadKeys, - onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, - }, tlsConf, enable0RTT, s.rttStats, @@ -464,11 +427,10 @@ var newClientConnection = func( logger, s.version, ) - s.clientHelloWritten = clientHelloWritten s.cryptoStreamHandler = cs - s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) + s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) - s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) + s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) if len(tlsConf.ServerName) > 0 { s.tokenStoreKey = tlsConf.ServerName } else { @@ -483,6 +445,8 @@ var newClientConnection = func( } func (s *connection) preSetup() { + s.initialStream = newCryptoStream() + s.handshakeStream = newCryptoStream() s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) @@ -535,6 +499,9 @@ func (s *connection) run() error { if err := s.cryptoStreamHandler.StartHandshake(); err != nil { return err } + if err := s.handleHandshakeEvents(); err != nil { + return err + } go func() { if err := s.sendQueue.Run(); err != nil { s.destroyImpl(err) @@ -542,17 +509,7 @@ func (s *connection) run() error { }() if s.perspective == protocol.PerspectiveClient { - select { - case zeroRTTParams := <-s.clientHelloWritten: - s.scheduleSending() - if zeroRTTParams != nil { - s.restoreTransportParameters(zeroRTTParams) - close(s.earlyConnReadyChan) - } - case closeErr := <-s.closeChan: - // put the close error back into the channel, so that the run loop can receive it - s.closeChan <- closeErr - } + s.scheduleSending() // so the ClientHello actually gets sent } var sendQueueAvailable <-chan struct{} @@ -563,8 +520,6 @@ runLoop: select { case closeErr = <-s.closeChan: break runLoop - case <-s.handshakeCompleteChan: - s.handleHandshakeComplete() default: } @@ -635,8 +590,6 @@ runLoop: if !wasProcessed { continue } - case <-s.handshakeCompleteChan: - s.handleHandshakeComplete() } } @@ -762,9 +715,8 @@ func (s *connection) idleTimeoutStartTime() time.Time { return utils.MaxTime(s.lastPacketReceivedTime, s.firstAckElicitingPacketAfterIdleSentTime) } -func (s *connection) handleHandshakeComplete() { +func (s *connection) handleHandshakeComplete() error { s.handshakeComplete = true - s.handshakeCompleteChan = nil // prevent this case from ever being selected again defer s.handshakeCtxCancel() // Once the handshake completes, we have derived 1-RTT keys. // There's no point in queueing undecryptable packets for later decryption any more. @@ -775,14 +727,16 @@ func (s *connection) handleHandshakeComplete() { if s.perspective == protocol.PerspectiveClient { s.applyTransportParameters() - return + return nil } - s.handleHandshakeConfirmed() + if err := s.handleHandshakeConfirmed(); err != nil { + return err + } ticket, err := s.cryptoStreamHandler.GetSessionTicket() if err != nil { - s.closeLocal(err) + return err } if ticket != nil { // may be nil if session tickets are disabled via tls.Config.SessionTicketsDisabled s.oneRTTStream.Write(ticket) @@ -792,13 +746,18 @@ func (s *connection) handleHandshakeComplete() { } token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) if err != nil { - s.closeLocal(err) + return err } s.queueControlFrame(&wire.NewTokenFrame{Token: token}) s.queueControlFrame(&wire.HandshakeDoneFrame{}) + return nil } -func (s *connection) handleHandshakeConfirmed() { +func (s *connection) handleHandshakeConfirmed() error { + if err := s.dropEncryptionLevel(protocol.EncryptionHandshake); err != nil { + return err + } + s.handshakeConfirmed = true s.sentPacketHandler.SetHandshakeConfirmed() s.cryptoStreamHandler.SetHandshakeConfirmed() @@ -810,6 +769,7 @@ func (s *connection) handleHandshakeConfirmed() { } s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize)) } + return nil } func (s *connection) handlePacketImpl(rp receivedPacket) bool { @@ -1211,6 +1171,14 @@ func (s *connection) handleUnpackedLongHeaderPacket( } } + if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake { + // On the server side, Initial keys are dropped as soon as the first Handshake packet is received. + // See Section 4.9.1 of RFC 9001. + if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { + return err + } + } + s.lastPacketReceivedTime = rcvTime s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false @@ -1376,13 +1344,41 @@ func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame } func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { - return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + if err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil { + return err + } + return s.handleHandshakeEvents() } -func (s *connection) receivedReadKeys() { - // Queue all packets for decryption that have been undecryptable so far. - s.undecryptablePacketsToProcess = s.undecryptablePackets - s.undecryptablePackets = nil +func (s *connection) handleHandshakeEvents() error { + for { + ev := s.cryptoStreamHandler.NextEvent() + var err error + switch ev.Kind { + case handshake.EventNoEvent: + return nil + case handshake.EventHandshakeComplete: + err = s.handleHandshakeComplete() + case handshake.EventReceivedTransportParameters: + err = s.handleTransportParameters(ev.TransportParameters) + case handshake.EventRestoredTransportParameters: + s.restoreTransportParameters(ev.TransportParameters) + close(s.earlyConnReadyChan) + case handshake.EventReceivedReadKeys: + // Queue all packets for decryption that have been undecryptable so far. + s.undecryptablePacketsToProcess = s.undecryptablePackets + s.undecryptablePackets = nil + case handshake.EventDiscard0RTTKeys: + err = s.dropEncryptionLevel(protocol.Encryption0RTT) + case handshake.EventWriteInitialData: + _, err = s.initialStream.Write(ev.Data) + case handshake.EventWriteHandshakeData: + _, err = s.handshakeStream.Write(ev.Data) + } + if err != nil { + return err + } + } } func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error { @@ -1491,7 +1487,9 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr return nil } if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed { - s.handleHandshakeConfirmed() + if err := s.handleHandshakeConfirmed(); err != nil { + return err + } } return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } @@ -1623,25 +1621,24 @@ func (s *connection) handleCloseError(closeErr *closeError) { s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket) } -func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { +func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { if s.tracer != nil { s.tracer.DroppedEncryptionLevel(encLevel) } s.sentPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel) - if err := s.cryptoStreamManager.Drop(encLevel); err != nil { - s.closeLocal(err) - return - } - if encLevel == protocol.Encryption0RTT { + //nolint:exhaustive // only Initial and 0-RTT need special treatment + switch encLevel { + case protocol.EncryptionInitial: + s.cryptoStreamHandler.DiscardInitialKeys() + case protocol.Encryption0RTT: s.streamsMap.ResetFor0RTT() if err := s.connFlowController.Reset(); err != nil { - s.closeLocal(err) - } - if err := s.framer.Handle0RTTRejection(); err != nil { - s.closeLocal(err) + return err } + return s.framer.Handle0RTTRejection() } + return s.cryptoStreamManager.Drop(encLevel) } // is called for the client, when restoring transport parameters saved for 0-RTT @@ -1659,13 +1656,12 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters s.connStateMutex.Unlock() } -func (s *connection) handleTransportParameters(params *wire.TransportParameters) { +func (s *connection) handleTransportParameters(params *wire.TransportParameters) error { if err := s.checkTransportParameters(params); err != nil { - s.closeLocal(&qerr.TransportError{ + return &qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: err.Error(), - }) - return + } } s.peerParams = params // On the client side we have to wait for handshake completion. @@ -1680,6 +1676,7 @@ func (s *connection) handleTransportParameters(params *wire.TransportParameters) s.connStateMutex.Lock() s.connState.SupportsDatagrams = s.supportsDatagrams() s.connStateMutex.Unlock() + return nil } func (s *connection) checkTransportParameters(params *wire.TransportParameters) error { @@ -1826,7 +1823,9 @@ func (s *connection) sendPackets(now time.Time) error { return err } s.sentFirstPacket = true - s.sendPackedCoalescedPacket(packet, now) + if err := s.sendPackedCoalescedPacket(packet, now); err != nil { + return err + } sendMode := s.sentPacketHandler.SendMode(now) if sendMode == ackhandler.SendPacingLimited { s.resetPacingDeadline() @@ -1946,8 +1945,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if packet == nil { return nil } - s.sendPackedCoalescedPacket(packet, time.Now()) - return nil + return s.sendPackedCoalescedPacket(packet, time.Now()) } p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) @@ -1991,8 +1989,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - s.sendPackedCoalescedPacket(packet, now) - return nil + return s.sendPackedCoalescedPacket(packet, now) } // appendPacket appends a new packet to the given packetBuffer. @@ -2022,7 +2019,7 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti s.connIDManager.SentPacket() } -func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) { +func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error { s.logCoalescedPacket(packet) for _, p := range packet.longHdrPackets { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { @@ -2033,6 +2030,13 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time largestAcked = p.ack.LargestAcked() } s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false) + if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { + // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. + // See Section 4.9.1 of RFC 9001. + if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { + return err + } + } } if p := packet.shortHdrPacket; p != nil { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { @@ -2046,6 +2050,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time } s.connIDManager.SentPacket() s.sendQueue.Send(packet.buffer, packet.buffer.Len()) + return nil } func (s *connection) sendConnectionClose(e error) ([]byte, error) { diff --git a/connection_test.go b/connection_test.go index 9eb60251..20e9872b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -358,6 +358,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) Expect(conn.run()).To(MatchError(expectedErr)) }() Expect(conn.handleFrame(&wire.ConnectionCloseFrame{ @@ -386,6 +387,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) Expect(conn.run()).To(MatchError(testErr)) }() ccf := &wire.ConnectionCloseFrame{ @@ -434,6 +436,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) runErr <- conn.run() }() Eventually(areConnsRunning).Should(BeTrue()) @@ -815,6 +818,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() expectReplaceWithClosed() @@ -857,6 +861,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -892,6 +897,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -917,6 +923,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) err := conn.run() Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) @@ -941,6 +948,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) runErr <- conn.run() }() expectReplaceWithClosed() @@ -965,6 +973,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) err := conn.run() Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) @@ -1054,6 +1063,7 @@ var _ = Describe("Connection", func() { BeforeEach(func() { tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) }) + getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, receivedPacket) { hdr := &wire.ExtendedHeader{ Header: wire.Header{ @@ -1082,6 +1092,8 @@ var _ = Describe("Connection", func() { hdr: &wire.ExtendedHeader{Header: wire.Header{}}, }, nil }) + cryptoSetup.EXPECT().DiscardInitialKeys() + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -1111,6 +1123,8 @@ var _ = Describe("Connection", func() { }, }, nil }) + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes() + cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() gomock.InOrder( tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), @@ -1134,6 +1148,8 @@ var _ = Describe("Connection", func() { }, nil }), ) + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes() + cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() gomock.InOrder( tracer.EXPECT().BufferedPacket(gomock.Any(), protocol.ByteCount(len(packet1.data))), tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), @@ -1158,6 +1174,8 @@ var _ = Describe("Connection", func() { }, nil }) _, packet2 := getPacketWithLength(wrongConnID, 123) + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial).AnyTimes() + cryptoSetup.EXPECT().DiscardInitialKeys().AnyTimes() // don't EXPECT any more calls to unpacker.UnpackLongHeader() gomock.InOrder( tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any()), @@ -1201,6 +1219,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() close(connDone) }() @@ -1419,6 +1438,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1443,6 +1463,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1467,6 +1488,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1483,6 +1505,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1500,6 +1523,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1518,6 +1542,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1544,6 +1569,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1566,6 +1592,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1584,6 +1611,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1606,6 +1634,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() @@ -1637,6 +1666,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() available := make(chan struct{}, 1) @@ -1668,6 +1698,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() // no packet will get sent @@ -1691,6 +1722,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() conn.scheduleSending() @@ -1738,6 +1770,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() // don't EXPECT any calls to mconn.Write() @@ -1772,6 +1805,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() Eventually(written).Should(BeClosed()) @@ -1836,6 +1870,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() @@ -1859,18 +1894,21 @@ var _ = Describe("Connection", func() { finishHandshake := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).AnyTimes() + sph.EXPECT().DropPackets(protocol.EncryptionHandshake) sph.EXPECT().SetHandshakeConfirmed() connRunner.EXPECT().Retire(clientDestConnID) go func() { defer GinkgoRecover() <-finishHandshake cryptoSetup.EXPECT().StartHandshake() + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() - close(conn.handshakeCompleteChan) conn.run() }() handshakeCtx := conn.HandshakeComplete() @@ -1889,18 +1927,21 @@ var _ = Describe("Connection", func() { Eventually(conn.Context().Done()).Should(BeClosed()) }) - It("sends a connection ticket when the handshake completes", func() { + It("sends a session ticket when the handshake completes", func() { const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() finishHandshake := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) + conn.sentPacketHandler.DropPackets(protocol.EncryptionInitial) + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) go func() { defer GinkgoRecover() <-finishHandshake cryptoSetup.EXPECT().StartHandshake() + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) - close(conn.handshakeCompleteChan) conn.run() }() @@ -1945,6 +1986,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake() + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() handshakeCtx := conn.HandshakeComplete() @@ -1975,13 +2017,16 @@ var _ = Describe("Connection", func() { return shortHeaderPacket{}, nil }) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) + sph.EXPECT().DropPackets(protocol.EncryptionHandshake) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake() + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() mconn.EXPECT().Write(gomock.Any(), gomock.Any()) - close(conn.handshakeCompleteChan) conn.run() }() Eventually(done).Should(BeClosed()) @@ -2001,6 +2046,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) Expect(conn.run()).To(Succeed()) close(done) }() @@ -2026,6 +2072,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) err := conn.run() Expect(err).To(MatchError(expectedErr)) close(done) @@ -2076,6 +2123,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() } @@ -2178,6 +2226,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -2203,6 +2252,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -2236,6 +2286,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -2263,6 +2314,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) err := conn.run() nerr, ok := err.(net.Error) @@ -2275,6 +2327,7 @@ var _ = Describe("Connection", func() { }) It("closes the connection due to the idle timeout after handshake", func() { + conn.sentPacketHandler.DropPackets(protocol.EncryptionInitial) packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() gomock.InOrder( connRunner.EXPECT().Retire(clientDestConnID), @@ -2282,6 +2335,7 @@ var _ = Describe("Connection", func() { ) cryptoSetup.EXPECT().Close() gomock.InOrder( + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake), tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { Expect(e).To(MatchError(&IdleTimeoutError{})) }), @@ -2292,9 +2346,10 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) - close(conn.handshakeCompleteChan) err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -2312,6 +2367,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -2326,10 +2382,10 @@ var _ = Describe("Connection", func() { Eventually(conn.Context().Done()).Should(BeClosed()) }) - It("time out earliest after 3 times the PTO", func() { + It("times out earliest after 3 times the PTO", func() { packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() - connRunner.EXPECT().Retire(clientDestConnID) - connRunner.EXPECT().Remove(gomock.Any()) + connRunner.EXPECT().Retire(gomock.Any()).AnyTimes() + connRunner.EXPECT().Remove(gomock.Any()).Times(2) cryptoSetup.EXPECT().Close() closeTimeChan := make(chan time.Time) tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { @@ -2343,9 +2399,9 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) - close(conn.handshakeCompleteChan) conn.run() close(done) }() @@ -2448,15 +2504,12 @@ var _ = Describe("Client Connection", func() { b, err := hdr.Append(nil, conn.version) Expect(err).ToNot(HaveOccurred()) return receivedPacket{ - data: append(b, data...), - buffer: getPacketBuffer(), + rcvTime: time.Now(), + data: append(b, data...), + buffer: getPacketBuffer(), } } - expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) - } - BeforeEach(func() { quicConf = populateConfig(&Config{}) tlsConf = nil @@ -2512,11 +2565,8 @@ var _ = Describe("Client Connection", func() { }, nil }) conn.unpacker = unpacker - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) - conn.run() - }() + done := make(chan struct{}) + packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) { close(done) }) newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7}) p := getPacket(&wire.ExtendedHeader{ Header: wire.Header{ @@ -2530,15 +2580,23 @@ var _ = Describe("Client Connection", func() { }, []byte("foobar")) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), p.Size(), []logging.Frame{}) Expect(conn.handlePacketImpl(p)).To(BeTrue()) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) + conn.run() + }() + Eventually(done).Should(BeClosed()) // make sure the go routine returns packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) + mconn.EXPECT().Write(gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.shutdown() Eventually(conn.Context().Done()).Should(BeClosed()) + time.Sleep(200 * time.Millisecond) }) It("continues accepting Long Header packets after using a new connection ID", func() { @@ -2572,6 +2630,8 @@ var _ = Describe("Client Connection", func() { conn.peerParams = &wire.TransportParameters{} sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) + sph.EXPECT().DropPackets(protocol.EncryptionHandshake) sph.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().SetHandshakeConfirmed() Expect(conn.handleHandshakeDoneFrame()).To(Succeed()) @@ -2582,7 +2642,9 @@ var _ = Describe("Client Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 3}}} + tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) sph.EXPECT().ReceivedAck(ack, protocol.Encryption1RTT, gomock.Any()).Return(true, nil) + sph.EXPECT().DropPackets(protocol.EncryptionHandshake) sph.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().SetLargest1RTTAcked(protocol.PacketNumber(3)) cryptoSetup.EXPECT().SetHandshakeConfirmed() @@ -2598,6 +2660,7 @@ var _ = Describe("Client Connection", func() { close(running) conn.closeLocal(errors.New("early error")) }) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().Close() connRunner.EXPECT().Remove(gomock.Any()) go func() { @@ -2633,8 +2696,9 @@ var _ = Describe("Client Connection", func() { versions, ) return receivedPacket{ - data: b, - buffer: getPacketBuffer(), + rcvTime: time.Now(), + data: b, + buffer: getPacketBuffer(), } } @@ -2645,9 +2709,14 @@ var _ = Describe("Client Connection", func() { sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4) conn.config.Versions = []protocol.VersionNumber{1234, 4321} errChan := make(chan error, 1) + start := make(chan struct{}) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().DoAndReturn(func() handshake.Event { + <-start + return handshake.Event{Kind: handshake.EventNoEvent} + }) errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID) @@ -2659,6 +2728,7 @@ var _ = Describe("Client Connection", func() { }) cryptoSetup.EXPECT().Close() Expect(conn.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse()) + close(start) var err error Eventually(errChan).Should(Receive(&err)) Expect(err).To(HaveOccurred()) @@ -2673,9 +2743,11 @@ var _ = Describe("Client Connection", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) + packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) gomock.InOrder( tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { @@ -2771,23 +2843,38 @@ var _ = Describe("Client Connection", func() { Context("transport parameters", func() { var ( - closed bool - errChan chan error + closed bool + errChan chan error + paramsChan chan *wire.TransportParameters ) JustBeforeEach(func() { errChan = make(chan error, 1) + paramsChan = make(chan *wire.TransportParameters, 1) closed = false + packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) + // This is not 100% what would happen in reality. + // The run loop calls NextEvent once when it starts up (to send out the ClientHello), + // and then again every time a CRYPTO frame is handled. + // Injecting a CRYPTO frame is not straightforward though, + // so we inject the transport parameters on the first call to NextEvent. + params := <-paramsChan + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{ + Kind: handshake.EventReceivedTransportParameters, + TransportParameters: params, + }) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}).MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}).MaxTimes(1) errChan <- conn.run() close(errChan) }() }) - expectClose := func(applicationClose bool) { - if !closed { + expectClose := func(applicationClose, errored bool) { + if !closed && !errored { connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) if applicationClose { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) @@ -2822,9 +2909,10 @@ var _ = Describe("Client Connection", func() { }, } packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).MaxTimes(1) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - conn.handleHandshakeComplete() + processed := make(chan struct{}) + tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) }) + paramsChan <- params + Eventually(processed).Should(BeClosed()) // make sure the connection ID is not retired cf, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1) Expect(cf).To(BeEmpty()) @@ -2832,7 +2920,7 @@ var _ = Describe("Client Connection", func() { Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) // shut down connRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) - expectClose(true) + expectClose(true, false) }) It("uses the minimum of the peers' idle timeouts", func() { @@ -2842,11 +2930,15 @@ var _ = Describe("Client Connection", func() { InitialSourceConnectionID: destConnID, MaxIdleTimeout: 18 * time.Second, } - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) - conn.handleHandshakeComplete() + processed := make(chan struct{}) + tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) }) + paramsChan <- params + Eventually(processed).Should(BeClosed()) + // close first + expectClose(true, false) + conn.shutdown() + // then check. Avoids race condition when accessing idleTimeout Expect(conn.idleTimeout).To(Equal(18 * time.Second)) - expectClose(true) }) It("errors if the transport parameters contain a wrong initial_source_connection_id", func() { @@ -2856,9 +2948,11 @@ var _ = Describe("Client Connection", func() { InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) + expectClose(false, true) + processed := make(chan struct{}) + tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) }) + paramsChan <- params + Eventually(processed).Should(BeClosed()) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "expected initial_source_connection_id to equal deadbeef, is decafbad", @@ -2873,9 +2967,11 @@ var _ = Describe("Client Connection", func() { InitialSourceConnectionID: destConnID, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) + expectClose(false, true) + processed := make(chan struct{}) + tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) }) + paramsChan <- params + Eventually(processed).Should(BeClosed()) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "missing retry_source_connection_id", @@ -2892,9 +2988,11 @@ var _ = Describe("Client Connection", func() { RetrySourceConnectionID: &rcid2, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) + expectClose(false, true) + processed := make(chan struct{}) + tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) }) + paramsChan <- params + Eventually(processed).Should(BeClosed()) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "expected retry_source_connection_id to equal deadbeef, is deadc0de", @@ -2909,9 +3007,11 @@ var _ = Describe("Client Connection", func() { RetrySourceConnectionID: &rcid, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) + expectClose(false, true) + processed := make(chan struct{}) + tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) }) + paramsChan <- params + Eventually(processed).Should(BeClosed()) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "received retry_source_connection_id, although no Retry was performed", @@ -2925,9 +3025,11 @@ var _ = Describe("Client Connection", func() { InitialSourceConnectionID: conn.handshakeDestConnID, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose(false) - tracer.EXPECT().ReceivedTransportParameters(params) - conn.handleTransportParameters(params) + expectClose(false, true) + processed := make(chan struct{}) + tracer.EXPECT().ReceivedTransportParameters(params).Do(func(*wire.TransportParameters) { close(processed) }) + paramsChan <- params + Eventually(processed).Should(BeClosed()) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "expected original_destination_connection_id to equal deadbeef, is decafbad", diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 8961965d..c48e238a 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -3,12 +3,14 @@ package quic import ( "fmt" + "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) type cryptoDataHandler interface { HandleMessage([]byte, protocol.EncryptionLevel) error + NextEvent() handshake.Event } type cryptoStreamManager struct { @@ -74,8 +76,6 @@ func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error { return m.initialStream.Finish() case protocol.EncryptionHandshake: return m.handshakeStream.Finish() - case protocol.Encryption0RTT: - return nil default: panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel)) } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index f06a8b65..daffffe6 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -87,8 +87,4 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().Finish() Expect(csm.Drop(protocol.EncryptionHandshake)).To(Succeed()) }) - - It("no-ops when dropping 0-RTT", func() { - Expect(csm.Drop(protocol.Encryption0RTT)).To(Succeed()) - }) }) diff --git a/fuzzing/handshake/cmd/corpus.go b/fuzzing/handshake/cmd/corpus.go index 3963fc1d..a7e0196f 100644 --- a/fuzzing/handshake/cmd/corpus.go +++ b/fuzzing/handshake/cmd/corpus.go @@ -13,70 +13,12 @@ import ( "github.com/quic-go/quic-go/internal/wire" ) -type chunk struct { - data []byte - encLevel protocol.EncryptionLevel -} - -type stream struct { - chunkChan chan<- chunk - encLevel protocol.EncryptionLevel -} - -func (s *stream) Write(b []byte) (int, error) { - data := append([]byte{}, b...) - select { - case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: - default: - panic("chunkChan too small") - } - return len(b), nil -} - -func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) { - chunkChan := make(chan chunk, 10) - initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial} - handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake} - return chunkChan, initialStream, handshakeStream -} - -type handshakeRunner interface { - OnReceivedParams(*wire.TransportParameters) - OnHandshakeComplete() - OnReceivedReadKeys() - DropKeys(protocol.EncryptionLevel) -} - -type runner struct { - handshakeComplete chan<- struct{} -} - -var _ handshakeRunner = &runner{} - -func newRunner(handshakeComplete chan<- struct{}) *runner { - return &runner{handshakeComplete: handshakeComplete} -} - -func (r *runner) OnReceivedParams(*wire.TransportParameters) {} -func (r *runner) OnReceivedReadKeys() {} -func (r *runner) OnHandshakeComplete() { - close(r.handshakeComplete) -} -func (r *runner) DropKeys(protocol.EncryptionLevel) {} - const alpn = "fuzz" func main() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - var client, server handshake.CryptoSetup - clientHandshakeCompleted := make(chan struct{}) - client, _ = handshake.NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - nil, + client := handshake.NewCryptoSetupClient( protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, - newRunner(clientHandshakeCompleted), &tls.Config{ MinVersion: tls.VersionTLS13, ServerName: "localhost", @@ -91,17 +33,11 @@ func main() { protocol.Version1, ) - sChunkChan, sInitialStream, sHandshakeStream := initStreams() config := testdata.GetTLSConfig() config.NextProtos = []string{alpn} - serverHandshakeCompleted := make(chan struct{}) - server = handshake.NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - nil, + server := handshake.NewCryptoSetupServer( protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, - newRunner(serverHandshakeCompleted), config, false, utils.NewRTTStats(), @@ -118,29 +54,55 @@ func main() { log.Fatal(err) } - done := make(chan struct{}) - go func() { - <-serverHandshakeCompleted - <-clientHandshakeCompleted - close(done) - }() - + var clientHandshakeComplete, serverHandshakeComplete bool var messages [][]byte -messageLoop: for { - select { - case c := <-cChunkChan: - messages = append(messages, c.data) - if err := server.HandleMessage(c.data, c.encLevel); err != nil { - log.Fatal(err) + clientLoop: + for { + ev := client.NextEvent() + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case handshake.EventNoEvent: + break clientLoop + case handshake.EventWriteInitialData: + messages = append(messages, ev.Data) + if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { + log.Fatal(err) + } + case handshake.EventWriteHandshakeData: + messages = append(messages, ev.Data) + if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { + log.Fatal(err) + } + case handshake.EventHandshakeComplete: + clientHandshakeComplete = true } - case c := <-sChunkChan: - messages = append(messages, c.data) - if err := client.HandleMessage(c.data, c.encLevel); err != nil { - log.Fatal(err) + } + + serverLoop: + for { + ev := server.NextEvent() + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case handshake.EventNoEvent: + break serverLoop + case handshake.EventWriteInitialData: + messages = append(messages, ev.Data) + if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { + log.Fatal(err) + } + case handshake.EventWriteHandshakeData: + messages = append(messages, ev.Data) + if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { + log.Fatal(err) + } + case handshake.EventHandshakeComplete: + serverHandshakeComplete = true } - case <-done: - break messageLoop + } + + if serverHandshakeComplete && clientHandshakeComplete { + break } } diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index 5c8fcb1b..c89f6a9f 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -126,57 +126,6 @@ func getClientAuth(rand uint8) tls.ClientAuthType { } } -type chunk struct { - data []byte - encLevel protocol.EncryptionLevel -} - -type stream struct { - chunkChan chan<- chunk - encLevel protocol.EncryptionLevel -} - -func (s *stream) Write(b []byte) (int, error) { - data := append([]byte{}, b...) - select { - case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: - default: - panic("chunkChan too small") - } - return len(b), nil -} - -func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) { - chunkChan := make(chan chunk, 10) - initialStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionInitial} - handshakeStream := &stream{chunkChan: chunkChan, encLevel: protocol.EncryptionHandshake} - return chunkChan, initialStream, handshakeStream -} - -type handshakeRunner interface { - OnReceivedParams(*wire.TransportParameters) - OnHandshakeComplete() - OnReceivedReadKeys() - DropKeys(protocol.EncryptionLevel) -} - -type runner struct { - handshakeComplete chan<- struct{} -} - -var _ handshakeRunner = &runner{} - -func newRunner(handshakeComplete chan<- struct{}) *runner { - return &runner{handshakeComplete: handshakeComplete} -} - -func (r *runner) OnReceivedParams(*wire.TransportParameters) {} -func (r *runner) OnReceivedReadKeys() {} -func (r *runner) OnHandshakeComplete() { - close(r.handshakeComplete) -} -func (r *runner) DropKeys(protocol.EncryptionLevel) {} - const ( alpn = "fuzzing" alpnWrong = "wrong" @@ -193,28 +142,6 @@ func toEncryptionLevel(n uint8) protocol.EncryptionLevel { } } -func maxEncLevel(cs handshake.CryptoSetup, encLevel protocol.EncryptionLevel) protocol.EncryptionLevel { - //nolint:exhaustive - switch encLevel { - case protocol.EncryptionInitial: - return protocol.EncryptionInitial - case protocol.EncryptionHandshake: - // Handshake opener not available. We can't possibly read a Handshake handshake message. - if opener, err := cs.GetHandshakeOpener(); err != nil || opener == nil { - return protocol.EncryptionInitial - } - return protocol.EncryptionHandshake - case protocol.Encryption1RTT: - // 1-RTT opener not available. We can't possibly read a post-handshake message. - if opener, err := cs.Get1RTTOpener(); err != nil || opener == nil { - return maxEncLevel(cs, protocol.EncryptionHandshake) - } - return protocol.Encryption1RTT - default: - panic("unexpected encryption level") - } -} - func getTransportParameters(seed uint8) *wire.TransportParameters { const maxVarInt = math.MaxUint64 / 4 r := mrand.New(mrand.NewSource(int64(seed))) @@ -357,16 +284,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. messageToReplace := messageConfig % 32 messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6) - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - var client, server handshake.CryptoSetup - clientHandshakeCompleted := make(chan struct{}) - client, _ = handshake.NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - nil, + if len(data) == 0 { + return -1 + } + + client := handshake.NewCryptoSetupClient( protocol.ConnectionID{}, clientTP, - newRunner(clientHandshakeCompleted), clientConf, enable0RTTClient, utils.NewRTTStats(), @@ -374,16 +298,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. utils.DefaultLogger.WithPrefix("client"), protocol.Version1, ) + if err := client.StartHandshake(); err != nil { + log.Fatal(err) + } - serverHandshakeCompleted := make(chan struct{}) - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - server = handshake.NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - nil, + server := handshake.NewCryptoSetupServer( protocol.ConnectionID{}, serverTP, - newRunner(serverHandshakeCompleted), serverConf, enable0RTTServer, utils.NewRTTStats(), @@ -391,57 +312,69 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. utils.DefaultLogger.WithPrefix("server"), protocol.Version1, ) - - if len(data) == 0 { - return -1 - } - - if err := client.StartHandshake(); err != nil { - log.Fatal(err) - } - if err := server.StartHandshake(); err != nil { log.Fatal(err) } - done := make(chan struct{}) - go func() { - <-serverHandshakeCompleted - <-clientHandshakeCompleted - close(done) - }() - -messageLoop: + var clientHandshakeComplete, serverHandshakeComplete bool for { - select { - case c := <-cChunkChan: - b := c.data - encLevel := c.encLevel - if len(b) > 0 && b[0] == messageToReplace { - fmt.Printf("replacing %s message to the server with %s\n", messageType(b[0]), messageType(data[0])) - b = data - encLevel = maxEncLevel(server, messageToReplaceEncLevel) + clientLoop: + for { + var processedEvent bool + ev := client.NextEvent() + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case handshake.EventNoEvent: + if !processedEvent && !clientHandshakeComplete { // handshake stuck + return 1 + } + break clientLoop + case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: + msg := ev.Data + if msg[0] == messageToReplace { + fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) + msg = data + } + if err := server.HandleMessage(msg, messageToReplaceEncLevel); err != nil { + return 1 + } + case handshake.EventHandshakeComplete: + clientHandshakeComplete = true } - if err := server.HandleMessage(b, encLevel); err != nil { - break messageLoop + processedEvent = true + } + + serverLoop: + for { + var processedEvent bool + ev := server.NextEvent() + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case handshake.EventNoEvent: + if !processedEvent && !serverHandshakeComplete { // handshake stuck + return 1 + } + break serverLoop + case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData: + msg := ev.Data + if msg[0] == messageToReplace { + fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel) + msg = data + } + if err := client.HandleMessage(msg, messageToReplaceEncLevel); err != nil { + return 1 + } + case handshake.EventHandshakeComplete: + serverHandshakeComplete = true } - case c := <-sChunkChan: - b := c.data - encLevel := c.encLevel - if len(b) > 0 && b[0] == messageToReplace { - fmt.Printf("replacing %s message to the client with %s\n", messageType(b[0]), messageType(data[0])) - b = data - encLevel = maxEncLevel(client, messageToReplaceEncLevel) - } - if err := client.HandleMessage(b, encLevel); err != nil { - break messageLoop - } - case <-done: // test done - break messageLoop + processedEvent = true + } + + if serverHandshakeComplete && clientHandshakeComplete { + break } } - <-done _ = client.ConnectionState() _ = server.ConnectionState() diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index 3675694f..e11ee1ee 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -47,6 +47,11 @@ func (h *receivedPacketHandler) ReceivedPacket( case protocol.EncryptionInitial: return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) case protocol.EncryptionHandshake: + // The Handshake packet number space might already have been dropped as a result + // of processing the CRYPTO frame that was contained in this packet. + if h.handshakePackets == nil { + return nil + } return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) case protocol.Encryption0RTT: if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 07978910..c955da6e 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -136,16 +136,6 @@ func newSentPacketHandler( } } -func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { - if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial { - // This function is called when the crypto setup seals a Handshake packet. - // If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space - // before SentPacket() was called for that Initial packet. - return - } - h.dropPackets(encLevel) -} - func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { if p.includedInBytesInFlight { if p.Length > h.bytesInFlight { @@ -156,7 +146,7 @@ func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { } } -func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { +func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { // The server won't await address validation after the handshake is confirmed. // This applies even if we didn't receive an ACK for a Handshake packet. if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake { @@ -165,6 +155,10 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { // remove outstanding packets from bytes_in_flight if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { pnSpace := h.getPacketNumberSpace(encLevel) + // We might already have dropped this packet number space. + if pnSpace == nil { + return + } pnSpace.history.Iterate(func(p *packet) (bool, error) { h.removeFromBytesInFlight(p) return true, nil @@ -238,10 +232,6 @@ func (h *sentPacketHandler) SentPacket( isPathMTUProbePacket bool, ) { h.bytesSent += size - // For the client, drop the Initial packet number space when the first Handshake packet is sent. - if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake && h.initialPackets != nil { - h.dropPackets(protocol.EncryptionInitial) - } pnSpace := h.getPacketNumberSpace(encLevel) if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { @@ -884,6 +874,12 @@ func (h *sentPacketHandler) ResetForRetry() error { } func (h *sentPacketHandler) SetHandshakeConfirmed() { + if h.initialPackets != nil { + panic("didn't drop initial correctly") + } + if h.handshakePackets != nil { + panic("didn't drop handshake correctly") + } h.handshakeConfirmed = true // We don't send PTOs for application data packets before the handshake completes. // Make sure the timer is armed now, if necessary. diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 2c94fdda..6c603c87 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -130,6 +130,13 @@ var _ = Describe("SentPacketHandler", func() { ExpectWithOffset(1, handler.rttStats.SmoothedRTT()).To(Equal(rtt)) } + // setHandshakeConfirmed drops both Initial and Handshake packets and then confirms the handshake + setHandshakeConfirmed := func() { + handler.DropPackets(protocol.EncryptionInitial) + handler.DropPackets(protocol.EncryptionHandshake) + handler.SetHandshakeConfirmed() + } + Context("registering sent packets", func() { It("accepts two consecutive packets", func() { sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, EncryptionLevel: protocol.EncryptionHandshake})) @@ -705,7 +712,7 @@ var _ = Describe("SentPacketHandler", func() { It("implements exponential backoff", func() { handler.peerAddressValidated = true - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() sendTime := time.Now().Add(-time.Hour) sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: sendTime})) timeout := handler.GetLossDetectionTimeout().Sub(sendTime) @@ -729,7 +736,7 @@ var _ = Describe("SentPacketHandler", func() { It("reset the PTO count when receiving an ACK", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) now := time.Now() - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) sentPacket(ackElicitingPacket(&packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) handler.appDataPackets.pns.(*skippingPacketNumberGenerator).next = 3 @@ -770,7 +777,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.ptoCount).To(BeEquivalentTo(1)) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOHandshake)) Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeHandshake.Add(handler.rttStats.PTO(false) << 1))) - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() handler.DropPackets(protocol.EncryptionHandshake) // PTO timer based on the 1-RTT packet Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeAppData.Add(handler.rttStats.PTO(true)))) // no backoff. PTO count = 0 @@ -780,7 +787,7 @@ var _ = Describe("SentPacketHandler", func() { It("allows two 1-RTT PTOs", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() var lostPackets []protocol.PacketNumber sentPacket(ackElicitingPacket(&packet{ PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT), @@ -802,7 +809,7 @@ var _ = Describe("SentPacketHandler", func() { It("only counts ack-eliciting packets as probe packets", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() sentPacket(ackElicitingPacket(&packet{ PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT), SendTime: time.Now().Add(-time.Hour), @@ -821,7 +828,7 @@ var _ = Describe("SentPacketHandler", func() { It("gets two probe packets if PTO expires", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) @@ -869,7 +876,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) @@ -877,7 +884,7 @@ var _ = Describe("SentPacketHandler", func() { It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() pn := handler.PopPacketNumber(protocol.Encryption1RTT) sentPacket(ackElicitingPacket(&packet{PacketNumber: pn, SendTime: time.Now().Add(-time.Hour)})) updateRTT(time.Second) @@ -902,7 +909,7 @@ var _ = Describe("SentPacketHandler", func() { It("doesn't set the PTO timer for Path MTU probe packets", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() + setHandshakeConfirmed() updateRTT(time.Second) sentPacket(ackElicitingPacket(&packet{PacketNumber: 5, SendTime: time.Now(), IsPathMTUProbePacket: true})) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) @@ -1021,6 +1028,7 @@ var _ = Describe("SentPacketHandler", func() { // Now receive an ACK for a Handshake packet. // This tells the client that the server completed address validation. sentPacket(handshakePacket(&packet{PacketNumber: 1})) + handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space _, err = handler.ReceivedAck( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionHandshake, @@ -1040,7 +1048,8 @@ var _ = Describe("SentPacketHandler", func() { ) Expect(err).ToNot(HaveOccurred()) - sentPacket(handshakePacketNonAckEliciting(&packet{PacketNumber: 1})) // also drops Initial packets + sentPacket(handshakePacketNonAckEliciting(&packet{PacketNumber: 1})) + handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOHandshake)) @@ -1075,7 +1084,7 @@ var _ = Describe("SentPacketHandler", func() { ) Expect(err).ToNot(HaveOccurred()) sentPacket(handshakePacketNonAckEliciting(&packet{PacketNumber: 1, SendTime: time.Now()})) - Expect(handler.initialPackets).To(BeNil()) + handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space pto := handler.rttStats.PTO(false) Expect(pto).ToNot(BeZero()) @@ -1235,39 +1244,6 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) }) - Context("deleting Initials", func() { - BeforeEach(func() { perspective = protocol.PerspectiveClient }) - - It("deletes Initials, as a client", func() { - for i := 0; i < 6; i++ { - sentPacket(ackElicitingPacket(&packet{ - PacketNumber: handler.PopPacketNumber(protocol.EncryptionInitial), - EncryptionLevel: protocol.EncryptionInitial, - Length: 1, - })) - } - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) - handler.DropPackets(protocol.EncryptionInitial) - // DropPackets should be ignored for clients and the Initial packet number space. - // It has to be possible to send another Initial packets after this function was called. - sentPacket(ackElicitingPacket(&packet{ - PacketNumber: handler.PopPacketNumber(protocol.EncryptionInitial), - EncryptionLevel: protocol.EncryptionInitial, - Length: 1, - })) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7))) - // Sending a Handshake packet triggers dropping of Initials. - sentPacket(ackElicitingPacket(&packet{ - PacketNumber: handler.PopPacketNumber(protocol.EncryptionHandshake), - EncryptionLevel: protocol.EncryptionHandshake, - })) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(1))) - Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission - Expect(handler.initialPackets).To(BeNil()) - Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) - }) - }) - It("deletes Handshake packets", func() { for i := protocol.PacketNumber(0); i < 6; i++ { sentPacket(ackElicitingPacket(&packet{ diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index ccda43ca..6aa89fb3 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -92,69 +92,3 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad [] func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) } - -type handshakeSealer struct { - LongHeaderSealer - - dropInitialKeys func() - dropped bool -} - -func newHandshakeSealer( - aead cipher.AEAD, - headerProtector headerProtector, - dropInitialKeys func(), - perspective protocol.Perspective, -) LongHeaderSealer { - sealer := newLongHeaderSealer(aead, headerProtector) - // The client drops Initial keys when sending the first Handshake packet. - if perspective == protocol.PerspectiveServer { - return sealer - } - return &handshakeSealer{ - LongHeaderSealer: sealer, - dropInitialKeys: dropInitialKeys, - } -} - -func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - data := s.LongHeaderSealer.Seal(dst, src, pn, ad) - if !s.dropped { - s.dropInitialKeys() - s.dropped = true - } - return data -} - -type handshakeOpener struct { - LongHeaderOpener - - dropInitialKeys func() - dropped bool -} - -func newHandshakeOpener( - aead cipher.AEAD, - headerProtector headerProtector, - dropInitialKeys func(), - perspective protocol.Perspective, -) LongHeaderOpener { - opener := newLongHeaderOpener(aead, headerProtector) - // The server drops Initial keys when first successfully processing a Handshake packet. - if perspective == protocol.PerspectiveClient { - return opener - } - return &handshakeOpener{ - LongHeaderOpener: opener, - dropInitialKeys: dropInitialKeys, - } -} - -func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad) - if err == nil && !o.dropped { - o.dropInitialKeys() - o.dropped = true - } - return dec, err -} diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 6dee25d3..85fe28d8 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -133,72 +133,5 @@ var _ = Describe("Long Header AEAD", func() { }) } }) - - Describe("Long Header AEAD", func() { - var ( - dropped chan struct{} // use a chan because closing it twice will panic - aead cipher.AEAD - hp headerProtector - ) - dropCb := func() { close(dropped) } - msg := []byte("Lorem ipsum dolor sit amet.") - ad := []byte("Donec in velit neque.") - - BeforeEach(func() { - dropped = make(chan struct{}) - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err = cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - hp = newHeaderProtector(cipherSuites[0], hpKey, true, protocol.Version1) - }) - - Context("for the server", func() { - It("drops keys when first successfully processing a Handshake packet", func() { - serverOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveServer) - // first try to open an invalid message - _, err := serverOpener.Open(nil, []byte("invalid"), 0, []byte("invalid")) - Expect(err).To(HaveOccurred()) - Expect(dropped).ToNot(BeClosed()) - // then open a valid message - enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 10, ad) - _, err = serverOpener.Open(nil, enc, 10, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(dropped).To(BeClosed()) - // now open the same message again to make sure the callback is only called once - _, err = serverOpener.Open(nil, enc, 10, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't drop keys when sealing a Handshake packet", func() { - serverSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveServer) - serverSealer.Seal(nil, msg, 1, ad) - Expect(dropped).ToNot(BeClosed()) - }) - }) - - Context("for the client", func() { - It("drops keys when first sealing a Handshake packet", func() { - clientSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveClient) - // seal the first message - clientSealer.Seal(nil, msg, 1, ad) - Expect(dropped).To(BeClosed()) - // seal another message to make sure the callback is only called once - clientSealer.Seal(nil, msg, 2, ad) - }) - - It("doesn't drop keys when processing a Handshake packet", func() { - enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 42, ad) - clientOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveClient) - _, err := clientOpener.Open(nil, enc, 42, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(dropped).ToNot(BeClosed()) - }) - }) - }) } }) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 29e66f57..66ed04d6 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "errors" "fmt" - "io" "sync" "sync/atomic" "time" @@ -30,16 +29,15 @@ type cryptoSetup struct { tlsConf *tls.Config conn *qtls.QUICConn + events []Event + version protocol.VersionNumber ourParams *wire.TransportParameters peerParams *wire.TransportParameters - runner handshakeRunner - - zeroRTTParameters *wire.TransportParameters - zeroRTTParametersChan chan<- *wire.TransportParameters - allow0RTT bool + zeroRTTParameters *wire.TransportParameters + allow0RTT bool rttStats *utils.RTTStats @@ -55,17 +53,14 @@ type cryptoSetup struct { zeroRTTOpener LongHeaderOpener // only set for the server zeroRTTSealer LongHeaderSealer // only set for the client - initialStream io.Writer initialOpener LongHeaderOpener initialSealer LongHeaderSealer - handshakeStream io.Writer handshakeOpener LongHeaderOpener handshakeSealer LongHeaderSealer used0RTT atomic.Bool - oneRTTStream io.Writer aead *updatableAEAD has1RTTSealer bool has1RTTOpener bool @@ -75,24 +70,18 @@ var _ CryptoSetup = &cryptoSetup{} // NewCryptoSetupClient creates a new crypto setup for the client func NewCryptoSetupClient( - initialStream, handshakeStream, oneRTTStream io.Writer, connID protocol.ConnectionID, tp *wire.TransportParameters, - runner handshakeRunner, tlsConf *tls.Config, enable0RTT bool, rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, -) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { - cs, clientHelloWritten := newCryptoSetup( - initialStream, - handshakeStream, - oneRTTStream, +) CryptoSetup { + cs := newCryptoSetup( connID, tp, - runner, rttStats, tracer, logger, @@ -109,15 +98,13 @@ func NewCryptoSetupClient( cs.conn = qtls.QUICClient(quicConf) cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) - return cs, clientHelloWritten + return cs } // NewCryptoSetupServer creates a new crypto setup for the server func NewCryptoSetupServer( - initialStream, handshakeStream, oneRTTStream io.Writer, connID protocol.ConnectionID, tp *wire.TransportParameters, - runner handshakeRunner, tlsConf *tls.Config, allow0RTT bool, rttStats *utils.RTTStats, @@ -125,13 +112,9 @@ func NewCryptoSetupServer( logger utils.Logger, version protocol.VersionNumber, ) CryptoSetup { - cs, _ := newCryptoSetup( - initialStream, - handshakeStream, - oneRTTStream, + cs := newCryptoSetup( connID, tp, - runner, rttStats, tracer, logger, @@ -150,38 +133,31 @@ func NewCryptoSetupServer( } func newCryptoSetup( - initialStream, handshakeStream, oneRTTStream io.Writer, connID protocol.ConnectionID, tp *wire.TransportParameters, - runner handshakeRunner, rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, version protocol.VersionNumber, -) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { +) *cryptoSetup { initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) if tracer != nil { tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } - zeroRTTParametersChan := make(chan *wire.TransportParameters, 1) return &cryptoSetup{ - initialStream: initialStream, - initialSealer: initialSealer, - initialOpener: initialOpener, - handshakeStream: handshakeStream, - oneRTTStream: oneRTTStream, - aead: newUpdatableAEAD(rttStats, tracer, logger, version), - runner: runner, - ourParams: tp, - rttStats: rttStats, - tracer: tracer, - logger: logger, - perspective: perspective, - zeroRTTParametersChan: zeroRTTParametersChan, - version: version, - }, zeroRTTParametersChan + initialSealer: initialSealer, + initialOpener: initialOpener, + aead: newUpdatableAEAD(rttStats, tracer, logger, version), + events: make([]Event, 0, 16), + ourParams: tp, + rttStats: rttStats, + tracer: tracer, + logger: logger, + perspective: perspective, + version: version, + } } func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { @@ -216,10 +192,9 @@ func (h *cryptoSetup) StartHandshake() error { if h.perspective == protocol.PerspectiveClient { if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { h.logger.Debugf("Doing 0-RTT.") - h.zeroRTTParametersChan <- h.zeroRTTParameters + h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters}) } else { h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil) - h.zeroRTTParametersChan <- nil } } return nil @@ -275,7 +250,8 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { h.rejected0RTT() return false, nil case qtls.QUICWriteData: - return false, h.WriteRecord(ev.Level, ev.Data) + h.WriteRecord(ev.Level, ev.Data) + return false, nil case qtls.QUICHandshakeDone: h.handshakeComplete() return false, nil @@ -284,13 +260,22 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { } } +func (h *cryptoSetup) NextEvent() Event { + if len(h.events) == 0 { + return Event{Kind: EventNoEvent} + } + ev := h.events[0] + h.events = h.events[1:] + return ev +} + func (h *cryptoSetup) handleTransportParameters(data []byte) error { var tp wire.TransportParameters if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { return err } h.peerParams = &tp - h.runner.OnReceivedParams(h.peerParams) + h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams}) return nil } @@ -392,7 +377,7 @@ func (h *cryptoSetup) rejected0RTT() { h.mutex.Unlock() if had0RTTKeys { - h.runner.DropKeys(protocol.Encryption0RTT) + h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys}) } } @@ -414,11 +399,9 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } case qtls.QUICEncryptionLevelHandshake: - h.handshakeOpener = newHandshakeOpener( + h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), - h.dropInitialKeys, - h.perspective, ) if h.logger.Debug() { h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) @@ -433,7 +416,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr panic("unexpected read encryption level") } h.mutex.Unlock() - h.runner.OnReceivedReadKeys() + h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) if h.tracer != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } @@ -462,11 +445,9 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t // don't set used0RTT here. 0-RTT might still get rejected. return case qtls.QUICEncryptionLevelHandshake: - h.handshakeSealer = newHandshakeSealer( + h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), - h.dropInitialKeys, - h.perspective, ) if h.logger.Debug() { h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) @@ -496,40 +477,34 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t } // WriteRecord is called when TLS writes data -func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) error { - h.mutex.Lock() - defer h.mutex.Unlock() - - var str io.Writer +func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) { //nolint:exhaustive // handshake records can only be written for Initial and Handshake. switch encLevel { case qtls.QUICEncryptionLevelInitial: - // assume that the first WriteRecord call contains the ClientHello - str = h.initialStream + h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p}) case qtls.QUICEncryptionLevelHandshake: - str = h.handshakeStream + h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p}) case qtls.QUICEncryptionLevelApplication: - str = h.oneRTTStream + panic("unexpected write") default: panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel)) } - _, err := str.Write(p) - return err } -// used a callback in the handshakeSealer and handshakeOpener -func (h *cryptoSetup) dropInitialKeys() { +func (h *cryptoSetup) DiscardInitialKeys() { h.mutex.Lock() + dropped := h.initialOpener != nil h.initialOpener = nil h.initialSealer = nil h.mutex.Unlock() - h.runner.DropKeys(protocol.EncryptionInitial) - h.logger.Debugf("Dropping Initial keys.") + if dropped { + h.logger.Debugf("Dropping Initial keys.") + } } func (h *cryptoSetup) handshakeComplete() { h.handshakeCompleteTime = time.Now() - h.runner.OnHandshakeComplete() + h.events = append(h.events, Event{Kind: EventHandshakeComplete}) } func (h *cryptoSetup) SetHandshakeConfirmed() { @@ -544,7 +519,6 @@ func (h *cryptoSetup) SetHandshakeConfirmed() { } h.mutex.Unlock() if dropped { - h.runner.DropKeys(protocol.EncryptionHandshake) h.logger.Debugf("Dropping Handshake keys.") } } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index f6d8ae3f..e891cb9f 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -27,46 +27,9 @@ const ( typeNewSessionTicket = 4 ) -type chunk struct { - data []byte - encLevel protocol.EncryptionLevel -} - -type stream struct { - encLevel protocol.EncryptionLevel - chunkChan chan<- chunk -} - -func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream { - return &stream{ - chunkChan: chunkChan, - encLevel: encLevel, - } -} - -func (s *stream) Write(b []byte) (int, error) { - data := make([]byte, len(b)) - copy(data, b) - select { - case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: - default: - panic("chunkChan too small") - } - return len(b), nil -} - var _ = Describe("Crypto Setup TLS", func() { var clientConf, serverConf *tls.Config - // unparam incorrectly complains that the first argument is never used. - //nolint:unparam - initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { - chunkChan := make(chan chunk, 100) - initialStream := newStream(chunkChan, protocol.EncryptionInitial) - handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) - return chunkChan, initialStream, handshakeStream - } - BeforeEach(func() { serverConf = testdata.GetTLSConfig() serverConf.NextProtos = []string{"crypto-setup"} @@ -78,17 +41,12 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("handles qtls errors occurring before during ClientHello generation", func() { - _, sInitialStream, sHandshakeStream := initStreams() tlsConf := testdata.GetTLSConfig() tlsConf.InsecureSkipVerify = true tlsConf.NextProtos = []string{""} - cl, _ := NewCryptoSetupClient( - sInitialStream, - sHandshakeStream, - nil, + cl := NewCryptoSetupClient( protocol.ConnectionID{}, &wire.TransportParameters{}, - NewMockHandshakeRunner(mockCtrl), tlsConf, false, &utils.RTTStats{}, @@ -104,16 +62,10 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("errors when a message is received at the wrong encryption level", func() { - _, sInitialStream, sHandshakeStream := initStreams() - runner := NewMockHandshakeRunner(mockCtrl) var token protocol.StatelessResetToken server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - nil, protocol.ConnectionID{}, &wire.TransportParameters{StatelessResetToken: &token}, - runner, testdata.GetTLSConfig(), false, &utils.RTTStats{}, @@ -158,32 +110,73 @@ var _ = Describe("Crypto Setup TLS", func() { return rttStats } - handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) { + // The clientEvents and serverEvents contain all events that were not processed by the function, + // i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete. + handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) { Expect(client.StartHandshake()).To(Succeed()) Expect(server.StartHandshake()).To(Succeed()) - for { - select { - case c := <-cChunkChan: - Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed()) - continue - default: - } - select { - case c := <-sChunkChan: - Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed()) - continue - default: - } - // no more messages to send from client and server. Handshake complete? - break - } + var clientHandshakeComplete, serverHandshakeComplete bool - ticket, err := server.GetSessionTicket() - Expect(err).ToNot(HaveOccurred()) - if ticket != nil { - Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) + for { + clientLoop: + for { + ev := client.NextEvent() + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case EventNoEvent: + break clientLoop + case EventWriteInitialData: + if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { + serverErr = err + return + } + case EventWriteHandshakeData: + if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { + serverErr = err + return + } + case EventHandshakeComplete: + clientHandshakeComplete = true + default: + clientEvents = append(clientEvents, ev) + } + } + + serverLoop: + for { + ev := server.NextEvent() + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case EventNoEvent: + break serverLoop + case EventWriteInitialData: + if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { + clientErr = err + return + } + case EventWriteHandshakeData: + if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { + clientErr = err + return + } + case EventHandshakeComplete: + serverHandshakeComplete = true + ticket, err := server.GetSessionTicket() + Expect(err).ToNot(HaveOccurred()) + if ticket != nil { + Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) + } + default: + serverEvents = append(serverEvents, ev) + } + } + + if clientHandshakeComplete && serverHandshakeComplete { + break + } } + return } handshakeWithTLSConf := func( @@ -191,22 +184,12 @@ var _ = Describe("Crypto Setup TLS", func() { clientRTTStats, serverRTTStats *utils.RTTStats, clientTransportParameters, serverTransportParameters *wire.TransportParameters, enable0RTT bool, - ) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { - var cHandshakeComplete bool - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cErrChan := make(chan error, 1) - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise - cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) - cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1) - client, clientHelloWrittenChan := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - nil, + ) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */ + CryptoSetup /* server */, []Event /* more server events */, error, /* server error */ + ) { + client := NewCryptoSetupClient( protocol.ConnectionID{}, clientTransportParameters, - cRunner, clientConf, enable0RTT, clientRTTStats, @@ -215,24 +198,13 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - var sHandshakeComplete bool - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sErrChan := make(chan error, 1) - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise - sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) if serverTransportParameters.StatelessResetToken == nil { var token protocol.StatelessResetToken serverTransportParameters.StatelessResetToken = &token } server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - nil, protocol.ConnectionID{}, serverTransportParameters, - sRunner, serverConf, enable0RTT, serverRTTStats, @@ -240,24 +212,12 @@ var _ = Describe("Crypto Setup TLS", func() { utils.DefaultLogger.WithPrefix("server"), protocol.Version1, ) - - handshake(client, cChunkChan, server, sChunkChan) - var cErr, sErr error - select { - case sErr = <-sErrChan: - default: - Expect(sHandshakeComplete).To(BeTrue()) - } - select { - case cErr = <-cErrChan: - default: - Expect(cHandshakeComplete).To(BeTrue()) - } - return clientHelloWrittenChan, client, cErr, server, sErr + cEvents, cErr, sEvents, sErr := handshake(client, server) + return client, cEvents, cErr, server, sEvents, sErr } It("handshakes", func() { - _, _, clientErr, _, serverErr := handshakeWithTLSConf( + _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, &utils.RTTStats{}, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -269,7 +229,7 @@ var _ = Describe("Crypto Setup TLS", func() { It("performs a HelloRetryRequst", func() { serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} - _, _, clientErr, _, serverErr := handshakeWithTLSConf( + _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, &utils.RTTStats{}, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -282,7 +242,7 @@ var _ = Describe("Crypto Setup TLS", func() { It("handshakes with client auth", func() { clientConf.Certificates = []tls.Certificate{generateCert()} serverConf.ClientAuth = tls.RequireAnyClientCert - _, _, clientErr, _, serverErr := handshakeWithTLSConf( + _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, &utils.RTTStats{}, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -292,50 +252,11 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(serverErr).ToNot(HaveOccurred()) }) - It("signals when it has written the ClientHello", func() { - runner := NewMockHandshakeRunner(mockCtrl) - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - client, chChan := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - nil, - protocol.ConnectionID{}, - &wire.TransportParameters{}, - runner, - &tls.Config{InsecureSkipVerify: true}, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.Version1, - ) - - Expect(client.StartHandshake()).To(Succeed()) - var ch chunk - Eventually(cChunkChan).Should(Receive(&ch)) - Eventually(chChan).Should(Receive(BeNil())) - // make sure the whole ClientHello was written - Expect(len(ch.data)).To(BeNumerically(">=", 4)) - Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello)) - length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) - Expect(len(ch.data) - 4).To(Equal(length)) - }) - It("receives transport parameters", func() { - var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second} - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedReadKeys().Times(2) - cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp }) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - nil, + cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second} + client := NewCryptoSetupClient( protocol.ConnectionID{}, cTransportParameters, - cRunner, clientConf, false, &utils.RTTStats{}, @@ -344,24 +265,15 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - sChunkChan, sInitialStream, sHandshakeStream := initStreams() var token protocol.StatelessResetToken - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedReadKeys().Times(2) - sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp }) - sRunner.EXPECT().OnHandshakeComplete() sTransportParameters := &wire.TransportParameters{ - MaxIdleTimeout: 0x1337 * time.Second, + MaxIdleTimeout: 1337 * time.Second, StatelessResetToken: &token, ActiveConnectionIDLimit: 2, } server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - nil, protocol.ConnectionID{}, sTransportParameters, - sRunner, serverConf, false, &utils.RTTStats{}, @@ -370,68 +282,38 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) - Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout)) - Expect(sTransportParametersRcvd).ToNot(BeNil()) - Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout)) + clientEvents, cErr, serverEvents, sErr := handshake(client, server) + Expect(cErr).ToNot(HaveOccurred()) + Expect(sErr).ToNot(HaveOccurred()) + var clientReceivedTransportParameters *wire.TransportParameters + for _, ev := range clientEvents { + if ev.Kind == EventReceivedTransportParameters { + clientReceivedTransportParameters = ev.TransportParameters + } + } + Expect(clientReceivedTransportParameters).ToNot(BeNil()) + Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second)) + + var serverReceivedTransportParameters *wire.TransportParameters + for _, ev := range serverEvents { + if ev.Kind == EventReceivedTransportParameters { + serverReceivedTransportParameters = ev.TransportParameters + } + } + Expect(serverReceivedTransportParameters).ToNot(BeNil()) + Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second)) }) Context("with session tickets", func() { It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnReceivedReadKeys().Times(2) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - nil, - protocol.ConnectionID{}, - &wire.TransportParameters{ActiveConnectionIDLimit: 2}, - cRunner, - clientConf, + client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.Version1, ) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnReceivedReadKeys().Times(2) - sRunner.EXPECT().OnHandshakeComplete() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - nil, - protocol.ConnectionID{}, - &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, - sRunner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.Version1, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) // inject an invalid session ticket b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) @@ -441,54 +323,14 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("errors when handling the NewSessionTicket fails", func() { - cChunkChan, cInitialStream, cHandshakeStream := initStreams() - cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnReceivedReadKeys().Times(2) - cRunner.EXPECT().OnHandshakeComplete() - client, _ := NewCryptoSetupClient( - cInitialStream, - cHandshakeStream, - nil, - protocol.ConnectionID{}, - &wire.TransportParameters{ActiveConnectionIDLimit: 2}, - cRunner, - clientConf, + client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( + clientConf, serverConf, + &utils.RTTStats{}, &utils.RTTStats{}, + &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("client"), - protocol.Version1, ) - - sChunkChan, sInitialStream, sHandshakeStream := initStreams() - sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnReceivedReadKeys().Times(2) - sRunner.EXPECT().OnHandshakeComplete() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - nil, - protocol.ConnectionID{}, - &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, - sRunner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.Version1, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handshake(client, cChunkChan, server, sChunkChan) - close(done) - }() - Eventually(done).Should(BeClosed()) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) // inject an invalid session ticket b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) @@ -509,7 +351,7 @@ var _ = Describe("Crypto Setup TLS", func() { clientConf.ClientSessionCache = csc const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, clientOrigRTTStats, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -520,12 +362,11 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(receivedSessionTicket).Should(BeClosed()) Expect(server.ConnectionState().DidResume).To(BeFalse()) Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( clientConf, serverConf, clientRTTStats, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -537,7 +378,6 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) }) It("doesn't use session resumption if the server disabled it", func() { @@ -550,7 +390,7 @@ var _ = Describe("Crypto Setup TLS", func() { close(receivedSessionTicket) }) clientConf.ClientSessionCache = csc - _, client, clientErr, server, serverErr := handshakeWithTLSConf( + client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, &utils.RTTStats{}, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -564,7 +404,7 @@ var _ = Describe("Crypto Setup TLS", func() { serverConf.SessionTicketsDisabled = true csc.EXPECT().Get(gomock.Any()).Return(state, true) - _, client, clientErr, server, serverErr = handshakeWithTLSConf( + client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( clientConf, serverConf, &utils.RTTStats{}, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -592,7 +432,7 @@ var _ = Describe("Crypto Setup TLS", func() { serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) const initialMaxData protocol.ByteCount = 1337 - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, clientOrigRTTStats, serverOrigRTTStats, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -604,14 +444,13 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(receivedSessionTicket).Should(BeClosed()) Expect(server.ConnectionState().DidResume).To(BeFalse()) Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} serverRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf( clientConf, serverConf, clientRTTStats, serverRTTStats, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -624,9 +463,30 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) var tp *wire.TransportParameters - Expect(clientHelloWrittenChan).To(Receive(&tp)) + var clientReceived0RTTKeys bool + for _, ev := range clientEvents { + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case EventRestoredTransportParameters: + tp = ev.TransportParameters + case EventReceivedReadKeys: + clientReceived0RTTKeys = true + } + } + Expect(clientReceived0RTTKeys).To(BeTrue()) + Expect(tp).ToNot(BeNil()) Expect(tp.InitialMaxData).To(Equal(initialMaxData)) + var serverReceived0RTTKeys bool + for _, ev := range serverEvents { + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case EventReceivedReadKeys: + serverReceived0RTTKeys = true + } + } + Expect(serverReceived0RTTKeys).To(BeTrue()) + Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(server.ConnectionState().Used0RTT).To(BeTrue()) @@ -646,7 +506,7 @@ var _ = Describe("Crypto Setup TLS", func() { const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) const initialMaxData protocol.ByteCount = 1337 - clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( + client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, clientOrigRTTStats, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -658,13 +518,12 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(receivedSessionTicket).Should(BeClosed()) Expect(server.ConnectionState().DidResume).To(BeFalse()) Expect(client.ConnectionState().DidResume).To(BeFalse()) - Expect(clientHelloWrittenChan).To(Receive(BeNil())) csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} - clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( + client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf( clientConf, serverConf, clientRTTStats, &utils.RTTStats{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, @@ -676,7 +535,18 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) var tp *wire.TransportParameters - Expect(clientHelloWrittenChan).To(Receive(&tp)) + var clientReceived0RTTKeys bool + for _, ev := range clientEvents { + //nolint:exhaustive // only need to process a few events + switch ev.Kind { + case EventRestoredTransportParameters: + tp = ev.TransportParameters + case EventReceivedReadKeys: + clientReceived0RTTKeys = true + } + } + Expect(clientReceived0RTTKeys).To(BeTrue()) + Expect(tp).ToNot(BeNil()) Expect(tp.InitialMaxData).To(Equal(initialMaxData)) Expect(server.ConnectionState().DidResume).To(BeTrue()) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index ab242953..fab224f9 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -53,18 +53,42 @@ type ShortHeaderSealer interface { KeyPhase() protocol.KeyPhaseBit } -type handshakeRunner interface { - OnReceivedParams(*wire.TransportParameters) - OnHandshakeComplete() - OnReceivedReadKeys() - DropKeys(protocol.EncryptionLevel) -} - type ConnectionState struct { tls.ConnectionState Used0RTT bool } +// EventKind is the kind of handshake event. +type EventKind uint8 + +const ( + // EventNoEvent signals that there are no new handshake events + EventNoEvent EventKind = iota + 1 + // EventWriteInitialData contains new CRYPTO data to send at the Initial encryption level + EventWriteInitialData + // EventWriteHandshakeData contains new CRYPTO data to send at the Handshake encryption level + EventWriteHandshakeData + // EventReceivedReadKeys signals that new decryption keys are available. + // It doesn't say which encryption level those keys are for. + EventReceivedReadKeys + // EventDiscard0RTTKeys signals that the Handshake keys were discarded. + EventDiscard0RTTKeys + // EventReceivedTransportParameters contains the transport parameters sent by the peer. + EventReceivedTransportParameters + // EventRestoredTransportParameters contains the transport parameters restored from the session ticket. + // It is only used for the client. + EventRestoredTransportParameters + // EventHandshakeComplete signals that the TLS handshake was completed. + EventHandshakeComplete +) + +// Event is a handshake event. +type Event struct { + Kind EventKind + Data []byte + TransportParameters *wire.TransportParameters +} + // CryptoSetup handles the handshake and protecting / unprotecting packets type CryptoSetup interface { StartHandshake() error @@ -73,7 +97,10 @@ type CryptoSetup interface { GetSessionTicket() ([]byte, error) HandleMessage([]byte, protocol.EncryptionLevel) error + NextEvent() Event + SetLargest1RTTAcked(protocol.PacketNumber) error + DiscardInitialKeys() SetHandshakeConfirmed() ConnectionState() ConnectionState diff --git a/internal/handshake/mock_handshake_runner_test.go b/internal/handshake/mock_handshake_runner_test.go deleted file mode 100644 index 9d3cfef5..00000000 --- a/internal/handshake/mock_handshake_runner_test.go +++ /dev/null @@ -1,84 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/internal/handshake (interfaces: HandshakeRunner) - -// Package handshake is a generated GoMock package. -package handshake - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/quic-go/quic-go/internal/protocol" - wire "github.com/quic-go/quic-go/internal/wire" -) - -// MockHandshakeRunner is a mock of HandshakeRunner interface. -type MockHandshakeRunner struct { - ctrl *gomock.Controller - recorder *MockHandshakeRunnerMockRecorder -} - -// MockHandshakeRunnerMockRecorder is the mock recorder for MockHandshakeRunner. -type MockHandshakeRunnerMockRecorder struct { - mock *MockHandshakeRunner -} - -// NewMockHandshakeRunner creates a new mock instance. -func NewMockHandshakeRunner(ctrl *gomock.Controller) *MockHandshakeRunner { - mock := &MockHandshakeRunner{ctrl: ctrl} - mock.recorder = &MockHandshakeRunnerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockHandshakeRunner) EXPECT() *MockHandshakeRunnerMockRecorder { - return m.recorder -} - -// DropKeys mocks base method. -func (m *MockHandshakeRunner) DropKeys(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DropKeys", arg0) -} - -// DropKeys indicates an expected call of DropKeys. -func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0) -} - -// OnHandshakeComplete mocks base method. -func (m *MockHandshakeRunner) OnHandshakeComplete() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnHandshakeComplete") -} - -// OnHandshakeComplete indicates an expected call of OnHandshakeComplete. -func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockHandshakeRunner)(nil).OnHandshakeComplete)) -} - -// OnReceivedParams mocks base method. -func (m *MockHandshakeRunner) OnReceivedParams(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnReceivedParams", arg0) -} - -// OnReceivedParams indicates an expected call of OnReceivedParams. -func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0) -} - -// OnReceivedReadKeys mocks base method. -func (m *MockHandshakeRunner) OnReceivedReadKeys() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnReceivedReadKeys") -} - -// OnReceivedReadKeys indicates an expected call of OnReceivedReadKeys. -func (mr *MockHandshakeRunnerMockRecorder) OnReceivedReadKeys() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedReadKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedReadKeys)) -} diff --git a/internal/handshake/mockgen.go b/internal/handshake/mockgen.go deleted file mode 100644 index 68b0988c..00000000 --- a/internal/handshake/mockgen.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build gomock || generate - -package handshake - -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package handshake -destination mock_handshake_runner_test.go github.com/quic-go/quic-go/internal/handshake HandshakeRunner" -type HandshakeRunner = handshakeRunner diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 0c5d528f..1c707b9c 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -75,6 +75,18 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) } +// DiscardInitialKeys mocks base method. +func (m *MockCryptoSetup) DiscardInitialKeys() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DiscardInitialKeys") +} + +// DiscardInitialKeys indicates an expected call of DiscardInitialKeys. +func (mr *MockCryptoSetupMockRecorder) DiscardInitialKeys() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscardInitialKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DiscardInitialKeys)) +} + // Get0RTTOpener mocks base method. func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { m.ctrl.T.Helper() @@ -224,6 +236,20 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) } +// NextEvent mocks base method. +func (m *MockCryptoSetup) NextEvent() handshake.Event { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextEvent") + ret0, _ := ret[0].(handshake.Event) + return ret0 +} + +// NextEvent indicates an expected call of NextEvent. +func (mr *MockCryptoSetupMockRecorder) NextEvent() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoSetup)(nil).NextEvent)) +} + // SetHandshakeConfirmed mocks base method. func (m *MockCryptoSetup) SetHandshakeConfirmed() { m.ctrl.T.Helper() diff --git a/mock_crypto_data_handler_test.go b/mock_crypto_data_handler_test.go index d6852891..d077886c 100644 --- a/mock_crypto_data_handler_test.go +++ b/mock_crypto_data_handler_test.go @@ -8,6 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + handshake "github.com/quic-go/quic-go/internal/handshake" protocol "github.com/quic-go/quic-go/internal/protocol" ) @@ -47,3 +48,17 @@ func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{ mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) } + +// NextEvent mocks base method. +func (m *MockCryptoDataHandler) NextEvent() handshake.Event { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextEvent") + ret0, _ := ret[0].(handshake.Event) + return ret0 +} + +// NextEvent indicates an expected call of NextEvent. +func (mr *MockCryptoDataHandlerMockRecorder) NextEvent() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoDataHandler)(nil).NextEvent)) +}