diff --git a/session.go b/session.go index 16e189b4..c776f9fb 100644 --- a/session.go +++ b/session.go @@ -688,6 +688,8 @@ func (s *session) handleHandshakeComplete() { s.connIDGenerator.SetHandshakeComplete() if s.perspective == protocol.PerspectiveServer { + s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() ticket, err := s.cryptoStreamHandler.GetSessionTicket() if err != nil { s.closeLocal(err) @@ -1238,6 +1240,8 @@ func (s *session) handleHandshakeDoneFrame() error { if s.perspective == protocol.PerspectiveServer { return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame") } + s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() s.cryptoStreamHandler.DropHandshakeKeys() return nil } @@ -1347,10 +1351,6 @@ func (s *session) handleCloseError(closeErr closeError) { } func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { - if encLevel == protocol.EncryptionHandshake { - s.handshakeConfirmed = true - s.sentPacketHandler.SetHandshakeConfirmed() - } s.sentPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel) if s.tracer != nil { diff --git a/session_test.go b/session_test.go index 2e382991..bf119243 100644 --- a/session_test.go +++ b/session_test.go @@ -1634,6 +1634,7 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().AnyTimes() + sph.EXPECT().SetHandshakeConfirmed() sessionRunner.EXPECT().Retire(clientDestConnID) go func() { defer GinkgoRecover() @@ -1729,14 +1730,17 @@ var _ = Describe("Session", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().HasPacingBudget().Return(true) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SetHandshakeConfirmed() + sph.EXPECT().SentPacket(gomock.Any()) + mconn.EXPECT().Write(gomock.Any()) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sess.sentPacketHandler = sph done := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackCoalescedPacket(gomock.Any()).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { + packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -1748,7 +1752,7 @@ var _ = Describe("Session", func() { buffer: getPacketBuffer(), }, nil }) - packer.EXPECT().PackCoalescedPacket(gomock.Any()).AnyTimes() + packer.EXPECT().PackPacket().AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -2270,6 +2274,9 @@ var _ = Describe("Client Session", func() { }) It("handles HANDSHAKE_DONE frames", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sess.sentPacketHandler = sph + sph.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().DropHandshakeKeys() Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) })