diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 9ec70216..ec36226d 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -32,7 +32,7 @@ type SentPacketHandler interface { GetStopWaitingFrame(force bool) *wire.StopWaitingFrame GetLowestPacketNotConfirmedAcked() protocol.PacketNumber DequeuePacketForRetransmission() (packet *Packet) - GetLeastUnacked() protocol.PacketNumber + GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen GetAlarmTimeout() time.Time OnAlarm() diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 53399442..12ebbd4b 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -354,8 +354,8 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { return packet } -func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber { - return h.lowestUnacked() +func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen { + return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked()) } func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame { diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 98818c3f..975b2172 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -59,9 +59,10 @@ var _ = Describe("SentPacketHandler", func() { return nil } - It("gets the LeastUnacked packet number", func() { + It("determines the packet number length", func() { handler.largestAcked = 0x1337 - Expect(handler.GetLeastUnacked()).To(Equal(protocol.PacketNumber(0x1337 + 1))) + Expect(handler.GetPacketNumberLen(0x1338)).To(Equal(protocol.PacketNumberLen2)) + Expect(handler.GetPacketNumberLen(0xfffffff)).To(Equal(protocol.PacketNumberLen4)) }) Context("registering sent packets", func() { diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 80bb62c4..2c0b5132 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -61,18 +61,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetAlarmTimeout)) } -// GetLeastUnacked mocks base method -func (m *MockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { - ret := m.ctrl.Call(m, "GetLeastUnacked") - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// GetLeastUnacked indicates an expected call of GetLeastUnacked -func (mr *MockSentPacketHandlerMockRecorder) GetLeastUnacked() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeastUnacked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLeastUnacked)) -} - // GetLowestPacketNotConfirmedAcked mocks base method func (m *MockSentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { ret := m.ctrl.Call(m, "GetLowestPacketNotConfirmedAcked") @@ -85,6 +73,18 @@ func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked)) } +// GetPacketNumberLen mocks base method +func (m *MockSentPacketHandler) GetPacketNumberLen(arg0 protocol.PacketNumber) protocol.PacketNumberLen { + ret := m.ctrl.Call(m, "GetPacketNumberLen", arg0) + ret0, _ := ret[0].(protocol.PacketNumberLen) + return ret0 +} + +// GetPacketNumberLen indicates an expected call of GetPacketNumberLen +func (mr *MockSentPacketHandlerMockRecorder) GetPacketNumberLen(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPacketNumberLen", reflect.TypeOf((*MockSentPacketHandler)(nil).GetPacketNumberLen), arg0) +} + // GetStopWaitingFrame mocks base method func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame { ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0) diff --git a/packet_packer.go b/packet_packer.go index 6070dd90..535708af 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -32,6 +32,7 @@ type packetPacker struct { cryptoSetup handshake.CryptoSetup packetNumberGenerator *packetNumberGenerator + getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen streams streamFrameSource controlFrameMutex sync.Mutex @@ -39,7 +40,6 @@ type packetPacker struct { stopWaiting *wire.StopWaitingFrame ackFrame *wire.AckFrame - leastUnacked protocol.PacketNumber omitConnectionID bool hasSentPacket bool // has the packetPacker already sent a packet numNonRetransmittableAcks int @@ -47,6 +47,7 @@ type packetPacker struct { func newPacketPacker(connectionID protocol.ConnectionID, initialPacketNumber protocol.PacketNumber, + getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen, cryptoSetup handshake.CryptoSetup, streamFramer streamFrameSource, perspective protocol.Perspective, @@ -58,6 +59,7 @@ func newPacketPacker(connectionID protocol.ConnectionID, perspective: perspective, version: version, streams: streamFramer, + getPacketNumberLen: getPacketNumberLen, packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), } } @@ -402,7 +404,7 @@ func (p *packetPacker) QueueControlFrame(frame wire.Frame) { func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { pnum := p.packetNumberGenerator.Peek() - packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked) + packetNumberLen := p.getPacketNumberLen(pnum) header := &wire.Header{ ConnectionID: p.connectionID, @@ -496,10 +498,6 @@ func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool { return encLevel == protocol.EncryptionForwardSecure } -func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) { - p.leastUnacked = leastUnacked -} - func (p *packetPacker) SetOmitConnectionID() { p.omitConnectionID = true } diff --git a/packet_packer_test.go b/packet_packer_test.go index e620869c..3b2e1fa5 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -68,6 +68,7 @@ var _ = Describe("Packet packer", func() { packer = newPacketPacker( 0x1337, 1, + func(protocol.PacketNumber) protocol.PacketNumberLen { return protocol.PacketNumberLen2 }, &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, mockStreamFramer, protocol.PerspectiveServer, @@ -283,14 +284,13 @@ var _ = Describe("Packet packer", func() { It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) - packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number - packer.packetNumberGenerator.next = packetNumber - swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} + packer.packetNumberGenerator.next = 0x1337 + swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337 - 0x100} packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(swf) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) }) It("does not pack a packet containing only a STOP_WAITING frame", func() { diff --git a/session.go b/session.go index 5c966830..069d0b44 100644 --- a/session.go +++ b/session.go @@ -338,6 +338,7 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker(s.connectionID, initialPacketNumber, + s.sentPacketHandler.GetPacketNumberLen, s.cryptoSetup, s.streamFramer, s.perspective, @@ -763,7 +764,6 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete } func (s *session) sendPackets() error { - s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) s.pacingDeadline = time.Time{} if !s.sentPacketHandler.SendingAllowed() { // if congestion limited, at least try sending an ACK frame return s.maybeSendAckOnlyPacket() @@ -905,7 +905,6 @@ func (s *session) sendPackedPacket(packet *packedPacket) error { } func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { - s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked()) packet, err := s.packer.PackConnectionClose(&wire.ConnectionCloseFrame{ ErrorCode: quicErr.ErrorCode, ReasonPhrase: quicErr.ErrorMessage, diff --git a/session_test.go b/session_test.go index 9e932a64..dffdeff2 100644 --- a/session_test.go +++ b/session_test.go @@ -725,7 +725,7 @@ var _ = Describe("Session", func() { It("doesn't retransmit an Initial packet if it already received a response", func() { sess.unpacker = &mockUnpacker{} sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ PacketNumber: 10, PacketType: protocol.PacketTypeInitial, @@ -749,7 +749,7 @@ var _ = Describe("Session", func() { It("sends a retransmission and a regular packet in the same run", func() { sess.windowUpdateQueue.callback(&wire.MaxDataFrame{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLeastUnacked() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ PacketNumber: 10, PacketType: protocol.PacketTypeHandshake, @@ -781,7 +781,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sph.EXPECT().DequeuePacketForRetransmission().AnyTimes() sess.sentPacketHandler = sph sess.packer.hasSentPacket = true @@ -894,7 +894,7 @@ var _ = Describe("Session", func() { It("sends ACK only packets", func() { swf := &wire.StopWaitingFrame{LeastUnacked: 10} sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sph.EXPECT().GetAlarmTimeout().AnyTimes() sph.EXPECT().SendingAllowed() sph.EXPECT().GetStopWaitingFrame(false).Return(swf) @@ -925,7 +925,7 @@ var _ = Describe("Session", func() { sess.version = versionIETFFrames done := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sph.EXPECT().GetAlarmTimeout().AnyTimes() sph.EXPECT().SendingAllowed() sph.EXPECT().TimeUntilSend() @@ -957,7 +957,7 @@ var _ = Describe("Session", func() { sess.packer.packetNumberGenerator.next = 0x1337 + 10 sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sess.sentPacketHandler = sph sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} }) @@ -1114,7 +1114,7 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendingAllowed().AnyTimes().Return(true) sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(1) - sph.EXPECT().GetLeastUnacked().AnyTimes() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{LeastUnacked: 10}) sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ PacketNumber: 0x1337, @@ -1145,7 +1145,6 @@ var _ = Describe("Session", func() { sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().GetAlarmTimeout().AnyTimes() sph.EXPECT().SendingAllowed().Return(true).AnyTimes() - sph.EXPECT().GetLeastUnacked().Times(2) sph.EXPECT().DequeuePacketForRetransmission() sph.EXPECT().GetStopWaitingFrame(gomock.Any()) sph.EXPECT().ShouldSendNumPackets().Return(1)