From 582edae63d992b30e9f97646c92bc09cd2f2d898 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 24 Dec 2022 12:28:27 +1300 Subject: [PATCH] refactor retransmissionQueue to remove version param from constructor --- connection.go | 2 +- connection_test.go | 4 +-- packet_packer.go | 6 ++-- packet_packer_test.go | 4 +-- retransmission_queue.go | 22 +++++++-------- retransmission_queue_test.go | 54 +++++++++++++++++------------------- 6 files changed, 44 insertions(+), 48 deletions(-) diff --git a/connection.go b/connection.go index 679990d5..df72bf3e 100644 --- a/connection.go +++ b/connection.go @@ -507,7 +507,7 @@ var newClientConnection = func( func (s *connection) preSetup() { s.sendQueue = newSendQueue(s.conn) - s.retransmissionQueue = newRetransmissionQueue(s.version) + s.retransmissionQueue = newRetransmissionQueue() s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) s.rttStats = &utils.RTTStats{} s.connFlowController = flowcontrol.NewConnectionFlowController( diff --git a/connection_test.go b/connection_test.go index d2e65ca4..8289fcdc 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1289,7 +1289,7 @@ var _ = Describe("Connection", func() { Context(fmt.Sprintf("sending %s probe packets", encLevel), func() { var sendMode ackhandler.SendMode - var getFrame func(protocol.ByteCount) wire.Frame + var getFrame func(protocol.ByteCount, protocol.VersionNumber) wire.Frame BeforeEach(func() { //nolint:exhaustive @@ -1356,7 +1356,7 @@ var _ = Describe("Connection", func() { Eventually(sent).Should(BeClosed()) // We're using a mock packet packer in this test. // We therefore need to test separately that the PING was actually queued. - Expect(getFrame(1000)).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(getFrame(1000, protocol.Version1)).To(BeAssignableToTypeOf(&wire.PingFrame{})) }) }) } diff --git a/packet_packer.go b/packet_packer.go index 0a11bac4..50d573cc 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -569,9 +569,9 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s switch encLevel { case protocol.EncryptionInitial: - f = p.retransmissionQueue.GetInitialFrame(maxPacketSize) + f = p.retransmissionQueue.GetInitialFrame(maxPacketSize, p.version) case protocol.EncryptionHandshake: - f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize) + f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, p.version) } if f == nil { break @@ -683,7 +683,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc if remainingLen < protocol.MinStreamFrameSize { break } - f := p.retransmissionQueue.GetAppDataFrame(remainingLen) + f := p.retransmissionQueue.GetAppDataFrame(remainingLen, p.version) if f == nil { break } diff --git a/packet_packer_test.go b/packet_packer_test.go index 7f07a0c7..539bfe69 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -85,7 +85,7 @@ var _ = Describe("Packet packer", func() { BeforeEach(func() { rand.Seed(GinkgoRandomSeed()) - retransmissionQueue = newRetransmissionQueue(version) + retransmissionQueue = newRetransmissionQueue() mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() initialStream = NewMockCryptoStream(mockCtrl) @@ -1593,7 +1593,7 @@ var _ = Describe("Converting to ackhandler.Packet", func() { {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }}, }, } - p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue(protocol.VersionTLS)) + p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue()) Expect(p.Frames).To(HaveLen(2)) Expect(p.Frames[0].OnLost).ToNot(BeNil()) p.Frames[1].OnLost(nil) diff --git a/retransmission_queue.go b/retransmission_queue.go index 0cfbbc4d..b2fe84c6 100644 --- a/retransmission_queue.go +++ b/retransmission_queue.go @@ -15,12 +15,10 @@ type retransmissionQueue struct { handshakeCryptoData []*wire.CryptoFrame appData []wire.Frame - - version protocol.VersionNumber } -func newRetransmissionQueue(ver protocol.VersionNumber) *retransmissionQueue { - return &retransmissionQueue{version: ver} +func newRetransmissionQueue() *retransmissionQueue { + return &retransmissionQueue{} } func (q *retransmissionQueue) AddInitial(f wire.Frame) { @@ -58,10 +56,10 @@ func (q *retransmissionQueue) AddAppData(f wire.Frame) { q.appData = append(q.appData, f) } -func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Frame { +func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { if len(q.initialCryptoData) > 0 { f := q.initialCryptoData[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) if newFrame == nil && !needsSplit { // the whole frame fits q.initialCryptoData = q.initialCryptoData[1:] return f @@ -74,17 +72,17 @@ func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Fr return nil } f := q.initial[0] - if f.Length(q.version) > maxLen { + if f.Length(v) > maxLen { return nil } q.initial = q.initial[1:] return f } -func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.Frame { +func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { if len(q.handshakeCryptoData) > 0 { f := q.handshakeCryptoData[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) if newFrame == nil && !needsSplit { // the whole frame fits q.handshakeCryptoData = q.handshakeCryptoData[1:] return f @@ -97,19 +95,19 @@ func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire. return nil } f := q.handshake[0] - if f.Length(q.version) > maxLen { + if f.Length(v) > maxLen { return nil } q.handshake = q.handshake[1:] return f } -func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount) wire.Frame { +func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { if len(q.appData) == 0 { return nil } f := q.appData[0] - if f.Length(q.version) > maxLen { + if f.Length(v) > maxLen { return nil } q.appData = q.appData[1:] diff --git a/retransmission_queue_test.go b/retransmission_queue_test.go index aa909034..5583cca1 100644 --- a/retransmission_queue_test.go +++ b/retransmission_queue_test.go @@ -9,26 +9,24 @@ import ( ) var _ = Describe("Retransmission queue", func() { - const version = protocol.VersionTLS - var q *retransmissionQueue BeforeEach(func() { - q = newRetransmissionQueue(version) + q = newRetransmissionQueue() }) Context("Initial data", func() { It("doesn't dequeue anything when it's empty", func() { Expect(q.HasInitialData()).To(BeFalse()) - Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) + Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil()) }) It("queues and retrieves a control frame", func() { f := &wire.MaxDataFrame{MaximumData: 0x42} q.AddInitial(f) Expect(q.HasInitialData()).To(BeTrue()) - Expect(q.GetInitialFrame(f.Length(version) - 1)).To(BeNil()) - Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) + Expect(q.GetInitialFrame(f.Length(protocol.Version1)-1, protocol.Version1)).To(BeNil()) + Expect(q.GetInitialFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f)) Expect(q.HasInitialData()).To(BeFalse()) }) @@ -36,7 +34,7 @@ var _ = Describe("Retransmission queue", func() { f := &wire.CryptoFrame{Data: []byte("foobar")} q.AddInitial(f) Expect(q.HasInitialData()).To(BeTrue()) - Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) + Expect(q.GetInitialFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f)) Expect(q.HasInitialData()).To(BeFalse()) }) @@ -47,13 +45,13 @@ var _ = Describe("Retransmission queue", func() { } q.AddInitial(f) Expect(q.HasInitialData()).To(BeTrue()) - f1 := q.GetInitialFrame(f.Length(version) - 3) + f1 := q.GetInitialFrame(f.Length(protocol.Version1)-3, protocol.Version1) Expect(f1).ToNot(BeNil()) Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo"))) Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100))) Expect(q.HasInitialData()).To(BeTrue()) - f2 := q.GetInitialFrame(protocol.MaxByteCount) + f2 := q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1) Expect(f2).ToNot(BeNil()) Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar"))) @@ -65,11 +63,11 @@ var _ = Describe("Retransmission queue", func() { f := &wire.CryptoFrame{Data: []byte("foobar")} q.AddInitial(f) q.AddInitial(&wire.PingFrame{}) - f1 := q.GetInitialFrame(2) // too small for a CRYPTO frame + f1 := q.GetInitialFrame(2, protocol.Version1) // too small for a CRYPTO frame Expect(f1).ToNot(BeNil()) Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{})) Expect(q.HasInitialData()).To(BeTrue()) - f2 := q.GetInitialFrame(protocol.MaxByteCount) + f2 := q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1) Expect(f2).To(Equal(f)) }) @@ -79,8 +77,8 @@ var _ = Describe("Retransmission queue", func() { q.AddInitial(f) q.AddInitial(cf) Expect(q.HasInitialData()).To(BeTrue()) - Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(f)) - Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(cf)) + Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(f)) + Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(cf)) Expect(q.HasInitialData()).To(BeFalse()) }) @@ -89,22 +87,22 @@ var _ = Describe("Retransmission queue", func() { q.AddInitial(&wire.MaxDataFrame{MaximumData: 0x42}) q.DropPackets(protocol.EncryptionInitial) Expect(q.HasInitialData()).To(BeFalse()) - Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) + Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil()) }) }) Context("Handshake data", func() { It("doesn't dequeue anything when it's empty", func() { Expect(q.HasHandshakeData()).To(BeFalse()) - Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil()) }) It("queues and retrieves a control frame", func() { f := &wire.MaxDataFrame{MaximumData: 0x42} q.AddHandshake(f) Expect(q.HasHandshakeData()).To(BeTrue()) - Expect(q.GetHandshakeFrame(f.Length(version) - 1)).To(BeNil()) - Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) + Expect(q.GetHandshakeFrame(f.Length(protocol.Version1)-1, protocol.Version1)).To(BeNil()) + Expect(q.GetHandshakeFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f)) Expect(q.HasHandshakeData()).To(BeFalse()) }) @@ -112,7 +110,7 @@ var _ = Describe("Retransmission queue", func() { f := &wire.CryptoFrame{Data: []byte("foobar")} q.AddHandshake(f) Expect(q.HasHandshakeData()).To(BeTrue()) - Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) + Expect(q.GetHandshakeFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f)) Expect(q.HasHandshakeData()).To(BeFalse()) }) @@ -123,13 +121,13 @@ var _ = Describe("Retransmission queue", func() { } q.AddHandshake(f) Expect(q.HasHandshakeData()).To(BeTrue()) - f1 := q.GetHandshakeFrame(f.Length(version) - 3) + f1 := q.GetHandshakeFrame(f.Length(protocol.Version1)-3, protocol.Version1) Expect(f1).ToNot(BeNil()) Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo"))) Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100))) Expect(q.HasHandshakeData()).To(BeTrue()) - f2 := q.GetHandshakeFrame(protocol.MaxByteCount) + f2 := q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1) Expect(f2).ToNot(BeNil()) Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar"))) @@ -141,11 +139,11 @@ var _ = Describe("Retransmission queue", func() { f := &wire.CryptoFrame{Data: []byte("foobar")} q.AddHandshake(f) q.AddHandshake(&wire.PingFrame{}) - f1 := q.GetHandshakeFrame(2) // too small for a CRYPTO frame + f1 := q.GetHandshakeFrame(2, protocol.Version1) // too small for a CRYPTO frame Expect(f1).ToNot(BeNil()) Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{})) Expect(q.HasHandshakeData()).To(BeTrue()) - f2 := q.GetHandshakeFrame(protocol.MaxByteCount) + f2 := q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1) Expect(f2).To(Equal(f)) }) @@ -155,8 +153,8 @@ var _ = Describe("Retransmission queue", func() { q.AddHandshake(f) q.AddHandshake(cf) Expect(q.HasHandshakeData()).To(BeTrue()) - Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(f)) - Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(cf)) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(f)) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(cf)) Expect(q.HasHandshakeData()).To(BeFalse()) }) @@ -165,13 +163,13 @@ var _ = Describe("Retransmission queue", func() { q.AddHandshake(&wire.MaxDataFrame{MaximumData: 0x42}) q.DropPackets(protocol.EncryptionHandshake) Expect(q.HasHandshakeData()).To(BeFalse()) - Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil()) }) }) Context("Application data", func() { It("doesn't dequeue anything when it's empty", func() { - Expect(q.GetAppDataFrame(protocol.MaxByteCount)).To(BeNil()) + Expect(q.GetAppDataFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil()) }) It("queues and retrieves a control frame", func() { @@ -179,8 +177,8 @@ var _ = Describe("Retransmission queue", func() { Expect(q.HasAppData()).To(BeFalse()) q.AddAppData(f) Expect(q.HasAppData()).To(BeTrue()) - Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil()) - Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f)) + Expect(q.GetAppDataFrame(f.Length(protocol.Version1)-1, protocol.Version1)).To(BeNil()) + Expect(q.GetAppDataFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f)) Expect(q.HasAppData()).To(BeFalse()) }) })