remove version dependency of packet number inferring

This commit is contained in:
Marten Seemann 2018-11-26 16:04:29 +07:00
parent 4145bcc8a7
commit faed2ba30a
5 changed files with 38 additions and 43 deletions

View file

@ -75,12 +75,10 @@ type sentPacketHandler struct {
alarm time.Time alarm time.Time
logger utils.Logger logger utils.Logger
version protocol.VersionNumber
} }
// NewSentPacketHandler creates a new sentPacketHandler // NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler { func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
congestion := congestion.NewCubicSender( congestion := congestion.NewCubicSender(
congestion.DefaultClock{}, congestion.DefaultClock{},
rttStats, rttStats,
@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
rttStats: rttStats, rttStats: rttStats,
congestion: congestion, congestion: congestion,
logger: logger, logger: logger,
version: version,
} }
} }
@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) { func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
pn := h.packetNumberGenerator.Peek() pn := h.packetNumberGenerator.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version) return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked())
} }
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber { func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {

View file

@ -49,7 +49,7 @@ var _ = Describe("SentPacketHandler", func() {
BeforeEach(func() { BeforeEach(func() {
rttStats := &congestion.RTTStats{} rttStats := &congestion.RTTStats{}
handler = NewSentPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*sentPacketHandler) handler = NewSentPacketHandler(rttStats, utils.DefaultLogger).(*sentPacketHandler)
handler.SetHandshakeComplete() handler.SetHandshakeComplete()
streamFrame = wire.StreamFrame{ streamFrame = wire.StreamFrame{
StreamID: 5, StreamID: 5,

View file

@ -5,7 +5,6 @@ func InferPacketNumber(
packetNumberLength PacketNumberLen, packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber, lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber, wirePacketNumber PacketNumber,
version VersionNumber,
) PacketNumber { ) PacketNumber {
var epochDelta PacketNumber var epochDelta PacketNumber
switch packetNumberLength { switch packetNumberLength {
@ -42,7 +41,7 @@ func delta(a, b PacketNumber) PacketNumber {
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header // GetPacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen { func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked) diff := uint64(packetNumber - leastUnacked)
if diff < (1 << (14 - 1)) { if diff < (1 << (14 - 1)) {
return PacketNumberLen2 return PacketNumberLen2

View file

@ -11,7 +11,7 @@ import (
// Tests taken and extended from chrome // Tests taken and extended from chrome
var _ = Describe("packet number calculation", func() { var _ = Describe("packet number calculation", func() {
Context("infering a packet number", func() { Context("infering a packet number", func() {
getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 { getEpoch := func(len PacketNumberLen) uint64 {
switch len { switch len {
case PacketNumberLen1: case PacketNumberLen1:
return uint64(1) << 7 return uint64(1) << 7
@ -24,36 +24,36 @@ var _ = Describe("packet number calculation", func() {
} }
return uint64(1) << (len * 8) return uint64(1) << (len * 8)
} }
check := func(length PacketNumberLen, expected, last uint64, v VersionNumber) { check := func(length PacketNumberLen, expected, last uint64) {
epoch := getEpoch(length, v) epoch := getEpoch(length)
epochMask := epoch - 1 epochMask := epoch - 1
wirePacketNumber := expected & epochMask wirePacketNumber := expected & epochMask
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected))) Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected)))
} }
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} { for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
length := l length := l
Context(fmt.Sprintf("with %d bytes", length), func() { Context(fmt.Sprintf("with %d bytes", length), func() {
epoch := getEpoch(length, VersionWhatever) epoch := getEpoch(length)
epochMask := epoch - 1 epochMask := epoch - 1
It("works near epoch start", func() { It("works near epoch start", func() {
// A few quick manual sanity check // A few quick manual sanity check
check(length, 1, 0, VersionWhatever) check(length, 1, 0)
check(length, epoch+1, epochMask, VersionWhatever) check(length, epoch+1, epochMask)
check(length, epoch, epochMask, VersionWhatever) check(length, epoch, epochMask)
// Cases where the last number was close to the start of the range. // Cases where the last number was close to the start of the range.
for last := uint64(0); last < 10; last++ { for last := uint64(0); last < 10; last++ {
// Small numbers should not wrap (even if they're out of order). // Small numbers should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
check(length, j, last, VersionWhatever) check(length, j, last)
} }
// Large numbers should not wrap either (because we're near 0 already). // Large numbers should not wrap either (because we're near 0 already).
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, VersionWhatever) check(length, epoch-1-j, last)
} }
} }
}) })
@ -65,12 +65,12 @@ var _ = Describe("packet number calculation", func() {
// Small numbers should wrap. // Small numbers should wrap.
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
check(length, epoch+j, last, VersionWhatever) check(length, epoch+j, last)
} }
// Large numbers should not (even if they're out of order). // Large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, VersionWhatever) check(length, epoch-1-j, last)
} }
} }
}) })
@ -85,13 +85,13 @@ var _ = Describe("packet number calculation", func() {
last := curEpoch + i last := curEpoch + i
// Small number should not wrap (even if they're out of order). // Small number should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
check(length, curEpoch+j, last, VersionWhatever) check(length, curEpoch+j, last)
} }
// But large numbers should reverse wrap. // But large numbers should reverse wrap.
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j num := epoch - 1 - j
check(length, prevEpoch+num, last, VersionWhatever) check(length, prevEpoch+num, last)
} }
} }
}) })
@ -105,13 +105,13 @@ var _ = Describe("packet number calculation", func() {
// Small numbers should wrap. // Small numbers should wrap.
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
check(length, nextEpoch+j, last, VersionWhatever) check(length, nextEpoch+j, last)
} }
// but large numbers should not (even if they're out of order). // but large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j num := epoch - 1 - j
check(length, curEpoch+num, last, VersionWhatever) check(length, curEpoch+num, last)
} }
} }
}) })
@ -128,13 +128,13 @@ var _ = Describe("packet number calculation", func() {
// Small numbers should not wrap, because they have nowhere to go. // Small numbers should not wrap, because they have nowhere to go.
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
check(length, maxEpoch+j, last, VersionWhatever) check(length, maxEpoch+j, last)
} }
// Large numbers should not wrap either. // Large numbers should not wrap either.
for j := uint64(0); j < 10; j++ { for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j num := epoch - 1 - j
check(length, maxEpoch+num, last, VersionWhatever) check(length, maxEpoch+num, last)
} }
} }
}) })
@ -142,17 +142,17 @@ var _ = Describe("packet number calculation", func() {
Context("shortening a packet number for the header", func() { Context("shortening a packet number for the header", func() {
Context("shortening", func() { Context("shortening", func() {
It("sends out low packet numbers as 2 byte", func() { It("sends out low packet numbers as 2 byte", func() {
length := GetPacketNumberLengthForHeader(4, 2, VersionWhatever) length := GetPacketNumberLengthForHeader(4, 2)
Expect(length).To(Equal(PacketNumberLen2)) Expect(length).To(Equal(PacketNumberLen2))
}) })
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, VersionWhatever) length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1)
Expect(length).To(Equal(PacketNumberLen2)) Expect(length).To(Equal(PacketNumberLen2))
}) })
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
length := GetPacketNumberLengthForHeader(40000, 2, VersionWhatever) length := GetPacketNumberLengthForHeader(40000, 2)
Expect(length).To(Equal(PacketNumberLen4)) Expect(length).To(Equal(PacketNumberLen4))
}) })
}) })
@ -162,10 +162,10 @@ var _ = Describe("packet number calculation", func() {
for i := uint64(1); i < 10000; i++ { for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i) packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1) leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever) length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever) inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
Expect(inferedPacketNumber).To(Equal(packetNumber)) Expect(inferedPacketNumber).To(Equal(packetNumber))
} }
}) })
@ -174,28 +174,28 @@ var _ = Describe("packet number calculation", func() {
for i := uint64(1); i < 10000; i++ { for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i) packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i / 2) leastUnacked := PacketNumber(i / 2)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever) length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
epochMask := getEpoch(length, VersionWhatever) - 1 epochMask := getEpoch(length) - 1
wirePacketNumber := uint64(packetNumber) & epochMask wirePacketNumber := uint64(packetNumber) & epochMask
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever) inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
Expect(inferedPacketNumber).To(Equal(packetNumber)) Expect(inferedPacketNumber).To(Equal(packetNumber))
} }
}) })
It("also works for larger packet numbers", func() { It("also works for larger packet numbers", func() {
var increment uint64 var increment uint64
for i := uint64(1); i < getEpoch(PacketNumberLen4, VersionWhatever); i += increment { for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment {
packetNumber := PacketNumber(i) packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1) leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever) length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
epochMask := getEpoch(length, VersionWhatever) - 1 epochMask := getEpoch(length) - 1
wirePacketNumber := uint64(packetNumber) & epochMask wirePacketNumber := uint64(packetNumber) & epochMask
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever) inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
Expect(inferedPacketNumber).To(Equal(packetNumber)) Expect(inferedPacketNumber).To(Equal(packetNumber))
increment = getEpoch(length, VersionWhatever) / 8 increment = getEpoch(length) / 8
} }
}) })
@ -203,10 +203,10 @@ var _ = Describe("packet number calculation", func() {
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
packetNumber := PacketNumber(i) packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i - 1000) leastUnacked := PacketNumber(i - 1000)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever) length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever) inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber))
Expect(inferedPacketNumber).To(Equal(packetNumber)) Expect(inferedPacketNumber).To(Equal(packetNumber))
} }
}) })

View file

@ -289,7 +289,7 @@ var newClientSession = func(
func (s *session) preSetup() { func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{} s.rttStats = &congestion.RTTStats{}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController( s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.InitialMaxData, protocol.InitialMaxData,
@ -512,7 +512,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
hdr.PacketNumberLen, hdr.PacketNumberLen,
s.largestRcvdPacketNumber, s.largestRcvdPacketNumber,
hdr.PacketNumber, hdr.PacketNumber,
s.version,
) )
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)