diff --git a/internal/ackhandler/frame.go b/internal/ackhandler/frame.go new file mode 100644 index 00000000..83b09ceb --- /dev/null +++ b/internal/ackhandler/frame.go @@ -0,0 +1,8 @@ +package ackhandler + +import "github.com/lucas-clemente/quic-go/internal/wire" + +type Frame struct { + wire.Frame // nil if the frame has already been acknowledged in another packet + OnLost func(wire.Frame) +} diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 2b03c977..c66e5008 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -11,13 +11,12 @@ import ( // A Packet is a packet type Packet struct { PacketNumber protocol.PacketNumber - Frames []wire.Frame + Frames []Frame LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel SendTime time.Time - canBeRetransmitted bool includedInBytesInFlight bool } @@ -43,8 +42,7 @@ type SentPacketHandler interface { // only to be called once the handshake is complete GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - DequeuePacketForRetransmission() *Packet - DequeueProbePacket() (*Packet, error) + QueueProbePacket() bool /* was a packet queued */ PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber diff --git a/internal/ackhandler/send_mode.go b/internal/ackhandler/send_mode.go index 8cdaa7e6..360ce2e3 100644 --- a/internal/ackhandler/send_mode.go +++ b/internal/ackhandler/send_mode.go @@ -10,8 +10,6 @@ const ( SendNone SendMode = iota // SendAck means an ACK-only packet should be sent SendAck - // SendRetransmission means that retransmissions should be sent - SendRetransmission // SendPTO means that a probe packet should be sent SendPTO // SendAny means that any packet should be sent @@ -24,8 +22,6 @@ func (s SendMode) String() string { return "none" case SendAck: return "ack" - case SendRetransmission: - return "retransmission" case SendPTO: return "pto" case SendAny: diff --git a/internal/ackhandler/send_mode_test.go b/internal/ackhandler/send_mode_test.go index 251daffc..0b846b85 100644 --- a/internal/ackhandler/send_mode_test.go +++ b/internal/ackhandler/send_mode_test.go @@ -11,7 +11,6 @@ var _ = Describe("Send Mode", func() { Expect(SendAny.String()).To(Equal("any")) Expect(SendAck.String()).To(Equal("ack")) Expect(SendPTO.String()).To(Equal("pto")) - Expect(SendRetransmission.String()).To(Equal("retransmission")) Expect(SendMode(123).String()).To(Equal("invalid send mode: 123")) }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 56600ef8..6a339b4c 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -1,7 +1,6 @@ package ackhandler import ( - "errors" "fmt" "math" "time" @@ -56,8 +55,6 @@ type sentPacketHandler struct { // Only applies to the application-data packet number space. lowestNotConfirmedAcked protocol.PacketNumber - retransmissionQueue []*Packet - bytesInFlight protocol.ByteCount congestion congestion.SendAlgorithmWithDebugInfos @@ -112,14 +109,6 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { } return true, nil }) - // remove packets from the retransmission queue - var queue []*Packet - for _, packet := range h.retransmissionQueue { - if packet.EncryptionLevel != encLevel { - queue = append(queue, packet) - } - } - h.retransmissionQueue = queue // drop the packet history switch encLevel { case protocol.EncryptionInitial: @@ -170,7 +159,6 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-elicit h.lastSentAckElicitingPacketTime = packet.SendTime packet.includedInBytesInFlight = true h.bytesInFlight += packet.Length - packet.canBeRetransmitted = true if h.numProbesToSend > 0 { h.numProbesToSend-- } @@ -335,12 +323,14 @@ func (h *sentPacketHandler) setLossDetectionTimer() { // Cancel the alarm if no packets are outstanding if !h.hasOutstandingPackets() { + h.logger.Debugf("setLossDetectionTimer: canceling. Bytes in flight: %d", h.bytesInFlight) h.alarm = time.Time{} return } // PTO alarm h.alarm = h.lastSentAckElicitingPacketTime.Add(h.rttStats.PTO() << h.ptoCount) + h.logger.Debugf("setLossDetectionTimer: setting to", h.alarm) } func (h *sentPacketHandler) detectLostPackets( @@ -388,27 +378,23 @@ func (h *sentPacketHandler) detectLostPackets( } for _, p := range lostPackets { + h.queueFramesForRetransmission(p) // the bytes in flight need to be reduced no matter if this packet will be retransmitted if p.includedInBytesInFlight { h.bytesInFlight -= p.Length h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) } - if p.canBeRetransmitted { - // queue the packet for retransmission, and report the loss to the congestion controller - if err := h.queuePacketForRetransmission(p, pnSpace); err != nil { - return err - } - } pnSpace.history.Remove(p.PacketNumber) if h.traceCallback != nil { + // TODO: trace frames h.traceCallback(quictrace.Event{ Time: now, EventType: quictrace.PacketLost, EncryptionLevel: p.EncryptionLevel, PacketNumber: p.PacketNumber, PacketSize: p.Length, - Frames: p.Frames, - TransportState: h.GetStats(), + //Frames: p.Frames, + TransportState: h.GetStats(), }) } } @@ -454,63 +440,16 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error { pnSpace := h.getPacketNumberSpace(p.EncryptionLevel) - // This happens if a packet and its retransmissions is acked in the same ACK. - // As soon as we process the first one, this will remove all the retransmissions, - // so we won't find the retransmitted packet number later. if packet := pnSpace.history.GetPacket(p.PacketNumber); packet == nil { return nil } - // this also applies to packets that have been retransmitted as probe packets if p.includedInBytesInFlight { h.bytesInFlight -= p.Length } - if err := pnSpace.history.MarkCannotBeRetransmitted(p.PacketNumber); err != nil { - return err - } return pnSpace.history.Remove(p.PacketNumber) } -func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { - if len(h.retransmissionQueue) == 0 { - return nil - } - packet := h.retransmissionQueue[0] - // Shift the slice and don't retain anything that isn't needed. - copy(h.retransmissionQueue, h.retransmissionQueue[1:]) - h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil - h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1] - return packet -} - -func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) { - if len(h.retransmissionQueue) > 0 { - return h.DequeuePacketForRetransmission(), nil - } - - var pnSpace *packetNumberSpace - var p *Packet - if h.initialPackets != nil { - pnSpace = h.initialPackets - p = h.initialPackets.history.FirstOutstanding() - } - if p == nil && h.handshakePackets != nil { - pnSpace = h.handshakePackets - p = h.handshakePackets.history.FirstOutstanding() - } - if p == nil { - pnSpace = h.oneRTTPackets - p = h.oneRTTPackets.history.FirstOutstanding() - } - if p == nil { - return nil, errors.New("cannot dequeue a probe packet. No outstanding packets") - } - if err := h.queuePacketForRetransmission(p, pnSpace); err != nil { - return nil, err - } - return h.DequeuePacketForRetransmission(), nil -} - func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel) @@ -530,7 +469,7 @@ func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) p } func (h *sentPacketHandler) SendMode() SendMode { - numTrackedPackets := len(h.retransmissionQueue) + h.oneRTTPackets.history.Len() + numTrackedPackets := h.oneRTTPackets.history.Len() if h.initialPackets != nil { numTrackedPackets += h.initialPackets.history.Len() } @@ -558,10 +497,6 @@ func (h *sentPacketHandler) SendMode() SendMode { } return SendAck } - // Send retransmissions first, if there are any. - if len(h.retransmissionQueue) > 0 { - return SendRetransmission - } if numTrackedPackets >= protocol.MaxOutstandingSentPackets { if h.logger.Debug() { h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) @@ -587,30 +522,45 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) } -func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet, pnSpace *packetNumberSpace) error { - if !p.canBeRetransmitted { - return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber) +func (h *sentPacketHandler) QueueProbePacket() bool { + var p *Packet + if h.initialPackets != nil { + p = h.initialPackets.history.FirstOutstanding() } - if err := pnSpace.history.MarkCannotBeRetransmitted(p.PacketNumber); err != nil { - return err + if p == nil && h.handshakePackets != nil { + p = h.handshakePackets.history.FirstOutstanding() + } + if p == nil { + p = h.oneRTTPackets.history.FirstOutstanding() + } + if p == nil { + return false + } + h.queueFramesForRetransmission(p) + // TODO: don't remove the packet here + // Keep track of acknowledged frames instead. + if p.includedInBytesInFlight { + h.bytesInFlight -= p.Length + } + if err := h.getPacketNumberSpace(p.EncryptionLevel).history.Remove(p.PacketNumber); err != nil { + // should never happen. We just got this packet from the history a lines above. + panic(err) + } + return true +} + +func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { + for _, f := range p.Frames { + f.OnLost(f.Frame) } - h.retransmissionQueue = append(h.retransmissionQueue, p) - return nil } func (h *sentPacketHandler) ResetForRetry() error { h.bytesInFlight = 0 - var packets []*Packet h.initialPackets.history.Iterate(func(p *Packet) (bool, error) { - if p.canBeRetransmitted { - packets = append(packets, p) - } + h.queueFramesForRetransmission(p) return true, nil }) - for _, p := range packets { - h.logger.Debugf("Queueing packet %#x for retransmission.", p.PacketNumber) - h.retransmissionQueue = append(h.retransmissionQueue, p) - } h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop()) h.setLossDetectionTimer() return nil diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 2bf944a0..74b81214 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -14,40 +14,15 @@ import ( . "github.com/onsi/gomega" ) -func ackElicitingPacket(p *Packet) *Packet { - if p.EncryptionLevel == protocol.EncryptionUnspecified { - p.EncryptionLevel = protocol.Encryption1RTT - } - if p.Length == 0 { - p.Length = 1 - } - if p.SendTime.IsZero() { - p.SendTime = time.Now() - } - p.Frames = []wire.Frame{&wire.PingFrame{}} - return p -} - -func nonAckElicitingPacket(p *Packet) *Packet { - p = ackElicitingPacket(p) - p.Frames = nil - p.LargestAcked = 1 - return p -} - -func cryptoPacket(p *Packet) *Packet { - p = ackElicitingPacket(p) - p.EncryptionLevel = protocol.EncryptionInitial - return p -} - var _ = Describe("SentPacketHandler", func() { var ( handler *sentPacketHandler streamFrame wire.StreamFrame + lostPackets []protocol.PacketNumber ) BeforeEach(func() { + lostPackets = nil rttStats := &congestion.RTTStats{} handler = NewSentPacketHandler(42, rttStats, nil, utils.DefaultLogger).(*sentPacketHandler) streamFrame = wire.StreamFrame{ @@ -63,6 +38,35 @@ var _ = Describe("SentPacketHandler", func() { return nil } + ackElicitingPacket := func(p *Packet) *Packet { + if p.EncryptionLevel == protocol.EncryptionUnspecified { + p.EncryptionLevel = protocol.Encryption1RTT + } + if p.Length == 0 { + p.Length = 1 + } + if p.SendTime.IsZero() { + p.SendTime = time.Now() + } + p.Frames = []Frame{ + {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, p.PacketNumber) }}, + } + return p + } + + nonAckElicitingPacket := func(p *Packet) *Packet { + p = ackElicitingPacket(p) + p.Frames = nil + p.LargestAcked = 1 + return p + } + + cryptoPacket := func(p *Packet) *Packet { + p = ackElicitingPacket(p) + p.EncryptionLevel = protocol.EncryptionInitial + return p + } + expectInPacketHistory := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { pnSpace := handler.getPacketNumberSpace(encLevel) ExpectWithOffset(1, pnSpace.history.Len()).To(Equal(len(expected))) @@ -161,21 +165,21 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("acks and nacks the right packets", func() { + Context("acks the right packets", func() { expectInPacketHistoryOrLost := func(expected []protocol.PacketNumber, encLevel protocol.EncryptionLevel) { pnSpace := handler.getPacketNumberSpace(encLevel) - ExpectWithOffset(1, pnSpace.history.Len()+len(handler.retransmissionQueue)).To(Equal(len(expected))) + ExpectWithOffset(1, pnSpace.history.Len()+len(lostPackets)).To(Equal(len(expected))) expectedLoop: for _, p := range expected { if _, ok := pnSpace.history.packetMap[p]; ok { continue } - for _, lost := range handler.retransmissionQueue { - if lost.PacketNumber == p { + for _, lostP := range lostPackets { + if lostP == p { continue expectedLoop } } - Fail(fmt.Sprintf("Packet %d neither in packet history nor declared lost.", p)) + Fail(fmt.Sprintf("Packet %d not in packet history.", p)) } } @@ -320,9 +324,26 @@ var _ = Describe("SentPacketHandler", func() { Context("determining which ACKs we have received an ACK for", func() { BeforeEach(func() { morePackets := []*Packet{ - {PacketNumber: 13, LargestAcked: 100, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT}, - {PacketNumber: 14, LargestAcked: 200, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT}, - {PacketNumber: 15, Frames: []wire.Frame{&streamFrame}, Length: 1, EncryptionLevel: protocol.Encryption1RTT}, + { + PacketNumber: 13, + LargestAcked: 100, + Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Length: 1, + EncryptionLevel: protocol.Encryption1RTT, + }, + { + PacketNumber: 14, + LargestAcked: 200, + Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Length: 1, + EncryptionLevel: protocol.Encryption1RTT, + }, + { + PacketNumber: 15, + Frames: []Frame{{Frame: &streamFrame, OnLost: func(wire.Frame) {}}}, + Length: 1, + EncryptionLevel: protocol.Encryption1RTT, + }, } for _, packet := range morePackets { handler.SentPacket(packet) @@ -355,11 +376,6 @@ var _ = Describe("SentPacketHandler", func() { }) }) - It("does not dequeue a packet if no ACK has been received", func() { - handler.SentPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.Encryption1RTT, SendTime: time.Now().Add(-time.Hour)}) - Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) - }) - Context("congestion", func() { var cong *mocks.MockSendAlgorithmWithDebugInfos @@ -380,7 +396,7 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(&Packet{ PacketNumber: 1, Length: 42, - Frames: []wire.Frame{&wire.PingFrame{}}, + Frames: []Frame{{Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) {}}}, EncryptionLevel: protocol.Encryption1RTT, }) }) @@ -474,32 +490,10 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode()).To(Equal(SendAck)) }) - It("doesn't allow retransmissions if congestion limited", func() { - handler.retransmissionQueue = []*Packet{{PacketNumber: 3}} - cong.EXPECT().CanSend(gomock.Any()).Return(false) - Expect(handler.SendMode()).To(Equal(SendAck)) - }) - - It("allows sending retransmissions", func() { - cong.EXPECT().CanSend(gomock.Any()).Return(true) - handler.retransmissionQueue = []*Packet{{PacketNumber: 3}} - Expect(handler.SendMode()).To(Equal(SendRetransmission)) - }) - - It("allows retransmissions, if we're keeping track of between MaxOutstandingSentPackets and MaxTrackedSentPackets packets", func() { - cong.EXPECT().CanSend(gomock.Any()).Return(true) - Expect(protocol.MaxOutstandingSentPackets).To(BeNumerically("<", protocol.MaxTrackedSentPackets)) - handler.retransmissionQueue = make([]*Packet, protocol.MaxOutstandingSentPackets+10) - Expect(handler.SendMode()).To(Equal(SendRetransmission)) - handler.retransmissionQueue = make([]*Packet, protocol.MaxTrackedSentPackets) - Expect(handler.SendMode()).To(Equal(SendNone)) - }) - It("allows RTOs, even when congestion limited", func() { // note that we don't EXPECT a call to GetCongestionWindow // that means retransmissions are sent without considering the congestion window handler.numProbesToSend = 1 - handler.retransmissionQueue = []*Packet{{PacketNumber: 3}} Expect(handler.SendMode()).To(Equal(SendPTO)) }) @@ -550,6 +544,19 @@ var _ = Describe("SentPacketHandler", func() { }) Context("probe packets", func() { + It("queues a probe packet", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) + queued := handler.QueueProbePacket() + Expect(queued).To(BeTrue()) + Expect(lostPackets).To(Equal([]protocol.PacketNumber{10})) + }) + + It("says when it can't queue a probe packet", func() { + queued := handler.QueueProbePacket() + Expect(queued).To(BeFalse()) + }) + It("implements exponential backoff", func() { sendTime := time.Now().Add(-time.Hour) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) @@ -563,9 +570,16 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout)) }) - It("sets the TPO send mode until two packets is sent", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - handler.OnLossDetectionTimeout() + It("sets the PTO send mode until two packets is sent", func() { + var lostPackets []protocol.PacketNumber + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: 1, + SendTime: time.Now().Add(-time.Hour), + Frames: []Frame{ + {Frame: &wire.PingFrame{}, OnLost: func(wire.Frame) { lostPackets = append(lostPackets, 1) }}, + }, + })) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTO)) Expect(handler.ShouldSendNumPackets()).To(Equal(2)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) @@ -576,7 +590,7 @@ var _ = Describe("SentPacketHandler", func() { It("only counts ack-eliciting packets as probe packets", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - handler.OnLossDetectionTimeout() + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTO)) Expect(handler.ShouldSendNumPackets()).To(Equal(2)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) @@ -596,104 +610,49 @@ var _ = Describe("SentPacketHandler", func() { updateRTT(time.Hour) Expect(handler.oneRTTPackets.lossTime.IsZero()).To(BeTrue()) - handler.OnLossDetectionTimeout() // TLP - handler.OnLossDetectionTimeout() // TLP - handler.OnLossDetectionTimeout() // RTO - p, err := handler.DequeueProbePacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1))) - p, err = handler.DequeueProbePacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP + Expect(handler.ptoCount).To(BeEquivalentTo(1)) + Expect(handler.SendMode()).To(Equal(SendPTO)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) + Expect(handler.SendMode()).To(Equal(SendPTO)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4})) - Expect(handler.ptoCount).To(BeEquivalentTo(3)) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // PTO + Expect(handler.ptoCount).To(BeEquivalentTo(2)) + Expect(handler.SendMode()).To(Equal(SendPTO)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5})) + Expect(handler.SendMode()).To(Equal(SendPTO)) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6})) + + Expect(handler.SendMode()).To(Equal(SendAny)) }) - It("gets two probe packets if RTO expires, for crypto packets", func() { + It("gets two probe packets if PTO expires, for crypto packets", func() { handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 1})) handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 2})) updateRTT(time.Hour) Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) - handler.OnLossDetectionTimeout() // TLP - handler.OnLossDetectionTimeout() // TLP - handler.OnLossDetectionTimeout() // RTO - p, err := handler.DequeueProbePacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1))) - p, err = handler.DequeueProbePacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTO)) + handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 3})) + Expect(handler.SendMode()).To(Equal(SendPTO)) + handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 3})) - Expect(handler.ptoCount).To(BeEquivalentTo(3)) - }) - - It("doesn't delete packets transmitted as PTO from the history", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: time.Now().Add(-time.Hour)})) - handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) - handler.OnLossDetectionTimeout() // TLP - handler.OnLossDetectionTimeout() // TLP - handler.OnLossDetectionTimeout() // RTO - _, err := handler.DequeueProbePacket() - Expect(err).ToNot(HaveOccurred()) - _, err = handler.DequeueProbePacket() - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistory([]protocol.PacketNumber{1, 2}, protocol.Encryption1RTT) - Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2))) - // Send a probe packet and receive an ACK for it. - // This verifies the RTO. - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 3}}} - Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.oneRTTPackets.history.Len()).To(BeZero()) - Expect(handler.bytesInFlight).To(BeZero()) - Expect(handler.retransmissionQueue).To(BeEmpty()) // 1 and 2 were already sent as probe packets + Expect(handler.SendMode()).To(Equal(SendAny)) }) It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) - handler.OnLossDetectionTimeout() + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendPTO)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendAny)) }) - It("gets packets sent before the probe packet for retransmission", func() { - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: time.Now().Add(-time.Hour)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3, SendTime: time.Now().Add(-time.Hour)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4, SendTime: time.Now().Add(-time.Hour)})) - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) - handler.OnLossDetectionTimeout() // TLP - handler.OnLossDetectionTimeout() // TLP - handler.OnLossDetectionTimeout() // RTO - _, err := handler.DequeueProbePacket() - Expect(err).ToNot(HaveOccurred()) - _, err = handler.DequeueProbePacket() - Expect(err).ToNot(HaveOccurred()) - expectInPacketHistory([]protocol.PacketNumber{1, 2, 3, 4, 5}, protocol.Encryption1RTT) - // Send a probe packet and receive an ACK for it. - // This verifies the RTO. - handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6})) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}} - err = handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now()) - Expect(err).ToNot(HaveOccurred()) - Expect(handler.oneRTTPackets.history.Len()).To(BeZero()) - Expect(handler.bytesInFlight).To(BeZero()) - Expect(handler.retransmissionQueue).To(HaveLen(3)) // packets 3, 4, 5 - }) - It("handles ACKs for the original packet", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) @@ -712,11 +671,7 @@ var _ = Describe("SentPacketHandler", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}} Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed()) expectInPacketHistory([]protocol.PacketNumber{4, 5}, protocol.Encryption1RTT) - for _, p := range []protocol.PacketNumber{1, 2, 3} { - lost := handler.DequeuePacketForRetransmission() - Expect(lost).ToNot(BeNil()) - Expect(lost.PacketNumber).To(Equal(p)) - } + Expect(lostPackets).To(Equal([]protocol.PacketNumber{1, 2, 3})) }) }) @@ -729,8 +684,6 @@ var _ = Describe("SentPacketHandler", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, now)).To(Succeed()) - Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil()) - Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) // no need to set an alarm, since packet 1 was already declared lost Expect(handler.oneRTTPackets.lossTime.IsZero()).To(BeTrue()) Expect(handler.bytesInFlight).To(BeZero()) @@ -750,11 +703,6 @@ var _ = Describe("SentPacketHandler", func() { // Packet 1 should be considered lost (1+1/8) RTTs after it was sent. Expect(handler.oneRTTPackets.lossTime.IsZero()).To(BeFalse()) Expect(handler.oneRTTPackets.lossTime.Sub(getPacket(1, protocol.Encryption1RTT).SendTime)).To(Equal(time.Second * 9 / 8)) - - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil()) - // make sure this is not an RTO: only packet 1 is retransmissted - Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) }) It("sets the early retransmit alarm for crypto packets", func() { @@ -771,22 +719,15 @@ var _ = Describe("SentPacketHandler", func() { // Packet 1 should be considered lost (1+1/8) RTTs after it was sent. Expect(handler.initialPackets.lossTime.IsZero()).To(BeFalse()) Expect(handler.initialPackets.lossTime.Sub(getPacket(1, protocol.EncryptionInitial).SendTime)).To(Equal(time.Second * 9 / 8)) - - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.DequeuePacketForRetransmission()).NotTo(BeNil()) - // make sure this is not an RTO: only packet 1 is retransmissted - Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) }) }) Context("crypto packets", func() { It("rejects an ACK that acks packets with a higher encryption level", func() { - handler.SentPacket(&Packet{ + handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: 13, EncryptionLevel: protocol.Encryption1RTT, - Frames: []wire.Frame{&streamFrame}, - Length: 1, - }) + })) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} err := handler.ReceivedAck(ack, 1, protocol.EncryptionHandshake, time.Now()) Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) @@ -794,43 +735,43 @@ var _ = Describe("SentPacketHandler", func() { It("deletes Initial packets", func() { for i := protocol.PacketNumber(0); i < 6; i++ { - p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionInitial}) - handler.SentPacket(p) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.EncryptionInitial, + })) } for i := protocol.PacketNumber(0); i < 10; i++ { - p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake}) - handler.SentPacket(p) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.EncryptionHandshake, + })) } Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) - handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionInitial), handler.getPacketNumberSpace(protocol.EncryptionInitial)) - lostPacket := getPacket(3, protocol.EncryptionHandshake) - handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) handler.DropPackets(protocol.EncryptionInitial) + Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) Expect(handler.initialPackets).To(BeNil()) Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) - packet := handler.DequeuePacketForRetransmission() - Expect(packet).To(Equal(lostPacket)) }) It("deletes Handshake packets", func() { for i := protocol.PacketNumber(0); i < 6; i++ { - p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake}) - handler.SentPacket(p) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.EncryptionHandshake, + })) } for i := protocol.PacketNumber(0); i < 10; i++ { - p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.Encryption1RTT}) - handler.SentPacket(p) + handler.SentPacket(ackElicitingPacket(&Packet{ + PacketNumber: i, + EncryptionLevel: protocol.Encryption1RTT, + })) } Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) - handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionInitial)) - lostPacket := getPacket(3, protocol.Encryption1RTT) - handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) handler.DropPackets(protocol.EncryptionHandshake) + Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) Expect(handler.handshakePackets).To(BeNil()) - packet := handler.DequeuePacketForRetransmission() - Expect(packet).To(Equal(lostPacket)) }) }) @@ -858,25 +799,16 @@ var _ = Describe("SentPacketHandler", func() { Context("resetting for retry", func() { It("queues outstanding packets for retransmission and cancels alarms", func() { - packet := &Packet{ - PacketNumber: 42, - EncryptionLevel: protocol.EncryptionInitial, - Frames: []wire.Frame{&wire.CryptoFrame{Data: []byte("foobar")}}, - Length: 100, - } - handler.SentPacket(packet) + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 42, EncryptionLevel: protocol.EncryptionInitial})) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.bytesInFlight).ToNot(BeZero()) - Expect(handler.DequeuePacketForRetransmission()).To(BeNil()) Expect(handler.SendMode()).To(Equal(SendAny)) // now receive a Retry Expect(handler.ResetForRetry()).To(Succeed()) + Expect(lostPackets).To(Equal([]protocol.PacketNumber{42})) Expect(handler.bytesInFlight).To(BeZero()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - Expect(handler.SendMode()).To(Equal(SendRetransmission)) - p := handler.DequeuePacketForRetransmission() - Expect(p.PacketNumber).To(Equal(packet.PacketNumber)) - Expect(p.Frames).To(Equal(packet.Frames)) + Expect(handler.SendMode()).To(Equal(SendAny)) }) }) }) diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index 857fd306..e8b9bd02 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -9,10 +9,6 @@ import ( type sentPacketHistory struct { packetList *PacketList packetMap map[protocol.PacketNumber]*PacketElement - - numOutstandingPackets int - - firstOutstanding *PacketElement } func newSentPacketHistory() *sentPacketHistory { @@ -23,19 +19,8 @@ func newSentPacketHistory() *sentPacketHistory { } func (h *sentPacketHistory) SentPacket(p *Packet) { - h.sentPacketImpl(p) -} - -func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement { el := h.packetList.PushBack(*p) h.packetMap[p.PacketNumber] = el - if h.firstOutstanding == nil { - h.firstOutstanding = el - } - if p.canBeRetransmitted { - h.numOutstandingPackets++ - } - return el } func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet { @@ -63,40 +48,10 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err // It must not be modified (e.g. retransmitted). // Use DequeueFirstPacketForRetransmission() to retransmit it. func (h *sentPacketHistory) FirstOutstanding() *Packet { - if h.firstOutstanding == nil { + if !h.HasOutstandingPackets() { return nil } - return &h.firstOutstanding.Value -} - -// QueuePacketForRetransmission marks a packet for retransmission. -// A packet can only be queued once. -func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error { - el, ok := h.packetMap[pn] - if !ok { - return fmt.Errorf("sent packet history: packet %d not found", pn) - } - if el.Value.canBeRetransmitted { - h.numOutstandingPackets-- - if h.numOutstandingPackets < 0 { - panic("numOutstandingHandshakePackets negative") - } - } - el.Value.canBeRetransmitted = false - if el == h.firstOutstanding { - h.readjustFirstOutstanding() - } - return nil -} - -// readjustFirstOutstanding readjusts the pointer to the first outstanding packet. -// This is necessary every time the first outstanding packet is deleted or retransmitted. -func (h *sentPacketHistory) readjustFirstOutstanding() { - el := h.firstOutstanding.Next() - for el != nil && !el.Value.canBeRetransmitted { - el = el.Next() - } - h.firstOutstanding = el + return &h.packetList.Front().Value } func (h *sentPacketHistory) Len() int { @@ -108,20 +63,11 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { if !ok { return fmt.Errorf("packet %d not found in sent packet history", p) } - if el == h.firstOutstanding { - h.readjustFirstOutstanding() - } - if el.Value.canBeRetransmitted { - h.numOutstandingPackets-- - if h.numOutstandingPackets < 0 { - panic("numOutstandingHandshakePackets negative") - } - } h.packetList.Remove(el) delete(h.packetMap, p) return nil } func (h *sentPacketHistory) HasOutstandingPackets() bool { - return h.numOutstandingPackets > 0 + return h.packetList.Len() > 0 } diff --git a/internal/ackhandler/sent_packet_history_test.go b/internal/ackhandler/sent_packet_history_test.go index eea18967..c0e25d08 100644 --- a/internal/ackhandler/sent_packet_history_test.go +++ b/internal/ackhandler/sent_packet_history_test.go @@ -15,13 +15,14 @@ var _ = Describe("SentPacketHistory", func() { ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers))) ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers))) i := 0 - hist.Iterate(func(p *Packet) (bool, error) { + err := hist.Iterate(func(p *Packet) (bool, error) { pn := packetNumbers[i] ExpectWithOffset(1, p.PacketNumber).To(Equal(pn)) ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn)) i++ return true, nil }) + Expect(err).ToNot(HaveOccurred()) } BeforeEach(func() { @@ -53,45 +54,6 @@ var _ = Describe("SentPacketHistory", func() { Expect(front).ToNot(BeNil()) Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2))) }) - - It("gets the second packet if the first one is retransmitted", func() { - hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true}) - hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true}) - hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true}) - front := hist.FirstOutstanding() - Expect(front).ToNot(BeNil()) - Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1))) - // Queue the first packet for retransmission. - // The first outstanding packet should now be 3. - err := hist.MarkCannotBeRetransmitted(1) - Expect(err).ToNot(HaveOccurred()) - front = hist.FirstOutstanding() - Expect(front).ToNot(BeNil()) - Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(3))) - }) - - It("gets the third packet if the first two are retransmitted", func() { - hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true}) - hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true}) - hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true}) - front := hist.FirstOutstanding() - Expect(front).ToNot(BeNil()) - Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1))) - // Queue the second packet for retransmission. - // The first outstanding packet should still be 3. - err := hist.MarkCannotBeRetransmitted(3) - Expect(err).ToNot(HaveOccurred()) - front = hist.FirstOutstanding() - Expect(front).ToNot(BeNil()) - Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1))) - // Queue the first packet for retransmission. - // The first outstanding packet should still be 4. - err = hist.MarkCannotBeRetransmitted(1) - Expect(err).ToNot(HaveOccurred()) - front = hist.FirstOutstanding() - Expect(front).ToNot(BeNil()) - Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(4))) - }) }) It("gets a packet by packet number", func() { @@ -164,60 +126,32 @@ var _ = Describe("SentPacketHistory", func() { Context("outstanding packets", func() { It("says if it has outstanding packets", func() { Expect(hist.HasOutstandingPackets()).To(BeFalse()) - hist.SentPacket(&Packet{ - EncryptionLevel: protocol.Encryption1RTT, - canBeRetransmitted: true, - }) + hist.SentPacket(&Packet{EncryptionLevel: protocol.Encryption1RTT}) Expect(hist.HasOutstandingPackets()).To(BeTrue()) }) - It("doesn't consider non-ack-eliciting packets as outstanding", func() { - hist.SentPacket(&Packet{ - EncryptionLevel: protocol.EncryptionInitial, - }) - Expect(hist.HasOutstandingPackets()).To(BeFalse()) - }) - It("accounts for deleted packets", func() { hist.SentPacket(&Packet{ - PacketNumber: 10, - EncryptionLevel: protocol.Encryption1RTT, - canBeRetransmitted: true, + PacketNumber: 10, + EncryptionLevel: protocol.Encryption1RTT, }) Expect(hist.HasOutstandingPackets()).To(BeTrue()) - err := hist.Remove(10) - Expect(err).ToNot(HaveOccurred()) - Expect(hist.HasOutstandingPackets()).To(BeFalse()) - }) - - It("doesn't count packets marked as non-ack-eliciting", func() { - hist.SentPacket(&Packet{ - PacketNumber: 10, - EncryptionLevel: protocol.Encryption1RTT, - canBeRetransmitted: true, - }) - Expect(hist.HasOutstandingPackets()).To(BeTrue()) - err := hist.MarkCannotBeRetransmitted(10) - Expect(err).ToNot(HaveOccurred()) + Expect(hist.Remove(10)).To(Succeed()) Expect(hist.HasOutstandingPackets()).To(BeFalse()) }) It("counts the number of packets", func() { hist.SentPacket(&Packet{ - PacketNumber: 10, - EncryptionLevel: protocol.Encryption1RTT, - canBeRetransmitted: true, + PacketNumber: 10, + EncryptionLevel: protocol.Encryption1RTT, }) hist.SentPacket(&Packet{ - PacketNumber: 11, - EncryptionLevel: protocol.Encryption1RTT, - canBeRetransmitted: true, + PacketNumber: 11, + EncryptionLevel: protocol.Encryption1RTT, }) - err := hist.Remove(11) - Expect(err).ToNot(HaveOccurred()) + Expect(hist.Remove(11)).To(Succeed()) Expect(hist.HasOutstandingPackets()).To(BeTrue()) - err = hist.Remove(10) - Expect(err).ToNot(HaveOccurred()) + Expect(hist.Remove(10)).To(Succeed()) Expect(hist.HasOutstandingPackets()).To(BeFalse()) }) }) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 2f231983..7c032987 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -38,35 +38,6 @@ func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { return m.recorder } -// DequeuePacketForRetransmission mocks base method -func (m *MockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DequeuePacketForRetransmission") - ret0, _ := ret[0].(*ackhandler.Packet) - return ret0 -} - -// DequeuePacketForRetransmission indicates an expected call of DequeuePacketForRetransmission -func (mr *MockSentPacketHandlerMockRecorder) DequeuePacketForRetransmission() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeuePacketForRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeuePacketForRetransmission)) -} - -// DequeueProbePacket mocks base method -func (m *MockSentPacketHandler) DequeueProbePacket() (*ackhandler.Packet, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DequeueProbePacket") - ret0, _ := ret[0].(*ackhandler.Packet) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DequeueProbePacket indicates an expected call of DequeueProbePacket -func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket)) -} - // DropPackets mocks base method func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { m.ctrl.T.Helper() @@ -164,6 +135,20 @@ func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) } +// QueueProbePacket mocks base method +func (m *MockSentPacketHandler) QueueProbePacket() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueueProbePacket") + ret0, _ := ret[0].(bool) + return ret0 +} + +// QueueProbePacket indicates an expected call of QueueProbePacket +func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket)) +} + // ReceivedAck mocks base method func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.PacketNumber, arg2 protocol.EncryptionLevel, arg3 time.Time) error { m.ctrl.T.Helper() diff --git a/mock_packer_test.go b/mock_packer_test.go index 805d8df8..b96cfb68 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -8,7 +8,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler" handshake "github.com/lucas-clemente/quic-go/internal/handshake" protocol "github.com/lucas-clemente/quic-go/internal/protocol" wire "github.com/lucas-clemente/quic-go/internal/wire" @@ -106,21 +105,6 @@ func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket)) } -// PackRetransmission mocks base method -func (m *MockPacker) PackRetransmission(arg0 *ackhandler.Packet) ([]*packedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackRetransmission", arg0) - ret0, _ := ret[0].([]*packedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// PackRetransmission indicates an expected call of PackRetransmission -func (mr *MockPackerMockRecorder) PackRetransmission(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackRetransmission", reflect.TypeOf((*MockPacker)(nil).PackRetransmission), arg0) -} - // SetToken mocks base method func (m *MockPacker) SetToken(arg0 []byte) { m.ctrl.T.Helper() diff --git a/packet_packer.go b/packet_packer.go index e2585d9c..898d42b3 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -17,7 +17,6 @@ import ( type packer interface { PackPacket() (*packedPacket, error) MaybePackAckPacket() (*packedPacket, error) - PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) HandleTransportParameters(*handshake.TransportParameters) @@ -62,17 +61,31 @@ func (p *packedPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) } -func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { +func (p *packedPacket) ToAckHandlerPacket(q *retransmissionQueue) *ackhandler.Packet { largestAcked := protocol.InvalidPacketNumber if p.ack != nil { largestAcked = p.ack.LargestAcked() } + encLevel := p.EncryptionLevel() + frames := make([]ackhandler.Frame, len(p.frames)) + for i, f := range p.frames { + frame := f + frames[i].Frame = frame + switch encLevel { + case protocol.EncryptionInitial: + frames[i].OnLost = q.AddInitial + case protocol.EncryptionHandshake: + frames[i].OnLost = q.AddHandshake + case protocol.Encryption1RTT: + frames[i].OnLost = q.AddAppData + } + } return &ackhandler.Packet{ PacketNumber: p.header.PacketNumber, LargestAcked: largestAcked, - Frames: p.frames, + Frames: frames, Length: protocol.ByteCount(len(p.raw)), - EncryptionLevel: p.EncryptionLevel(), + EncryptionLevel: encLevel, SendTime: time.Now(), } } @@ -130,9 +143,10 @@ type packetPacker struct { token []byte - pnManager packetNumberManager - framer frameSource - acks ackFrameSource + pnManager packetNumberManager + framer frameSource + acks ackFrameSource + retransmissionQueue *retransmissionQueue maxPacketSize protocol.ByteCount numNonAckElicitingAcks int @@ -146,6 +160,7 @@ func newPacketPacker( initialStream cryptoStream, handshakeStream cryptoStream, packetNumberManager packetNumberManager, + retransmissionQueue *retransmissionQueue, remoteAddr net.Addr, // only used for determining the max packet size cryptoSetup sealingManager, framer frameSource, @@ -154,17 +169,18 @@ func newPacketPacker( version protocol.VersionNumber, ) *packetPacker { return &packetPacker{ - cryptoSetup: cryptoSetup, - destConnID: destConnID, - srcConnID: srcConnID, - initialStream: initialStream, - handshakeStream: handshakeStream, - perspective: perspective, - version: version, - framer: framer, - acks: acks, - pnManager: packetNumberManager, - maxPacketSize: getMaxPacketSize(remoteAddr), + cryptoSetup: cryptoSetup, + destConnID: destConnID, + srcConnID: srcConnID, + initialStream: initialStream, + handshakeStream: handshakeStream, + retransmissionQueue: retransmissionQueue, + perspective: perspective, + version: version, + framer: framer, + acks: acks, + pnManager: packetNumberManager, + maxPacketSize: getMaxPacketSize(remoteAddr), } } @@ -237,80 +253,6 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { return p.writeAndSealPacket(hdr, payload, encLevel, sealer) } -// PackRetransmission packs a retransmission -// For packets sent after completion of the handshake, it might happen that 2 packets have to be sent. -// This can happen e.g. when a longer packet number is used in the header. -func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) { - var controlFrames []wire.Frame - var streamFrames []*wire.StreamFrame - for _, f := range packet.Frames { - // CRYPTO frames are treated as control frames here. - // Since we're making sure that the header can never be larger for a retransmission, - // we never have to split CRYPTO frames. - if sf, ok := f.(*wire.StreamFrame); ok { - sf.DataLenPresent = true - streamFrames = append(streamFrames, sf) - } else { - controlFrames = append(controlFrames, f) - } - } - - var packets []*packedPacket - for len(controlFrames) > 0 || len(streamFrames) > 0 { - var frames []wire.Frame - var length protocol.ByteCount - - sealer, hdr, err := p.getSealerAndHeader(packet.EncryptionLevel) - if err != nil { - return nil, err - } - - hdrLen := hdr.GetLength(p.version) - maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - hdrLen - - for len(controlFrames) > 0 { - frame := controlFrames[0] - frameLen := frame.Length(p.version) - if length+frameLen > maxSize { - break - } - length += frameLen - frames = append(frames, frame) - controlFrames = controlFrames[1:] - } - - for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize { - frame := streamFrames[0] - frame.DataLenPresent = false - frameToAdd := frame - - sf, needsSplit := frame.MaybeSplitOffFrame(maxSize-length, p.version) - if needsSplit { - if sf == nil { // size too small to create a new STREAM frame - continue - } - frameToAdd = sf - } else { - streamFrames = streamFrames[1:] - } - frame.DataLenPresent = true - length += frameToAdd.Length(p.version) - frames = append(frames, frameToAdd) - } - if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok { - sfLen := sf.Length(p.version) - sf.DataLenPresent = false - length += sf.Length(p.version) - sfLen - } - p, err := p.writeAndSealPacket(hdr, payload{frames: frames, length: length}, packet.EncryptionLevel, sealer) - if err != nil { - return nil, err - } - packets = append(packets, p) - } - return packets, nil -} - // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { @@ -371,9 +313,10 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { } hasData := p.initialStream.HasData() + hasRetransmission := p.retransmissionQueue.HasInitialData() ack := p.acks.GetAckFrame(protocol.EncryptionInitial) var sealer handshake.LongHeaderSealer - if hasData || ack != nil { + if hasData || hasRetransmission || ack != nil { s = p.initialStream encLevel = protocol.EncryptionInitial sealer = initialSealer @@ -382,8 +325,9 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { } } else { hasData = p.handshakeStream.HasData() + hasRetransmission = p.retransmissionQueue.HasHandshakeData() ack = p.acks.GetAckFrame(protocol.EncryptionHandshake) - if hasData || ack != nil { + if hasData || hasRetransmission || ack != nil { s = p.handshakeStream encLevel = protocol.EncryptionHandshake sealer = handshakeSealer @@ -403,7 +347,24 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { } hdr := p.getLongHeader(encLevel) hdrLen := hdr.GetLength(p.version) - if hasData { + if hasRetransmission { + for { + var f wire.Frame + switch encLevel { + case protocol.EncryptionInitial: + remainingLen := protocol.MinInitialPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length + f = p.retransmissionQueue.GetInitialFrame(remainingLen) + case protocol.EncryptionHandshake: + remainingLen := p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length + f = p.retransmissionQueue.GetHandshakeFrame(remainingLen) + } + if f == nil { + break + } + payload.frames = append(payload.frames, f) + payload.length += f.Length(p.version) + } + } else if hasData { cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length) payload.frames = []wire.Frame{cf} payload.length += cf.Length(p.version) @@ -419,14 +380,25 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo payload.length += ack.Length(p.version) } - frames, lengthAdded := p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) + for { + remainingLen := maxFrameSize - payload.length + if remainingLen < protocol.MinStreamFrameSize { + break + } + f := p.retransmissionQueue.GetAppDataFrame(remainingLen) + if f == nil { + break + } + payload.frames = append(payload.frames, f) + payload.length += f.Length(p.version) + } + + var lengthAdded protocol.ByteCount + payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) payload.length += lengthAdded - frames, lengthAdded = p.framer.AppendStreamFrames(frames, maxFrameSize-payload.length) - if len(frames) > 0 { - payload.frames = append(payload.frames, frames...) - payload.length += lengthAdded - } + payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length) + payload.length += lengthAdded return payload, nil } @@ -478,6 +450,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex switch encLevel { case protocol.EncryptionInitial: hdr.Type = protocol.PacketTypeInitial + hdr.Token = p.token case protocol.EncryptionHandshake: hdr.Type = protocol.PacketTypeHandshake } @@ -507,7 +480,6 @@ func (p *packetPacker) writeAndSealPacket( if encLevel != protocol.Encryption1RTT { if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { - header.Token = p.token headerLen := header.GetLength(p.version) header.Length = pnLen + protocol.MinInitialPacketSize - headerLen paddingLen = protocol.ByteCount(protocol.MinInitialPacketSize-sealer.Overhead()) - headerLen - payload.length @@ -550,6 +522,7 @@ func (p *packetPacker) writeAndSealPacketWithPadding( } if payloadSize := protocol.ByteCount(buffer.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { + fmt.Printf("%#v\n", payload) return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) } if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { diff --git a/packet_packer_test.go b/packet_packer_test.go index 1c152100..d9309c5f 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" @@ -20,21 +19,24 @@ import ( var _ = Describe("Packet packer", func() { const maxPacketSize protocol.ByteCount = 1357 + const version = protocol.VersionTLS + var ( - packer *packetPacker - framer *MockFrameSource - ackFramer *MockAckFrameSource - initialStream *MockCryptoStream - handshakeStream *MockCryptoStream - sealingManager *MockSealingManager - pnManager *mockackhandler.MockSentPacketHandler + packer *packetPacker + retransmissionQueue *retransmissionQueue + framer *MockFrameSource + ackFramer *MockAckFrameSource + initialStream *MockCryptoStream + handshakeStream *MockCryptoStream + sealingManager *MockSealingManager + pnManager *mockackhandler.MockSentPacketHandler ) checkLength := func(data []byte) { hdr, _, _, err := wire.ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever) + extHdr, err := hdr.ParseExtended(r, version) Expect(err).ToNot(HaveOccurred()) ExpectWithOffset(1, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen))) } @@ -61,7 +63,7 @@ var _ = Describe("Packet packer", func() { BeforeEach(func() { rand.Seed(GinkgoRandomSeed()) - version := protocol.VersionTLS + retransmissionQueue = newRetransmissionQueue(version) mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() initialStream = NewMockCryptoStream(mockCtrl) @@ -77,6 +79,7 @@ var _ = Describe("Packet packer", func() { initialStream, handshakeStream, pnManager, + retransmissionQueue, &net.TCPAddr{}, sealingManager, framer, @@ -408,6 +411,60 @@ var _ = Describe("Packet packer", func() { Expect(r.Len()).To(BeZero()) }) + It("packs multiple small STREAM frames into single packet", func() { + f1 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 1"), + DataLenPresent: true, + } + f2 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 2"), + DataLenPresent: true, + } + f3 := &wire.StreamFrame{ + StreamID: 3, + Data: []byte("frame 3"), + DataLenPresent: true, + } + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + expectAppendControlFrames() + expectAppendStreamFrames(f1, f2, f3) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(3)) + Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) + Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) + Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) + }) + + It("adds retransmissions", func() { + f1 := &wire.StreamFrame{Data: []byte("frame 1")} + cf := &wire.MaxDataFrame{ByteOffset: 0x42} + retransmissionQueue.AddAppData(f1) + retransmissionQueue.AddAppData(cf) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + expectAppendControlFrames() + f2 := &wire.StreamFrame{Data: []byte("frame 2")} + expectAppendStreamFrames(f2) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(3)) + Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) + Expect(p.frames[1]).To(Equal(cf)) + Expect(p.frames[2]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) + }) + Context("making ACK packets ack-eliciting", func() { sendMaxNumNonAckElicitingAcks := func() { for i := 0; i < protocol.MaxNonAckElicitingAcks; i++ { @@ -491,230 +548,6 @@ var _ = Describe("Packet packer", func() { }) }) - Context("STREAM frame handling", func() { - It("does not split a STREAM frame with maximum size", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - expectAppendControlFrames() - sf := &wire.StreamFrame{ - Offset: 1, - StreamID: 5, - } - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - sf.Data = bytes.Repeat([]byte{'f'}, int(maxSize-sf.Length(packer.version))) - return []wire.Frame{sf}, sf.Length(packer.version) - }) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.raw).To(HaveLen(int(maxPacketSize))) - Expect(p.frames[0].(*wire.StreamFrame).Data).To(HaveLen(len(sf.Data))) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - }) - - It("packs multiple small STREAM frames into single packet", func() { - f1 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("frame 1"), - DataLenPresent: true, - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("frame 2"), - DataLenPresent: true, - } - f3 := &wire.StreamFrame{ - StreamID: 3, - Data: []byte("frame 3"), - DataLenPresent: true, - } - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - expectAppendControlFrames() - expectAppendStreamFrames(f1, f2, f3) - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(3)) - Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) - Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) - Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) - }) - }) - - Context("retransmissions", func() { - It("retransmits a small packet", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - frames := []wire.Frame{ - &wire.MaxDataFrame{ByteOffset: 0x1234}, - &wire.StreamFrame{StreamID: 42, Data: []byte("foobar")}, - } - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.Encryption1RTT, - Frames: frames, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(p.frames).To(Equal(frames)) - }) - - It("packs two packets for retransmission if the original packet contained many control frames", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) - var frames []wire.Frame - var totalLen protocol.ByteCount - // pack a bunch of control frames, such that the packet is way bigger than a single packet - for i := 0; totalLen < maxPacketSize*3/2; i++ { - f := &wire.MaxStreamDataFrame{ - StreamID: protocol.StreamID(i), - ByteOffset: protocol.ByteCount(i), - } - frames = append(frames, f) - totalLen += f.Length(packer.version) - } - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.Encryption1RTT, - Frames: frames, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(2)) - Expect(len(packets[0].frames) + len(packets[1].frames)).To(Equal(len(frames))) - Expect(packets[1].frames).To(Equal(frames[len(packets[0].frames):])) - // check that the first packet was filled up as far as possible: - // if the first frame (after the STOP_WAITING) was packed into the first packet, it would have overflown the MaxPacketSize - Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize)) - }) - - It("splits a STREAM frame that doesn't fit", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.Encryption1RTT, - Frames: []wire.Frame{&wire.StreamFrame{ - StreamID: 42, - Offset: 1337, - Data: bytes.Repeat([]byte{'a'}, int(maxPacketSize)*3/2), - }}, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(2)) - Expect(packets[0].frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(packets[1].frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - sf1 := packets[0].frames[0].(*wire.StreamFrame) - sf2 := packets[1].frames[0].(*wire.StreamFrame) - Expect(sf1.StreamID).To(Equal(protocol.StreamID(42))) - Expect(sf1.Offset).To(Equal(protocol.ByteCount(1337))) - Expect(sf1.DataLenPresent).To(BeFalse()) - Expect(sf2.StreamID).To(Equal(protocol.StreamID(42))) - Expect(sf2.Offset).To(Equal(protocol.ByteCount(1337) + sf1.DataLen())) - Expect(sf2.DataLenPresent).To(BeFalse()) - Expect(sf1.DataLen() + sf2.DataLen()).To(Equal(maxPacketSize * 3 / 2)) - Expect(packets[0].raw).To(HaveLen(int(maxPacketSize))) - }) - - It("splits STREAM frames, if necessary", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).AnyTimes() - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).AnyTimes() - for i := 0; i < 100; i++ { - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).MaxTimes(2) - sf1 := &wire.StreamFrame{ - StreamID: 42, - Offset: 1337, - Data: bytes.Repeat([]byte{'a'}, 1+int(rand.Int31n(int32(maxPacketSize*4/5)))), - } - sf2 := &wire.StreamFrame{ - StreamID: 2, - Offset: 42, - Data: bytes.Repeat([]byte{'b'}, 1+int(rand.Int31n(int32(maxPacketSize*4/5)))), - } - expectedDataLen := sf1.DataLen() + sf2.DataLen() - frames := []wire.Frame{sf1, sf2} - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.Encryption1RTT, - Frames: frames, - }) - Expect(err).ToNot(HaveOccurred()) - - if len(packets) > 1 { - Expect(packets[0].raw).To(HaveLen(int(maxPacketSize))) - } - - var dataLen protocol.ByteCount - for _, p := range packets { - for _, f := range p.frames { - dataLen += f.(*wire.StreamFrame).DataLen() - } - } - Expect(dataLen).To(Equal(expectedDataLen)) - } - }) - - It("packs two packets for retransmission if the original packet contained many STREAM frames", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) - var frames []wire.Frame - var totalLen protocol.ByteCount - // pack a bunch of control frames, such that the packet is way bigger than a single packet - for i := 0; totalLen < maxPacketSize*3/2; i++ { - f := &wire.StreamFrame{ - StreamID: protocol.StreamID(i), - Data: []byte("foobar"), - DataLenPresent: true, - } - frames = append(frames, f) - totalLen += f.Length(packer.version) - } - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.Encryption1RTT, - Frames: frames, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(2)) - Expect(len(packets[0].frames) + len(packets[1].frames)).To(Equal(len(frames))) // all frames - Expect(packets[1].frames).To(Equal(frames[len(packets[0].frames):])) - // check that the first packet was filled up as far as possible: - // if the first frame was packed into the first packet, it would have overflown the MaxPacketSize - Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize-protocol.MinStreamFrameSize)) - }) - - It("correctly sets the DataLenPresent on STREAM frames", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - frames := []wire.Frame{ - &wire.StreamFrame{StreamID: 4, Data: []byte("foobar"), DataLenPresent: true}, - &wire.StreamFrame{StreamID: 5, Data: []byte("barfoo")}, - } - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.Encryption1RTT, - Frames: frames, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(p.frames[1]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - sf1 := p.frames[0].(*wire.StreamFrame) - sf2 := p.frames[1].(*wire.StreamFrame) - Expect(sf1.StreamID).To(Equal(protocol.StreamID(4))) - Expect(sf1.DataLenPresent).To(BeTrue()) - Expect(sf2.StreamID).To(Equal(protocol.StreamID(5))) - Expect(sf2.DataLenPresent).To(BeFalse()) - }) - }) - Context("max packet size", func() { It("sets the maximum packet size", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) @@ -805,8 +638,25 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - expectedPacketLen := packer.maxPacketSize - Expect(p.raw).To(HaveLen(int(expectedPacketLen))) + Expect(p.raw).To(HaveLen(int(packer.maxPacketSize))) + Expect(p.header.IsLongHeader).To(BeTrue()) + checkLength(p.raw) + }) + + It("adds retransmissions", func() { + f := &wire.CryptoFrame{Data: []byte("Initial")} + retransmissionQueue.AddInitial(f) + retransmissionQueue.AddHandshake(&wire.CryptoFrame{Data: []byte("Handshake")}) + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + initialStream.EXPECT().HasData() + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.frames).To(Equal([]wire.Frame{f})) Expect(p.header.IsLongHeader).To(BeTrue()) checkLength(p.raw) }) @@ -923,63 +773,25 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) }) - - Context("retransmitions", func() { - cf := &wire.CryptoFrame{Data: []byte("foo")} - - It("packs a retransmission with the right encryption level", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - packet := &ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionInitial, - Frames: []wire.Frame{cf}, - } - p, err := packer.PackRetransmission(packet) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(HaveLen(1)) - Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p[0].frames).To(Equal([]wire.Frame{cf})) - Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - }) - - It("packs a retransmission for an Initial packet", func() { - token := []byte("initial token") - packer.SetToken(token) - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - packer.perspective = protocol.PerspectiveClient - packet := &ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionInitial, - Frames: []wire.Frame{cf}, - } - packets, err := packer.PackRetransmission(packet) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.frames).To(Equal([]wire.Frame{cf})) - Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p.header.Token).To(Equal(token)) - Expect(p.raw).To(HaveLen(protocol.MinInitialPacketSize)) - }) - }) }) }) }) var _ = Describe("Converting to AckHandler packets", func() { It("convert a packet", func() { + f1 := &wire.MaxDataFrame{} + f2 := &wire.PingFrame{} packet := &packedPacket{ header: &wire.ExtendedHeader{Header: wire.Header{}}, - frames: []wire.Frame{&wire.MaxDataFrame{}, &wire.PingFrame{}}, + frames: []wire.Frame{f1, f2}, ack: &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 80}}}, raw: []byte("foobar"), } - p := packet.ToAckHandlerPacket() + p := packet.ToAckHandlerPacket(nil) Expect(p.Length).To(Equal(protocol.ByteCount(6))) - Expect(p.Frames).To(Equal(packet.frames)) + Expect(p.Frames).To(HaveLen(2)) + Expect(p.Frames[0].Frame).To(Equal(f1)) + Expect(p.Frames[1].Frame).To(Equal(f2)) Expect(p.LargestAcked).To(Equal(protocol.PacketNumber(100))) Expect(p.SendTime).To(BeTemporally("~", time.Now(), 50*time.Millisecond)) }) @@ -990,7 +802,7 @@ var _ = Describe("Converting to AckHandler packets", func() { frames: []wire.Frame{&wire.MaxDataFrame{}, &wire.PingFrame{}}, raw: []byte("foobar"), } - p := packet.ToAckHandlerPacket() + p := packet.ToAckHandlerPacket(nil) Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber)) }) }) diff --git a/retransmission_queue.go b/retransmission_queue.go new file mode 100644 index 00000000..0de60d1f --- /dev/null +++ b/retransmission_queue.go @@ -0,0 +1,132 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type retransmissionQueue struct { + initial []wire.Frame + initialCryptoData []*wire.CryptoFrame + + handshake []wire.Frame + handshakeCryptoData []*wire.CryptoFrame + + appData []wire.Frame + streamData []*wire.StreamFrame + + version protocol.VersionNumber +} + +func newRetransmissionQueue(ver protocol.VersionNumber) *retransmissionQueue { + return &retransmissionQueue{version: ver} +} + +func (q *retransmissionQueue) AddInitial(f wire.Frame) { + if cf, ok := f.(*wire.CryptoFrame); ok { + q.initialCryptoData = append(q.initialCryptoData, cf) + return + } + q.initial = append(q.initial, f) +} + +func (q *retransmissionQueue) AddHandshake(f wire.Frame) { + if cf, ok := f.(*wire.CryptoFrame); ok { + q.handshakeCryptoData = append(q.handshakeCryptoData, cf) + return + } + q.handshake = append(q.handshake, f) +} + +func (q *retransmissionQueue) HasInitialData() bool { + return len(q.initialCryptoData) > 0 || len(q.initial) > 0 +} + +func (q *retransmissionQueue) HasHandshakeData() bool { + return len(q.handshakeCryptoData) > 0 || len(q.handshake) > 0 +} + +func (q *retransmissionQueue) AddAppData(f wire.Frame) { + if sf, ok := f.(*wire.StreamFrame); ok { + sf.DataLenPresent = true + q.streamData = append(q.streamData, sf) + return + } + q.appData = append(q.appData, f) +} + +func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.initialCryptoData) > 0 { + if f := q.initialCryptoData[0]; f.Length(q.version) <= maxLen { + q.initialCryptoData = q.initialCryptoData[1:] + return f + } + } + if len(q.initial) == 0 { + return nil + } + f := q.initial[0] + if f.Length(q.version) > maxLen { + return nil + } + q.initial = q.initial[1:] + return f +} + +func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.handshakeCryptoData) > 0 { + if f := q.handshakeCryptoData[0]; f.Length(q.version) <= maxLen { + q.handshakeCryptoData = q.handshakeCryptoData[1:] + return f + } + } + if len(q.handshake) == 0 { + return nil + } + f := q.handshake[0] + if f.Length(q.version) > maxLen { + return nil + } + q.handshake = q.handshake[1:] + return f +} + +func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.streamData) > 0 { + f := q.streamData[0] + if f.Length(q.version) <= maxLen { + q.streamData = q.streamData[1:] + return f + } + if maxLen >= protocol.MinStreamFrameSize { + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + if needsSplit && newFrame != nil { + return newFrame + } + } + } + if len(q.appData) == 0 { + return nil + } + f := q.appData[0] + if f.Length(q.version) > maxLen { + return nil + } + q.appData = q.appData[1:] + return f +} + +func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) { + switch encLevel { + case protocol.EncryptionInitial: + q.initial = nil + q.initialCryptoData = nil + case protocol.EncryptionHandshake: + q.handshake = nil + q.handshakeCryptoData = nil + default: + panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) + } +} diff --git a/retransmission_queue_test.go b/retransmission_queue_test.go new file mode 100644 index 00000000..c6c4b963 --- /dev/null +++ b/retransmission_queue_test.go @@ -0,0 +1,158 @@ +package quic + +import ( + "math/rand" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Retransmission queue", func() { + const version = protocol.VersionTLS + + var q *retransmissionQueue + + BeforeEach(func() { + q = newRetransmissionQueue(version) + }) + + Context("Initial data", func() { + It("doesn't dequeue anything when it's empty", func() { + Expect(q.HasInitialData()).To(BeFalse()) + Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) + }) + + It("queues and retrieves a control frame", func() { + f := &wire.MaxDataFrame{ByteOffset: 0x42} + q.AddInitial(f) + Expect(q.HasInitialData()).To(BeTrue()) + Expect(q.GetInitialFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasInitialData()).To(BeFalse()) + }) + + It("queues and retrieves a CRYPTO frame", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddInitial(f) + Expect(q.HasInitialData()).To(BeTrue()) + Expect(q.GetInitialFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetInitialFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasInitialData()).To(BeFalse()) + }) + + It("retrieves both a CRYPTO frame and a control frame", func() { + cf := &wire.MaxDataFrame{ByteOffset: 0x42} + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddInitial(f) + q.AddInitial(cf) + Expect(q.HasInitialData()).To(BeTrue()) + Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(f)) + Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(Equal(cf)) + Expect(q.HasInitialData()).To(BeFalse()) + }) + + It("drops all Initial frames", func() { + q.AddInitial(&wire.CryptoFrame{Data: []byte("foobar")}) + q.AddInitial(&wire.MaxDataFrame{ByteOffset: 0x42}) + q.DropPackets(protocol.EncryptionInitial) + Expect(q.HasInitialData()).To(BeFalse()) + Expect(q.GetInitialFrame(protocol.MaxByteCount)).To(BeNil()) + }) + }) + + Context("Handshake data", func() { + It("doesn't dequeue anything when it's empty", func() { + Expect(q.HasHandshakeData()).To(BeFalse()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) + }) + + It("queues and retrieves a control frame", func() { + f := &wire.MaxDataFrame{ByteOffset: 0x42} + q.AddHandshake(f) + Expect(q.HasHandshakeData()).To(BeTrue()) + Expect(q.GetHandshakeFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasHandshakeData()).To(BeFalse()) + }) + + It("queues and retrieves a CRYPTO frame", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddHandshake(f) + Expect(q.HasHandshakeData()).To(BeTrue()) + Expect(q.GetHandshakeFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetHandshakeFrame(f.Length(version))).To(Equal(f)) + Expect(q.HasHandshakeData()).To(BeFalse()) + }) + + It("retrieves both a CRYPTO frame and a control frame", func() { + cf := &wire.MaxDataFrame{ByteOffset: 0x42} + f := &wire.CryptoFrame{Data: []byte("foobar")} + q.AddHandshake(f) + q.AddHandshake(cf) + Expect(q.HasHandshakeData()).To(BeTrue()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(f)) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(Equal(cf)) + Expect(q.HasHandshakeData()).To(BeFalse()) + }) + + It("drops all Initial frames", func() { + q.AddHandshake(&wire.CryptoFrame{Data: []byte("foobar")}) + q.AddHandshake(&wire.MaxDataFrame{ByteOffset: 0x42}) + q.DropPackets(protocol.EncryptionHandshake) + Expect(q.HasHandshakeData()).To(BeFalse()) + Expect(q.GetHandshakeFrame(protocol.MaxByteCount)).To(BeNil()) + }) + }) + + Context("Application data", func() { + It("doesn't dequeue anything when it's empty", func() { + Expect(q.GetAppDataFrame(protocol.MaxByteCount)).To(BeNil()) + }) + + It("queues and retrieves a control frame", func() { + f := &wire.MaxDataFrame{ByteOffset: 0x42} + q.AddAppData(f) + Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f)) + }) + + It("queues and retrieves a STREAM frame", func() { + f := &wire.StreamFrame{Data: []byte("foobar")} + q.AddAppData(f) + Expect(q.GetAppDataFrame(f.Length(version) - 1)).To(BeNil()) + Expect(q.GetAppDataFrame(f.Length(version))).To(Equal(f)) + }) + + It("splits STREAM frames larger than MinStreamFrameSize", func() { + data := make([]byte, 1000) + rand.Read(data) + f := &wire.StreamFrame{ + Data: data, + FinBit: true, + } + q.AddAppData(f) + Expect(q.GetAppDataFrame(protocol.MinStreamFrameSize - 1)).To(BeNil()) + f1 := q.GetAppDataFrame(protocol.MinStreamFrameSize).(*wire.StreamFrame) + Expect(f1).ToNot(BeNil()) + Expect(f1.Length(version)).To(Equal(protocol.MinStreamFrameSize)) + Expect(f1.FinBit).To(BeFalse()) + Expect(f1.Data).To(Equal(data[:f1.DataLen()])) + f2 := q.GetAppDataFrame(protocol.MaxByteCount).(*wire.StreamFrame) + Expect(f2).ToNot(BeNil()) + Expect(f2.FinBit).To(BeTrue()) + Expect(f1.DataLen() + f2.DataLen()).To(BeEquivalentTo(1000)) + Expect(f2.Data).To(Equal(data[f1.DataLen():])) + Expect(q.GetAppDataFrame(protocol.MaxByteCount)).To(BeNil()) + }) + + It("returns a control frame if it doesn't split a STREAM frame", func() { + cf := &wire.MaxDataFrame{ByteOffset: 0x42} + q.AddAppData(&wire.StreamFrame{Data: make([]byte, 1000)}) + q.AddAppData(cf) + Expect(q.GetAppDataFrame(protocol.MinStreamFrameSize - 1)).To(Equal(cf)) + }) + }) +}) diff --git a/session.go b/session.go index 6ea39965..cd9db0fb 100644 --- a/session.go +++ b/session.go @@ -115,6 +115,7 @@ type session struct { cryptoStreamManager *cryptoStreamManager sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler + retransmissionQueue *retransmissionQueue framer framer windowUpdateQueue *windowUpdateQueue connFlowController flowcontrol.ConnectionFlowController @@ -242,6 +243,7 @@ var newSession = func( initialStream, handshakeStream, s.sentPacketHandler, + s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, @@ -328,6 +330,7 @@ var newClientSession = func( initialStream, handshakeStream, s.sentPacketHandler, + s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, @@ -350,6 +353,7 @@ var newClientSession = func( func (s *session) preSetup() { s.sendQueue = newSendQueue(s.conn) + s.retransmissionQueue = newRetransmissionQueue(s.version) s.frameParser = wire.NewFrameParser(s.version) s.rttStats = &congestion.RTTStats{} s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) @@ -1109,16 +1113,6 @@ sendLoop: return err } numPacketsSent++ - case ackhandler.SendRetransmission: - sentPacket, err := s.maybeSendRetransmission() - if err != nil { - return err - } - if sentPacket { - numPacketsSent++ - // This can happen if a retransmission queued, but it wasn't necessary to send it. - // e.g. when an Initial is queued, but we already received a packet from the server. - } case ackhandler.SendAny: sentPacket, err := s.sendPacket() if err != nil { @@ -1152,47 +1146,29 @@ func (s *session) maybeSendAckOnlyPacket() error { if packet == nil { return nil } - s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket()) + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(s.retransmissionQueue)) s.sendQueue.Send(packet) return nil } -// maybeSendRetransmission sends retransmissions for at most one packet. -// It takes care that Initials aren't retransmitted, if a packet from the server was already received. -func (s *session) maybeSendRetransmission() (bool, error) { - retransmitPacket := s.sentPacketHandler.DequeuePacketForRetransmission() - if retransmitPacket == nil { - return false, nil - } - - s.logger.Debugf("Dequeueing retransmission for packet 0x%x (%s)", retransmitPacket.PacketNumber, retransmitPacket.EncryptionLevel) - packets, err := s.packer.PackRetransmission(retransmitPacket) - if err != nil { - return false, err - } - for _, packet := range packets { - s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket()) - s.sendPackedPacket(packet) - } - return true, nil -} - func (s *session) sendProbePacket() error { - p, err := s.sentPacketHandler.DequeueProbePacket() - if err != nil { - return err + // Queue probe packets until we actually send out a packet. + for { + if wasQueued := s.sentPacketHandler.QueueProbePacket(); !wasQueued { + break + } + sent, err := s.sendPacket() + if err != nil { + return err + } + if sent { + return nil + } } - s.logger.Debugf("Sending a retransmission for %#x as a probe packet.", p.PacketNumber) - - packets, err := s.packer.PackRetransmission(p) - if err != nil { - return err - } - for _, packet := range packets { - s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket()) - s.sendPackedPacket(packet) - } - return nil + // If there is nothing else to queue, make sure we send out something. + s.framer.QueueControlFrame(&wire.PingFrame{}) + _, err := s.sendPacket() + return err } func (s *session) sendPacket() (bool, error) { @@ -1205,7 +1181,7 @@ func (s *session) sendPacket() (bool, error) { if err != nil || packet == nil { return false, err } - s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket()) + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(s.retransmissionQueue)) s.sendPackedPacket(packet) return true, nil } diff --git a/session_test.go b/session_test.go index f418da0d..ddd708ec 100644 --- a/session_test.go +++ b/session_test.go @@ -892,72 +892,14 @@ var _ = Describe("Session", func() { Expect(frames).To(Equal([]wire.Frame{&wire.DataBlockedFrame{DataLimit: 1337}})) }) - It("sends a retransmission and a regular packet in the same run", func() { - packetToRetransmit := &ackhandler.Packet{PacketNumber: 10} - retransmittedPacket := getPacket(123) - newPacket := getPacket(234) - sess.windowUpdateQueue.callback(&wire.MaxDataFrame{}) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission().Return(packetToRetransmit) - sph.EXPECT().SendMode().Return(ackhandler.SendRetransmission) - sph.EXPECT().SendMode().Return(ackhandler.SendAny) - sph.EXPECT().ShouldSendNumPackets().Return(2) - sph.EXPECT().TimeUntilSend() - gomock.InOrder( - packer.EXPECT().PackRetransmission(packetToRetransmit).Return([]*packedPacket{retransmittedPacket}, nil), - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) - }), - packer.EXPECT().PackPacket().Return(newPacket, nil), - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(234))) - }), - ) - sess.sentPacketHandler = sph - Expect(sess.sendPackets()).To(Succeed()) - Eventually(mconn.written).Should(HaveLen(2)) - }) - - It("sends multiple packets, if the retransmission is split", func() { - packet := &ackhandler.Packet{ - PacketNumber: 42, - Frames: []wire.Frame{&wire.StreamFrame{ - StreamID: 0x5, - Data: []byte("foobar"), - }}, - EncryptionLevel: protocol.Encryption1RTT, - } - retransmissions := []*packedPacket{getPacket(1337), getPacket(1338)} - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission().Return(packet) - packer.EXPECT().PackRetransmission(packet).Return(retransmissions, nil) - gomock.InOrder( - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(1337))) - }), - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(1338))) - }), - ) - sess.sentPacketHandler = sph - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Eventually(mconn.written).Should(HaveLen(2)) - }) - It("sends a probe packet", func() { - packetToRetransmit := &ackhandler.Packet{PacketNumber: 0x42} - retransmittedPacket := getPacket(123) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend() sph.EXPECT().SendMode().Return(ackhandler.SendPTO) sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().DequeueProbePacket().Return(packetToRetransmit, nil) - packer.EXPECT().PackRetransmission(packetToRetransmit).Return([]*packedPacket{retransmittedPacket}, nil) + sph.EXPECT().QueueProbePacket() + packer.EXPECT().PackPacket().Return(getPacket(123), nil) sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) @@ -965,6 +907,25 @@ var _ = Describe("Session", func() { Expect(sess.sendPackets()).To(Succeed()) }) + It("sends a PING as a probe packet", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend() + sph.EXPECT().SendMode().Return(ackhandler.SendPTO) + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().QueueProbePacket().Return(false) + packer.EXPECT().PackPacket().Return(getPacket(123), nil) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) + }) + sess.sentPacketHandler = sph + Expect(sess.sendPackets()).To(Succeed()) + // We're using a mock packet packer in this test. + // We therefore need to test separately that the PING was actually queued. + frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) + Expect(frames).To(Equal([]wire.Frame{&wire.PingFrame{}})) + }) + It("doesn't send when the SentPacketHandler doesn't allow it", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() @@ -981,7 +942,6 @@ var _ = Describe("Session", func() { BeforeEach(func() { sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission().AnyTimes() sess.sentPacketHandler = sph streamManager.EXPECT().CloseWithError(gomock.Any()) })