retransmit frames instead of packets

This commit is contained in:
Marten Seemann 2019-08-28 16:28:23 +07:00
parent 515c21dd1d
commit 1d392b4276
16 changed files with 705 additions and 962 deletions

View file

@ -0,0 +1,8 @@
package ackhandler
import "github.com/lucas-clemente/quic-go/internal/wire"
type Frame struct {
wire.Frame // nil if the frame has already been acknowledged in another packet
OnLost func(wire.Frame)
}

View file

@ -11,13 +11,12 @@ import (
// A Packet is a packet
type Packet struct {
PacketNumber protocol.PacketNumber
Frames []wire.Frame
Frames []Frame
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
canBeRetransmitted bool
includedInBytesInFlight bool
}
@ -43,8 +42,7 @@ type SentPacketHandler interface {
// only to be called once the handshake is complete
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() *Packet
DequeueProbePacket() (*Packet, error)
QueueProbePacket() bool /* was a packet queued */
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber

View file

@ -10,8 +10,6 @@ const (
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendRetransmission means that retransmissions should be sent
SendRetransmission
// SendPTO means that a probe packet should be sent
SendPTO
// SendAny means that any packet should be sent
@ -24,8 +22,6 @@ func (s SendMode) String() string {
return "none"
case SendAck:
return "ack"
case SendRetransmission:
return "retransmission"
case SendPTO:
return "pto"
case SendAny:

View file

@ -11,7 +11,6 @@ var _ = Describe("Send Mode", func() {
Expect(SendAny.String()).To(Equal("any"))
Expect(SendAck.String()).To(Equal("ack"))
Expect(SendPTO.String()).To(Equal("pto"))
Expect(SendRetransmission.String()).To(Equal("retransmission"))
Expect(SendMode(123).String()).To(Equal("invalid send mode: 123"))
})
})

View file

@ -1,7 +1,6 @@
package ackhandler
import (
"errors"
"fmt"
"math"
"time"
@ -56,8 +55,6 @@ type sentPacketHandler struct {
// Only applies to the application-data packet number space.
lowestNotConfirmedAcked protocol.PacketNumber
retransmissionQueue []*Packet
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithmWithDebugInfos
@ -112,14 +109,6 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
}
return true, nil
})
// remove packets from the retransmission queue
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel != encLevel {
queue = append(queue, packet)
}
}
h.retransmissionQueue = queue
// drop the packet history
switch encLevel {
case protocol.EncryptionInitial:
@ -170,7 +159,6 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-elicit
h.lastSentAckElicitingPacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
packet.canBeRetransmitted = true
if h.numProbesToSend > 0 {
h.numProbesToSend--
}
@ -335,12 +323,14 @@ func (h *sentPacketHandler) setLossDetectionTimer() {
// Cancel the alarm if no packets are outstanding
if !h.hasOutstandingPackets() {
h.logger.Debugf("setLossDetectionTimer: canceling. Bytes in flight: %d", h.bytesInFlight)
h.alarm = time.Time{}
return
}
// PTO alarm
h.alarm = h.lastSentAckElicitingPacketTime.Add(h.rttStats.PTO() << h.ptoCount)
h.logger.Debugf("setLossDetectionTimer: setting to", h.alarm)
}
func (h *sentPacketHandler) detectLostPackets(
@ -388,27 +378,23 @@ func (h *sentPacketHandler) detectLostPackets(
}
for _, p := range lostPackets {
h.queueFramesForRetransmission(p)
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
if p.canBeRetransmitted {
// queue the packet for retransmission, and report the loss to the congestion controller
if err := h.queuePacketForRetransmission(p, pnSpace); err != nil {
return err
}
}
pnSpace.history.Remove(p.PacketNumber)
if h.traceCallback != nil {
// TODO: trace frames
h.traceCallback(quictrace.Event{
Time: now,
EventType: quictrace.PacketLost,
EncryptionLevel: p.EncryptionLevel,
PacketNumber: p.PacketNumber,
PacketSize: p.Length,
Frames: p.Frames,
TransportState: h.GetStats(),
//Frames: p.Frames,
TransportState: h.GetStats(),
})
}
}
@ -454,63 +440,16 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
pnSpace := h.getPacketNumberSpace(p.EncryptionLevel)
// This happens if a packet and its retransmissions is acked in the same ACK.
// As soon as we process the first one, this will remove all the retransmissions,
// so we won't find the retransmitted packet number later.
if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil {
return nil
}
// this also applies to packets that have been retransmitted as probe packets
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if err := pnSpace.history.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
return pnSpace.history.Remove(p.PacketNumber)
}
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 {
return nil
}
packet := h.retransmissionQueue[0]
// Shift the slice and don't retain anything that isn't needed.
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet
}
func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
if len(h.retransmissionQueue) > 0 {
return h.DequeuePacketForRetransmission(), nil
}
var pnSpace *packetNumberSpace
var p *Packet
if h.initialPackets != nil {
pnSpace = h.initialPackets
p = h.initialPackets.history.FirstOutstanding()
}
if p == nil && h.handshakePackets != nil {
pnSpace = h.handshakePackets
p = h.handshakePackets.history.FirstOutstanding()
}
if p == nil {
pnSpace = h.oneRTTPackets
p = h.oneRTTPackets.history.FirstOutstanding()
}
if p == nil {
return nil, errors.New("cannot dequeue a probe packet. No outstanding packets")
}
if err := h.queuePacketForRetransmission(p, pnSpace); err != nil {
return nil, err
}
return h.DequeuePacketForRetransmission(), nil
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel)
@ -530,7 +469,7 @@ func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) p
}
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.oneRTTPackets.history.Len()
numTrackedPackets := h.oneRTTPackets.history.Len()
if h.initialPackets != nil {
numTrackedPackets += h.initialPackets.history.Len()
}
@ -558,10 +497,6 @@ func (h *sentPacketHandler) SendMode() SendMode {
}
return SendAck
}
// Send retransmissions first, if there are any.
if len(h.retransmissionQueue) > 0 {
return SendRetransmission
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
@ -587,30 +522,45 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet, pnSpace *packetNumberSpace) error {
if !p.canBeRetransmitted {
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
func (h *sentPacketHandler) QueueProbePacket() bool {
var p *Packet
if h.initialPackets != nil {
p = h.initialPackets.history.FirstOutstanding()
}
if err := pnSpace.history.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
if p == nil && h.handshakePackets != nil {
p = h.handshakePackets.history.FirstOutstanding()
}
if p == nil {
p = h.oneRTTPackets.history.FirstOutstanding()
}
if p == nil {
return false
}
h.queueFramesForRetransmission(p)
// TODO: don't remove the packet here
// Keep track of acknowledged frames instead.
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if err := h.getPacketNumberSpace(p.EncryptionLevel).history.Remove(p.PacketNumber); err != nil {
// should never happen. We just got this packet from the history a lines above.
panic(err)
}
return true
}
func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) {
for _, f := range p.Frames {
f.OnLost(f.Frame)
}
h.retransmissionQueue = append(h.retransmissionQueue, p)
return nil
}
func (h *sentPacketHandler) ResetForRetry() error {
h.bytesInFlight = 0
var packets []*Packet
h.initialPackets.history.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted {
packets = append(packets, p)
}
h.queueFramesForRetransmission(p)
return true, nil
})
for _, p := range packets {
h.logger.Debugf("Queueing packet %#x for retransmission.", p.PacketNumber)
h.retransmissionQueue = append(h.retransmissionQueue, p)
}
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop())
h.setLossDetectionTimer()
return nil

View file

