From b0ab718c7a96b21325168fe727da6253fd5cebcd Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 5 Jan 2018 14:22:42 +0700 Subject: [PATCH] delete non-forward-secure retransmissions when the handshake completes --- ackhandler/sent_packet_handler.go | 7 +++++ ackhandler/sent_packet_handler_test.go | 42 ++++++++++++++++++-------- session.go | 4 --- session_test.go | 16 ---------- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 102e33cd..ade3ebaf 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -99,6 +99,13 @@ func (h *sentPacketHandler) ShouldSendRetransmittablePacket() bool { } func (h *sentPacketHandler) SetHandshakeComplete() { + var queue []*Packet + for _, packet := range h.retransmissionQueue { + if packet.EncryptionLevel == protocol.EncryptionForwardSecure { + queue = append(queue, packet) + } + } + h.retransmissionQueue = queue h.handshakeComplete = true } diff --git a/ackhandler/sent_packet_handler_test.go b/ackhandler/sent_packet_handler_test.go index e663a2b9..d4fa3f11 100644 --- a/ackhandler/sent_packet_handler_test.go +++ b/ackhandler/sent_packet_handler_test.go @@ -576,13 +576,13 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { packets = []*Packet{ - {PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1}, - {PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1}, + {PacketNumber: 1, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted}, + {PacketNumber: 2, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted}, + {PacketNumber: 3, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionUnencrypted}, + {PacketNumber: 4, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure}, + {PacketNumber: 5, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionSecure}, + {PacketNumber: 6, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure}, + {PacketNumber: 7, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.EncryptionForwardSecure}, } for _, packet := range packets { handler.SentPacket(packet) @@ -590,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() { // Increase RTT, because the tests would be flaky otherwise handler.rttStats.UpdateRTT(time.Minute, 0, time.Now()) // Ack a single packet so that we have non-RTO timings - handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionUnencrypted, time.Now()) + handler.ReceivedAck(&wire.AckFrame{LargestAcked: 2, LowestAcked: 2}, 1, protocol.EncryptionForwardSecure, time.Now()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) }) @@ -610,15 +610,33 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) }) - Context("StopWaitings", func() { - It("gets a StopWaitingFrame", func() { + It("deletes non forward-secure packets when the handshake completes", func() { + for i := protocol.PacketNumber(1); i <= 7; i++ { + if i == 2 { // packet 2 was already acked in BeforeEach + continue + } + handler.queuePacketForRetransmission(getPacketElement(i)) + } + Expect(handler.retransmissionQueue).To(HaveLen(6)) + handler.SetHandshakeComplete() + packet := handler.DequeuePacketForRetransmission() + Expect(packet).ToNot(BeNil()) + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(6))) + packet = handler.DequeuePacketForRetransmission() + Expect(packet).ToNot(BeNil()) + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(7))) + Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) + }) + + Context("STOP_WAITINGs", func() { + It("gets a STOP_WAITING frame", func() { ack := wire.AckFrame{LargestAcked: 5, LowestAcked: 5} - err := handler.ReceivedAck(&ack, 2, protocol.EncryptionUnencrypted, time.Now()) + err := handler.ReceivedAck(&ack, 2, protocol.EncryptionForwardSecure, time.Now()) Expect(err).ToNot(HaveOccurred()) Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6})) }) - It("gets a StopWaitingFrame after queueing a retransmission", func() { + It("gets a STOP_WAITING frame after queueing a retransmission", func() { handler.queuePacketForRetransmission(getPacketElement(5)) Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6})) }) diff --git a/session.go b/session.go index 3ec8c9ec..dc32f057 100644 --- a/session.go +++ b/session.go @@ -758,10 +758,6 @@ func (s *session) sendPacket() (bool, error) { // retransmit handshake packets if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { - if s.handshakeComplete { - // don't retransmit handshake packets when the handshake is complete - continue - } utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) if !s.version.UsesIETFFrameFormat() { s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) diff --git a/session_test.go b/session_test.go index ca909a8d..71b61d08 100644 --- a/session_test.go +++ b/session_test.go @@ -1000,22 +1000,6 @@ var _ = Describe("Session", func() { Expect(sent).To(BeTrue()) Expect(mconn.written).To(HaveLen(1)) }) - - It("doesn't retransmit handshake packets when the handshake is complete", func() { - sess.handshakeComplete = true - sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.EXPECT().DequeuePacketForRetransmission().Return( - &ackhandler.Packet{ - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionSecure, - }) - sph.EXPECT().DequeuePacketForRetransmission() - sph.EXPECT().ShouldSendRetransmittablePacket() - sent, err := sess.sendPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeFalse()) - Expect(mconn.written).To(BeEmpty()) - }) }) Context("for packets after the handshake", func() {