From 62b3f22a77d1472c9865249acf3ee6cc4f7daa3c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 5 Aug 2023 18:25:04 -0400 Subject: [PATCH] fix handling of ACK frames serialized after CRYPTO frames (#4018) --- connection.go | 31 ++++++++++++++++++++----- connection_test.go | 57 ++++++++++------------------------------------ 2 files changed, 37 insertions(+), 51 deletions(-) diff --git a/connection.go b/connection.go index 601e5c77..955f8f8a 100644 --- a/connection.go +++ b/connection.go @@ -724,20 +724,22 @@ func (s *connection) idleTimeoutStartTime() time.Time { } func (s *connection) handleHandshakeComplete() error { - s.handshakeComplete = true defer s.handshakeCtxCancel() // Once the handshake completes, we have derived 1-RTT keys. - // There's no point in queueing undecryptable packets for later decryption any more. + // There's no point in queueing undecryptable packets for later decryption anymore. s.undecryptablePackets = nil s.connIDManager.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete() + // The server applies transport parameters right away, but the client side has to wait for handshake completion. + // During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets. if s.perspective == protocol.PerspectiveClient { s.applyTransportParameters() return nil } + // All these only apply to the server side. if err := s.handleHandshakeConfirmed(); err != nil { return err } @@ -1235,6 +1237,7 @@ func (s *connection) handleFrames( if log != nil { frames = make([]logging.Frame, 0, 4) } + handshakeWasComplete := s.handshakeComplete var handleErr error for len(data) > 0 { l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version) @@ -1271,6 +1274,17 @@ func (s *connection) handleFrames( return false, handleErr } } + + // Handle completion of the handshake after processing all the frames. + // This ensures that we correctly handle the following case on the server side: + // We receive a Handshake packet that contains the CRYPTO frame that allows us to complete the handshake, + // and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame. + if !handshakeWasComplete && s.handshakeComplete { + if err := s.handleHandshakeComplete(); err != nil { + return false, err + } + } + return } @@ -1366,7 +1380,9 @@ func (s *connection) handleHandshakeEvents() error { case handshake.EventNoEvent: return nil case handshake.EventHandshakeComplete: - err = s.handleHandshakeComplete() + // Don't call handleHandshakeComplete yet. + // It's advantageous to process ACK frames that might be serialized after the CRYPTO frame first. + s.handshakeComplete = true case handshake.EventReceivedTransportParameters: err = s.handleTransportParameters(ev.TransportParameters) case handshake.EventRestoredTransportParameters: @@ -1494,6 +1510,9 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr if !acked1RTTPacket { return nil } + // On the client side: If the packet acknowledged a 1-RTT packet, this confirms the handshake. + // This is only possible if the ACK was sent in a 1-RTT packet. + // This is an optimization over simply waiting for a HANDSHAKE_DONE frame, see section 4.1.2 of RFC 9001. if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed { if err := s.handleHandshakeConfirmed(); err != nil { return err @@ -1665,6 +1684,9 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters } func (s *connection) handleTransportParameters(params *wire.TransportParameters) error { + if s.tracer != nil { + s.tracer.ReceivedTransportParameters(params) + } if err := s.checkTransportParameters(params); err != nil { return &qerr.TransportError{ ErrorCode: qerr.TransportParameterError, @@ -1699,9 +1721,6 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters) if s.logger.Debug() { s.logger.Debugf("Processed Transport Parameters: %s", params) } - if s.tracer != nil { - s.tracer.ReceivedTransportParameters(params) - } // check the initial_source_connection_id if params.InitialSourceConnectionID != s.handshakeDestConnID { diff --git a/connection_test.go b/connection_test.go index b372085c..95f1b346 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1892,7 +1892,6 @@ var _ = Describe("Connection", func() { It("cancels the HandshakeComplete context when the handshake completes", func() { packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() - finishHandshake := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) @@ -1902,53 +1901,26 @@ var _ = Describe("Connection", func() { sph.EXPECT().DropPackets(protocol.EncryptionHandshake) sph.EXPECT().SetHandshakeConfirmed() connRunner.EXPECT().Retire(clientDestConnID) - go func() { - defer GinkgoRecover() - <-finishHandshake - cryptoSetup.EXPECT().StartHandshake() - cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}) - cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) - cryptoSetup.EXPECT().SetHandshakeConfirmed() - cryptoSetup.EXPECT().GetSessionTicket() - conn.run() - }() + cryptoSetup.EXPECT().SetHandshakeConfirmed() + cryptoSetup.EXPECT().GetSessionTicket() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) - close(finishHandshake) + Expect(conn.handleHandshakeComplete()).To(Succeed()) Eventually(handshakeCtx).Should(BeClosed()) - // make sure the go routine returns - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) }) It("sends a session ticket when the handshake completes", func() { const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes() - finishHandshake := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) conn.sentPacketHandler.DropPackets(protocol.EncryptionInitial) tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) - go func() { - defer GinkgoRecover() - <-finishHandshake - cryptoSetup.EXPECT().StartHandshake() - cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}) - cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) - cryptoSetup.EXPECT().SetHandshakeConfirmed() - cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) - conn.run() - }() + cryptoSetup.EXPECT().SetHandshakeConfirmed() + cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) - close(finishHandshake) + Expect(conn.handleHandshakeComplete()).To(Succeed()) var frames []ackhandler.Frame Eventually(func() []ackhandler.Frame { frames, _ = conn.framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1) @@ -1964,16 +1936,6 @@ var _ = Describe("Connection", func() { } } Expect(size).To(BeEquivalentTo(s)) - // make sure the go routine returns - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(conn.Context().Done()).Should(BeClosed()) }) It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { @@ -2028,6 +1990,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() mconn.EXPECT().Write(gomock.Any(), gomock.Any()) + Expect(conn.handleHandshakeComplete()).To(Succeed()) conn.run() }() Eventually(done).Should(BeClosed()) @@ -2351,6 +2314,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) + Expect(conn.handleHandshakeComplete()).To(Succeed()) err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -2868,7 +2832,10 @@ var _ = Describe("Client Connection", func() { TransportParameters: params, }) cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}).MaxTimes(1) - cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}).MaxTimes(1) + cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}).MaxTimes(1).Do(func() { + defer GinkgoRecover() + Expect(conn.handleHandshakeComplete()).To(Succeed()) + }) errChan <- conn.run() close(errChan) }()