diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 0e253354..f155bdd9 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -75,12 +75,10 @@ type sentPacketHandler struct { alarm time.Time logger utils.Logger - - version protocol.VersionNumber } // NewSentPacketHandler creates a new sentPacketHandler -func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler { +func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler { congestion := congestion.NewCubicSender( congestion.DefaultClock{}, rttStats, @@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve rttStats: rttStats, congestion: congestion, logger: logger, - version: version, } } @@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) { func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) { pn := h.packetNumberGenerator.Peek() - return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version) + return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked()) } func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 54a4857c..fabdd133 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -49,7 +49,7 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { rttStats := &congestion.RTTStats{} - handler = NewSentPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*sentPacketHandler) + handler = NewSentPacketHandler(rttStats, utils.DefaultLogger).(*sentPacketHandler) handler.SetHandshakeComplete() streamFrame = wire.StreamFrame{ StreamID: 5, diff --git a/internal/protocol/packet_number.go b/internal/protocol/packet_number.go index e32d6baa..e7b3b384 100644 --- a/internal/protocol/packet_number.go +++ b/internal/protocol/packet_number.go @@ -5,7 +5,6 @@ func InferPacketNumber( packetNumberLength PacketNumberLen, lastPacketNumber PacketNumber, wirePacketNumber PacketNumber, - version VersionNumber, ) PacketNumber { var epochDelta PacketNumber switch packetNumberLength { @@ -42,7 +41,7 @@ func delta(a, b PacketNumber) PacketNumber { // GetPacketNumberLengthForHeader gets the length of the packet number for the public header // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen { +func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen { diff := uint64(packetNumber - leastUnacked) if diff < (1 << (14 - 1)) { return PacketNumberLen2 diff --git a/internal/protocol/packet_number_test.go b/internal/protocol/packet_number_test.go index 6d23d7a3..49f3b142 100644 --- a/internal/protocol/packet_number_test.go +++ b/internal/protocol/packet_number_test.go @@ -11,7 +11,7 @@ import ( // Tests taken and extended from chrome var _ = Describe("packet number calculation", func() { Context("infering a packet number", func() { - getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 { + getEpoch := func(len PacketNumberLen) uint64 { switch len { case PacketNumberLen1: return uint64(1) << 7 @@ -24,36 +24,36 @@ var _ = Describe("packet number calculation", func() { } return uint64(1) << (len * 8) } - check := func(length PacketNumberLen, expected, last uint64, v VersionNumber) { - epoch := getEpoch(length, v) + check := func(length PacketNumberLen, expected, last uint64) { + epoch := getEpoch(length) epochMask := epoch - 1 wirePacketNumber := expected & epochMask - Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected))) + Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected))) } for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} { length := l Context(fmt.Sprintf("with %d bytes", length), func() { - epoch := getEpoch(length, VersionWhatever) + epoch := getEpoch(length) epochMask := epoch - 1 It("works near epoch start", func() { // A few quick manual sanity check - check(length, 1, 0, VersionWhatever) - check(length, epoch+1, epochMask, VersionWhatever) - check(length, epoch, epochMask, VersionWhatever) + check(length, 1, 0) + check(length, epoch+1, epochMask) + check(length, epoch, epochMask) // Cases where the last number was close to the start of the range. for last := uint64(0); last < 10; last++ { // Small numbers should not wrap (even if they're out of order). for j := uint64(0); j < 10; j++ { - check(length, j, last, VersionWhatever) + check(length, j, last) } // Large numbers should not wrap either (because we're near 0 already). for j := uint64(0); j < 10; j++ { - check(length, epoch-1-j, last, VersionWhatever) + check(length, epoch-1-j, last) } } }) @@ -65,12 +65,12 @@ var _ = Describe("packet number calculation", func() { // Small numbers should wrap. for j := uint64(0); j < 10; j++ { - check(length, epoch+j, last, VersionWhatever) + check(length, epoch+j, last) } // Large numbers should not (even if they're out of order). for j := uint64(0); j < 10; j++ { - check(length, epoch-1-j, last, VersionWhatever) + check(length, epoch-1-j, last) } } }) @@ -85,13 +85,13 @@ var _ = Describe("packet number calculation", func() { last := curEpoch + i // Small number should not wrap (even if they're out of order). for j := uint64(0); j < 10; j++ { - check(length, curEpoch+j, last, VersionWhatever) + check(length, curEpoch+j, last) } // But large numbers should reverse wrap. for j := uint64(0); j < 10; j++ { num := epoch - 1 - j - check(length, prevEpoch+num, last, VersionWhatever) + check(length, prevEpoch+num, last) } } }) @@ -105,13 +105,13 @@ var _ = Describe("packet number calculation", func() { // Small numbers should wrap. for j := uint64(0); j < 10; j++ { - check(length, nextEpoch+j, last, VersionWhatever) + check(length, nextEpoch+j, last) } // but large numbers should not (even if they're out of order). for j := uint64(0); j < 10; j++ { num := epoch - 1 - j - check(length, curEpoch+num, last, VersionWhatever) + check(length, curEpoch+num, last) } } }) @@ -128,13 +128,13 @@ var _ = Describe("packet number calculation", func() { // Small numbers should not wrap, because they have nowhere to go. for j := uint64(0); j < 10; j++ { - check(length, maxEpoch+j, last, VersionWhatever) + check(length, maxEpoch+j, last) } // Large numbers should not wrap either. for j := uint64(0); j < 10; j++ { num := epoch - 1 - j - check(length, maxEpoch+num, last, VersionWhatever) + check(length, maxEpoch+num, last) } } }) @@ -142,17 +142,17 @@ var _ = Describe("packet number calculation", func() { Context("shortening a packet number for the header", func() { Context("shortening", func() { It("sends out low packet numbers as 2 byte", func() { - length := GetPacketNumberLengthForHeader(4, 2, VersionWhatever) + length := GetPacketNumberLengthForHeader(4, 2) Expect(length).To(Equal(PacketNumberLen2)) }) It("sends out high packet numbers as 2 byte, if all ACKs are received", func() { - length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, VersionWhatever) + length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1) Expect(length).To(Equal(PacketNumberLen2)) }) It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() { - length := GetPacketNumberLengthForHeader(40000, 2, VersionWhatever) + length := GetPacketNumberLengthForHeader(40000, 2) Expect(length).To(Equal(PacketNumberLen4)) }) }) @@ -162,10 +162,10 @@ var _ = Describe("packet number calculation", func() { for i := uint64(1); i < 10000; i++ { packetNumber := PacketNumber(i) leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever) + inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) Expect(inferedPacketNumber).To(Equal(packetNumber)) } }) @@ -174,28 +174,28 @@ var _ = Describe("packet number calculation", func() { for i := uint64(1); i < 10000; i++ { packetNumber := PacketNumber(i) leastUnacked := PacketNumber(i / 2) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever) - epochMask := getEpoch(length, VersionWhatever) - 1 + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + epochMask := getEpoch(length) - 1 wirePacketNumber := uint64(packetNumber) & epochMask - inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever) + inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) Expect(inferedPacketNumber).To(Equal(packetNumber)) } }) It("also works for larger packet numbers", func() { var increment uint64 - for i := uint64(1); i < getEpoch(PacketNumberLen4, VersionWhatever); i += increment { + for i := uint64(1); i < getEpoch(PacketNumberLen4); i += increment { packetNumber := PacketNumber(i) leastUnacked := PacketNumber(1) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever) - epochMask := getEpoch(length, VersionWhatever) - 1 + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) + epochMask := getEpoch(length) - 1 wirePacketNumber := uint64(packetNumber) & epochMask - inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever) + inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) Expect(inferedPacketNumber).To(Equal(packetNumber)) - increment = getEpoch(length, VersionWhatever) / 8 + increment = getEpoch(length) / 8 } }) @@ -203,10 +203,10 @@ var _ = Describe("packet number calculation", func() { for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) { packetNumber := PacketNumber(i) leastUnacked := PacketNumber(i - 1000) - length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever) + length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked) wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8) - inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever) + inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber)) Expect(inferedPacketNumber).To(Equal(packetNumber)) } }) diff --git a/session.go b/session.go index c3b1892d..6eb682ba 100644 --- a/session.go +++ b/session.go @@ -289,7 +289,7 @@ var newClientSession = func( func (s *session) preSetup() { s.rttStats = &congestion.RTTStats{} - s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version) + s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.InitialMaxData, @@ -512,7 +512,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { hdr.PacketNumberLen, s.largestRcvdPacketNumber, hdr.PacketNumber, - s.version, ) packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)