mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 05:07:36 +03:00
remove version dependency of packet number inferring
This commit is contained in:
parent
4145bcc8a7
commit
faed2ba30a
5 changed files with 38 additions and 43 deletions
|
@ -75,12 +75,10 @@ type sentPacketHandler struct {
|
||||||
alarm time.Time
|
alarm time.Time
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
|
|
||||||
version protocol.VersionNumber
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSentPacketHandler creates a new sentPacketHandler
|
// NewSentPacketHandler creates a new sentPacketHandler
|
||||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
|
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
|
||||||
congestion := congestion.NewCubicSender(
|
congestion := congestion.NewCubicSender(
|
||||||
congestion.DefaultClock{},
|
congestion.DefaultClock{},
|
||||||
rttStats,
|
rttStats,
|
||||||
|
@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
congestion: congestion,
|
congestion: congestion,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
version: version,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
|
||||||
|
|
||||||
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||||
pn := h.packetNumberGenerator.Peek()
|
pn := h.packetNumberGenerator.Peek()
|
||||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
|
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||||
|
|
|
@ -49,7 +49,7 @@ var _ = Describe("SentPacketHandler", func() {
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
rttStats := &congestion.RTTStats{}
|
rttStats := &congestion.RTTStats{}
|
||||||
handler = NewSentPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*sentPacketHandler)
|
handler = NewSentPacketHandler(rttStats, utils.DefaultLogger).(*sentPacketHandler)
|
||||||
handler.SetHandshakeComplete()
|
handler.SetHandshakeComplete()
|
||||||
streamFrame = wire.StreamFrame{
|
streamFrame = wire.StreamFrame{
|
||||||
StreamID: 5,
|
StreamID: 5,
|
||||||
|
|
|
@ -5,7 +5,6 @@ func InferPacketNumber(
|
||||||
packetNumberLength PacketNumberLen,
|
packetNumberLength PacketNumberLen,
|
||||||
lastPacketNumber PacketNumber,
|
lastPacketNumber PacketNumber,
|
||||||
wirePacketNumber PacketNumber,
|
wirePacketNumber PacketNumber,
|
||||||
version VersionNumber,
|
|
||||||
) PacketNumber {
|
) PacketNumber {
|
||||||
var epochDelta PacketNumber
|
var epochDelta PacketNumber
|
||||||
switch packetNumberLength {
|
switch packetNumberLength {
|
||||||
|
@ -42,7 +41,7 @@ func delta(a, b PacketNumber) PacketNumber {
|
||||||
|
|
||||||
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
|
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
|
||||||
diff := uint64(packetNumber - leastUnacked)
|
diff := uint64(packetNumber - leastUnacked)
|
||||||
if diff < (1 << (14 - 1)) {
|
if diff < (1 << (14 - 1)) {
|
||||||
return PacketNumberLen2
|
return PacketNumberLen2
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
// Tests taken and extended from chrome
|
// Tests taken and extended from chrome
|
||||||
var _ = Describe("packet number calculation", func() {
|
var _ = Describe("packet number calculation", func() {
|
||||||
Context("infering a packet number", func() {
|
Context("infering a packet number", func() {
|
||||||
getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 {
|
getEpoch := func(len PacketNumberLen) uint64 {
|
||||||
switch len {
|
switch len {
|
||||||
case PacketNumberLen1:
|
case PacketNumberLen1:
|
||||||
return uint64(1) << 7
|
return uint64(1) << 7
|
||||||
|
@ -24,36 +24,36 @@ var _ = Describe("packet number calculation", func() {
|
||||||
}
|
}
|
||||||
return uint64(1) << (len * 8)
|
return uint64(1) << (len * 8)
|
||||||
}
|
}
|
||||||
check := func(length PacketNumberLen, expected, last uint64, v VersionNumber) {
|
check := func(length PacketNumberLen, expected, last uint64) {
|
||||||
epoch := getEpoch(length, v)
|
epoch := getEpoch(length)
|
||||||
epochMask := epoch - 1
|
epochMask := epoch - 1
|
||||||
wirePacketNumber := expected & epochMask
|
wirePacketNumber := expected & epochMask
|
||||||
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected)))
|
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected)))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
|
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
|
||||||
length := l
|
length := l
|
||||||
|
|
||||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||||
epoch := getEpoch(length, VersionWhatever)
|
epoch := getEpoch(length)
|
||||||
epochMask := epoch - 1
|
epochMask := epoch - 1
|
||||||
|
|
||||||
It("works near epoch start", func() {
|
It("works near epoch start", func() {
|
||||||
// A few quick manual sanity check
|
// A few quick manual sanity check
|
||||||
check(length, 1, 0, VersionWhatever)
|
check(length, 1, 0)
|
||||||
check(length, epoch+1, epochMask, VersionWhatever)
|
check(length, epoch+1, epochMask)
|
||||||
check(length, epoch, epochMask, VersionWhatever)
|
check(length, epoch, epochMask)
|
||||||
|
|
||||||
// Cases where the last number was close to the start of the range.
|
// Cases where the last number was close to the start of the range.
|
||||||
for last := uint64(0); last < 10; last++ {
|
for last := uint64(0); last < 10; last++ {
|
||||||
// Small numbers should not wrap (even if they're out of order).
|
// Small numbers should not wrap (even if they're out of order).
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
check(length, j, last, VersionWhatever)
|
check(length, j, last)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Large numbers should not wrap either (because we're near 0 already).
|
// Large numbers should not wrap either (because we're near 0 already).
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
check(length, epoch-1-j, last, VersionWhatever)
|
check(length, epoch-1-j, last)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -65,12 +65,12 @@ var _ = Describe("packet number calculation", func() {
|
||||||
|
|
||||||
// Small numbers should wrap.
|
// Small numbers should wrap.
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
check(length, epoch+j, last, VersionWhatever)
|
check(length, epoch+j, last)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Large numbers should not (even if they're out of order).
|
// Large numbers should not (even if they're out of order).
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
check(length, epoch-1-j, last, VersionWhatever)
|
check(length, epoch-1-j, last)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -85,13 +85,13 @@ var _ = Describe("packet number calculation", func() {
|
||||||
last := curEpoch + i
|
last := curEpoch + i
|
||||||
// Small number should not wrap (even if they're out of order).
|
// Small number should not wrap (even if they're out of order).
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
check(length, curEpoch+j, last, VersionWhatever)
|
check(length, curEpoch+j, last)
|
||||||
}
|
}
|
||||||
|
|
||||||
// But large numbers should reverse wrap.
|
// But large numbers should reverse wrap.
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
num := epoch - 1 - j
|
num := epoch - 1 - j
|
||||||
check(length, prevEpoch+num, last, VersionWhatever)
|
check(length, prevEpoch+num, last)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -105,13 +105,13 @@ var _ = Describe("packet number calculation", func() {
|
||||||
|
|
||||||
// Small numbers should wrap.
|
// Small numbers should wrap.
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
check(length, nextEpoch+j, last, VersionWhatever)
|
check(length, nextEpoch+j, last)
|
||||||
}
|
}
|
||||||
|
|
||||||
// but large numbers should not (even if they're out of order).
|
// but large numbers should not (even if they're out of order).
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
num := epoch - 1 - j
|
num := epoch - 1 - j
|
||||||
check(length, curEpoch+num, last, VersionWhatever)
|
check(length, curEpoch+num, last)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -128,13 +128,13 @@ var _ = Describe("packet number calculation", func() {
|
||||||
|
|
||||||
// Small numbers should not wrap, because they have nowhere to go.
|
// Small numbers should not wrap, because they have nowhere to go.
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
check(length, maxEpoch+j, last, VersionWhatever)
|
check(length, maxEpoch+j, last)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Large numbers should not wrap either.
|
// Large numbers should not wrap either.
|
||||||
for j := uint64(0); j < 10; j++ {
|
for j := uint64(0); j < 10; j++ {
|
||||||
num := epoch - 1 - j
|
num := epoch - 1 - j
|
||||||
check(length, maxEpoch+num, last, VersionWhatever)
|
check(length, maxEpoch+num, last)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -142,17 +142,17 @@ var _ = Describe("packet number calculation", func() {
|
||||||
Context("shortening a packet number for the header", func() {
|
Context("shortening a packet number for the header", func() {
|
||||||
Context("shortening", func() {
|
Context("shortening", func() {
|
||||||
It("sends out low packet numbers as 2 byte", func() {
|
It("sends out low packet numbers as 2 byte", func() {
|
||||||
length := GetPacketNumberLengthForHeader(4, 2, VersionWhatever)
|
length := GetPacketNumberLengthForHeader(4, 2)
|
||||||
Expect(length).To(Equal(PacketNumberLen2))
|
Expect(length).To(Equal(PacketNumberLen2))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
|
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
|
||||||
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, VersionWhatever)
|
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1)
|
||||||
Expect(length).To(Equal(PacketNumberLen2))
|
Expect(length).To(Equal(PacketNumberLen2))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
|
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
|
||||||
length := GetPacketNumberLengthForHeader(40000, 2, VersionWhatever)
|
length := GetPacketNumberLengthForHeader(40000, 2)
|
||||||
Expect(length).To(Equal(PacketNumberLen4))
|
Expect(length).To(Equal(PacketNumberLen4))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -162,10 +162,10 @@ var _ = Describe("packet number calculation", func() {
|
||||||
for i := uint64(1); i < 10000; i++ {
|
for i := uint64(1); i < 10000; i++ {
|
||||||
packetNumber := PacketNumber(i)
|
packetNumber := PacketNumber(i)
|
||||||
leastUnacked := PacketNumber(1)
|
leastUnacked := PacketNumber(1)
|
||||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
|
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||||
|
|
||||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
|
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -174,28 +174,28 @@ var _ = Describe("packet number calculation", func() {
|
||||||
for i := uint64(1); i < 10000; i++ {
|
for i := uint64(1); i < 10000; i++ {
|
||||||
packetNumber := PacketNumber(i)
|
packetNumber := PacketNumber(i)
|
||||||
leastUnacked := PacketNumber(i / 2)
|
leastUnacked := PacketNumber(i / 2)
|
||||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
|
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||||
epochMask := getEpoch(length, VersionWhatever) - 1
|
epochMask := getEpoch(length) - 1
|
||||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
wirePacketNumber := uint64(packetNumber) & epochMask
|
||||||
|
|
||||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
|
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("also works for larger packet numbers", func() {
|
It("also works for larger packet numbers", func() {
|
||||||
var increment uint64
|
var increment uint64
|
||||||
for i := uint64(1); i < getEpoch(PacketNumberLen4, VersionWhatever); i += increment {
|
for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment {
|
||||||
packetNumber := PacketNumber(i)
|
packetNumber := PacketNumber(i)
|
||||||
leastUnacked := PacketNumber(1)
|
leastUnacked := PacketNumber(1)
|
||||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
|
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||||
epochMask := getEpoch(length, VersionWhatever) - 1
|
epochMask := getEpoch(length) - 1
|
||||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
wirePacketNumber := uint64(packetNumber) & epochMask
|
||||||
|
|
||||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
|
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||||
|
|
||||||
increment = getEpoch(length, VersionWhatever) / 8
|
increment = getEpoch(length) / 8
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -203,10 +203,10 @@ var _ = Describe("packet number calculation", func() {
|
||||||
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
|
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
|
||||||
packetNumber := PacketNumber(i)
|
packetNumber := PacketNumber(i)
|
||||||
leastUnacked := PacketNumber(i - 1000)
|
leastUnacked := PacketNumber(i - 1000)
|
||||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
|
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
|
||||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||||
|
|
||||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
|
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
|
||||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -289,7 +289,7 @@ var newClientSession = func(
|
||||||
|
|
||||||
func (s *session) preSetup() {
|
func (s *session) preSetup() {
|
||||||
s.rttStats = &congestion.RTTStats{}
|
s.rttStats = &congestion.RTTStats{}
|
||||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version)
|
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
|
||||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
|
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
|
||||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||||
protocol.InitialMaxData,
|
protocol.InitialMaxData,
|
||||||
|
@ -512,7 +512,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
|
||||||
hdr.PacketNumberLen,
|
hdr.PacketNumberLen,
|
||||||
s.largestRcvdPacketNumber,
|
s.largestRcvdPacketNumber,
|
||||||
hdr.PacketNumber,
|
hdr.PacketNumber,
|
||||||
s.version,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
|
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue