refactor retransmissionQueue to remove version param from constructor

This commit is contained in:
Marten Seemann 2022-12-24 12:28:27 +13:00
parent 5c7d120b8f
commit 582edae63d
6 changed files with 44 additions and 48 deletions

View file

@ -507,7 +507,7 @@ var newClientConnection = func(
func (s *connection) preSetup() { func (s *connection) preSetup() {
s.sendQueue = newSendQueue(s.conn) s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue(s.version) s.retransmissionQueue = newRetransmissionQueue()
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version)
s.rttStats = &utils.RTTStats{} s.rttStats = &utils.RTTStats{}
s.connFlowController = flowcontrol.NewConnectionFlowController( s.connFlowController = flowcontrol.NewConnectionFlowController(

View file

@ -1289,7 +1289,7 @@ var _ = Describe("Connection", func() {
Context(fmt.Sprintf("sending %s probe packets", encLevel), func() { Context(fmt.Sprintf("sending %s probe packets", encLevel), func() {
var sendMode ackhandler.SendMode var sendMode ackhandler.SendMode
var getFrame func(protocol.ByteCount) wire.Frame var getFrame func(protocol.ByteCount, protocol.VersionNumber) wire.Frame
BeforeEach(func() { BeforeEach(func() {
//nolint:exhaustive //nolint:exhaustive
@ -1356,7 +1356,7 @@ var _ = Describe("Connection", func() {
Eventually(sent).Should(BeClosed()) Eventually(sent).Should(BeClosed())
// We're using a mock packet packer in this test. // We're using a mock packet packer in this test.
// We therefore need to test separately that the PING was actually queued. // We therefore need to test separately that the PING was actually queued.
Expect(getFrame(1000)).To(BeAssignableToTypeOf(&wire.PingFrame{})) Expect(getFrame(1000, protocol.Version1)).To(BeAssignableToTypeOf(&wire.PingFrame{}))
}) })
}) })
} }

View file

@ -569,9 +569,9 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
//nolint:exhaustive // 0-RTT packets can't contain any retransmission.s //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
f = p.retransmissionQueue.GetInitialFrame(maxPacketSize) f = p.retransmissionQueue.GetInitialFrame(maxPacketSize, p.version)
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize) f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, p.version)
} }
if f == nil { if f == nil {
break break
@ -683,7 +683,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if remainingLen < protocol.MinStreamFrameSize { if remainingLen < protocol.MinStreamFrameSize {
break break
} }
f := p.retransmissionQueue.GetAppDataFrame(remainingLen) f := p.retransmissionQueue.GetAppDataFrame(remainingLen, p.version)
if f == nil { if f == nil {
break break
} }

View file

@ -85,7 +85,7 @@ var _ = Describe("Packet packer", func() {
BeforeEach(func() { BeforeEach(func() {
rand.Seed(GinkgoRandomSeed()) rand.Seed(GinkgoRandomSeed())
retransmissionQueue = newRetransmissionQueue(version) retransmissionQueue = newRetransmissionQueue()
mockSender := NewMockStreamSender(mockCtrl) mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
initialStream = NewMockCryptoStream(mockCtrl) initialStream = NewMockCryptoStream(mockCtrl)
@ -1593,7 +1593,7 @@ var _ = Describe("Converting to ackhandler.Packet", func() {
{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }}, {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { pingLost = true }},
}, },
} }
p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue(protocol.VersionTLS)) p := packet.ToAckHandlerPacket(time.Now(), newRetransmissionQueue())
Expect(p.Frames).To(HaveLen(2)) Expect(p.Frames).To(HaveLen(2))
Expect(p.Frames[0].OnLost).ToNot(BeNil()) Expect(p.Frames[0].OnLost).ToNot(BeNil())
p.Frames[1].OnLost(nil) p.Frames[1].OnLost(nil)

View file

@ -15,12 +15,10 @@ type retransmissionQueue struct {
handshakeCryptoData []*wire.CryptoFrame handshakeCryptoData []*wire.CryptoFrame
appData []wire.Frame appData []wire.Frame
version protocol.VersionNumber
} }
func newRetransmissionQueue(ver protocol.VersionNumber) *retransmissionQueue { func newRetransmissionQueue() *retransmissionQueue {
return &retransmissionQueue{version: ver} return &retransmissionQueue{}
} }
func (q *retransmissionQueue) AddInitial(f wire.Frame) { func (q *retransmissionQueue) AddInitial(f wire.Frame) {
@ -58,10 +56,10 @@ func (q *retransmissionQueue) AddAppData(f wire.Frame) {
q.appData = append(q.appData, f) q.appData = append(q.appData, f)
} }
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Frame { func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
if len(q.initialCryptoData) > 0 { if len(q.initialCryptoData) > 0 {
f := q.initialCryptoData[0] f := q.initialCryptoData[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
if newFrame == nil && !needsSplit { // the whole frame fits if newFrame == nil && !needsSplit { // the whole frame fits
q.initialCryptoData = q.initialCryptoData[1:] q.initialCryptoData = q.initialCryptoData[1:]
return f return f
@ -74,17 +72,17 @@ func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Fr
return nil return nil
} }
f := q.initial[0] f := q.initial[0]
if f.Length(q.version) > maxLen { if f.Length(v) > maxLen {
return nil return nil
} }
q.initial = q.initial[1:] q.initial = q.initial[1:]
return f return f
} }
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.Frame { func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
if len(q.handshakeCryptoData) > 0 { if len(q.handshakeCryptoData) > 0 {
f := q.handshakeCryptoData[0] f := q.handshakeCryptoData[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
if newFrame == nil && !needsSplit { // the whole frame fits if newFrame == nil && !needsSplit { // the whole frame fits
q.handshakeCryptoData = q.handshakeCryptoData[1:] q.handshakeCryptoData = q.handshakeCryptoData[1:]
return f return f
@ -97,19 +95,19 @@ func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.
return nil return nil
} }
f := q.handshake[0] f := q.handshake[0]
if f.Length(q.version) > maxLen { if f.Length(v) > maxLen {
return nil return nil
} }
q.handshake = q.handshake[1:] q.handshake = q.handshake[1:]
return f return f
} }
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount) wire.Frame { func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame {
if len(q.appData) == 0 { if len(q.appData) == 0 {
return nil return nil
} }
f := q.appData[0] f := q.appData[0]
if f.Length(q.version) > maxLen { if f.Length(v) > maxLen {
return nil return nil
} }
q.appData = q.appData[1:] q.appData = q.appData[1:]

View file

@ -9,26 +9,24 @@ import (
) )
var _ = Describe("Retransmission queue", func() { var _ = Describe("Retransmission queue", func() {
const version = protocol.VersionTLS
var q *retransmissionQueue var q *retransmissionQueue
BeforeEach(func() { BeforeEach(func() {
q = newRetransmissionQueue(version) q = newRetransmissionQueue()
}) })
Context("Initial data", func() { Context("Initial data", func() {
It("doesn't dequeue anything when it's empty", func() { It("doesn't dequeue anything when it's empty", func() {
Expect(q.HasInitialData()).To(BeFalse()) Expect(q.HasInitialData()).To(BeFalse())
Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil())
}) })
It("queues and retrieves a control frame", func() { It("queues and retrieves a control frame", func() {
f := &wire.MaxDataFrame{MaximumData: 0x42} f := &wire.MaxDataFrame{MaximumData: 0x42}
q.AddInitial(f) q.AddInitial(f)
Expect(q.HasInitialData()).To(BeTrue()) Expect(q.HasInitialData()).To(BeTrue())
Expect(q.GetInitialFrame(f.Length(version) - 1)).To(BeNil()) Expect(q.GetInitialFrame(f.Length(protocol.Version1)-1, protocol.Version1)).To(BeNil())
Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) Expect(q.GetInitialFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f))
Expect(q.HasInitialData()).To(BeFalse()) Expect(q.HasInitialData()).To(BeFalse())
}) })
@ -36,7 +34,7 @@ var _ = Describe("Retransmission queue", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")} f := &wire.CryptoFrame{Data: []byte("foobar")}
q.AddInitial(f) q.AddInitial(f)
Expect(q.HasInitialData()).To(BeTrue()) Expect(q.HasInitialData()).To(BeTrue())
Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) Expect(q.GetInitialFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f))
Expect(q.HasInitialData()).To(BeFalse()) Expect(q.HasInitialData()).To(BeFalse())
}) })
@ -47,13 +45,13 @@ var _ = Describe("Retransmission queue", func() {
} }
q.AddInitial(f) q.AddInitial(f)
Expect(q.HasInitialData()).To(BeTrue()) Expect(q.HasInitialData()).To(BeTrue())
f1 := q.GetInitialFrame(f.Length(version) - 3) f1 := q.GetInitialFrame(f.Length(protocol.Version1)-3, protocol.Version1)
Expect(f1).ToNot(BeNil()) Expect(f1).ToNot(BeNil())
Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{}))
Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo"))) Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo")))
Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100))) Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100)))
Expect(q.HasInitialData()).To(BeTrue()) Expect(q.HasInitialData()).To(BeTrue())
f2 := q.GetInitialFrame(protocol.MaxByteCount) f2 := q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)
Expect(f2).ToNot(BeNil()) Expect(f2).ToNot(BeNil())
Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{}))
Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar"))) Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar")))
@ -65,11 +63,11 @@ var _ = Describe("Retransmission queue", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")} f := &wire.CryptoFrame{Data: []byte("foobar")}
q.AddInitial(f) q.AddInitial(f)
q.AddInitial(&wire.PingFrame{}) q.AddInitial(&wire.PingFrame{})
f1 := q.GetInitialFrame(2) // too small for a CRYPTO frame f1 := q.GetInitialFrame(2, protocol.Version1) // too small for a CRYPTO frame
Expect(f1).ToNot(BeNil()) Expect(f1).ToNot(BeNil())
Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{})) Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{}))
Expect(q.HasInitialData()).To(BeTrue()) Expect(q.HasInitialData()).To(BeTrue())
f2 := q.GetInitialFrame(protocol.MaxByteCount) f2 := q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)
Expect(f2).To(Equal(f)) Expect(f2).To(Equal(f))
}) })
@ -79,8 +77,8 @@ var _ = Describe("Retransmission queue", func() {
q.AddInitial(f) q.AddInitial(f)
q.AddInitial(cf) q.AddInitial(cf)
Expect(q.HasInitialData()).To(BeTrue()) Expect(q.HasInitialData()).To(BeTrue())
Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(f)) Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(f))
Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(cf)) Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(cf))
Expect(q.HasInitialData()).To(BeFalse()) Expect(q.HasInitialData()).To(BeFalse())
}) })
@ -89,22 +87,22 @@ var _ = Describe("Retransmission queue", func() {
q.AddInitial(&wire.MaxDataFrame{MaximumData: 0x42}) q.AddInitial(&wire.MaxDataFrame{MaximumData: 0x42})
q.DropPackets(protocol.EncryptionInitial) q.DropPackets(protocol.EncryptionInitial)
Expect(q.HasInitialData()).To(BeFalse()) Expect(q.HasInitialData()).To(BeFalse())
Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) Expect(q.GetInitialFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil())
}) })
}) })
Context("Handshake data", func() { Context("Handshake data", func() {
It("doesn't dequeue anything when it's empty", func() { It("doesn't dequeue anything when it's empty", func() {
Expect(q.HasHandshakeData()).To(BeFalse()) Expect(q.HasHandshakeData()).To(BeFalse())
Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil())
}) })
It("queues and retrieves a control frame", func() { It("queues and retrieves a control frame", func() {
f := &wire.MaxDataFrame{MaximumData: 0x42} f := &wire.MaxDataFrame{MaximumData: 0x42}
q.AddHandshake(f) q.AddHandshake(f)
Expect(q.HasHandshakeData()).To(BeTrue()) Expect(q.HasHandshakeData()).To(BeTrue())
Expect(q.GetHandshakeFrame(f.Length(version) - 1)).To(BeNil()) Expect(q.GetHandshakeFrame(f.Length(protocol.Version1)-1, protocol.Version1)).To(BeNil())
Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) Expect(q.GetHandshakeFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f))
Expect(q.HasHandshakeData()).To(BeFalse()) Expect(q.HasHandshakeData()).To(BeFalse())
}) })
@ -112,7 +110,7 @@ var _ = Describe("Retransmission queue", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")} f := &wire.CryptoFrame{Data: []byte("foobar")}
q.AddHandshake(f) q.AddHandshake(f)
Expect(q.HasHandshakeData()).To(BeTrue()) Expect(q.HasHandshakeData()).To(BeTrue())
Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) Expect(q.GetHandshakeFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f))
Expect(q.HasHandshakeData()).To(BeFalse()) Expect(q.HasHandshakeData()).To(BeFalse())
}) })
@ -123,13 +121,13 @@ var _ = Describe("Retransmission queue", func() {
} }
q.AddHandshake(f) q.AddHandshake(f)
Expect(q.HasHandshakeData()).To(BeTrue()) Expect(q.HasHandshakeData()).To(BeTrue())
f1 := q.GetHandshakeFrame(f.Length(version) - 3) f1 := q.GetHandshakeFrame(f.Length(protocol.Version1)-3, protocol.Version1)
Expect(f1).ToNot(BeNil()) Expect(f1).ToNot(BeNil())
Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) Expect(f1).To(BeAssignableToTypeOf(&wire.CryptoFrame{}))
Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo"))) Expect(f1.(*wire.CryptoFrame).Data).To(Equal([]byte("foo")))
Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100))) Expect(f1.(*wire.CryptoFrame).Offset).To(Equal(protocol.ByteCount(100)))
Expect(q.HasHandshakeData()).To(BeTrue()) Expect(q.HasHandshakeData()).To(BeTrue())
f2 := q.GetHandshakeFrame(protocol.MaxByteCount) f2 := q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)
Expect(f2).ToNot(BeNil()) Expect(f2).ToNot(BeNil())
Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) Expect(f2).To(BeAssignableToTypeOf(&wire.CryptoFrame{}))
Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar"))) Expect(f2.(*wire.CryptoFrame).Data).To(Equal([]byte("bar")))
@ -141,11 +139,11 @@ var _ = Describe("Retransmission queue", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")} f := &wire.CryptoFrame{Data: []byte("foobar")}
q.AddHandshake(f) q.AddHandshake(f)
q.AddHandshake(&wire.PingFrame{}) q.AddHandshake(&wire.PingFrame{})
f1 := q.GetHandshakeFrame(2) // too small for a CRYPTO frame f1 := q.GetHandshakeFrame(2, protocol.Version1) // too small for a CRYPTO frame
Expect(f1).ToNot(BeNil()) Expect(f1).ToNot(BeNil())
Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{})) Expect(f1).To(BeAssignableToTypeOf(&wire.PingFrame{}))
Expect(q.HasHandshakeData()).To(BeTrue()) Expect(q.HasHandshakeData()).To(BeTrue())
f2 := q.GetHandshakeFrame(protocol.MaxByteCount) f2 := q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)
Expect(f2).To(Equal(f)) Expect(f2).To(Equal(f))
}) })
@ -155,8 +153,8 @@ var _ = Describe("Retransmission queue", func() {
q.AddHandshake(f) q.AddHandshake(f)
q.AddHandshake(cf) q.AddHandshake(cf)
Expect(q.HasHandshakeData()).To(BeTrue()) Expect(q.HasHandshakeData()).To(BeTrue())
Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(f)) Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(f))
Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(cf)) Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(Equal(cf))
Expect(q.HasHandshakeData()).To(BeFalse()) Expect(q.HasHandshakeData()).To(BeFalse())
}) })
@ -165,13 +163,13 @@ var _ = Describe("Retransmission queue", func() {
q.AddHandshake(&wire.MaxDataFrame{MaximumData: 0x42}) q.AddHandshake(&wire.MaxDataFrame{MaximumData: 0x42})
q.DropPackets(protocol.EncryptionHandshake) q.DropPackets(protocol.EncryptionHandshake)
Expect(q.HasHandshakeData()).To(BeFalse()) Expect(q.HasHandshakeData()).To(BeFalse())
Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) Expect(q.GetHandshakeFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil())
}) })
}) })
Context("Application data", func() { Context("Application data", func() {
It("doesn't dequeue anything when it's empty", func() { It("doesn't dequeue anything when it's empty", func() {
Expect(q.GetAppDataFrame(protocol.MaxByteCount)).To(BeNil()) Expect(q.GetAppDataFrame(protocol.MaxByteCount, protocol.Version1)).To(BeNil())
}) })
It("queues and retrieves a control frame", func() { It("queues and retrieves a control frame", func() {
@ -179,8 +177,8 @@ var _ = Describe("Retransmission queue", func() {
Expect(q.HasAppData()).To(BeFalse()) Expect(q.HasAppData()).To(BeFalse())
q.AddAppData(f) q.AddAppData(f)
Expect(q.HasAppData()).To(BeTrue()) Expect(q.HasAppData()).To(BeTrue())
Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil()) Expect(q.GetAppDataFrame(f.Length(protocol.Version1)-1, protocol.Version1)).To(BeNil())
Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f)) Expect(q.GetAppDataFrame(f.Length(protocol.Version1), protocol.Version1)).To(Equal(f))
Expect(q.HasAppData()).To(BeFalse()) Expect(q.HasAppData()).To(BeFalse())
}) })
}) })