@ -14,40 +14,15 @@ import (
. "github.com/onsi/gomega"
)
func ackElicitingPacket(p *Packet) *Packet {
if p.EncryptionLevel == protocol.EncryptionUnspecified {
p.EncryptionLevel = protocol.Encryption1RTT
}
if p.Length == 0 {
p.Length = 1
}
if p.SendTime.IsZero() {
p.SendTime = time.Now()
}
p.Frames = []wire.Frame{&wire.PingFrame{}}
return p
}
func nonAckElicitingPacket(p *Packet) *Packet {
p = ackElicitingPacket(p)
p.Frames = nil
p.LargestAcked = 1
return p
}
func cryptoPacket(p *Packet) *Packet {
p = ackElicitingPacket(p)
p.EncryptionLevel = protocol.EncryptionInitial
return p
}
var _ = Describe("SentPacketHandler", func() {
var (
handler *sentPacketHandler
streamFrame wire.StreamFrame
lostPackets []protocol.PacketNumber
)
BeforeEach(func() {
lostPackets = nil
rttStats := &congestion.RTTStats{}
handler = NewSentPacketHandler(42, rttStats, nil, utils.DefaultLogger).(*sentPacketHandler)
streamFrame = wire.StreamFrame{
@ -63,6 +38,35 @@ var _ = Describe("SentPacketHandler", func() {
return nil
}
ackElicitingPacket := func(p *Packet) *Packet {
if p.EncryptionLevel == protocol.EncryptionUnspecified {
p.EncryptionLevel = protocol.Encryption1RTT
}
if p.Length == 0 {
p.Length = 1
}
if p.SendTime.IsZero() {
p.SendTime = time.Now()
}
p.Frames = []Frame{
{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, p.PacketNumber) }},
}
return p
}
nonAckElicitingPacket := func(p *Packet) *Packet {
p = ackElicitingPacket(p)
p.Frames = nil
p.LargestAcked = 1
return p
}
cryptoPacket := func(p *Packet) *Packet {
p = ackElicitingPacket(p)
p.EncryptionLevel = protocol.EncryptionInitial
return p
}
expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) {
pnSpace := handler.getPacketNumberSpace(encLevel)
ExpectWithOffset(1, pnSpace.history.Len()).To(Equal(len(expected)))
@ -161,21 +165,21 @@ var _ = Describe("SentPacketHandler", func() {
})
})
Context("acks and nacks the right packets", func() {
Context("acks the right packets", func() {
expectInPacketHistoryOrLost := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) {
pnSpace := handler.getPacketNumberSpace(encLevel)
ExpectWithOffset(1, pnSpace.history.Len()+len(handler.retransmissionQueue)).To(Equal(len(expected)))
ExpectWithOffset(1, pnSpace.history.Len()+len(lostPackets)).To(Equal(len(expected)))
expectedLoop:
for _, p := range expected {
if _, ok := pnSpace.history.packetMap[p]; ok {
continue
}
for _, lost := range handler.retransmissionQueue {
if lost.PacketNumber == p {
for _, lostP := range lostPackets {
if lostP == p {
continue expectedLoop
}
}
Fail(fmt.Sprintf("Packet %d neither in packet history nor declared lost.", p))
Fail(fmt.Sprintf("Packet %d not in packet history.", p))
}
}
@ -320,9 +324,26 @@ var _ = Describe("SentPacketHandler", func() {
Context("determining which ACKs we have received an ACK for", func() {
BeforeEach(func() {
morePackets := []*Packet{
{PacketNumber: 13, LargestAcked: 100, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT},
{PacketNumber: 14, LargestAcked: 200, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT},
{PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT},
{
PacketNumber: 13,
LargestAcked: 100,
Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}},
Length: 1,
EncryptionLevel: protocol.Encryption1RTT,
},
{
PacketNumber: 14,
LargestAcked: 200,
Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}},
Length: 1,
EncryptionLevel: protocol.Encryption1RTT,
},
{
PacketNumber: 15,
Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}},
Length: 1,
EncryptionLevel: protocol.Encryption1RTT,
},
}
for _, packet := range morePackets {
handler.SentPacket(packet)
@ -355,11 +376,6 @@ var _ = Describe("SentPacketHandler", func() {
})
})
It("does not dequeue a packet if no ACK has been received", func() {
handler.SentPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.Encryption1RTT, SendTime: time.Now().Add(-time.Hour)})
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
})
Context("congestion", func() {
var cong *mocks.MockSendAlgorithmWithDebugInfos
@ -380,7 +396,7 @@ var _ = Describe("SentPacketHandler", func() {
handler.SentPacket(&Packet{
PacketNumber: 1,
Length: 42,
Frames: []wire.Frame{&wire.PingFrame{}},
Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) {}}},
EncryptionLevel: protocol.Encryption1RTT,
})
})
@ -474,32 +490,10 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.SendMode()).To(Equal(SendAck))
})
It("doesn't allow retransmissions if congestion limited", func() {
handler.retransmissionQueue = []*Packet{{PacketNumber: 3}}
cong.EXPECT().CanSend(gomock.Any()).Return(false)
Expect(handler.SendMode()).To(Equal(SendAck))
})
It("allows sending retransmissions", func() {
cong.EXPECT().CanSend(gomock.Any()).Return(true)
handler.retransmissionQueue = []*Packet{{PacketNumber: 3}}
Expect(handler.SendMode()).To(Equal(SendRetransmission))
})
It("allows retransmissions, if we're keeping track of between MaxOutstandingSentPackets and MaxTrackedSentPackets packets", func() {
cong.EXPECT().CanSend(gomock.Any()).Return(true)
Expect(protocol.MaxOutstandingSentPackets).To(BeNumerically("<", protocol.MaxTrackedSentPackets))
handler.retransmissionQueue = make([]*Packet, protocol.MaxOutstandingSentPackets+10)
Expect(handler.SendMode()).To(Equal(SendRetransmission))
handler.retransmissionQueue = make([]*Packet, protocol.MaxTrackedSentPackets)
Expect(handler.SendMode()).To(Equal(SendNone))
})
It("allows RTOs, even when congestion limited", func() {
// note that we don't EXPECT a call to GetCongestionWindow
// that means retransmissions are sent without considering the congestion window
handler.numProbesToSend = 1
handler.retransmissionQueue = []*Packet{{PacketNumber: 3}}
Expect(handler.SendMode()).To(Equal(SendPTO))
})
@ -550,6 +544,19 @@ var _ = Describe("SentPacketHandler", func() {
})
Context("probe packets", func() {
It("queues a probe packet", func() {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11}))
queued := handler.QueueProbePacket()
Expect(queued).To(BeTrue())
Expect(lostPackets).To(Equal([]protocol.PacketNumber{10}))
})
It("says when it can't queue a probe packet", func() {
queued := handler.QueueProbePacket()
Expect(queued).To(BeFalse())
})
It("implements exponential backoff", func() {
sendTime := time.Now().Add(-time.Hour)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime}))
@ -563,9 +570,16 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout))
})
It("sets the TPO send mode until two packets is sent", func() {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
handler.OnLossDetectionTimeout()
It("sets the PTO send mode until two packets is sent", func() {
var lostPackets []protocol.PacketNumber
handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: 1,
SendTime: time.Now().Add(-time.Hour),
Frames: []Frame{
{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }},
},
}))
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendPTO))
Expect(handler.ShouldSendNumPackets()).To(Equal(2))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2}))
@ -576,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() {
It("only counts ack-eliciting packets as probe packets", func() {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
handler.OnLossDetectionTimeout()
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendPTO))
Expect(handler.ShouldSendNumPackets()).To(Equal(2))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2}))
@ -596,104 +610,49 @@ var _ = Describe("SentPacketHandler", func() {
updateRTT(time.Hour)
Expect(handler.oneRTTPackets.lossTime.IsZero()).To(BeTrue())
handler.OnLossDetectionTimeout() // TLP
handler.OnLossDetectionTimeout() // TLP
handler.OnLossDetectionTimeout() // RTO
p, err := handler.DequeueProbePacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1)))
p, err = handler.DequeueProbePacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(2)))
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP
Expect(handler.ptoCount).To(BeEquivalentTo(1))
Expect(handler.SendMode()).To(Equal(SendPTO))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3}))
Expect(handler.SendMode()).To(Equal(SendPTO))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4}))
Expect(handler.ptoCount).To(BeEquivalentTo(3))
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // PTO
Expect(handler.ptoCount).To(BeEquivalentTo(2))
Expect(handler.SendMode()).To(Equal(SendPTO))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5}))
Expect(handler.SendMode()).To(Equal(SendPTO))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6}))
Expect(handler.SendMode()).To(Equal(SendAny))
})
It("gets two probe packets if RTO expires, for crypto packets", func() {
It("gets two probe packets if PTO expires, for crypto packets", func() {
handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 1}))
handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 2}))
updateRTT(time.Hour)
Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue())
handler.OnLossDetectionTimeout() // TLP
handler.OnLossDetectionTimeout() // TLP
handler.OnLossDetectionTimeout() // RTO
p, err := handler.DequeueProbePacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1)))
p, err = handler.DequeueProbePacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(2)))
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendPTO))
handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 3}))
Expect(handler.SendMode()).To(Equal(SendPTO))
handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 3}))
Expect(handler.ptoCount).To(BeEquivalentTo(3))
})
It("doesn't delete packets transmitted as PTO from the history", func() {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: time.Now().Add(-time.Hour)}))
handler.rttStats.UpdateRTT(time.Second, 0, time.Now())
handler.OnLossDetectionTimeout() // TLP
handler.OnLossDetectionTimeout() // TLP
handler.OnLossDetectionTimeout() // RTO
_, err := handler.DequeueProbePacket()
Expect(err).ToNot(HaveOccurred())
_, err = handler.DequeueProbePacket()
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{1, 2}, protocol.Encryption1RTT)
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
// Send a probe packet and receive an ACK for it.
// This verifies the RTO.
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 3}}}
Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed())
Expect(err).ToNot(HaveOccurred())
Expect(handler.oneRTTPackets.history.Len()).To(BeZero())
Expect(handler.bytesInFlight).To(BeZero())
Expect(handler.retransmissionQueue).To(BeEmpty()) // 1 and 2 were already sent as probe packets
Expect(handler.SendMode()).To(Equal(SendAny))
})
It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
handler.rttStats.UpdateRTT(time.Second, 0, time.Now())
handler.OnLossDetectionTimeout()
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendPTO))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}
Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendAny))
})
It("gets packets sent before the probe packet for retransmission", func() {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: time.Now().Add(-time.Hour)}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3, SendTime: time.Now().Add(-time.Hour)}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4, SendTime: time.Now().Add(-time.Hour)}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)}))
handler.OnLossDetectionTimeout() // TLP
handler.OnLossDetectionTimeout() // TLP
handler.OnLossDetectionTimeout() // RTO
_, err := handler.DequeueProbePacket()
Expect(err).ToNot(HaveOccurred())
_, err = handler.DequeueProbePacket()
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{1, 2, 3, 4, 5}, protocol.Encryption1RTT)
// Send a probe packet and receive an ACK for it.
// This verifies the RTO.
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}}
err = handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.oneRTTPackets.history.Len()).To(BeZero())
Expect(handler.bytesInFlight).To(BeZero())
Expect(handler.retransmissionQueue).To(HaveLen(3)) // packets 3, 4, 5
})
It("handles ACKs for the original packet", func() {
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)}))
handler.rttStats.UpdateRTT(time.Second, 0, time.Now())
@ -712,11 +671,7 @@ var _ = Describe("SentPacketHandler", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}}
Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed())
expectInPacketHistory([]protocol.PacketNumber{4, 5}, protocol.Encryption1RTT)
for _, p := range []protocol.PacketNumber{1, 2, 3} {
lost := handler.DequeuePacketForRetransmission()
Expect(lost).ToNot(BeNil())
Expect(lost.PacketNumber).To(Equal(p))
}
Expect(lostPackets).To(Equal([]protocol.PacketNumber{1, 2, 3}))
})
})
@ -729,8 +684,6 @@ var _ = Describe("SentPacketHandler", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, now)).To(Succeed())
Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil())
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
// no need to set an alarm, since packet 1 was already declared lost
Expect(handler.oneRTTPackets.lossTime.IsZero()).To(BeTrue())
Expect(handler.bytesInFlight).To(BeZero())
@ -750,11 +703,6 @@ var _ = Describe("SentPacketHandler", func() {
// Packet 1 should be considered lost (1+1/8) RTTs after it was sent.
Expect(handler.oneRTTPackets.lossTime.IsZero()).To(BeFalse())
Expect(handler.oneRTTPackets.lossTime.Sub(getPacket(1, protocol.Encryption1RTT).SendTime)).To(Equal(time.Second * 9 / 8))
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil())
// make sure this is not an RTO: only packet 1 is retransmissted
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
})
It("sets the early retransmit alarm for crypto packets", func() {
@ -771,22 +719,15 @@ var _ = Describe("SentPacketHandler", func() {
// Packet 1 should be considered lost (1+1/8) RTTs after it was sent.
Expect(handler.initialPackets.lossTime.IsZero()).To(BeFalse())
Expect(handler.initialPackets.lossTime.Sub(getPacket(1, protocol.EncryptionInitial).SendTime)).To(Equal(time.Second * 9 / 8))
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil())
// make sure this is not an RTO: only packet 1 is retransmissted
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
})
})
Context("crypto packets", func() {
It("rejects an ACK that acks packets with a higher encryption level", func() {
handler.SentPacket(&Packet{
handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: 13,
EncryptionLevel: protocol.Encryption1RTT,
Frames: []wire.Frame{&streamFrame},
Length: 1,
})
}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionHandshake, time.Now())
Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet"))
@ -794,43 +735,43 @@ var _ = Describe("SentPacketHandler", func() {
It("deletes Initial packets", func() {
for i := protocol.PacketNumber(0); i < 6; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionInitial})
handler.SentPacket(p)
handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: i,
EncryptionLevel: protocol.EncryptionInitial,
}))
}
for i := protocol.PacketNumber(0); i < 10; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake})
handler.SentPacket(p)
handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: i,
EncryptionLevel: protocol.EncryptionHandshake,
}))
}
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16)))
handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionInitial), handler.getPacketNumberSpace(protocol.EncryptionInitial))
lostPacket := getPacket(3, protocol.EncryptionHandshake)
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.DropPackets(protocol.EncryptionInitial)
Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
Expect(handler.initialPackets).To(BeNil())
Expect(handler.handshakePackets.history.Len()).ToNot(BeZero())
packet := handler.DequeuePacketForRetransmission()
Expect(packet).To(Equal(lostPacket))
})
It("deletes Handshake packets", func() {
for i := protocol.PacketNumber(0); i < 6; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake})
handler.SentPacket(p)
handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: i,
EncryptionLevel: protocol.EncryptionHandshake,
}))
}
for i := protocol.PacketNumber(0); i < 10; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.Encryption1RTT})
handler.SentPacket(p)
handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: i,
EncryptionLevel: protocol.Encryption1RTT,
}))
}
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16)))
handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionInitial))
lostPacket := getPacket(3, protocol.Encryption1RTT)
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.DropPackets(protocol.EncryptionHandshake)
Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
Expect(handler.handshakePackets).To(BeNil())
packet := handler.DequeuePacketForRetransmission()
Expect(packet).To(Equal(lostPacket))
})
})
@ -858,25 +799,16 @@ var _ = Describe("SentPacketHandler", func() {
Context("resetting for retry", func() {
It("queues outstanding packets for retransmission and cancels alarms", func() {
packet := &Packet{
PacketNumber: 42,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []wire.Frame{&wire.CryptoFrame{Data: []byte("foobar")}},
Length: 100,
}
handler.SentPacket(packet)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 42, EncryptionLevel: protocol.EncryptionInitial}))
Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero())
Expect(handler.bytesInFlight).ToNot(BeZero())
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
Expect(handler.SendMode()).To(Equal(SendAny))
// now receive a Retry
Expect(handler.ResetForRetry()).To(Succeed())
Expect(lostPackets).To(Equal([]protocol.PacketNumber{42}))
Expect(handler.bytesInFlight).To(BeZero())
Expect(handler.GetLossDetectionTimeout()).To(BeZero())
Expect(handler.SendMode()).To(Equal(SendRetransmission))
p := handler.DequeuePacketForRetransmission()
Expect(p.PacketNumber).To(Equal(packet.PacketNumber))
Expect(p.Frames).To(Equal(packet.Frames))
Expect(handler.SendMode()).To(Equal(SendAny))
})
})
})

View file

@ -9,10 +9,6 @@ import (
type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
numOutstandingPackets int
firstOutstanding *PacketElement
}
func newSentPacketHistory() *sentPacketHistory {
@ -23,19 +19,8 @@ func newSentPacketHistory() *sentPacketHistory {
}
func (h *sentPacketHistory) SentPacket(p *Packet) {
h.sentPacketImpl(p)
}
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
if p.canBeRetransmitted {
h.numOutstandingPackets++
}
return el
}
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
@ -63,40 +48,10 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err
// It must not be modified (e.g. retransmitted).
// Use DequeueFirstPacketForRetransmission() to retransmit it.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
if h.firstOutstanding == nil {
if !h.HasOutstandingPackets() {
return nil
}
return &h.firstOutstanding.Value
}
// QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once.
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
el, ok := h.packetMap[pn]
if !ok {
return fmt.Errorf("sent packet history: packet %d not found", pn)
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
el.Value.canBeRetransmitted = false
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
return nil
}
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next()
for el != nil && !el.Value.canBeRetransmitted {
el = el.Next()
}
h.firstOutstanding = el
return &h.packetList.Front().Value
}
func (h *sentPacketHistory) Len() int {
@ -108,20 +63,11 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstandingPackets > 0
return h.packetList.Len() > 0
}

View file

