From 6c0142cb4c338a99e333307a25a5a051bc62b62b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 29 Mar 2021 08:51:23 +0700 Subject: [PATCH 1/4] notify the sent packet handler about all received bytes --- session.go | 3 ++- session_test.go | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/session.go b/session.go index 7c68f65c..18c7976b 100644 --- a/session.go +++ b/session.go @@ -608,7 +608,6 @@ runLoop: // nothing to see here. case <-sendQueueAvailable: case firstPacket := <-s.receivedPackets: - s.sentPacketHandler.ReceivedBytes(firstPacket.Size()) wasProcessed := s.handlePacketImpl(firstPacket) // Don't set timers and send packets if the packet made us close the session. select { @@ -830,6 +829,8 @@ func (s *session) handleHandshakeComplete() { } func (s *session) handlePacketImpl(rp *receivedPacket) bool { + s.sentPacketHandler.ReceivedBytes(rp.Size()) + if wire.IsVersionNegotiationPacket(rp.data) { s.handleVersionNegotiationPacket(rp) return false diff --git a/session_test.go b/session_test.go index 31681ca7..082ce907 100644 --- a/session_test.go +++ b/session_test.go @@ -2592,6 +2592,7 @@ var _ = Describe("Client Session", func() { It("closes and returns the right error", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph + sph.EXPECT().ReceivedBytes(gomock.Any()) sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4) sess.config.Versions = []protocol.VersionNumber{1234, 4321} errChan := make(chan error, 1) @@ -2690,6 +2691,7 @@ var _ = Describe("Client Session", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph sph.EXPECT().ResetForRetry() + sph.EXPECT().ReceivedBytes(gomock.Any()) cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) packer.EXPECT().SetToken([]byte("foobar")) tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { @@ -2976,6 +2978,7 @@ var _ = Describe("Client Session", func() { It("ignores Initial packets which use original source id, after accepting a Retry", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph + sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) sph.EXPECT().ResetForRetry() newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) From 3fab321ea79da350fc07886d6ea125b5f116749d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 29 Mar 2021 09:32:00 +0700 Subject: [PATCH 2/4] cancel the loss detection timer when amplification limited --- internal/ackhandler/sent_packet_handler.go | 10 +++ .../ackhandler/sent_packet_handler_test.go | 61 +++++++++++-------- 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 91ff05a3..8ceb4691 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -490,6 +490,16 @@ func (h *sentPacketHandler) setLossDetectionTimer() { return } + // Cancel the alarm if amplification limited. + if h.isAmplificationLimited() { + h.alarm = time.Time{} + h.logger.Debugf("Canceling loss detection timer. Amplification limited.") + if h.tracer != nil && !oldAlarm.IsZero() { + h.tracer.LossTimerCanceled() + } + return + } + // Cancel the alarm if no packets are outstanding if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation { h.alarm = time.Time{} diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 380be3c8..6251d988 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -540,29 +540,6 @@ var _ = Describe("SentPacketHandler", func() { handler.SendMode() }) - It("limits the window to 3x the bytes received, to avoid amplification attacks", func() { - handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address - handler.ReceivedBytes(200) - cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).Times(2) - handler.SentPacket(&Packet{ - PacketNumber: 1, - Length: 599, - EncryptionLevel: protocol.EncryptionInitial, - Frames: []Frame{{Frame: &wire.PingFrame{}}}, - SendTime: time.Now(), - }) - cong.EXPECT().CanSend(protocol.ByteCount(599)).Return(true) - Expect(handler.SendMode()).To(Equal(SendAny)) - handler.SentPacket(&Packet{ - PacketNumber: 2, - Length: 1, - EncryptionLevel: protocol.EncryptionInitial, - Frames: []Frame{{Frame: &wire.PingFrame{}}}, - SendTime: time.Now(), - }) - Expect(handler.SendMode()).To(Equal(SendNone)) - }) - It("allows sending of ACKs when congestion limited", func() { handler.ReceivedPacket(protocol.EncryptionHandshake) cong.EXPECT().CanSend(gomock.Any()).Return(true) @@ -635,6 +612,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("implements exponential backoff", func() { + handler.peerAddressValidated = true handler.SetHandshakeConfirmed() sendTime := time.Now().Add(-time.Hour) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) @@ -828,7 +806,42 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("amplification limit", func() { + Context("amplification limit, for the server", func() { + It("limits the window to 3x the bytes received, to avoid amplification attacks", func() { + handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address + handler.ReceivedBytes(200) + handler.SentPacket(&Packet{ + PacketNumber: 1, + Length: 599, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + Expect(handler.SendMode()).To(Equal(SendAny)) + handler.SentPacket(&Packet{ + PacketNumber: 2, + Length: 1, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + Expect(handler.SendMode()).To(Equal(SendNone)) + }) + + It("cancels the loss detection timer when it is amplification limited", func() { + handler.ReceivedBytes(300) + handler.SentPacket(&Packet{ + PacketNumber: 1, + Length: 900, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + }) + }) + + Context("amplification limit, for the client", func() { BeforeEach(func() { perspective = protocol.PerspectiveClient }) From a695bae01968def15c423ab6120c6d1df910dbd2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 29 Mar 2021 09:39:58 +0700 Subject: [PATCH 3/4] restart the loss detection timer when the server becomes unblocked --- internal/ackhandler/sent_packet_handler.go | 4 ++++ internal/ackhandler/sent_packet_handler_test.go | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 8ceb4691..2c31fd89 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -194,7 +194,11 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { } func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { + wasAmplificationLimit := h.isAmplificationLimited() h.bytesReceived += n + if wasAmplificationLimit && !h.isAmplificationLimited() { + h.setLossDetectionTimer() + } } func (h *sentPacketHandler) ReceivedPacket(encLevel protocol.EncryptionLevel) { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 6251d988..afc51194 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -828,7 +828,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode()).To(Equal(SendNone)) }) - It("cancels the loss detection timer when it is amplification limited", func() { + It("cancels the loss detection timer when it is amplification limited, and resets it when becoming unblocked", func() { handler.ReceivedBytes(300) handler.SentPacket(&Packet{ PacketNumber: 1, @@ -837,7 +837,11 @@ var _ = Describe("SentPacketHandler", func() { Frames: []Frame{{Frame: &wire.PingFrame{}}}, SendTime: time.Now(), }) + // Amplification limited. We don't need to set a timer now. Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + // Unblock the server. Now we should fire up the timer. + handler.ReceivedBytes(1) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) }) }) From b6634fe1242c38169062224cc10e3be489ecf7ba Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 29 Mar 2021 14:53:35 +0700 Subject: [PATCH 4/4] reset the loss detection timer when the client's address is validated --- internal/ackhandler/sent_packet_handler.go | 24 ++++++++++++------- .../ackhandler/sent_packet_handler_test.go | 15 ++++++++++++ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 2c31fd89..5f25659e 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -201,9 +201,10 @@ func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { } } -func (h *sentPacketHandler) ReceivedPacket(encLevel protocol.EncryptionLevel) { - if h.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionHandshake { +func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) { + if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated { h.peerAddressValidated = true + h.setLossDetectionTimer() } } @@ -485,7 +486,8 @@ func (h *sentPacketHandler) hasOutstandingPackets() bool { func (h *sentPacketHandler) setLossDetectionTimer() { oldAlarm := h.alarm // only needed in case tracing is enabled - if lossTime, encLevel := h.getLossTimeAndSpace(); !lossTime.IsZero() { + lossTime, encLevel := h.getLossTimeAndSpace() + if !lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = lossTime if h.tracer != nil && h.alarm != oldAlarm { @@ -497,9 +499,11 @@ func (h *sentPacketHandler) setLossDetectionTimer() { // Cancel the alarm if amplification limited. if h.isAmplificationLimited() { h.alarm = time.Time{} - h.logger.Debugf("Canceling loss detection timer. Amplification limited.") - if h.tracer != nil && !oldAlarm.IsZero() { - h.tracer.LossTimerCanceled() + if !oldAlarm.IsZero() { + h.logger.Debugf("Canceling loss detection timer. Amplification limited.") + if h.tracer != nil { + h.tracer.LossTimerCanceled() + } } return } @@ -507,9 +511,11 @@ func (h *sentPacketHandler) setLossDetectionTimer() { // Cancel the alarm if no packets are outstanding if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation { h.alarm = time.Time{} - h.logger.Debugf("Canceling loss detection timer. No packets in flight.") - if h.tracer != nil && !oldAlarm.IsZero() { - h.tracer.LossTimerCanceled() + if !oldAlarm.IsZero() { + h.logger.Debugf("Canceling loss detection timer. No packets in flight.") + if h.tracer != nil { + h.tracer.LossTimerCanceled() + } } return } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index afc51194..5fef0dc1 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -843,6 +843,21 @@ var _ = Describe("SentPacketHandler", func() { handler.ReceivedBytes(1) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) }) + + It("resets the loss detection timer when the client's address is validated", func() { + handler.ReceivedBytes(300) + handler.SentPacket(&Packet{ + PacketNumber: 1, + Length: 900, + EncryptionLevel: protocol.EncryptionHandshake, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + // Amplification limited. We don't need to set a timer now. + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + handler.ReceivedPacket(protocol.EncryptionHandshake) + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + }) }) Context("amplification limit, for the client", func() {