diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 373d677e..db739fb4 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -28,7 +28,7 @@ type SentPacketHandler interface { ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) ResetForRetry() error - SetHandshakeComplete() + SetHandshakeConfirmed() // The SendMode determines if and what kind of packets can be sent. SendMode() SendMode diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 2332dcf7..203fcdee 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -58,7 +58,7 @@ type sentPacketHandler struct { // Always true for the client. peerAddressValidated bool - handshakeComplete bool + handshakeConfirmed bool // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived // example: we send an ACK for packets 90-100 with packet number 20 @@ -444,7 +444,7 @@ func (h *sentPacketHandler) getPTOTimeAndSpace() (time.Time, protocol.Encryption encLevel = protocol.EncryptionHandshake } } - if h.handshakeComplete && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { + if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t @@ -468,7 +468,7 @@ func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { func (h *sentPacketHandler) hasOutstandingPackets() bool { // We only send application data probe packets once the handshake completes, // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. - return (h.handshakeComplete && h.appDataPackets.history.HasOutstandingPackets()) || + return (h.handshakeConfirmed && h.appDataPackets.history.HasOutstandingPackets()) || h.hasOutstandingCryptoPackets() } @@ -802,8 +802,8 @@ func (h *sentPacketHandler) ResetForRetry() error { return nil } -func (h *sentPacketHandler) SetHandshakeComplete() { - h.handshakeComplete = true +func (h *sentPacketHandler) SetHandshakeConfirmed() { + h.handshakeConfirmed = true // We don't send PTOs for application data packets before the handshake completes. // Make sure the timer is armed now, if necessary. h.setLossDetectionTimer() diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index fc0cafdd..21054d4a 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -604,7 +604,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("implements exponential backoff", func() { - handler.SetHandshakeComplete() + handler.SetHandshakeConfirmed() sendTime := time.Now().Add(-time.Hour) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) timeout := handler.GetLossDetectionTimeout().Sub(sendTime) @@ -620,7 +620,7 @@ var _ = Describe("SentPacketHandler", func() { It("reset the PTO count when receiving an ACK", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) now := time.Now() - handler.SetHandshakeComplete() + handler.SetHandshakeConfirmed() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) @@ -659,7 +659,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.ptoCount).To(BeEquivalentTo(1)) Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeHandshake.Add(handler.rttStats.PTO(false) << 1))) - handler.SetHandshakeComplete() + handler.SetHandshakeConfirmed() handler.DropPackets(protocol.EncryptionHandshake) // PTO timer based on the 1-RTT packet Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeAppData.Add(handler.rttStats.PTO(true)))) // no backoff. PTO count = 0 @@ -669,7 +669,7 @@ var _ = Describe("SentPacketHandler", func() { It("allows two 1-RTT PTOs", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeComplete() + handler.SetHandshakeConfirmed() var lostPackets []protocol.PacketNumber handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: 1, @@ -688,7 +688,7 @@ var _ = Describe("SentPacketHandler", func() { It("only counts ack-eliciting packets as probe packets", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeComplete() + handler.SetHandshakeConfirmed() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTOAppData)) @@ -704,7 +704,7 @@ var _ = Describe("SentPacketHandler", func() { It("gets two probe packets if PTO expires", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeComplete() + handler.SetHandshakeConfirmed() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) @@ -752,7 +752,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP Expect(handler.GetLossDetectionTimeout()).To(BeZero()) Expect(handler.SendMode()).To(Equal(SendAny)) - handler.SetHandshakeComplete() + handler.SetHandshakeConfirmed() Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTOAppData)) @@ -760,7 +760,7 @@ var _ = Describe("SentPacketHandler", func() { It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.SetHandshakeComplete() + handler.SetHandshakeConfirmed() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) @@ -902,7 +902,7 @@ var _ = Describe("SentPacketHandler", func() { It("sets the early retransmit alarm", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) - handler.handshakeComplete = true + handler.handshakeConfirmed = true now := time.Now() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-2 * time.Second)})) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 355695f2..74928606 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -229,16 +229,16 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) } -// SetHandshakeComplete mocks base method -func (m *MockSentPacketHandler) SetHandshakeComplete() { +// SetHandshakeConfirmed mocks base method +func (m *MockSentPacketHandler) SetHandshakeConfirmed() { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeComplete") + m.ctrl.Call(m, "SetHandshakeConfirmed") } -// SetHandshakeComplete indicates an expected call of SetHandshakeComplete -func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call { +// SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed +func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) } // TimeUntilSend mocks base method diff --git a/session.go b/session.go index d25a684b..f0819bcb 100644 --- a/session.go +++ b/session.go @@ -683,7 +683,6 @@ func (s *session) handleHandshakeComplete() { s.handshakeCtxCancel() s.connIDGenerator.SetHandshakeComplete() - s.sentPacketHandler.SetHandshakeComplete() if s.perspective == protocol.PerspectiveServer { ticket, err := s.cryptoStreamHandler.GetSessionTicket() @@ -1331,6 +1330,7 @@ func (s *session) handleCloseError(closeErr closeError) { func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { if encLevel == protocol.EncryptionHandshake { s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() } s.sentPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel) diff --git a/session_test.go b/session_test.go index bbadd556..31281217 100644 --- a/session_test.go +++ b/session_test.go @@ -1597,13 +1597,11 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("cancels the HandshakeComplete context and informs the SentPacketHandler when the handshake completes", func() { + It("cancels the HandshakeComplete context when the handshake completes", func() { packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() finishHandshake := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph - sphNotified := make(chan struct{}) - sph.EXPECT().SetHandshakeComplete().Do(func() { close(sphNotified) }) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().AnyTimes() @@ -1621,7 +1619,6 @@ var _ = Describe("Session", func() { Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) close(finishHandshake) Eventually(handshakeCtx.Done()).Should(BeClosed()) - Eventually(sphNotified).Should(BeClosed()) // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() @@ -1704,7 +1701,6 @@ var _ = Describe("Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount) - sph.EXPECT().SetHandshakeComplete() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().HasPacingBudget().Return(true)