@ -15,13 +15,14 @@ var _ = Describe("SentPacketHistory", func() {
ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers)))
ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers)))
i := 0
hist.Iterate(func(p *Packet) (bool, error) {
err := hist.Iterate(func(p *Packet) (bool, error) {
pn := packetNumbers[i]
ExpectWithOffset(1, p.PacketNumber).To(Equal(pn))
ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn))
i++
return true, nil
})
Expect(err).ToNot(HaveOccurred())
}
BeforeEach(func() {
@ -53,45 +54,6 @@ var _ = Describe("SentPacketHistory", func() {
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2)))
})
It("gets the second packet if the first one is retransmitted", func() {
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
front := hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the first packet for retransmission.
// The first outstanding packet should now be 3.
err := hist.MarkCannotBeRetransmitted(1)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(3)))
})
It("gets the third packet if the first two are retransmitted", func() {
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
front := hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the second packet for retransmission.
// The first outstanding packet should still be 3.
err := hist.MarkCannotBeRetransmitted(3)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the first packet for retransmission.
// The first outstanding packet should still be 4.
err = hist.MarkCannotBeRetransmitted(1)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(4)))
})
})
It("gets a packet by packet number", func() {
@ -164,60 +126,32 @@ var _ = Describe("SentPacketHistory", func() {
Context("outstanding packets", func() {
It("says if it has outstanding packets", func() {
Expect(hist.HasOutstandingPackets()).To(BeFalse())
hist.SentPacket(&Packet{
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
})
hist.SentPacket(&Packet{EncryptionLevel: protocol.Encryption1RTT})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
})
It("doesn't consider non-ack-eliciting packets as outstanding", func() {
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionInitial,
})
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("accounts for deleted packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
PacketNumber: 10,
EncryptionLevel: protocol.Encryption1RTT,
})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err := hist.Remove(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("doesn't count packets marked as non-ack-eliciting", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err := hist.MarkCannotBeRetransmitted(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.Remove(10)).To(Succeed())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("counts the number of packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
PacketNumber: 10,
EncryptionLevel: protocol.Encryption1RTT,
})
hist.SentPacket(&Packet{
PacketNumber: 11,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
PacketNumber: 11,
EncryptionLevel: protocol.Encryption1RTT,
})
err := hist.Remove(11)
Expect(err).ToNot(HaveOccurred())
Expect(hist.Remove(11)).To(Succeed())
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err = hist.Remove(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.Remove(10)).To(Succeed())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
})

View file

@ -38,35 +38,6 @@ func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder {
return m.recorder
}
// DequeuePacketForRetransmission mocks base method
func (m *MockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DequeuePacketForRetransmission")
ret0, _ := ret[0].(*ackhandler.Packet)
return ret0
}
// DequeuePacketForRetransmission indicates an expected call of DequeuePacketForRetransmission
func (mr *MockSentPacketHandlerMockRecorder) DequeuePacketForRetransmission() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeuePacketForRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeuePacketForRetransmission))
}
// DequeueProbePacket mocks base method
func (m *MockSentPacketHandler) DequeueProbePacket() (*ackhandler.Packet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DequeueProbePacket")
ret0, _ := ret[0].(*ackhandler.Packet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DequeueProbePacket indicates an expected call of DequeueProbePacket
func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket))
}
// DropPackets mocks base method
func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
@ -164,6 +135,20 @@ func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0)
}
// QueueProbePacket mocks base method
func (m *MockSentPacketHandler) QueueProbePacket() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueueProbePacket")
ret0, _ := ret[0].(bool)
return ret0
}
// QueueProbePacket indicates an expected call of QueueProbePacket
func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket))
}
// ReceivedAck mocks base method
func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.PacketNumber, arg2 protocol.EncryptionLevel, arg3 time.Time) error {
m.ctrl.T.Helper()

View file

@ -8,7 +8,6 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
@ -106,21 +105,6 @@ func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket))
}
// PackRetransmission mocks base method
func (m *MockPacker) PackRetransmission(arg0 *ackhandler.Packet) ([]*packedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PackRetransmission", arg0)
ret0, _ := ret[0].([]*packedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PackRetransmission indicates an expected call of PackRetransmission
func (mr *MockPackerMockRecorder) PackRetransmission(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackRetransmission", reflect.TypeOf((*MockPacker)(nil).PackRetransmission), arg0)
}
// SetToken mocks base method
func (m *MockPacker) SetToken(arg0 []byte) {
m.ctrl.T.Helper()

View file

@ -17,7 +17,6 @@ import (
type packer interface {
PackPacket() (*packedPacket, error)
MaybePackAckPacket() (*packedPacket, error)
PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error)
PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
HandleTransportParameters(*handshake.TransportParameters)
@ -62,17 +61,31 @@ func (p *packedPacket) IsAckEliciting() bool {
return ackhandler.HasAckElicitingFrames(p.frames)
}
func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
func (p *packedPacket) ToAckHandlerPacket(q *retransmissionQueue) *ackhandler.Packet {
largestAcked := protocol.InvalidPacketNumber
if p.ack != nil {
largestAcked = p.ack.LargestAcked()
}
encLevel := p.EncryptionLevel()
frames := make([]ackhandler.Frame, len(p.frames))
for i, f := range p.frames {
frame := f
frames[i].Frame = frame
switch encLevel {
case protocol.EncryptionInitial:
frames[i].OnLost = q.AddInitial
case protocol.EncryptionHandshake:
frames[i].OnLost = q.AddHandshake
case protocol.Encryption1RTT:
frames[i].OnLost = q.AddAppData
}
}
return &ackhandler.Packet{
PacketNumber: p.header.PacketNumber,
LargestAcked: largestAcked,
Frames: p.frames,
Frames: frames,
Length: protocol.ByteCount(len(p.raw)),
EncryptionLevel: p.EncryptionLevel(),
EncryptionLevel: encLevel,
SendTime: time.Now(),
}
}
@ -130,9 +143,10 @@ type packetPacker struct {
token []byte
pnManager packetNumberManager
framer frameSource
acks ackFrameSource
pnManager packetNumberManager
framer frameSource
acks ackFrameSource
retransmissionQueue *retransmissionQueue
maxPacketSize protocol.ByteCount
numNonAckElicitingAcks int
@ -146,6 +160,7 @@ func newPacketPacker(
initialStream cryptoStream,
handshakeStream cryptoStream,
packetNumberManager packetNumberManager,
retransmissionQueue *retransmissionQueue,
remoteAddr net.Addr, // only used for determining the max packet size
cryptoSetup sealingManager,
framer frameSource,
@ -154,17 +169,18 @@ func newPacketPacker(
version protocol.VersionNumber,
) *packetPacker {
return &packetPacker{
cryptoSetup: cryptoSetup,
destConnID: destConnID,
srcConnID: srcConnID,
initialStream: initialStream,
handshakeStream: handshakeStream,
perspective: perspective,
version: version,
framer: framer,
acks: acks,
pnManager: packetNumberManager,
maxPacketSize: getMaxPacketSize(remoteAddr),
cryptoSetup: cryptoSetup,
destConnID: destConnID,
srcConnID: srcConnID,
initialStream: initialStream,
handshakeStream: handshakeStream,
retransmissionQueue: retransmissionQueue,
perspective: perspective,
version: version,
framer: framer,
acks: acks,
pnManager: packetNumberManager,
maxPacketSize: getMaxPacketSize(remoteAddr),
}
}
@ -237,80 +253,6 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
return p.writeAndSealPacket(hdr, payload, encLevel, sealer)
}
// PackRetransmission packs a retransmission
// For packets sent after completion of the handshake, it might happen that 2 packets have to be sent.
// This can happen e.g. when a longer packet number is used in the header.
func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) {
var controlFrames []wire.Frame
var streamFrames []*wire.StreamFrame
for _, f := range packet.Frames {
// CRYPTO frames are treated as control frames here.
// Since we're making sure that the header can never be larger for a retransmission,
// we never have to split CRYPTO frames.
if sf, ok := f.(*wire.StreamFrame); ok {
sf.DataLenPresent = true
streamFrames = append(streamFrames, sf)
} else {
controlFrames = append(controlFrames, f)
}
}
var packets []*packedPacket
for len(controlFrames) > 0 || len(streamFrames) > 0 {
var frames []wire.Frame
var length protocol.ByteCount
sealer, hdr, err := p.getSealerAndHeader(packet.EncryptionLevel)
if err != nil {
return nil, err
}
hdrLen := hdr.GetLength(p.version)
maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - hdrLen
for len(controlFrames) > 0 {
frame := controlFrames[0]
frameLen := frame.Length(p.version)
if length+frameLen > maxSize {
break
}
length += frameLen
frames = append(frames, frame)
controlFrames = controlFrames[1:]
}
for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize {
frame := streamFrames[0]
frame.DataLenPresent = false
frameToAdd := frame
sf, needsSplit := frame.MaybeSplitOffFrame(maxSize-length, p.version)
if needsSplit {
if sf == nil { // size too small to create a new STREAM frame
continue
}
frameToAdd = sf
} else {
streamFrames = streamFrames[1:]
}
frame.DataLenPresent = true
length += frameToAdd.Length(p.version)
frames = append(frames, frameToAdd)
}
if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
sfLen := sf.Length(p.version)
sf.DataLenPresent = false
length += sf.Length(p.version) - sfLen
}
p, err := p.writeAndSealPacket(hdr, payload{frames: frames, length: length}, packet.EncryptionLevel, sealer)
if err != nil {
return nil, err
}
packets = append(packets, p)
}
return packets, nil
}
// PackPacket packs a new packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket() (*packedPacket, error) {
@ -371,9 +313,10 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
}
hasData := p.initialStream.HasData()
hasRetransmission := p.retransmissionQueue.HasInitialData()
ack := p.acks.GetAckFrame(protocol.EncryptionInitial)
var sealer handshake.LongHeaderSealer
if hasData || ack != nil {
if hasData || hasRetransmission || ack != nil {
s = p.initialStream
encLevel = protocol.EncryptionInitial
sealer = initialSealer
@ -382,8 +325,9 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
}
} else {
hasData = p.handshakeStream.HasData()
hasRetransmission = p.retransmissionQueue.HasHandshakeData()
ack = p.acks.GetAckFrame(protocol.EncryptionHandshake)
if hasData || ack != nil {
if hasData || hasRetransmission || ack != nil {
s = p.handshakeStream
encLevel = protocol.EncryptionHandshake
sealer = handshakeSealer
@ -403,7 +347,24 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
}
hdr := p.getLongHeader(encLevel)
hdrLen := hdr.GetLength(p.version)
if hasData {
if hasRetransmission {
for {
var f wire.Frame
switch encLevel {
case protocol.EncryptionInitial:
remainingLen := protocol.MinInitialPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length
f = p.retransmissionQueue.GetInitialFrame(remainingLen)
case protocol.EncryptionHandshake:
remainingLen := p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length
f = p.retransmissionQueue.GetHandshakeFrame(remainingLen)
}
if f == nil {
break
}
payload.frames = append(payload.frames, f)
payload.length += f.Length(p.version)
}
} else if hasData {
cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length)
payload.frames = []wire.Frame{cf}
payload.length += cf.Length(p.version)
@ -419,14 +380,25 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo
payload.length += ack.Length(p.version)
}
frames, lengthAdded := p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length)
for {
remainingLen := maxFrameSize - payload.length
if remainingLen < protocol.MinStreamFrameSize {
break
}
f := p.retransmissionQueue.GetAppDataFrame(remainingLen)
if f == nil {
break
}
payload.frames = append(payload.frames, f)
payload.length += f.Length(p.version)
}
var lengthAdded protocol.ByteCount
payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length)
payload.length += lengthAdded
frames, lengthAdded = p.framer.AppendStreamFrames(frames, maxFrameSize-payload.length)
if len(frames) > 0 {
payload.frames = append(payload.frames, frames...)
payload.length += lengthAdded
}
payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length)
payload.length += lengthAdded
return payload, nil
}
@ -478,6 +450,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
switch encLevel {
case protocol.EncryptionInitial:
hdr.Type = protocol.PacketTypeInitial
hdr.Token = p.token
case protocol.EncryptionHandshake:
hdr.Type = protocol.PacketTypeHandshake
}
@ -507,7 +480,6 @@ func (p *packetPacker) writeAndSealPacket(
if encLevel != protocol.Encryption1RTT {
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
header.Token = p.token
headerLen := header.GetLength(p.version)
header.Length = pnLen + protocol.MinInitialPacketSize - headerLen
paddingLen = protocol.ByteCount(protocol.MinInitialPacketSize-sealer.Overhead()) - headerLen - payload.length
@ -550,6 +522,7 @@ func (p *packetPacker) writeAndSealPacketWithPadding(
}
if payloadSize := protocol.ByteCount(buffer.Len()-payloadOffset) - paddingLen; payloadSize != payload.length {
fmt.Printf("%#v\n", payload)
return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize)
}
if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {

View file

@ -8,7 +8,6 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/ackhandler"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler"
@ -20,21 +19,24 @@ import (
var _ = Describe("Packet packer", func() {
const maxPacketSize protocol.ByteCount = 1357
const version = protocol.VersionTLS
var (
packer *packetPacker
framer *MockFrameSource
ackFramer *MockAckFrameSource
initialStream *MockCryptoStream
handshakeStream *MockCryptoStream
sealingManager *MockSealingManager
pnManager *mockackhandler.MockSentPacketHandler
packer *packetPacker
retransmissionQueue *retransmissionQueue
framer *MockFrameSource
ackFramer *MockAckFrameSource
initialStream *MockCryptoStream
handshakeStream *MockCryptoStream
sealingManager *MockSealingManager
pnManager *mockackhandler.MockSentPacketHandler
)
checkLength := func(data []byte) {
hdr, _, _, err := wire.ParsePacket(data, 0)
Expect(err).ToNot(HaveOccurred())
r := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever)
extHdr, err := hdr.ParseExtended(r, version)
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen)))
}
@ -61,7 +63,7 @@ var _ = Describe("Packet packer", func() {
BeforeEach(func() {
rand.Seed(GinkgoRandomSeed())
version := protocol.VersionTLS
retransmissionQueue = newRetransmissionQueue(version)
mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
initialStream = NewMockCryptoStream(mockCtrl)
@ -77,6 +79,7 @@ var _ = Describe("Packet packer", func() {
initialStream,
handshakeStream,
pnManager,
retransmissionQueue,
&net.TCPAddr{},
sealingManager,
framer,
@ -408,6 +411,60 @@ var _ = Describe("Packet packer", func() {
Expect(r.Len()).To(BeZero())
})
It("packs multiple small STREAM frames into single packet", func() {
f1 := &wire.StreamFrame{
StreamID: 5,
Data: []byte("frame 1"),
DataLenPresent: true,
}
f2 := &wire.StreamFrame{
StreamID: 5,
Data: []byte("frame 2"),
DataLenPresent: true,
}
f3 := &wire.StreamFrame{
StreamID: 3,
Data: []byte("frame 3"),
DataLenPresent: true,
}
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames()
expectAppendStreamFrames(f1, f2, f3)
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(3))
Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1")))
Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2")))
Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3")))
})
It("adds retransmissions", func() {
f1 := &wire.StreamFrame{Data: []byte("frame 1")}
cf := &wire.MaxDataFrame{ByteOffset: 0x42}
retransmissionQueue.AddAppData(f1)
retransmissionQueue.AddAppData(cf)
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames()
f2 := &wire.StreamFrame{Data: []byte("frame 2")}
expectAppendStreamFrames(f2)
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(3))
Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{}))
Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1")))
Expect(p.frames[1]).To(Equal(cf))
Expect(p.frames[2]).To(BeAssignableToTypeOf(&wire.StreamFrame{}))
Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2")))
})
Context("making ACK packets ack-eliciting", func() {
sendMaxNumNonAckElicitingAcks := func() {
for i := 0; i < protocol.MaxNonAckElicitingAcks; i++ {
@ -491,230 +548,6 @@ var _ = Describe("Packet packer", func() {
})
})
Context("STREAM frame handling", func() {
It("does not split a STREAM frame with maximum size", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
expectAppendControlFrames()
sf := &wire.StreamFrame{
Offset: 1,
StreamID: 5,
}
framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) {
sf.Data = bytes.Repeat([]byte{'f'}, int(maxSize-sf.Length(packer.version)))
return []wire.Frame{sf}, sf.Length(packer.version)
})
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.raw).To(HaveLen(int(maxPacketSize)))
Expect(p.frames[0].(*wire.StreamFrame).Data).To(HaveLen(len(sf.Data)))
Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse())
})
It("packs multiple small STREAM frames into single packet", func() {
f1 := &wire.StreamFrame{
StreamID: 5,
Data: []byte("frame 1"),
DataLenPresent: true,
}
f2 := &wire.StreamFrame{
StreamID: 5,
Data: []byte("frame 2"),
DataLenPresent: true,
}
f3 := &wire.StreamFrame{
StreamID: 3,
Data: []byte("frame 3"),
DataLenPresent: true,
}
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
expectAppendControlFrames()
expectAppendStreamFrames(f1, f2, f3)
p, err := packer.PackPacket()
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(3))
Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1")))
Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2")))
Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3")))
})
})
Context("retransmissions", func() {
It("retransmits a small packet", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
frames := []wire.Frame{
&wire.MaxDataFrame{ByteOffset: 0x1234},
&wire.StreamFrame{StreamID: 42, Data: []byte("foobar")},
}
packets, err := packer.PackRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.Encryption1RTT,
Frames: frames,
})
Expect(err).ToNot(HaveOccurred())
Expect(packets).To(HaveLen(1))
p := packets[0]
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
Expect(p.frames).To(Equal(frames))
})
It("packs two packets for retransmission if the original packet contained many control frames", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2)
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2)
var frames []wire.Frame
var totalLen protocol.ByteCount
// pack a bunch of control frames, such that the packet is way bigger than a single packet
for i := 0; totalLen < maxPacketSize*3/2; i++ {
f := &wire.MaxStreamDataFrame{
StreamID: protocol.StreamID(i),
ByteOffset: protocol.ByteCount(i),
}
frames = append(frames, f)
totalLen += f.Length(packer.version)
}
packets, err := packer.PackRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.Encryption1RTT,
Frames: frames,
})
Expect(err).ToNot(HaveOccurred())
Expect(packets).To(HaveLen(2))
Expect(len(packets[0].frames) + len(packets[1].frames)).To(Equal(len(frames)))
Expect(packets[1].frames).To(Equal(frames[len(packets[0].frames):]))
// check that the first packet was filled up as far as possible:
// if the first frame (after the STOP_WAITING) was packed into the first packet, it would have overflown the MaxPacketSize
Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize))
})
It("splits a STREAM frame that doesn't fit", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2)
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2)
packets, err := packer.PackRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.Encryption1RTT,
Frames: []wire.Frame{&wire.StreamFrame{
StreamID: 42,
Offset: 1337,
Data: bytes.Repeat([]byte{'a'}, int(maxPacketSize)*3/2),
}},
})
Expect(err).ToNot(HaveOccurred())
Expect(packets).To(HaveLen(2))
Expect(packets[0].frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{}))
Expect(packets[1].frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{}))
sf1 := packets[0].frames[0].(*wire.StreamFrame)
sf2 := packets[1].frames[0].(*wire.StreamFrame)
Expect(sf1.StreamID).To(Equal(protocol.StreamID(42)))
Expect(sf1.Offset).To(Equal(protocol.ByteCount(1337)))
Expect(sf1.DataLenPresent).To(BeFalse())
Expect(sf2.StreamID).To(Equal(protocol.StreamID(42)))
Expect(sf2.Offset).To(Equal(protocol.ByteCount(1337) + sf1.DataLen()))
Expect(sf2.DataLenPresent).To(BeFalse())
Expect(sf1.DataLen() + sf2.DataLen()).To(Equal(maxPacketSize * 3 / 2))
Expect(packets[0].raw).To(HaveLen(int(maxPacketSize)))
})
It("splits STREAM frames, if necessary", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).AnyTimes()
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).AnyTimes()
for i := 0; i < 100; i++ {
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).MaxTimes(2)
sf1 := &wire.StreamFrame{
StreamID: 42,
Offset: 1337,
Data: bytes.Repeat([]byte{'a'}, 1+int(rand.Int31n(int32(maxPacketSize*4/5)))),
}
sf2 := &wire.StreamFrame{
StreamID: 2,
Offset: 42,
Data: bytes.Repeat([]byte{'b'}, 1+int(rand.Int31n(int32(maxPacketSize*4/5)))),
}
expectedDataLen := sf1.DataLen() + sf2.DataLen()
frames := []wire.Frame{sf1, sf2}
packets, err := packer.PackRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.Encryption1RTT,
Frames: frames,
})
Expect(err).ToNot(HaveOccurred())
if len(packets) > 1 {
Expect(packets[0].raw).To(HaveLen(int(maxPacketSize)))
}
var dataLen protocol.ByteCount
for _, p := range packets {
for _, f := range p.frames {
dataLen += f.(*wire.StreamFrame).DataLen()
}
}
Expect(dataLen).To(Equal(expectedDataLen))
}
})
It("packs two packets for retransmission if the original packet contained many STREAM frames", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2)
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2)
var frames []wire.Frame
var totalLen protocol.ByteCount
// pack a bunch of control frames, such that the packet is way bigger than a single packet
for i := 0; totalLen < maxPacketSize*3/2; i++ {
f := &wire.StreamFrame{
StreamID: protocol.StreamID(i),
Data: []byte("foobar"),
DataLenPresent: true,
}
frames = append(frames, f)
totalLen += f.Length(packer.version)
}
packets, err := packer.PackRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.Encryption1RTT,
Frames: frames,
})
Expect(err).ToNot(HaveOccurred())
Expect(packets).To(HaveLen(2))
Expect(len(packets[0].frames) + len(packets[1].frames)).To(Equal(len(frames))) // all frames
Expect(packets[1].frames).To(Equal(frames[len(packets[0].frames):]))
// check that the first packet was filled up as far as possible:
// if the first frame was packed into the first packet, it would have overflown the MaxPacketSize
Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize-protocol.MinStreamFrameSize))
})
It("correctly sets the DataLenPresent on STREAM frames", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
frames := []wire.Frame{
&wire.StreamFrame{StreamID: 4, Data: []byte("foobar"), DataLenPresent: true},
&wire.StreamFrame{StreamID: 5, Data: []byte("barfoo")},
}
packets, err := packer.PackRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.Encryption1RTT,
Frames: frames,
})
Expect(err).ToNot(HaveOccurred())
Expect(packets).To(HaveLen(1))
p := packets[0]
Expect(p.frames).To(HaveLen(2))
Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{}))
Expect(p.frames[1]).To(BeAssignableToTypeOf(&wire.StreamFrame{}))
sf1 := p.frames[0].(*wire.StreamFrame)
sf2 := p.frames[1].(*wire.StreamFrame)
Expect(sf1.StreamID).To(Equal(protocol.StreamID(4)))
Expect(sf1.DataLenPresent).To(BeTrue())
Expect(sf2.StreamID).To(Equal(protocol.StreamID(5)))
Expect(sf2.DataLenPresent).To(BeFalse())
})
})
Context("max packet size", func() {
It("sets the maximum packet size", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
@ -805,8 +638,25 @@ var _ = Describe("Packet packer", func() {
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
expectedPacketLen := packer.maxPacketSize
Expect(p.raw).To(HaveLen(int(expectedPacketLen)))
Expect(p.raw).To(HaveLen(int(packer.maxPacketSize)))
Expect(p.header.IsLongHeader).To(BeTrue())
checkLength(p.raw)
})
It("adds retransmissions", func() {
f := &wire.CryptoFrame{Data: []byte("Initial")}
retransmissionQueue.AddInitial(f)
retransmissionQueue.AddHandshake(&wire.CryptoFrame{Data: []byte("Handshake")})
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil)
sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer"))
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
initialStream.EXPECT().HasData()
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
Expect(p.frames).To(Equal([]wire.Frame{f}))
Expect(p.header.IsLongHeader).To(BeTrue())
checkLength(p.raw)
})
@ -923,63 +773,25 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(packet).ToNot(BeNil())
})
Context("retransmitions", func() {
cf := &wire.CryptoFrame{Data: []byte("foo")}
It("packs a retransmission with the right encryption level", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil)
packet := &ackhandler.Packet{
EncryptionLevel: protocol.EncryptionInitial,
Frames: []wire.Frame{cf},
}
p, err := packer.PackRetransmission(packet)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(HaveLen(1))
Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial))
Expect(p[0].frames).To(Equal([]wire.Frame{cf}))
Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
})
It("packs a retransmission for an Initial packet", func() {
token := []byte("initial token")
packer.SetToken(token)
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil)
packer.perspective = protocol.PerspectiveClient
packet := &ackhandler.Packet{
EncryptionLevel: protocol.EncryptionInitial,
Frames: []wire.Frame{cf},
}
packets, err := packer.PackRetransmission(packet)
Expect(err).ToNot(HaveOccurred())
Expect(packets).To(HaveLen(1))
p := packets[0]
Expect(p.frames).To(Equal([]wire.Frame{cf}))
Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial))
Expect(p.header.Token).To(Equal(token))
Expect(p.raw).To(HaveLen(protocol.MinInitialPacketSize))
})
})
})
})
})
var _ = Describe("Converting to AckHandler packets", func() {
It("convert a packet", func() {
f1 := &wire.MaxDataFrame{}
f2 := &wire.PingFrame{}
packet := &packedPacket{
header: &wire.ExtendedHeader{Header: wire.Header{}},
frames: []wire.Frame{&wire.MaxDataFrame{}, &wire.PingFrame{}},
frames: []wire.Frame{f1, f2},
ack: &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 80}}},
raw: []byte("foobar"),
}
p := packet.ToAckHandlerPacket()
p := packet.ToAckHandlerPacket(nil)
Expect(p.Length).To(Equal(protocol.ByteCount(6)))
Expect(p.Frames).To(Equal(packet.frames))
Expect(p.Frames).To(HaveLen(2))
Expect(p.Frames[0].Frame).To(Equal(f1))
Expect(p.Frames[1].Frame).To(Equal(f2))
Expect(p.LargestAcked).To(Equal(protocol.PacketNumber(100)))
Expect(p.SendTime).To(BeTemporally("~", time.Now(), 50*time.Millisecond))
})
@ -990,7 +802,7 @@ var _ = Describe("Converting to AckHandler packets", func() {
frames: []wire.Frame{&wire.MaxDataFrame{}, &wire.PingFrame{}},
raw: []byte("foobar"),
}
p := packet.ToAckHandlerPacket()
p := packet.ToAckHandlerPacket(nil)
Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber))
})
})

132
retransmission_queue.go Normal file
View file

@ -0,0 +1,132 @@
package quic
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type retransmissionQueue struct {
initial []wire.Frame
initialCryptoData []*wire.CryptoFrame
handshake []wire.Frame
handshakeCryptoData []*wire.CryptoFrame
appData []wire.Frame
streamData []*wire.StreamFrame
version protocol.VersionNumber
}
func newRetransmissionQueue(ver protocol.VersionNumber) *retransmissionQueue {
return &retransmissionQueue{version: ver}
}
func (q *retransmissionQueue) AddInitial(f wire.Frame) {
if cf, ok := f.(*wire.CryptoFrame); ok {
q.initialCryptoData = append(q.initialCryptoData, cf)
return
}
q.initial = append(q.initial, f)
}
func (q *retransmissionQueue) AddHandshake(f wire.Frame) {
if cf, ok := f.(*wire.CryptoFrame); ok {
q.handshakeCryptoData = append(q.handshakeCryptoData, cf)
return
}
q.handshake = append(q.handshake, f)
}
func (q *retransmissionQueue) HasInitialData() bool {
return len(q.initialCryptoData) > 0 || len(q.initial) > 0
}
func (q *retransmissionQueue) HasHandshakeData() bool {
return len(q.handshakeCryptoData) > 0 || len(q.handshake) > 0
}
func (q *retransmissionQueue) AddAppData(f wire.Frame) {
if sf, ok := f.(*wire.StreamFrame); ok {
sf.DataLenPresent = true
q.streamData = append(q.streamData, sf)
return
}
q.appData = append(q.appData, f)
}
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Frame {
if len(q.initialCryptoData) > 0 {
if f := q.initialCryptoData[0]; f.Length(q.version) <= maxLen {
q.initialCryptoData = q.initialCryptoData[1:]
return f
}
}
if len(q.initial) == 0 {
return nil
}
f := q.initial[0]
if f.Length(q.version) > maxLen {
return nil
}
q.initial = q.initial[1:]
return f
}
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.Frame {
if len(q.handshakeCryptoData) > 0 {
if f := q.handshakeCryptoData[0]; f.Length(q.version) <= maxLen {
q.handshakeCryptoData = q.handshakeCryptoData[1:]
return f
}
}
if len(q.handshake) == 0 {
return nil
}
f := q.handshake[0]
if f.Length(q.version) > maxLen {
return nil
}
q.handshake = q.handshake[1:]
return f
}
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount) wire.Frame {
if len(q.streamData) > 0 {
f := q.streamData[0]
if f.Length(q.version) <= maxLen {
q.streamData = q.streamData[1:]
return f
}
if maxLen >= protocol.MinStreamFrameSize {
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version)
if needsSplit && newFrame != nil {
return newFrame
}
}
}
if len(q.appData) == 0 {
return nil
}
f := q.appData[0]
if f.Length(q.version) > maxLen {
return nil
}
q.appData = q.appData[1:]
return f
}
func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) {
switch encLevel {
case protocol.EncryptionInitial:
q.initial = nil
q.initialCryptoData = nil
case protocol.EncryptionHandshake:
q.handshake = nil
q.handshakeCryptoData = nil
default:
panic(fmt.Sprintf("unexpected encryption level: %s", encLevel))
}
}

