mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
move congestion control from session to ackhandler
this simplifies things a bit, especially in prep for #56
This commit is contained in:
parent
7a97f34fac
commit
a0228a9049
6 changed files with 202 additions and 261 deletions
|
@ -1,8 +1,6 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
@ -10,13 +8,15 @@ import (
|
|||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
SentPacket(packet *Packet) error
|
||||
ReceivedAck(ackFrame *frames.AckFrame) (time.Duration, []*Packet, []*Packet, error)
|
||||
ReceivedAck(ackFrame *frames.AckFrame) error
|
||||
|
||||
HasPacketForRetransmission() bool
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
|
||||
BytesInFlight() protocol.ByteCount
|
||||
GetLargestObserved() protocol.PacketNumber
|
||||
|
||||
AllowsSending() bool
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
|
|
|
@ -4,9 +4,11 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -42,13 +44,28 @@ type sentPacketHandler struct {
|
|||
stopWaitingManager StopWaitingManager
|
||||
|
||||
bytesInFlight protocol.ByteCount
|
||||
|
||||
rttStats *congestion.RTTStats
|
||||
congestion congestion.SendAlgorithm
|
||||
}
|
||||
|
||||
// NewSentPacketHandler creates a new sentPacketHandler
|
||||
func NewSentPacketHandler(stopWaitingManager StopWaitingManager) SentPacketHandler {
|
||||
rttStats := &congestion.RTTStats{}
|
||||
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
false, /* don't use reno since chromium doesn't (why?) */
|
||||
protocol.InitialCongestionWindow,
|
||||
protocol.DefaultMaxCongestionWindow,
|
||||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
packetHistory: make(map[protocol.PacketNumber]*Packet),
|
||||
stopWaitingManager: stopWaitingManager,
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,6 +128,15 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|||
packet.Entropy = h.lastSentPacketEntropy
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
h.packetHistory[packet.PacketNumber] = packet
|
||||
|
||||
h.congestion.OnPacketSent(
|
||||
time.Now(),
|
||||
h.BytesInFlight(),
|
||||
packet.PacketNumber,
|
||||
packet.Length,
|
||||
true, /* TODO: is retransmittable */
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -144,39 +170,42 @@ func (h *sentPacketHandler) calculateExpectedEntropy(ackFrame *frames.AckFrame)
|
|||
}
|
||||
|
||||
// TODO: Simplify return types
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duration, []*Packet, []*Packet, error) {
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) error {
|
||||
if ackFrame.LargestObserved > h.lastSentPacketNumber {
|
||||
return 0, nil, nil, errAckForUnsentPacket
|
||||
return errAckForUnsentPacket
|
||||
}
|
||||
|
||||
if ackFrame.LargestObserved <= h.LargestObserved { // duplicate or out-of-order AckFrame
|
||||
return 0, nil, nil, ErrDuplicateOrOutOfOrderAck
|
||||
return ErrDuplicateOrOutOfOrderAck
|
||||
}
|
||||
|
||||
expectedEntropy, err := h.calculateExpectedEntropy(ackFrame)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if byte(expectedEntropy) != ackFrame.Entropy {
|
||||
return 0, nil, nil, ErrEntropy
|
||||
return ErrEntropy
|
||||
}
|
||||
|
||||
// Entropy ok. Now actually process the ACK packet
|
||||
h.LargestObserved = ackFrame.LargestObserved
|
||||
highestInOrderAckedPacketNumber := ackFrame.GetHighestInOrderPacketNumber()
|
||||
|
||||
// Calculate the RTT
|
||||
// Update the RTT
|
||||
timeDelta := time.Now().Sub(h.packetHistory[h.LargestObserved].sendTime)
|
||||
// TODO: Don't always update RTT
|
||||
h.rttStats.UpdateRTT(timeDelta, ackFrame.DelayTime, time.Now())
|
||||
utils.Debugf("\tEstimated RTT: %dms", h.rttStats.SmoothedRTT()/time.Millisecond)
|
||||
|
||||
var ackedPackets []*Packet
|
||||
var lostPackets []*Packet
|
||||
var ackedPackets congestion.PacketVector
|
||||
var lostPackets congestion.PacketVector
|
||||
|
||||
// ACK all packets below the highestInOrderAckedPacketNumber
|
||||
for i := h.highestInOrderAckedPacketNumber; i <= highestInOrderAckedPacketNumber; i++ {
|
||||
p := h.ackPacket(i)
|
||||
if p != nil {
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
ackedPackets = append(ackedPackets, congestion.PacketInfo{Number: p.PacketNumber, Length: p.Length})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -193,15 +222,15 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duratio
|
|||
if nackRange.ContainsPacketNumber(i) {
|
||||
p, err := h.nackPacket(i)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
return err
|
||||
}
|
||||
if p != nil {
|
||||
lostPackets = append(lostPackets, p)
|
||||
lostPackets = append(lostPackets, congestion.PacketInfo{Number: p.PacketNumber, Length: p.Length})
|
||||
}
|
||||
} else {
|
||||
p := h.ackPacket(i)
|
||||
if p != nil {
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
ackedPackets = append(ackedPackets, congestion.PacketInfo{Number: p.PacketNumber, Length: p.Length})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -209,7 +238,14 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duratio
|
|||
|
||||
h.highestInOrderAckedPacketNumber = highestInOrderAckedPacketNumber
|
||||
|
||||
return timeDelta, ackedPackets, lostPackets, nil
|
||||
h.congestion.OnCongestionEvent(
|
||||
true, /* TODO: rtt updated */
|
||||
h.BytesInFlight(),
|
||||
ackedPackets,
|
||||
lostPackets,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) HasPacketForRetransmission() bool {
|
||||
|
@ -238,3 +274,7 @@ func (h *sentPacketHandler) BytesInFlight() protocol.ByteCount {
|
|||
func (h *sentPacketHandler) GetLargestObserved() protocol.PacketNumber {
|
||||
return h.LargestObserved
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) AllowsSending() bool {
|
||||
return h.BytesInFlight() <= h.congestion.GetCongestionWindow()
|
||||
}
|
||||
|
|
|
@ -3,15 +3,51 @@ package ackhandler
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockCongestion struct {
|
||||
nCalls int
|
||||
argsOnPacketSent []interface{}
|
||||
argsOnCongestionEvent []interface{}
|
||||
}
|
||||
|
||||
func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
||||
m.nCalls++
|
||||
m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount {
|
||||
m.nCalls++
|
||||
return protocol.DefaultTCPMSS
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets congestion.PacketVector, lostPackets congestion.PacketVector) {
|
||||
m.nCalls++
|
||||
m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets}
|
||||
}
|
||||
|
||||
func (m *mockCongestion) SetNumEmulatedConnections(n int) { panic("not implemented") }
|
||||
func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) { panic("not implemented") }
|
||||
func (m *mockCongestion) OnConnectionMigration() { panic("not implemented") }
|
||||
func (m *mockCongestion) RetransmissionDelay() time.Duration { panic("not implemented") }
|
||||
func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") }
|
||||
|
||||
var _ = Describe("SentPacketHandler", func() {
|
||||
var handler *sentPacketHandler
|
||||
var streamFrame frames.StreamFrame
|
||||
var (
|
||||
handler *sentPacketHandler
|
||||
streamFrame frames.StreamFrame
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stopWaitingManager := NewStopWaitingManager()
|
||||
handler = NewSentPacketHandler(stopWaitingManager).(*sentPacketHandler)
|
||||
|
@ -88,7 +124,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
LargestObserved: 1,
|
||||
Entropy: byte(entropy),
|
||||
}
|
||||
_, _, _, err = handler.ReceivedAck(&ack)
|
||||
err = handler.ReceivedAck(&ack)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(0)))
|
||||
err = handler.SentPacket(&packet2)
|
||||
|
@ -200,7 +236,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
LargestObserved: 4,
|
||||
Entropy: 1,
|
||||
}
|
||||
_, _, _, err := handler.ReceivedAck(&ack)
|
||||
err := handler.ReceivedAck(&ack)
|
||||
Expect(err).To(MatchError(ErrEntropy))
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(6)))
|
||||
})
|
||||
|
@ -215,19 +251,13 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
LargestObserved: protocol.PacketNumber(largestObserved),
|
||||
Entropy: byte(entropy),
|
||||
}
|
||||
_, acked, lost, err := handler.ReceivedAck(&ack)
|
||||
err := handler.ReceivedAck(&ack)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(2)))
|
||||
Expect(handler.LargestObserved).To(Equal(protocol.PacketNumber(largestObserved)))
|
||||
Expect(handler.highestInOrderAckedPacketNumber).To(Equal(protocol.PacketNumber(largestObserved)))
|
||||
Expect(handler.packetHistory).ToNot(HaveKey(protocol.PacketNumber(largestObserved - 1)))
|
||||
Expect(handler.packetHistory).To(HaveKey(protocol.PacketNumber(largestObserved + 1)))
|
||||
Expect(acked).To(HaveLen(4))
|
||||
Expect(acked[0].PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(acked[1].PacketNumber).To(Equal(protocol.PacketNumber(2)))
|
||||
Expect(acked[2].PacketNumber).To(Equal(protocol.PacketNumber(3)))
|
||||
Expect(acked[3].PacketNumber).To(Equal(protocol.PacketNumber(4)))
|
||||
Expect(lost).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("completely processes an ACK with a NACK range", func() {
|
||||
|
@ -247,7 +277,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
{FirstPacketNumber: 3, LastPacketNumber: 3},
|
||||
},
|
||||
}
|
||||
_, acked, lost, err := handler.ReceivedAck(&ack)
|
||||
err := handler.ReceivedAck(&ack)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(2)))
|
||||
Expect(handler.LargestObserved).To(Equal(protocol.PacketNumber(largestObserved)))
|
||||
|
@ -257,12 +287,6 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Expect(handler.packetHistory).ToNot(HaveKey(protocol.PacketNumber(4)))
|
||||
Expect(handler.packetHistory).To(HaveKey(protocol.PacketNumber(5)))
|
||||
Expect(handler.packetHistory).ToNot(HaveKey(protocol.PacketNumber(6)))
|
||||
Expect(acked).To(HaveLen(4))
|
||||
Expect(acked[0].PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(acked[1].PacketNumber).To(Equal(protocol.PacketNumber(2)))
|
||||
Expect(acked[2].PacketNumber).To(Equal(protocol.PacketNumber(6)))
|
||||
Expect(acked[3].PacketNumber).To(Equal(protocol.PacketNumber(4)))
|
||||
Expect(lost).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -290,10 +314,10 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
ack := frames.AckFrame{
|
||||
LargestObserved: protocol.PacketNumber(largestObserved),
|
||||
}
|
||||
_, _, _, err := handler.ReceivedAck(&ack)
|
||||
err := handler.ReceivedAck(&ack)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3)))
|
||||
_, _, _, err = handler.ReceivedAck(&ack)
|
||||
err = handler.ReceivedAck(&ack)
|
||||
Expect(err).To(MatchError(ErrDuplicateOrOutOfOrderAck))
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3)))
|
||||
})
|
||||
|
@ -303,11 +327,11 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
ack := frames.AckFrame{
|
||||
LargestObserved: protocol.PacketNumber(largestObserved),
|
||||
}
|
||||
_, _, _, err := handler.ReceivedAck(&ack)
|
||||
err := handler.ReceivedAck(&ack)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3)))
|
||||
ack.LargestObserved--
|
||||
_, _, _, err = handler.ReceivedAck(&ack)
|
||||
err = handler.ReceivedAck(&ack)
|
||||
Expect(err).To(MatchError(ErrDuplicateOrOutOfOrderAck))
|
||||
Expect(handler.LargestObserved).To(Equal(protocol.PacketNumber(largestObserved)))
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(3)))
|
||||
|
@ -317,29 +341,39 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
ack := frames.AckFrame{
|
||||
LargestObserved: packets[len(packets)-1].PacketNumber + 1337,
|
||||
}
|
||||
_, _, _, err := handler.ReceivedAck(&ack)
|
||||
err := handler.ReceivedAck(&ack)
|
||||
Expect(err).To(MatchError(errAckForUnsentPacket))
|
||||
Expect(handler.highestInOrderAckedPacketNumber).To(Equal(protocol.PacketNumber(0)))
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(6)))
|
||||
})
|
||||
})
|
||||
|
||||
It("calculates the time delta", func() {
|
||||
now := time.Now()
|
||||
// First, fake the sent times of the first, second and last packet
|
||||
handler.packetHistory[1].sendTime = now.Add(-10 * time.Minute)
|
||||
handler.packetHistory[2].sendTime = now.Add(-5 * time.Minute)
|
||||
handler.packetHistory[6].sendTime = now.Add(-1 * time.Minute)
|
||||
// Now, check that the proper times are used when calculating the deltas
|
||||
d, _, _, err := handler.ReceivedAck(&frames.AckFrame{LargestObserved: 1})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(d).To(BeNumerically("~", 10*time.Minute, 1*time.Second))
|
||||
d, _, _, err = handler.ReceivedAck(&frames.AckFrame{LargestObserved: 2})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(d).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
|
||||
d, _, _, err = handler.ReceivedAck(&frames.AckFrame{LargestObserved: 6})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(d).To(BeNumerically("~", 1*time.Minute, 1*time.Second))
|
||||
Context("calculating RTT", func() {
|
||||
It("calculates the RTT", func() {
|
||||
now := time.Now()
|
||||
// First, fake the sent times of the first, second and last packet
|
||||
handler.packetHistory[1].sendTime = now.Add(-10 * time.Minute)
|
||||
handler.packetHistory[2].sendTime = now.Add(-5 * time.Minute)
|
||||
handler.packetHistory[6].sendTime = now.Add(-1 * time.Minute)
|
||||
// Now, check that the proper times are used when calculating the deltas
|
||||
err := handler.ReceivedAck(&frames.AckFrame{LargestObserved: 1})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second))
|
||||
err = handler.ReceivedAck(&frames.AckFrame{LargestObserved: 2})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
|
||||
err = handler.ReceivedAck(&frames.AckFrame{LargestObserved: 6})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 1*time.Minute, 1*time.Second))
|
||||
})
|
||||
|
||||
It("uses the DelayTime in the ack frame", func() {
|
||||
now := time.Now()
|
||||
handler.packetHistory[1].sendTime = now.Add(-10 * time.Minute)
|
||||
err := handler.ReceivedAck(&frames.AckFrame{LargestObserved: 1, DelayTime: 5 * time.Minute})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -411,7 +445,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
LargestObserved: 4,
|
||||
NackRanges: []frames.NackRange{{FirstPacketNumber: 3, LastPacketNumber: 3}},
|
||||
}
|
||||
_, _, _, err := handler.ReceivedAck(&ack)
|
||||
err := handler.ReceivedAck(&ack)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.highestInOrderAckedPacketNumber).To(Equal(protocol.PacketNumber(2)))
|
||||
handler.nackPacket(3) // this is the second NACK for this packet
|
||||
|
@ -434,7 +468,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
LargestObserved: 2,
|
||||
NackRanges: []frames.NackRange{{FirstPacketNumber: 1, LastPacketNumber: 1}},
|
||||
}
|
||||
_, _, _, err = handler.ReceivedAck(&ack)
|
||||
err = handler.ReceivedAck(&ack)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(1)))
|
||||
|
||||
|
@ -453,32 +487,70 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
ack = frames.AckFrame{
|
||||
LargestObserved: 3,
|
||||
}
|
||||
_, _, _, err = handler.ReceivedAck(&ack)
|
||||
err = handler.ReceivedAck(&ack)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.BytesInFlight()).To(Equal(protocol.ByteCount(0)))
|
||||
})
|
||||
})
|
||||
|
||||
It("returns lost packets in ReceivedAck()", func() {
|
||||
packet1 := Packet{PacketNumber: 1, Frames: []frames.Frame{&streamFrame}, EntropyBit: false, Length: 1}
|
||||
packet2 := Packet{PacketNumber: 2, Frames: []frames.Frame{&streamFrame}, EntropyBit: false, Length: 2}
|
||||
err := handler.SentPacket(&packet1)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = handler.SentPacket(&packet2)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Context("congestion", func() {
|
||||
var (
|
||||
cong *mockCongestion
|
||||
)
|
||||
|
||||
// First, simulate a NACK for packet number 1
|
||||
handler.nackPacket(1)
|
||||
// Now, simulate an ack frame
|
||||
ack := &frames.AckFrame{
|
||||
LargestObserved: 2,
|
||||
NackRanges: []frames.NackRange{{FirstPacketNumber: 1, LastPacketNumber: 1}},
|
||||
}
|
||||
_, acked, lost, err := handler.ReceivedAck(ack)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(acked).To(HaveLen(1))
|
||||
Expect(acked[0].PacketNumber).To(Equal(protocol.PacketNumber(2)))
|
||||
Expect(lost).To(HaveLen(1))
|
||||
Expect(lost[0].PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
BeforeEach(func() {
|
||||
cong = &mockCongestion{}
|
||||
handler.congestion = cong
|
||||
})
|
||||
|
||||
It("should call OnSent", func() {
|
||||
p := &Packet{
|
||||
PacketNumber: 1,
|
||||
Frames: []frames.Frame{&frames.StreamFrame{StreamID: 5}},
|
||||
Length: 42,
|
||||
}
|
||||
err := handler.SentPacket(p)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cong.nCalls).To(Equal(1))
|
||||
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(42)))
|
||||
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(42)))
|
||||
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should call OnCongestionEvent", func() {
|
||||
handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1})
|
||||
handler.SentPacket(&Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 2})
|
||||
handler.SentPacket(&Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 3})
|
||||
err := handler.ReceivedAck(&frames.AckFrame{
|
||||
LargestObserved: 3,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cong.nCalls).To(Equal(4)) // 3 * SentPacket + 1 * ReceivedAck
|
||||
// rttUpdated, bytesInFlight, ackedPackets, lostPackets
|
||||
Expect(cong.argsOnCongestionEvent[0]).To(BeTrue())
|
||||
Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(2)))
|
||||
Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{1, 1}, {3, 3}}))
|
||||
Expect(cong.argsOnCongestionEvent[3]).To(BeEmpty())
|
||||
|
||||
// Loose the packet
|
||||
handler.SentPacket(&Packet{PacketNumber: 4, Frames: []frames.Frame{}, Length: 4})
|
||||
err = handler.ReceivedAck(&frames.AckFrame{
|
||||
LargestObserved: 4,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{4, 4}}))
|
||||
Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{2, 2}}))
|
||||
})
|
||||
|
||||
It("allows or denies sending", func() {
|
||||
Expect(handler.AllowsSending()).To(BeTrue())
|
||||
err := handler.SentPacket(&Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: protocol.DefaultTCPMSS + 1})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.AllowsSending()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -2,7 +2,6 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
|
@ -15,29 +14,13 @@ import (
|
|||
|
||||
type mockSentPacketHandler struct{}
|
||||
|
||||
func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duration, []*ackhandler.Packet, []*ackhandler.Packet, error) {
|
||||
return 0, nil, nil, nil
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) DequeuePacketForRetransmission() (packet *ackhandler.Packet) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) HasPacketForRetransmission() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) BytesInFlight() protocol.ByteCount {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) GetLargestObserved() protocol.PacketNumber {
|
||||
return 1
|
||||
}
|
||||
func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error { return nil }
|
||||
func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) error { return nil }
|
||||
func (h *mockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet { return nil }
|
||||
func (h *mockSentPacketHandler) HasPacketForRetransmission() bool { return false }
|
||||
func (h *mockSentPacketHandler) BytesInFlight() protocol.ByteCount { return 0 }
|
||||
func (h *mockSentPacketHandler) GetLargestObserved() protocol.PacketNumber { return 1 }
|
||||
func (h *mockSentPacketHandler) AllowsSending() bool { panic("not implemented") }
|
||||
|
||||
func newMockSentPacketHandler() ackhandler.SentPacketHandler {
|
||||
return &mockSentPacketHandler{}
|
||||
|
|
54
session.go
54
session.go
|
@ -9,7 +9,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
|
@ -76,9 +75,6 @@ type Session struct {
|
|||
// Used to calculate the next packet number from the truncated wire
|
||||
// representation, and sent back in public reset packets
|
||||
lastRcvdPacketNumber protocol.PacketNumber
|
||||
|
||||
rttStats congestion.RTTStats
|
||||
congestion congestion.SendAlgorithm
|
||||
}
|
||||
|
||||
// newSession makes a new session
|
||||
|
@ -100,7 +96,6 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
|||
receivedPackets: make(chan receivedPacket, protocol.MaxSessionUnprocessedPackets),
|
||||
closeChan: make(chan struct{}, 1),
|
||||
sendingScheduled: make(chan struct{}, 1),
|
||||
rttStats: congestion.RTTStats{},
|
||||
connectionParametersManager: connectionParametersManager,
|
||||
undecryptablePackets: make([]receivedPacket, 0, protocol.MaxUndecryptablePackets),
|
||||
aeadChanged: make(chan struct{}, 1),
|
||||
|
@ -122,14 +117,6 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
|||
}
|
||||
session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v}
|
||||
|
||||
session.congestion = congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
&session.rttStats,
|
||||
false, /* don't use reno since chromium doesn't (why?) */
|
||||
protocol.InitialCongestionWindow,
|
||||
protocol.DefaultMaxCongestionWindow,
|
||||
)
|
||||
|
||||
return session, err
|
||||
}
|
||||
|
||||
|
@ -336,33 +323,12 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error {
|
|||
}
|
||||
|
||||
func (s *Session) handleAckFrame(frame *frames.AckFrame) error {
|
||||
duration, acked, lost, err := s.sentPacketHandler.ReceivedAck(frame)
|
||||
if err != nil {
|
||||
|
||||
if err := s.sentPacketHandler.ReceivedAck(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Don't always update RTT
|
||||
s.rttStats.UpdateRTT(duration, frame.DelayTime, time.Now())
|
||||
|
||||
cAcked := make(congestion.PacketVector, len(acked))
|
||||
for i, v := range acked {
|
||||
cAcked[i].Number = v.PacketNumber
|
||||
cAcked[i].Length = v.Length
|
||||
}
|
||||
cLost := make(congestion.PacketVector, len(lost))
|
||||
for i, v := range lost {
|
||||
cLost[i].Number = v.PacketNumber
|
||||
cLost[i].Length = v.Length
|
||||
}
|
||||
s.congestion.OnCongestionEvent(
|
||||
true, /* rtt updated */
|
||||
s.sentPacketHandler.BytesInFlight(),
|
||||
cAcked,
|
||||
cLost,
|
||||
)
|
||||
|
||||
utils.Debugf("\t<- %#v", frame)
|
||||
utils.Debugf("\tEstimated RTT: %dms", s.rttStats.SmoothedRTT()/time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -409,7 +375,7 @@ func (s *Session) closeStreamsWithError(err error) {
|
|||
}
|
||||
|
||||
func (s *Session) maybeSendPacket() error {
|
||||
if !s.congestionAllowsSending() {
|
||||
if !s.sentPacketHandler.AllowsSending() {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -452,7 +418,7 @@ func (s *Session) maybeSendPacket() error {
|
|||
}
|
||||
|
||||
func (s *Session) sendPacket() error {
|
||||
if !s.congestionAllowsSending() {
|
||||
if !s.sentPacketHandler.AllowsSending() {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -514,14 +480,6 @@ func (s *Session) sendPacket() error {
|
|||
return err
|
||||
}
|
||||
|
||||
s.congestion.OnPacketSent(
|
||||
time.Now(),
|
||||
s.sentPacketHandler.BytesInFlight(),
|
||||
packet.number,
|
||||
protocol.ByteCount(len(packet.raw)),
|
||||
true, /* TODO: is retransmittable */
|
||||
)
|
||||
|
||||
s.stopWaitingManager.SentStopWaitingWithPacket(packet.number)
|
||||
|
||||
utils.Debugf("-> Sending packet 0x%x (%d bytes)", packet.number, len(packet.raw))
|
||||
|
@ -635,10 +593,6 @@ func (s *Session) scheduleSending() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Session) congestionAllowsSending() bool {
|
||||
return s.sentPacketHandler.BytesInFlight() <= s.congestion.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) {
|
||||
utils.Debugf("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber)
|
||||
if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets {
|
||||
|
|
110
session_test.go
110
session_test.go
|
@ -12,7 +12,6 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
|
@ -33,52 +32,6 @@ func (m *mockConnection) write(p []byte) error {
|
|||
|
||||
func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {}
|
||||
|
||||
type mockCongestion struct {
|
||||
nCalls int
|
||||
argsOnPacketSent []interface{}
|
||||
argsOnCongestionEvent []interface{}
|
||||
}
|
||||
|
||||
func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
||||
m.nCalls++
|
||||
m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount {
|
||||
m.nCalls++
|
||||
return protocol.DefaultTCPMSS
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets congestion.PacketVector, lostPackets congestion.PacketVector) {
|
||||
m.nCalls++
|
||||
m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets}
|
||||
}
|
||||
|
||||
func (m *mockCongestion) SetNumEmulatedConnections(n int) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnConnectionMigration() {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) RetransmissionDelay() time.Duration {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
var _ = Describe("Session", func() {
|
||||
var (
|
||||
session *Session
|
||||
|
@ -604,68 +557,7 @@ var _ = Describe("Session", func() {
|
|||
close(done)
|
||||
}, 0.5)
|
||||
|
||||
Context("congestion", func() {
|
||||
var (
|
||||
cong *mockCongestion
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
cong = &mockCongestion{}
|
||||
session.congestion = cong
|
||||
})
|
||||
|
||||
It("should call OnSent", func() {
|
||||
session.queueStreamFrame(&frames.StreamFrame{StreamID: 5})
|
||||
session.sendPacket()
|
||||
Expect(cong.nCalls).To(Equal(2)) // OnPacketSent + GetCongestionWindow
|
||||
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(25)))
|
||||
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(25)))
|
||||
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should call OnCongestionEvent", func() {
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1})
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 2})
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 3})
|
||||
err := session.handleAckFrame(&frames.AckFrame{
|
||||
LargestObserved: 3,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cong.nCalls).To(Equal(1))
|
||||
//(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets cong.PacketVector, lostPackets cong.PacketVector)
|
||||
Expect(cong.argsOnCongestionEvent[0]).To(BeTrue())
|
||||
Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(2)))
|
||||
Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{1, 1}, {3, 3}}))
|
||||
Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{}))
|
||||
|
||||
// Loose the packet
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 4, Frames: []frames.Frame{}, Length: 4})
|
||||
err = session.handleAckFrame(&frames.AckFrame{
|
||||
LargestObserved: 4,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 5, Frames: []frames.Frame{}, Length: 5})
|
||||
err = session.handleAckFrame(&frames.AckFrame{
|
||||
LargestObserved: 5,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 6, Frames: []frames.Frame{}, Length: 6})
|
||||
err = session.handleAckFrame(&frames.AckFrame{
|
||||
LargestObserved: 6,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{6, 6}}))
|
||||
Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{2, 2}}))
|
||||
})
|
||||
})
|
||||
|
||||
It("stored up to MaxSessionUnprocessedPackets packets", func(done Done) {
|
||||
It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) {
|
||||
// Nothing here should block
|
||||
for i := 0; i < protocol.MaxSessionUnprocessedPackets+10; i++ {
|
||||
session.handlePacket(nil, nil, nil)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue