Merge pull request #3132 from lucas-clemente/amplification-limit-fixes

various amplification limit fixes
This commit is contained in:
Marten Seemann 2021-04-02 18:32:38 +07:00 committed by GitHub
commit 4b10e67bf1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 31 deletions

View file

@ -196,12 +196,17 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
}
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
wasAmplificationLimit := h.isAmplificationLimited()
h.bytesReceived += n
if wasAmplificationLimit && !h.isAmplificationLimited() {
h.setLossDetectionTimer()
}
}
func (h *sentPacketHandler) ReceivedPacket(encLevel protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionHandshake {
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
h.peerAddressValidated = true
h.setLossDetectionTimer()
}
}
@ -485,7 +490,8 @@ func (h *sentPacketHandler) hasOutstandingPackets() bool {
func (h *sentPacketHandler) setLossDetectionTimer() {
oldAlarm := h.alarm // only needed in case tracing is enabled
if lossTime, encLevel := h.getLossTimeAndSpace(); !lossTime.IsZero() {
lossTime, encLevel := h.getLossTimeAndSpace()
if !lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = lossTime
if h.tracer != nil && h.alarm != oldAlarm {
@ -494,12 +500,26 @@ func (h *sentPacketHandler) setLossDetectionTimer() {
return
}
// Cancel the alarm if amplification limited.
if h.isAmplificationLimited() {
h.alarm = time.Time{}
if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. Amplification limited.")
if h.tracer != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
// Cancel the alarm if no packets are outstanding
if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation {
h.alarm = time.Time{}
h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
if h.tracer != nil && !oldAlarm.IsZero() {
h.tracer.LossTimerCanceled()
if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
if h.tracer != nil {
h.tracer.LossTimerCanceled()
}
}
return
}

View file

@ -540,29 +540,6 @@ var _ = Describe("SentPacketHandler", func() {
handler.SendMode()
})
It("limits the window to 3x the bytes received, to avoid amplification attacks", func() {
handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address
handler.ReceivedBytes(200)
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).Times(2)
handler.SentPacket(&Packet{
PacketNumber: 1,
Length: 599,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
cong.EXPECT().CanSend(protocol.ByteCount(599)).Return(true)
Expect(handler.SendMode()).To(Equal(SendAny))
handler.SentPacket(&Packet{
PacketNumber: 2,
Length: 1,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
Expect(handler.SendMode()).To(Equal(SendNone))
})
It("allows sending of ACKs when congestion limited", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
cong.EXPECT().CanSend(gomock.Any()).Return(true)
@ -635,6 +612,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("implements exponential backoff", func() {
handler.peerAddressValidated = true
handler.SetHandshakeConfirmed()
sendTime := time.Now().Add(-time.Hour)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime}))
@ -828,7 +806,61 @@ var _ = Describe("SentPacketHandler", func() {
})
})
Context("amplification limit", func() {
Context("amplification limit, for the server", func() {
It("limits the window to 3x the bytes received, to avoid amplification attacks", func() {
handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address
handler.ReceivedBytes(200)
handler.SentPacket(&Packet{
PacketNumber: 1,
Length: 599,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
Expect(handler.SendMode()).To(Equal(SendAny))
handler.SentPacket(&Packet{
PacketNumber: 2,
Length: 1,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
Expect(handler.SendMode()).To(Equal(SendNone))
})
It("cancels the loss detection timer when it is amplification limited, and resets it when becoming unblocked", func() {
handler.ReceivedBytes(300)
handler.SentPacket(&Packet{
PacketNumber: 1,
Length: 900,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
// Amplification limited. We don't need to set a timer now.
Expect(handler.GetLossDetectionTimeout()).To(BeZero())
// Unblock the server. Now we should fire up the timer.
handler.ReceivedBytes(1)
Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero())
})
It("resets the loss detection timer when the client's address is validated", func() {
handler.ReceivedBytes(300)
handler.SentPacket(&Packet{
PacketNumber: 1,
Length: 900,
EncryptionLevel: protocol.EncryptionHandshake,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
// Amplification limited. We don't need to set a timer now.
Expect(handler.GetLossDetectionTimeout()).To(BeZero())
handler.ReceivedPacket(protocol.EncryptionHandshake)
Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero())
})
})
Context("amplification limit, for the client", func() {
BeforeEach(func() {
perspective = protocol.PerspectiveClient
})

View file

@ -610,7 +610,6 @@ runLoop:
// nothing to see here.
case <-sendQueueAvailable:
case firstPacket := <-s.receivedPackets:
s.sentPacketHandler.ReceivedBytes(firstPacket.Size())
wasProcessed := s.handlePacketImpl(firstPacket)
// Don't set timers and send packets if the packet made us close the session.
select {
@ -834,6 +833,8 @@ func (s *session) handleHandshakeConfirmed() {
}
func (s *session) handlePacketImpl(rp *receivedPacket) bool {
s.sentPacketHandler.ReceivedBytes(rp.Size())
if wire.IsVersionNegotiationPacket(rp.data) {
s.handleVersionNegotiationPacket(rp)
return false

View file

@ -2594,6 +2594,7 @@ var _ = Describe("Client Session", func() {
It("closes and returns the right error", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sph.EXPECT().ReceivedBytes(gomock.Any())
sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4)
sess.config.Versions = []protocol.VersionNumber{1234, 4321}
errChan := make(chan error, 1)
@ -2692,6 +2693,7 @@ var _ = Describe("Client Session", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sph.EXPECT().ResetForRetry()
sph.EXPECT().ReceivedBytes(gomock.Any())
cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
packer.EXPECT().SetToken([]byte("foobar"))
tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) {
@ -2981,6 +2983,7 @@ var _ = Describe("Client Session", func() {
It("ignores Initial packets which use original source id, after accepting a Retry", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2)
sph.EXPECT().ResetForRetry()
newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID)