diff --git a/internal/ackhandler/ackhandler.go b/internal/ackhandler/ackhandler.go index 989d5ed0..8a58e5db 100644 --- a/internal/ackhandler/ackhandler.go +++ b/internal/ackhandler/ackhandler.go @@ -17,6 +17,6 @@ func NewAckHandler( logger utils.Logger, version protocol.VersionNumber, ) (SentPacketHandler, ReceivedPacketHandler) { - return newSentPacketHandler(initialPacketNumber, rttStats, pers, traceCallback, qlogger, logger), - newReceivedPacketHandler(rttStats, logger, version) + sph := newSentPacketHandler(initialPacketNumber, rttStats, pers, traceCallback, qlogger, logger) + return sph, newReceivedPacketHandler(sph, rttStats, logger, version) } diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 4b302239..9fcead66 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -42,7 +42,6 @@ type SentPacketHandler interface { ShouldSendNumPackets() int // only to be called once the handshake is complete - GetLowestPacketNotConfirmedAcked() protocol.PacketNumber QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) @@ -55,10 +54,13 @@ type SentPacketHandler interface { GetStats() *quictrace.TransportState } +type sentPacketTracker interface { + GetLowestPacketNotConfirmedAcked() protocol.PacketNumber +} + // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error - IgnoreBelow(protocol.PacketNumber) DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go new file mode 100644 index 00000000..9ababef4 --- /dev/null +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/ackhandler (interfaces: SentPacketTracker) + +// Package ackhandler is a generated GoMock package. +package ackhandler + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockSentPacketTracker is a mock of SentPacketTracker interface +type MockSentPacketTracker struct { + ctrl *gomock.Controller + recorder *MockSentPacketTrackerMockRecorder +} + +// MockSentPacketTrackerMockRecorder is the mock recorder for MockSentPacketTracker +type MockSentPacketTrackerMockRecorder struct { + mock *MockSentPacketTracker +} + +// NewMockSentPacketTracker creates a new mock instance +func NewMockSentPacketTracker(ctrl *gomock.Controller) *MockSentPacketTracker { + mock := &MockSentPacketTracker{ctrl: ctrl} + mock.recorder = &MockSentPacketTrackerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSentPacketTracker) EXPECT() *MockSentPacketTrackerMockRecorder { + return m.recorder +} + +// GetLowestPacketNotConfirmedAcked mocks base method +func (m *MockSentPacketTracker) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLowestPacketNotConfirmedAcked") + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 +} + +// GetLowestPacketNotConfirmedAcked indicates an expected call of GetLowestPacketNotConfirmedAcked +func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked)) +} diff --git a/internal/ackhandler/mockgen.go b/internal/ackhandler/mockgen.go new file mode 100644 index 00000000..e957d253 --- /dev/null +++ b/internal/ackhandler/mockgen.go @@ -0,0 +1,3 @@ +package ackhandler + +//go:generate sh -c "../../mockgen_private.sh ackhandler mock_sent_packet_tracker_test.go github.com/lucas-clemente/quic-go/internal/ackhandler sentPacketTracker" diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index ffda04f3..367d5355 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -32,6 +32,8 @@ const ( ) type receivedPacketHandler struct { + sentPackets sentPacketTracker + initialPackets *receivedPacketTracker handshakePackets *receivedPacketTracker appDataPackets *receivedPacketTracker @@ -42,11 +44,13 @@ type receivedPacketHandler struct { var _ ReceivedPacketHandler = &receivedPacketHandler{} func newReceivedPacketHandler( + sentPackets sentPacketTracker, rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber, ) ReceivedPacketHandler { return &receivedPacketHandler{ + sentPackets: sentPackets, initialPackets: newReceivedPacketTracker(rttStats, logger, version), handshakePackets: newReceivedPacketTracker(rttStats, logger, version), appDataPackets: newReceivedPacketTracker(rttStats, logger, version), @@ -74,6 +78,7 @@ func (h *receivedPacketHandler) ReceivedPacket( if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { h.lowest1RTTPacket = pn } + h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) h.appDataPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) default: panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) @@ -81,11 +86,6 @@ func (h *receivedPacketHandler) ReceivedPacket( return nil } -// only to be used with 1-RTT packets -func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) { - h.appDataPackets.IgnoreBelow(pn) -} - func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { switch encLevel { case protocol.EncryptionInitial: diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index 8927fd90..985c6ce6 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -14,9 +14,12 @@ import ( var _ = Describe("Received Packet Handler", func() { var handler ReceivedPacketHandler + var sentPackets *MockSentPacketTracker BeforeEach(func() { + sentPackets = NewMockSentPacketTracker(mockCtrl) handler = newReceivedPacketHandler( + sentPackets, &congestion.RTTStats{}, utils.DefaultLogger, protocol.VersionWhatever, @@ -24,6 +27,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("generates ACKs for different packet number spaces", func() { + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) @@ -49,6 +53,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) @@ -59,6 +64,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() { + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now() Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(11, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) @@ -66,6 +72,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("allows reordered 0-RTT packets", func() { + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now() Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(12, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) @@ -83,6 +90,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("drops Handshake packets", func() { + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) @@ -92,7 +100,26 @@ var _ = Describe("Received Packet Handler", func() { Expect(handler.GetAckFrame(protocol.Encryption1RTT)).ToNot(BeNil()) }) - It("does nothing when droping 0-RTT packets", func() { + It("does nothing when dropping 0-RTT packets", func() { handler.DropPackets(protocol.Encryption0RTT) }) + + It("drops old ACK ranges", func() { + sendTime := time.Now() + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2) + Expect(handler.ReceivedPacket(1, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + ack := handler.GetAckFrame(protocol.Encryption1RTT) + Expect(ack).ToNot(BeNil()) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1))) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2))) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked() + Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Return(protocol.PacketNumber(2)) + Expect(handler.ReceivedPacket(4, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) + ack = handler.GetAckFrame(protocol.Encryption1RTT) + Expect(ack).ToNot(BeNil()) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(2))) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4))) + }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 5a9653d0..0f151110 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -80,6 +80,9 @@ type sentPacketHandler struct { logger utils.Logger } +var _ SentPacketHandler = &sentPacketHandler{} +var _ sentPacketTracker = &sentPacketHandler{} + func newSentPacketHandler( initialPacketNumber protocol.PacketNumber, rttStats *congestion.RTTStats, diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index 640b725a..45026e3d 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -76,18 +76,6 @@ func (mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) } -// IgnoreBelow mocks base method -func (m *MockReceivedPacketHandler) IgnoreBelow(arg0 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "IgnoreBelow", arg0) -} - -// IgnoreBelow indicates an expected call of IgnoreBelow -func (mr *MockReceivedPacketHandlerMockRecorder) IgnoreBelow(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IgnoreBelow", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IgnoreBelow), arg0) -} - // ReceivedPacket mocks base method func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel, arg2 time.Time, arg3 bool) error { m.ctrl.T.Helper() diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 0f56d230..07e57747 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -64,20 +64,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetLossDetectionTimeout() *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLossDetectionTimeout)) } -// GetLowestPacketNotConfirmedAcked mocks base method -func (m *MockSentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLowestPacketNotConfirmedAcked") - ret0, _ := ret[0].(protocol.PacketNumber) - return ret0 -} - -// GetLowestPacketNotConfirmedAcked indicates an expected call of GetLowestPacketNotConfirmedAcked -func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked)) -} - // GetStats mocks base method func (m *MockSentPacketHandler) GetStats() *quictrace.TransportState { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 5ab6e7fa..b811ffe9 100644 --- a/session.go +++ b/session.go @@ -1060,7 +1060,6 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt return err } if encLevel == protocol.Encryption1RTT { - s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } return nil diff --git a/session_test.go b/session_test.go index d6ca84dd..463e12b9 100644 --- a/session_test.go +++ b/session_test.go @@ -154,19 +154,6 @@ var _ = Describe("Session", func() { err := sess.handleAckFrame(f, protocol.EncryptionHandshake) Expect(err).ToNot(HaveOccurred()) }) - - It("tells the ReceivedPacketHandler to ignore low ranges", func() { - cryptoSetup.EXPECT().SetLargest1RTTAcked(protocol.PacketNumber(3)) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 3}}} - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().ReceivedAck(gomock.Any(), gomock.Any(), gomock.Any()) - sph.EXPECT().GetLowestPacketNotConfirmedAcked().Return(protocol.PacketNumber(0x42)) - sess.sentPacketHandler = sph - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().IgnoreBelow(protocol.PacketNumber(0x42)) - sess.receivedPacketHandler = rph - Expect(sess.handleAckFrame(ack, protocol.Encryption1RTT)).To(Succeed()) - }) }) Context("handling RESET_STREAM frames", func() {