View file

@ -0,0 +1,158 @@
package quic
import (
"math/rand"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Retransmission queue", func() {
const version = protocol.VersionTLS
var q *retransmissionQueue
BeforeEach(func() {
q = newRetransmissionQueue(version)
})
Context("Initial data", func() {
It("doesn't dequeue anything when it's empty", func() {
Expect(q.HasInitialData()).To(BeFalse())
Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil())
})
It("queues and retrieves a control frame", func() {
f := &wire.MaxDataFrame{ByteOffset: 0x42}
q.AddInitial(f)
Expect(q.HasInitialData()).To(BeTrue())
Expect(q.GetInitialFrame(f.Length(version) - 1)).To(BeNil())
Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f))
Expect(q.HasInitialData()).To(BeFalse())
})
It("queues and retrieves a CRYPTO frame", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")}
q.AddInitial(f)
Expect(q.HasInitialData()).To(BeTrue())
Expect(q.GetInitialFrame(f.Length(version) - 1)).To(BeNil())
Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f))
Expect(q.HasInitialData()).To(BeFalse())
})
It("retrieves both a CRYPTO frame and a control frame", func() {
cf := &wire.MaxDataFrame{ByteOffset: 0x42}
f := &wire.CryptoFrame{Data: []byte("foobar")}
q.AddInitial(f)
q.AddInitial(cf)
Expect(q.HasInitialData()).To(BeTrue())
Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(f))
Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(cf))
Expect(q.HasInitialData()).To(BeFalse())
})
It("drops all Initial frames", func() {
q.AddInitial(&wire.CryptoFrame{Data: []byte("foobar")})
q.AddInitial(&wire.MaxDataFrame{ByteOffset: 0x42})
q.DropPackets(protocol.EncryptionInitial)
Expect(q.HasInitialData()).To(BeFalse())
Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil())
})
})
Context("Handshake data", func() {
It("doesn't dequeue anything when it's empty", func() {
Expect(q.HasHandshakeData()).To(BeFalse())
Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil())
})
It("queues and retrieves a control frame", func() {
f := &wire.MaxDataFrame{ByteOffset: 0x42}
q.AddHandshake(f)
Expect(q.HasHandshakeData()).To(BeTrue())
Expect(q.GetHandshakeFrame(f.Length(version) - 1)).To(BeNil())
Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f))
Expect(q.HasHandshakeData()).To(BeFalse())
})
It("queues and retrieves a CRYPTO frame", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")}
q.AddHandshake(f)
Expect(q.HasHandshakeData()).To(BeTrue())
Expect(q.GetHandshakeFrame(f.Length(version) - 1)).To(BeNil())
Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f))
Expect(q.HasHandshakeData()).To(BeFalse())
})
It("retrieves both a CRYPTO frame and a control frame", func() {
cf := &wire.MaxDataFrame{ByteOffset: 0x42}
f := &wire.CryptoFrame{Data: []byte("foobar")}
q.AddHandshake(f)
q.AddHandshake(cf)
Expect(q.HasHandshakeData()).To(BeTrue())
Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(f))
Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(cf))
Expect(q.HasHandshakeData()).To(BeFalse())
})
It("drops all Initial frames", func() {
q.AddHandshake(&wire.CryptoFrame{Data: []byte("foobar")})
q.AddHandshake(&wire.MaxDataFrame{ByteOffset: 0x42})
q.DropPackets(protocol.EncryptionHandshake)
Expect(q.HasHandshakeData()).To(BeFalse())
Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil())
})
})
Context("Application data", func() {
It("doesn't dequeue anything when it's empty", func() {
Expect(q.GetAppDataFrame(protocol.MaxByteCount)).To(BeNil())
})
It("queues and retrieves a control frame", func() {
f := &wire.MaxDataFrame{ByteOffset: 0x42}
q.AddAppData(f)
Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil())
Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f))
})
It("queues and retrieves a STREAM frame", func() {
f := &wire.StreamFrame{Data: []byte("foobar")}
q.AddAppData(f)
Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil())
Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f))
})
It("splits STREAM frames larger than MinStreamFrameSize", func() {
data := make([]byte, 1000)
rand.Read(data)
f := &wire.StreamFrame{
Data: data,
FinBit: true,
}
q.AddAppData(f)
Expect(q.GetAppDataFrame(protocol.MinStreamFrameSize - 1)).To(BeNil())
f1 := q.GetAppDataFrame(protocol.MinStreamFrameSize).(*wire.StreamFrame)
Expect(f1).ToNot(BeNil())
Expect(f1.Length(version)).To(Equal(protocol.MinStreamFrameSize))
Expect(f1.FinBit).To(BeFalse())
Expect(f1.Data).To(Equal(data[:f1.DataLen()]))
f2 := q.GetAppDataFrame(protocol.MaxByteCount).(*wire.StreamFrame)
Expect(f2).ToNot(BeNil())
Expect(f2.FinBit).To(BeTrue())
Expect(f1.DataLen() + f2.DataLen()).To(BeEquivalentTo(1000))
Expect(f2.Data).To(Equal(data[f1.DataLen():]))
Expect(q.GetAppDataFrame(protocol.MaxByteCount)).To(BeNil())
})
It("returns a control frame if it doesn't split a STREAM frame", func() {
cf := &wire.MaxDataFrame{ByteOffset: 0x42}
q.AddAppData(&wire.StreamFrame{Data: make([]byte, 1000)})
q.AddAppData(cf)
Expect(q.GetAppDataFrame(protocol.MinStreamFrameSize - 1)).To(Equal(cf))
})
})
})

