only arm the application data PTO timer after the handshake is confirmed

This commit is contained in:
Marten Seemann 2020-07-27 16:40:21 +07:00
parent a854a4ace9
commit 8db76ab9c2
6 changed files with 23 additions and 27 deletions

View file

@ -28,7 +28,7 @@ type SentPacketHandler interface {
ReceivedBytes(protocol.ByteCount) ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel) DropPackets(protocol.EncryptionLevel)
ResetForRetry() error ResetForRetry() error
SetHandshakeComplete() SetHandshakeConfirmed()
// The SendMode determines if and what kind of packets can be sent. // The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode SendMode() SendMode

View file

@ -58,7 +58,7 @@ type sentPacketHandler struct {
// Always true for the client. // Always true for the client.
peerAddressValidated bool peerAddressValidated bool
handshakeComplete bool handshakeConfirmed bool
// lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20 // example: we send an ACK for packets 90-100 with packet number 20
@ -444,7 +444,7 @@ func (h *sentPacketHandler) getPTOTimeAndSpace() (time.Time, protocol.Encryption
encLevel = protocol.EncryptionHandshake encLevel = protocol.EncryptionHandshake
} }
} }
if h.handshakeComplete && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount)
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t pto = t
@ -468,7 +468,7 @@ func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
func (h *sentPacketHandler) hasOutstandingPackets() bool { func (h *sentPacketHandler) hasOutstandingPackets() bool {
// We only send application data probe packets once the handshake completes, // We only send application data probe packets once the handshake completes,
// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
return (h.handshakeComplete && h.appDataPackets.history.HasOutstandingPackets()) || return (h.handshakeConfirmed && h.appDataPackets.history.HasOutstandingPackets()) ||
h.hasOutstandingCryptoPackets() h.hasOutstandingCryptoPackets()
} }
@ -802,8 +802,8 @@ func (h *sentPacketHandler) ResetForRetry() error {
return nil return nil
} }
func (h *sentPacketHandler) SetHandshakeComplete() { func (h *sentPacketHandler) SetHandshakeConfirmed() {
h.handshakeComplete = true h.handshakeConfirmed = true
// We don't send PTOs for application data packets before the handshake completes. // We don't send PTOs for application data packets before the handshake completes.
// Make sure the timer is armed now, if necessary. // Make sure the timer is armed now, if necessary.
h.setLossDetectionTimer() h.setLossDetectionTimer()

View file

@ -604,7 +604,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("implements exponential backoff", func() { It("implements exponential backoff", func() {
handler.SetHandshakeComplete() handler.SetHandshakeConfirmed()
sendTime := time.Now().Add(-time.Hour) sendTime := time.Now().Add(-time.Hour)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime}))
timeout := handler.GetLossDetectionTimeout().Sub(sendTime) timeout := handler.GetLossDetectionTimeout().Sub(sendTime)
@ -620,7 +620,7 @@ var _ = Describe("SentPacketHandler", func() {
It("reset the PTO count when receiving an ACK", func() { It("reset the PTO count when receiving an ACK", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake) handler.ReceivedPacket(protocol.EncryptionHandshake)
now := time.Now() now := time.Now()
handler.SetHandshakeComplete() handler.SetHandshakeConfirmed()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)}))
Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second))
@ -659,7 +659,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.ptoCount).To(BeEquivalentTo(1)) Expect(handler.ptoCount).To(BeEquivalentTo(1))
Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) Expect(handler.SendMode()).To(Equal(SendPTOHandshake))
Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeHandshake.Add(handler.rttStats.PTO(false) << 1))) Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeHandshake.Add(handler.rttStats.PTO(false) << 1)))
handler.SetHandshakeComplete() handler.SetHandshakeConfirmed()
handler.DropPackets(protocol.EncryptionHandshake) handler.DropPackets(protocol.EncryptionHandshake)
// PTO timer based on the 1-RTT packet // PTO timer based on the 1-RTT packet
Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeAppData.Add(handler.rttStats.PTO(true)))) // no backoff. PTO count = 0 Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeAppData.Add(handler.rttStats.PTO(true)))) // no backoff. PTO count = 0
@ -669,7 +669,7 @@ var _ = Describe("SentPacketHandler", func() {
It("allows two 1-RTT PTOs", func() { It("allows two 1-RTT PTOs", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake) handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete() handler.SetHandshakeConfirmed()
var lostPackets []protocol.PacketNumber var lostPackets []protocol.PacketNumber
handler.SentPacket(ackElicitingPacket(&Packet{ handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: 1, PacketNumber: 1,
@ -688,7 +688,7 @@ var _ = Describe("SentPacketHandler", func() {
It("only counts ack-eliciting packets as probe packets", func() { It("only counts ack-eliciting packets as probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake) handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete() handler.SetHandshakeConfirmed()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendPTOAppData)) Expect(handler.SendMode()).To(Equal(SendPTOAppData))
@ -704,7 +704,7 @@ var _ = Describe("SentPacketHandler", func() {
It("gets two probe packets if PTO expires", func() { It("gets two probe packets if PTO expires", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake) handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete() handler.SetHandshakeConfirmed()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2}))
@ -752,7 +752,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP
Expect(handler.GetLossDetectionTimeout()).To(BeZero()) Expect(handler.GetLossDetectionTimeout()).To(BeZero())
Expect(handler.SendMode()).To(Equal(SendAny)) Expect(handler.SendMode()).To(Equal(SendAny))
handler.SetHandshakeComplete() handler.SetHandshakeConfirmed()
Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero())
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendPTOAppData)) Expect(handler.SendMode()).To(Equal(SendPTOAppData))
@ -760,7 +760,7 @@ var _ = Describe("SentPacketHandler", func() {
It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake) handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete() handler.SetHandshakeConfirmed()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) handler.rttStats.UpdateRTT(time.Second, 0, time.Now())
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.OnLossDetectionTimeout()).To(Succeed())
@ -902,7 +902,7 @@ var _ = Describe("SentPacketHandler", func() {
It("sets the early retransmit alarm", func() { It("sets the early retransmit alarm", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake) handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.handshakeComplete = true handler.handshakeConfirmed = true
now := time.Now() now := time.Now()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-2 * time.Second)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: now.Add(-2 * time.Second)}))

View file

@ -229,16 +229,16 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0)
} }
// SetHandshakeComplete mocks base method // SetHandshakeConfirmed mocks base method
func (m *MockSentPacketHandler) SetHandshakeComplete() { func (m *MockSentPacketHandler) SetHandshakeConfirmed() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "SetHandshakeComplete") m.ctrl.Call(m, "SetHandshakeConfirmed")
} }
// SetHandshakeComplete indicates an expected call of SetHandshakeComplete // SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed
func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call { func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed))
} }
// TimeUntilSend mocks base method // TimeUntilSend mocks base method

View file

@ -683,7 +683,6 @@ func (s *session) handleHandshakeComplete() {
s.handshakeCtxCancel() s.handshakeCtxCancel()
s.connIDGenerator.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete()
s.sentPacketHandler.SetHandshakeComplete()
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
ticket, err := s.cryptoStreamHandler.GetSessionTicket() ticket, err := s.cryptoStreamHandler.GetSessionTicket()
@ -1331,6 +1330,7 @@ func (s *session) handleCloseError(closeErr closeError) {
func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
if encLevel == protocol.EncryptionHandshake { if encLevel == protocol.EncryptionHandshake {
s.handshakeConfirmed = true s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
} }
s.sentPacketHandler.DropPackets(encLevel) s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel)

View file

@ -1597,13 +1597,11 @@ var _ = Describe("Session", func() {
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
It("cancels the HandshakeComplete context and informs the SentPacketHandler when the handshake completes", func() { It("cancels the HandshakeComplete context when the handshake completes", func() {
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes()
finishHandshake := make(chan struct{}) finishHandshake := make(chan struct{})
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph sess.sentPacketHandler = sph
sphNotified := make(chan struct{})
sph.EXPECT().SetHandshakeComplete().Do(func() { close(sphNotified) })
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode().AnyTimes() sph.EXPECT().SendMode().AnyTimes()
@ -1621,7 +1619,6 @@ var _ = Describe("Session", func() {
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
close(finishHandshake) close(finishHandshake)
Eventually(handshakeCtx.Done()).Should(BeClosed()) Eventually(handshakeCtx.Done()).Should(BeClosed())
Eventually(sphNotified).Should(BeClosed())
// make sure the go routine returns // make sure the go routine returns
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed() expectReplaceWithClosed()
@ -1704,7 +1701,6 @@ var _ = Describe("Session", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount) sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount)
sph.EXPECT().SetHandshakeComplete()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().HasPacingBudget().Return(true) sph.EXPECT().HasPacingBudget().Return(true)