View file

@ -115,6 +115,7 @@ type session struct {
cryptoStreamManager *cryptoStreamManager
sentPacketHandler ackhandler.SentPacketHandler
receivedPacketHandler ackhandler.ReceivedPacketHandler
retransmissionQueue *retransmissionQueue
framer framer
windowUpdateQueue *windowUpdateQueue
connFlowController flowcontrol.ConnectionFlowController
@ -242,6 +243,7 @@ var newSession = func(
initialStream,
handshakeStream,
s.sentPacketHandler,
s.retransmissionQueue,
s.RemoteAddr(),
cs,
s.framer,
@ -328,6 +330,7 @@ var newClientSession = func(
initialStream,
handshakeStream,
s.sentPacketHandler,
s.retransmissionQueue,
s.RemoteAddr(),
cs,
s.framer,
@ -350,6 +353,7 @@ var newClientSession = func(
func (s *session) preSetup() {
s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue(s.version)
s.frameParser = wire.NewFrameParser(s.version)
s.rttStats = &congestion.RTTStats{}
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
@ -1109,16 +1113,6 @@ sendLoop:
return err
}
numPacketsSent++
case ackhandler.SendRetransmission:
sentPacket, err := s.maybeSendRetransmission()
if err != nil {
return err
}
if sentPacket {
numPacketsSent++
// This can happen if a retransmission queued, but it wasn't necessary to send it.
// e.g. when an Initial is queued, but we already received a packet from the server.
}
case ackhandler.SendAny:
sentPacket, err := s.sendPacket()
if err != nil {
@ -1152,47 +1146,29 @@ func (s *session) maybeSendAckOnlyPacket() error {
if packet == nil {
return nil
}
s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket())
s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(s.retransmissionQueue))
s.sendQueue.Send(packet)
return nil
}
// maybeSendRetransmission sends retransmissions for at most one packet.
// It takes care that Initials aren't retransmitted, if a packet from the server was already received.
func (s *session) maybeSendRetransmission() (bool, error) {
retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission()
if retransmitPacket == nil {
return false, nil
}
s.logger.Debugf("Dequeueing retransmission for packet 0x%x (%s)", retransmitPacket.PacketNumber, retransmitPacket.EncryptionLevel)
packets, err := s.packer.PackRetransmission(retransmitPacket)
if err != nil {
return false, err
}
for _, packet := range packets {
s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket())
s.sendPackedPacket(packet)
}
return true, nil
}
func (s *session) sendProbePacket() error {
p, err := s.sentPacketHandler.DequeueProbePacket()
if err != nil {
return err
// Queue probe packets until we actually send out a packet.
for {
if wasQueued := s.sentPacketHandler.QueueProbePacket(); !wasQueued {
break
}
sent, err := s.sendPacket()
if err != nil {
return err
}
if sent {
return nil
}
}
s.logger.Debugf("Sending a retransmission for %#x as a probe packet.", p.PacketNumber)
packets, err := s.packer.PackRetransmission(p)
if err != nil {
return err
}
for _, packet := range packets {
s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket())
s.sendPackedPacket(packet)
}
return nil
// If there is nothing else to queue, make sure we send out something.
s.framer.QueueControlFrame(&wire.PingFrame{})
_, err := s.sendPacket()
return err
}
func (s *session) sendPacket() (bool, error) {
@ -1205,7 +1181,7 @@ func (s *session) sendPacket() (bool, error) {
if err != nil || packet == nil {
return false, err
}
s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket())
s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(s.retransmissionQueue))
s.sendPackedPacket(packet)
return true, nil
}

View file

@ -892,72 +892,14 @@ var _ = Describe("Session", func() {
Expect(frames).To(Equal([]wire.Frame{&wire.DataBlockedFrame{DataLimit: 1337}}))
})
It("sends a retransmission and a regular packet in the same run", func() {
packetToRetransmit := &ackhandler.Packet{PacketNumber: 10}
retransmittedPacket := getPacket(123)
newPacket := getPacket(234)
sess.windowUpdateQueue.callback(&wire.MaxDataFrame{})
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().DequeuePacketForRetransmission().Return(packetToRetransmit)
sph.EXPECT().SendMode().Return(ackhandler.SendRetransmission)
sph.EXPECT().SendMode().Return(ackhandler.SendAny)
sph.EXPECT().ShouldSendNumPackets().Return(2)
sph.EXPECT().TimeUntilSend()
gomock.InOrder(
packer.EXPECT().PackRetransmission(packetToRetransmit).Return([]*packedPacket{retransmittedPacket}, nil),
sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) {
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123)))
}),
packer.EXPECT().PackPacket().Return(newPacket, nil),
sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) {
Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(234)))
}),
)
sess.sentPacketHandler = sph
Expect(sess.sendPackets()).To(Succeed())
Eventually(mconn.written).Should(HaveLen(2))
})
It("sends multiple packets, if the retransmission is split", func() {
packet := &ackhandler.Packet{
PacketNumber: 42,
Frames: []wire.Frame{&wire.StreamFrame{
StreamID: 0x5,
Data: []byte("foobar"),
}},
EncryptionLevel: protocol.Encryption1RTT,
}
retransmissions := []*packedPacket{getPacket(1337), getPacket(1338)}
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().DequeuePacketForRetransmission().Return(packet)
packer.EXPECT().PackRetransmission(packet).Return(retransmissions, nil)
gomock.InOrder(
sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) {
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(1337)))
}),
sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) {
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(1338)))
}),
)
sess.sentPacketHandler = sph
sent, err := sess.maybeSendRetransmission()
Expect(err).NotTo(HaveOccurred())
Expect(sent).To(BeTrue())
Eventually(mconn.written).Should(HaveLen(2))
})
It("sends a probe packet", func() {
packetToRetransmit := &ackhandler.Packet{PacketNumber: 0x42}
retransmittedPacket := getPacket(123)
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend()
sph.EXPECT().SendMode().Return(ackhandler.SendPTO)
sph.EXPECT().ShouldSendNumPackets().Return(1)
sph.EXPECT().DequeueProbePacket().Return(packetToRetransmit, nil)
packer.EXPECT().PackRetransmission(packetToRetransmit).Return([]*packedPacket{retransmittedPacket}, nil)
sph.EXPECT().QueueProbePacket()
packer.EXPECT().PackPacket().Return(getPacket(123), nil)
sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) {
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123)))
})
@ -965,6 +907,25 @@ var _ = Describe("Session", func() {
Expect(sess.sendPackets()).To(Succeed())
})
It("sends a PING as a probe packet", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend()
sph.EXPECT().SendMode().Return(ackhandler.SendPTO)
sph.EXPECT().ShouldSendNumPackets().Return(1)
sph.EXPECT().QueueProbePacket().Return(false)
packer.EXPECT().PackPacket().Return(getPacket(123), nil)
sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) {
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123)))
})
sess.sentPacketHandler = sph
Expect(sess.sendPackets()).To(Succeed())
// We're using a mock packet packer in this test.
// We therefore need to test separately that the PING was actually queued.
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
Expect(frames).To(Equal([]wire.Frame{&wire.PingFrame{}}))
})
It("doesn't send when the SentPacketHandler doesn't allow it", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
@ -981,7 +942,6 @@ var _ = Describe("Session", func() {
BeforeEach(func() {
sph = mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().DequeuePacketForRetransmission().AnyTimes()
sess.sentPacketHandler = sph
streamManager.EXPECT().CloseWithError(gomock.Any())
})