From 29f98a296c61f3a6e71ec4ccc69a9fdd7bd9c0a7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 17 Jan 2025 23:40:30 -0800 Subject: [PATCH] ackhandler: avoid calling time.Now() when setting loss detection timer (#4885) --- connection.go | 48 +++---- internal/ackhandler/interfaces.go | 10 +- .../mock_sent_packet_tracker_test.go | 13 +- .../ackhandler/received_packet_handler.go | 2 +- .../received_packet_handler_test.go | 23 ++- internal/ackhandler/sent_packet_handler.go | 46 +++--- .../ackhandler/sent_packet_handler_test.go | 133 +++++++++--------- .../mocks/ackhandler/sent_packet_handler.go | 48 +++---- 8 files changed, 162 insertions(+), 161 deletions(-) diff --git a/connection.go b/connection.go index 4c25fc88..6bc9169b 100644 --- a/connection.go +++ b/connection.go @@ -518,7 +518,7 @@ func (s *connection) run() error { if err := s.cryptoStreamHandler.StartHandshake(s.ctx); err != nil { return err } - if err := s.handleHandshakeEvents(); err != nil { + if err := s.handleHandshakeEvents(time.Now()); err != nil { return err } go func() { @@ -619,7 +619,7 @@ runLoop: if timeout := s.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && timeout.Before(now) { // This could cause packets to be retransmitted. // Check it before trying to send packets. - if err := s.sentPacketHandler.OnLossDetectionTimeout(); err != nil { + if err := s.sentPacketHandler.OnLossDetectionTimeout(now); err != nil { s.closeLocal(err) } } @@ -744,7 +744,7 @@ func (s *connection) idleTimeoutStartTime() time.Time { return startTime } -func (s *connection) handleHandshakeComplete() error { +func (s *connection) handleHandshakeComplete(now time.Time) error { defer close(s.handshakeCompleteChan) // Once the handshake completes, we have derived 1-RTT keys. // There's no point in queueing undecryptable packets for later decryption anymore. @@ -765,7 +765,7 @@ func (s *connection) handleHandshakeComplete() error { } // All these only apply to the server side. - if err := s.handleHandshakeConfirmed(); err != nil { + if err := s.handleHandshakeConfirmed(now); err != nil { return err } @@ -788,13 +788,13 @@ func (s *connection) handleHandshakeComplete() error { return nil } -func (s *connection) handleHandshakeConfirmed() error { - if err := s.dropEncryptionLevel(protocol.EncryptionHandshake); err != nil { +func (s *connection) handleHandshakeConfirmed(now time.Time) error { + if err := s.dropEncryptionLevel(protocol.EncryptionHandshake, now); err != nil { return err } s.handshakeConfirmed = true - s.sentPacketHandler.SetHandshakeConfirmed() + s.sentPacketHandler.SetHandshakeConfirmed(now) s.cryptoStreamHandler.SetHandshakeConfirmed() if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF { @@ -804,7 +804,7 @@ func (s *connection) handleHandshakeConfirmed() error { } func (s *connection) handlePacketImpl(rp receivedPacket) bool { - s.sentPacketHandler.ReceivedBytes(rp.Size()) + s.sentPacketHandler.ReceivedBytes(rp.Size(), rp.rcvTime) if wire.IsVersionNegotiationPacket(rp.data) { s.handleVersionNegotiationPacket(rp) @@ -1211,7 +1211,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( !s.droppedInitialKeys { // On the server side, Initial keys are dropped as soon as the first Handshake packet is received. // See Section 4.9.1 of RFC 9001. - if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { + if err := s.dropEncryptionLevel(protocol.EncryptionInitial, rcvTime); err != nil { return err } } @@ -1308,7 +1308,7 @@ func (s *connection) handleFrames( // We receive a Handshake packet that contains the CRYPTO frame that allows us to complete the handshake, // and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame. if !handshakeWasComplete && s.handshakeComplete { - if err := s.handleHandshakeComplete(); err != nil { + if err := s.handleHandshakeComplete(rcvTime); err != nil { return false, err } } @@ -1326,11 +1326,11 @@ func (s *connection) handleFrame( wire.LogFrame(s.logger, f, false) switch frame := f.(type) { case *wire.CryptoFrame: - err = s.handleCryptoFrame(frame, encLevel) + err = s.handleCryptoFrame(frame, encLevel, rcvTime) case *wire.StreamFrame: err = s.handleStreamFrame(frame, rcvTime) case *wire.AckFrame: - err = s.handleAckFrame(frame, encLevel) + err = s.handleAckFrame(frame, encLevel, rcvTime) case *wire.ConnectionCloseFrame: s.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: @@ -1363,7 +1363,7 @@ func (s *connection) handleFrame( case *wire.RetireConnectionIDFrame: err = s.handleRetireConnectionIDFrame(frame, destConnID) case *wire.HandshakeDoneFrame: - err = s.handleHandshakeDoneFrame() + err = s.handleHandshakeDoneFrame(rcvTime) case *wire.DatagramFrame: err = s.handleDatagramFrame(frame) default: @@ -1402,7 +1402,7 @@ func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame }) } -func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { +func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { if err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil { return err } @@ -1415,10 +1415,10 @@ func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protoco return err } } - return s.handleHandshakeEvents() + return s.handleHandshakeEvents(rcvTime) } -func (s *connection) handleHandshakeEvents() error { +func (s *connection) handleHandshakeEvents(now time.Time) error { for { ev := s.cryptoStreamHandler.NextEvent() var err error @@ -1439,7 +1439,7 @@ func (s *connection) handleHandshakeEvents() error { s.undecryptablePacketsToProcess = s.undecryptablePackets s.undecryptablePackets = nil case handshake.EventDiscard0RTTKeys: - err = s.dropEncryptionLevel(protocol.Encryption0RTT) + err = s.dropEncryptionLevel(protocol.Encryption0RTT, now) case handshake.EventWriteInitialData: _, err = s.initialStream.Write(ev.Data) case handshake.EventWriteHandshakeData: @@ -1540,7 +1540,7 @@ func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFra return s.connIDGenerator.Retire(f.SequenceNumber, destConnID) } -func (s *connection) handleHandshakeDoneFrame() error { +func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error { if s.perspective == protocol.PerspectiveServer { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, @@ -1548,12 +1548,12 @@ func (s *connection) handleHandshakeDoneFrame() error { } } if !s.handshakeConfirmed { - return s.handleHandshakeConfirmed() + return s.handleHandshakeConfirmed(rcvTime) } return nil } -func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { +func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { acked1RTTPacket, err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime) if err != nil { return err @@ -1565,7 +1565,7 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr // This is only possible if the ACK was sent in a 1-RTT packet. // This is an optimization over simply waiting for a HANDSHAKE_DONE frame, see section 4.1.2 of RFC 9001. if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed { - if err := s.handleHandshakeConfirmed(); err != nil { + if err := s.handleHandshakeConfirmed(rcvTime); err != nil { return err } } @@ -1697,11 +1697,11 @@ func (s *connection) handleCloseError(closeErr *closeError) { s.connIDGenerator.ReplaceWithClosed(connClosePacket) } -func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { +func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel, now time.Time) error { if s.tracer != nil && s.tracer.DroppedEncryptionLevel != nil { s.tracer.DroppedEncryptionLevel(encLevel) } - s.sentPacketHandler.DropPackets(encLevel) + s.sentPacketHandler.DropPackets(encLevel, now) s.receivedPacketHandler.DropPackets(encLevel) //nolint:exhaustive // only Initial and 0-RTT need special treatment switch encLevel { @@ -2140,7 +2140,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot !s.droppedInitialKeys { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // See Section 4.9.1 of RFC 9001. - if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { + if err := s.dropEncryptionLevel(protocol.EncryptionInitial, now); err != nil { return err } } diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 24d69603..a99c00ea 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -14,10 +14,10 @@ type SentPacketHandler interface { // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) - ReceivedBytes(protocol.ByteCount) - DropPackets(protocol.EncryptionLevel) + ReceivedBytes(_ protocol.ByteCount, rcvTime time.Time) + DropPackets(_ protocol.EncryptionLevel, rcvTime time.Time) ResetForRetry(rcvTime time.Time) - SetHandshakeConfirmed() + SetHandshakeConfirmed(now time.Time) // The SendMode determines if and what kind of packets can be sent. SendMode(now time.Time) SendMode @@ -34,12 +34,12 @@ type SentPacketHandler interface { PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber GetLossDetectionTimeout() time.Time - OnLossDetectionTimeout() error + OnLossDetectionTimeout(now time.Time) error } type sentPacketTracker interface { GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - ReceivedPacket(protocol.EncryptionLevel) + ReceivedPacket(_ protocol.EncryptionLevel, rcvTime time.Time) } // ReceivedPacketHandler handles ACKs needed to send for incoming packets diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go index dd4a91c3..644046c5 100644 --- a/internal/ackhandler/mock_sent_packet_tracker_test.go +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -11,6 +11,7 @@ package ackhandler import ( reflect "reflect" + time "time" protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" @@ -79,15 +80,15 @@ func (c *MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall) DoAndReturn( } // ReceivedPacket mocks base method. -func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) { +func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel, rcvTime time.Time) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedPacket", arg0) + m.ctrl.Call(m, "ReceivedPacket", arg0, rcvTime) } // ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 any) *MockSentPacketTrackerReceivedPacketCall { +func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0, rcvTime any) *MockSentPacketTrackerReceivedPacketCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0, rcvTime) return &MockSentPacketTrackerReceivedPacketCall{Call: call} } @@ -103,13 +104,13 @@ func (c *MockSentPacketTrackerReceivedPacketCall) Return() *MockSentPacketTracke } // Do rewrite *gomock.Call.Do -func (c *MockSentPacketTrackerReceivedPacketCall) Do(f func(protocol.EncryptionLevel)) *MockSentPacketTrackerReceivedPacketCall { +func (c *MockSentPacketTrackerReceivedPacketCall) Do(f func(protocol.EncryptionLevel, time.Time)) *MockSentPacketTrackerReceivedPacketCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSentPacketTrackerReceivedPacketCall) DoAndReturn(f func(protocol.EncryptionLevel)) *MockSentPacketTrackerReceivedPacketCall { +func (c *MockSentPacketTrackerReceivedPacketCall) DoAndReturn(f func(protocol.EncryptionLevel, time.Time)) *MockSentPacketTrackerReceivedPacketCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index 1175c790..e487f12a 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -38,7 +38,7 @@ func (h *receivedPacketHandler) ReceivedPacket( rcvTime time.Time, ackEliciting bool, ) error { - h.sentPackets.ReceivedPacket(encLevel) + h.sentPackets.ReceivedPacket(encLevel, rcvTime) switch encLevel { case protocol.EncryptionInitial: return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index 9f35b28b..e66f72bd 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -17,12 +17,11 @@ func TestGenerateACKsForPacketNumberSpaces(t *testing.T) { sentPackets := NewMockSentPacketTracker(ctrl) handler := newReceivedPacketHandler(sentPackets, utils.DefaultLogger) - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() - sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial).Times(2) - sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake).Times(2) - sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT).Times(2) - sendTime := time.Now().Add(-time.Second) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial, sendTime).Times(2) + sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake, sendTime).Times(2) + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT, sendTime).Times(2) require.NoError(t, handler.ReceivedPacket(2, protocol.ECT0, protocol.EncryptionInitial, sendTime, true)) require.NoError(t, handler.ReceivedPacket(1, protocol.ECT1, protocol.EncryptionHandshake, sendTime, true)) @@ -64,11 +63,11 @@ func TestReceive0RTTAnd1RTT(t *testing.T) { sentPackets := NewMockSentPacketTracker(mockCtrl) handler := newReceivedPacketHandler(sentPackets, utils.DefaultLogger) - sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() - sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT).AnyTimes() - sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT) - sendTime := time.Now().Add(-time.Second) + sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT, sendTime).AnyTimes() + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT, sendTime) + require.NoError(t, handler.ReceivedPacket(2, protocol.ECNNon, protocol.Encryption0RTT, sendTime, true)) require.NoError(t, handler.ReceivedPacket(3, protocol.ECNNon, protocol.Encryption1RTT, sendTime, true)) @@ -85,7 +84,7 @@ func TestReceive0RTTAnd1RTT(t *testing.T) { func TestDropPackets(t *testing.T) { mockCtrl := gomock.NewController(t) sentPackets := NewMockSentPacketTracker(mockCtrl) - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() + sentPackets.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any()).AnyTimes() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() handler := newReceivedPacketHandler(sentPackets, utils.DefaultLogger) @@ -115,7 +114,7 @@ func TestDropPackets(t *testing.T) { func TestAckRangePruning(t *testing.T) { mockCtrl := gomock.NewController(t) sentPackets := NewMockSentPacketTracker(mockCtrl) - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() + sentPackets.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any()).AnyTimes() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(3) handler := newReceivedPacketHandler(sentPackets, utils.DefaultLogger) @@ -139,7 +138,7 @@ func TestAckRangePruning(t *testing.T) { func TestPacketDuplicateDetection(t *testing.T) { mockCtrl := gomock.NewController(t) sentPackets := NewMockSentPacketTracker(mockCtrl) - sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() + sentPackets.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any()).AnyTimes() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() handler := newReceivedPacketHandler(sentPackets, utils.DefaultLogger) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index d528e6c0..7f64cf8f 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -155,7 +155,7 @@ func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { } } -func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { +func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now time.Time) { // The server won't await address validation after the handshake is confirmed. // This applies even if we didn't receive an ACK for a Handshake packet. if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake { @@ -202,21 +202,21 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { h.ptoCount = 0 h.numProbesToSend = 0 h.ptoMode = SendNone - h.setLossDetectionTimer() + h.setLossDetectionTimer(now) } -func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { +func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount, t time.Time) { wasAmplificationLimit := h.isAmplificationLimited() h.bytesReceived += n if wasAmplificationLimit && !h.isAmplificationLimited() { - h.setLossDetectionTimer() + h.setLossDetectionTimer(t) } } -func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) { +func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel, t time.Time) { if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated { h.peerAddressValidated = true - h.setLossDetectionTimer() + h.setLossDetectionTimer(t) } } @@ -269,7 +269,7 @@ func (h *sentPacketHandler) SentPacket( if !isAckEliciting { pnSpace.history.SentNonAckElicitingPacket(pn) if !h.peerCompletedAddressValidation { - h.setLossDetectionTimer() + h.setLossDetectionTimer(t) } return } @@ -289,7 +289,7 @@ func (h *sentPacketHandler) SentPacket( if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } - h.setLossDetectionTimer() + h.setLossDetectionTimer(t) } func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace { @@ -322,7 +322,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.peerCompletedAddressValidation = true h.logger.Debugf("Peer doesn't await address validation any longer.") // Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets. - h.setLossDetectionTimer() + h.setLossDetectionTimer(rcvTime) } priorInFlight := h.bytesInFlight @@ -387,7 +387,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } - h.setLossDetectionTimer() + h.setLossDetectionTimer(rcvTime) return acked1RTTPacket, nil } @@ -498,14 +498,14 @@ func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration } // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime -func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { +func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { // We only send application data probe packets once the handshake is confirmed, // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() { if h.peerCompletedAddressValidation { return } - t := time.Now().Add(h.getScaledPTO(false)) + t := now.Add(h.getScaledPTO(false)) if h.initialPackets != nil { return t, protocol.EncryptionInitial, true } @@ -549,7 +549,7 @@ func (h *sentPacketHandler) hasOutstandingPackets() bool { return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets() } -func (h *sentPacketHandler) setLossDetectionTimer() { +func (h *sentPacketHandler) setLossDetectionTimer(now time.Time) { oldAlarm := h.alarm // only needed in case tracing is enabled lossTime, encLevel := h.getLossTimeAndSpace() if !lossTime.IsZero() { @@ -586,7 +586,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } // PTO alarm - ptoTime, encLevel, ok := h.getPTOTimeAndSpace() + ptoTime, encLevel, ok := h.getPTOTimeAndSpace(now) if !ok { if !oldAlarm.IsZero() { h.alarm = time.Time{} @@ -669,8 +669,8 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E }) } -func (h *sentPacketHandler) OnLossDetectionTimeout() error { - defer h.setLossDetectionTimer() +func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error { + defer h.setLossDetectionTimer(now) earliestLossTime, encLevel := h.getLossTimeAndSpace() if !earliestLossTime.IsZero() { if h.logger.Debug() { @@ -680,13 +680,13 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection - return h.detectLostPackets(time.Now(), encLevel) + return h.detectLostPackets(now, encLevel) } // PTO - // When all outstanding are acknowledged, the alarm is canceled in - // setLossDetectionTimer. This doesn't reset the timer in the session though. - // When OnAlarm is called, we therefore need to make sure that there are + // When all outstanding are acknowledged, the alarm is canceled in setLossDetectionTimer. + // However, there's no way to reset the timer in the connection. + // When OnLossDetectionTimeout is called, we therefore need to make sure that there are // actually packets outstanding. if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation { h.ptoCount++ @@ -701,7 +701,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { return nil } - _, encLevel, ok := h.getPTOTimeAndSpace() + _, encLevel, ok := h.getPTOTimeAndSpace(now) if !ok { return nil } @@ -913,7 +913,7 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) { h.ptoCount = 0 } -func (h *sentPacketHandler) SetHandshakeConfirmed() { +func (h *sentPacketHandler) SetHandshakeConfirmed(now time.Time) { if h.initialPackets != nil { panic("didn't drop initial correctly") } @@ -923,5 +923,5 @@ func (h *sentPacketHandler) SetHandshakeConfirmed() { h.handshakeConfirmed = true // We don't send PTOs for application data packets before the handshake completes. // Make sure the timer is armed now, if necessary. - h.setLossDetectionTimer() + h.setLossDetectionTimer(now) } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index a215c039..ce08c994 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -131,9 +131,9 @@ var _ = Describe("SentPacketHandler", func() { // setHandshakeConfirmed drops both Initial and Handshake packets and then confirms the handshake setHandshakeConfirmed := func() { - handler.DropPackets(protocol.EncryptionInitial) - handler.DropPackets(protocol.EncryptionHandshake) - handler.SetHandshakeConfirmed() + handler.DropPackets(protocol.EncryptionInitial, time.Now()) + handler.DropPackets(protocol.EncryptionHandshake, time.Now()) + handler.SetHandshakeConfirmed(time.Now()) } Context("registering sent packets", func() { @@ -629,7 +629,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("passes the bytes in flight to the congestion controller", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(42), gomock.Any(), protocol.ByteCount(42), true) sentPacket(&packet{ Length: 42, @@ -643,7 +643,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows sending of ACKs when congestion limited", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) cong.EXPECT().CanSend(gomock.Any()).Return(true) cong.EXPECT().HasPacingBudget(gomock.Any()).Return(true) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) @@ -652,7 +652,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows sending of ACKs when we're keeping track of MaxOutstandingSentPackets packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes() cong.EXPECT().HasPacingBudget(gomock.Any()).Return(true).AnyTimes() cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() @@ -664,7 +664,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows PTOs, even when congestion limited", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) // note that we don't EXPECT a call to GetCongestionWindow // that means retransmissions are sent without considering the congestion window handler.numProbesToSend = 1 @@ -680,7 +680,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("doesn't set an alarm if there are no outstanding packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) sentPacket(ackElicitingPacket(&packet{PacketNumber: 10})) sentPacket(ackElicitingPacket(&packet{PacketNumber: 11})) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}} @@ -690,8 +690,8 @@ var _ = Describe("SentPacketHandler", func() { }) It("does nothing on OnAlarm if there are no outstanding packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) }) @@ -712,35 +712,36 @@ var _ = Describe("SentPacketHandler", func() { It("implements exponential backoff", func() { handler.peerAddressValidated = true setHandshakeConfirmed() - sendTime := time.Now().Add(-time.Hour) + now := time.Now() + sendTime := now.Add(-time.Hour) sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: sendTime})) timeout := handler.GetLossDetectionTimeout().Sub(sendTime) Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(timeout)) handler.ptoCount = 1 - handler.setLossDetectionTimer() + handler.setLossDetectionTimer(now) Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(2 * timeout)) handler.ptoCount = 2 - handler.setLossDetectionTimer() + handler.setLossDetectionTimer(now) Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout)) // truncated when the exponential gets too large handler.ptoCount = 20 - handler.setLossDetectionTimer() + handler.setLossDetectionTimer(now) Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(maxPTODuration)) // protected from rollover handler.ptoCount = 100 - handler.setLossDetectionTimer() + handler.setLossDetectionTimer(now) Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(maxPTODuration)) }) It("reset the PTO count when receiving an ACK", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) now := time.Now() setHandshakeConfirmed() sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) sentPacket(ackElicitingPacket(&packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) handler.appDataPackets.pns.(*skippingPacketNumberGenerator).next = 3 Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) Expect(handler.ptoCount).To(BeEquivalentTo(1)) _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.Encryption1RTT, time.Now()) @@ -749,7 +750,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("resets the PTO mode and PTO count when a packet number space is dropped", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) now := time.Now() handler.rttStats.UpdateRTT(time.Second/2, 0) @@ -772,12 +773,12 @@ var _ = Describe("SentPacketHandler", func() { })) // PTO timer based on the Handshake packet - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.ptoCount).To(BeEquivalentTo(1)) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOHandshake)) Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeHandshake.Add(handler.rttStats.PTO(false) << 1))) setHandshakeConfirmed() - handler.DropPackets(protocol.EncryptionHandshake) + handler.DropPackets(protocol.EncryptionHandshake, time.Now()) // PTO timer based on the 1-RTT packet Expect(handler.GetLossDetectionTimeout()).To(Equal(sendTimeAppData.Add(handler.rttStats.PTO(true)))) // no backoff. PTO count = 0 Expect(handler.SendMode(time.Now())).ToNot(Equal(SendPTOHandshake)) @@ -785,7 +786,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows two 1-RTT PTOs", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) setHandshakeConfirmed() var lostPackets []protocol.PacketNumber sentPacket(ackElicitingPacket(&packet{ @@ -798,7 +799,7 @@ var _ = Describe("SentPacketHandler", func() { }, }, })) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) @@ -807,13 +808,13 @@ var _ = Describe("SentPacketHandler", func() { }) It("only counts ack-eliciting packets as probe packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) setHandshakeConfirmed() sentPacket(ackElicitingPacket(&packet{ PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT), SendTime: time.Now().Add(-time.Hour), })) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) @@ -826,7 +827,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("gets two probe packets if PTO expires", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) setHandshakeConfirmed() sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) @@ -834,14 +835,14 @@ var _ = Describe("SentPacketHandler", func() { updateRTT(time.Hour) Expect(handler.appDataPackets.lossTime.IsZero()).To(BeTrue()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) // TLP Expect(handler.ptoCount).To(BeEquivalentTo(1)) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // PTO + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) // PTO Expect(handler.ptoCount).To(BeEquivalentTo(2)) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) @@ -852,14 +853,14 @@ var _ = Describe("SentPacketHandler", func() { }) It("gets two probe packets if PTO expires, for Handshake packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) sentPacket(initialPacket(&packet{PacketNumber: 1})) sentPacket(initialPacket(&packet{PacketNumber: 2})) updateRTT(time.Hour) Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOInitial)) sentPacket(initialPacket(&packet{PacketNumber: 3})) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOInitial)) @@ -869,25 +870,25 @@ var _ = Describe("SentPacketHandler", func() { }) It("doesn't send 1-RTT probe packets before the handshake completes", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) sentPacket(ackElicitingPacket(&packet{PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT)})) updateRTT(time.Hour) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) setHandshakeConfirmed() Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) }) It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) setHandshakeConfirmed() pn := handler.PopPacketNumber(protocol.Encryption1RTT) sentPacket(ackElicitingPacket(&packet{PacketNumber: pn, SendTime: time.Now().Add(-time.Hour)})) updateRTT(time.Second) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOAppData)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: pn, Largest: pn}}} _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) @@ -896,18 +897,18 @@ var _ = Describe("SentPacketHandler", func() { }) It("handles ACKs for the original packet", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) sentPacket(ackElicitingPacket(&packet{ PacketNumber: handler.PopPacketNumber(protocol.Encryption1RTT), SendTime: time.Now().Add(-time.Hour), })) updateRTT(time.Second) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) }) It("doesn't set the PTO timer for Path MTU probe packets", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) setHandshakeConfirmed() updateRTT(time.Second) sentPacket(ackElicitingPacket(&packet{PacketNumber: 5, SendTime: time.Now(), IsPathMTUProbePacket: true})) @@ -918,8 +919,8 @@ var _ = Describe("SentPacketHandler", func() { Context("amplification limit, for the server", func() { It("limits the window to 3x the bytes received, to avoid amplification attacks", func() { now := time.Now() - handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address - handler.ReceivedBytes(200) + handler.ReceivedPacket(protocol.EncryptionInitial, now) // receiving an Initial packet doesn't validate the client's address + handler.ReceivedBytes(200, now) sentPacket(&packet{ PacketNumber: 1, Length: 599, @@ -939,7 +940,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("cancels the loss detection timer when it is amplification limited, and resets it when becoming unblocked", func() { - handler.ReceivedBytes(300) + handler.ReceivedBytes(300, time.Now()) sentPacket(&packet{ PacketNumber: 1, Length: 900, @@ -950,12 +951,12 @@ var _ = Describe("SentPacketHandler", func() { // Amplification limited. We don't need to set a timer now. Expect(handler.GetLossDetectionTimeout()).To(BeZero()) // Unblock the server. Now we should fire up the timer. - handler.ReceivedBytes(1) + handler.ReceivedBytes(1, time.Now()) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) }) It("resets the loss detection timer when the client's address is validated", func() { - handler.ReceivedBytes(300) + handler.ReceivedBytes(300, time.Now()) sentPacket(&packet{ PacketNumber: 1, Length: 900, @@ -965,13 +966,13 @@ var _ = Describe("SentPacketHandler", func() { }) // Amplification limited. We don't need to set a timer now. Expect(handler.GetLossDetectionTimeout()).To(BeZero()) - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) }) It("cancels the loss detection alarm when all Handshake packets are acknowledged", func() { t := time.Now().Add(-time.Second) - handler.ReceivedBytes(99999) + handler.ReceivedBytes(99999, time.Now()) sentPacket(ackElicitingPacket(&packet{PacketNumber: 2, SendTime: t})) sentPacket(handshakePacket(&packet{PacketNumber: 3, SendTime: t})) sentPacket(handshakePacket(&packet{PacketNumber: 4, SendTime: t})) @@ -988,7 +989,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("do not limits the window", func() { - handler.ReceivedBytes(0) + handler.ReceivedBytes(0, time.Now()) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) sentPacket(&packet{ PacketNumber: 1, @@ -1017,7 +1018,7 @@ var _ = Describe("SentPacketHandler", func() { // No packets are outstanding at this point. // Make sure that a probe packet is sent. Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOInitial)) // send a single packet to unblock the server @@ -1027,7 +1028,7 @@ var _ = Describe("SentPacketHandler", func() { // Now receive an ACK for a Handshake packet. // This tells the client that the server completed address validation. sentPacket(handshakePacket(&packet{PacketNumber: 1})) - handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space + handler.DropPackets(protocol.EncryptionInitial, time.Now()) // sending a Handshake packet drops the Initial packet number space _, err = handler.ReceivedAck( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionHandshake, @@ -1048,9 +1049,9 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) sentPacket(handshakePacketNonAckEliciting(&packet{PacketNumber: 1})) - handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space + handler.DropPackets(protocol.EncryptionInitial, time.Now()) // sending a Handshake packet drops the Initial packet number space Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOHandshake)) // Now receive an ACK for this packet, and send another one. @@ -1067,10 +1068,10 @@ var _ = Describe("SentPacketHandler", func() { It("doesn't send a packet to unblock the server after handshake confirmation, even if no Handshake ACK was received", func() { sentPacket(handshakePacket(&packet{PacketNumber: 1})) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOHandshake)) // confirm the handshake - handler.DropPackets(protocol.EncryptionHandshake) + handler.DropPackets(protocol.EncryptionHandshake, time.Now()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) }) @@ -1083,7 +1084,7 @@ var _ = Describe("SentPacketHandler", func() { ) Expect(err).ToNot(HaveOccurred()) sentPacket(handshakePacketNonAckEliciting(&packet{PacketNumber: 1, SendTime: time.Now()})) - handler.DropPackets(protocol.EncryptionInitial) // sending a Handshake packet drops the Initial packet number space + handler.DropPackets(protocol.EncryptionInitial, time.Now()) // sending a Handshake packet drops the Initial packet number space pto := handler.rttStats.PTO(false) Expect(pto).ToNot(BeZero()) @@ -1097,7 +1098,7 @@ var _ = Describe("SentPacketHandler", func() { sentPacket(initialPacket(&packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) sentPacket(initialPacket(&packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOInitial)) Expect(handler.ptoCount).To(BeEquivalentTo(1)) _, err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, time.Now()) @@ -1136,7 +1137,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("sets the early retransmit alarm", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) handler.handshakeConfirmed = true now := time.Now() sentPacket(ackElicitingPacket(&packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)})) @@ -1154,13 +1155,13 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) expectInPacketHistory([]protocol.PacketNumber{1, 3}, protocol.Encryption1RTT) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) expectInPacketHistory([]protocol.PacketNumber{3}, protocol.Encryption1RTT) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) }) It("sets the early retransmit alarm for crypto packets", func() { - handler.ReceivedBytes(1000) + handler.ReceivedBytes(1000, time.Now()) now := time.Now() sentPacket(initialPacket(&packet{PacketNumber: 1, SendTime: now.Add(-2 * time.Second)})) sentPacket(initialPacket(&packet{PacketNumber: 2, SendTime: now.Add(-2 * time.Second)})) @@ -1177,7 +1178,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) expectInPacketHistory([]protocol.PacketNumber{1, 3}, protocol.EncryptionInitial) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) expectInPacketHistory([]protocol.PacketNumber{3}, protocol.EncryptionInitial) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) }) @@ -1202,7 +1203,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(err).ToNot(HaveOccurred()) Expect(mtuPacketDeclaredLost).To(BeFalse()) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(mtuPacketDeclaredLost).To(BeTrue()) Expect(handler.GetLossDetectionTimeout()).To(BeZero()) }) @@ -1236,7 +1237,7 @@ var _ = Describe("SentPacketHandler", func() { })) } Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) - handler.DropPackets(protocol.EncryptionInitial) + handler.DropPackets(protocol.EncryptionInitial, time.Now()) Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) Expect(handler.initialPackets).To(BeNil()) @@ -1257,7 +1258,7 @@ var _ = Describe("SentPacketHandler", func() { })) } Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16))) - handler.DropPackets(protocol.EncryptionHandshake) + handler.DropPackets(protocol.EncryptionHandshake, time.Now()) Expect(lostPackets).To(BeEmpty()) // frames must not be queued for retransmission Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) Expect(handler.handshakePackets).To(BeNil()) @@ -1278,21 +1279,21 @@ var _ = Describe("SentPacketHandler", func() { sentPacket(ackElicitingPacket(&packet{PacketNumber: i})) } Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11))) - handler.DropPackets(protocol.Encryption0RTT) + handler.DropPackets(protocol.Encryption0RTT, time.Now()) Expect(lostPackets).To(BeEmpty()) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(6))) }) It("cancels the PTO when dropping a packet number space", func() { - handler.ReceivedPacket(protocol.EncryptionHandshake) + handler.ReceivedPacket(protocol.EncryptionHandshake, time.Now()) now := time.Now() sentPacket(handshakePacket(&packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) sentPacket(handshakePacket(&packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", now.Add(-time.Minute), time.Second)) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOHandshake)) Expect(handler.ptoCount).To(BeEquivalentTo(1)) - handler.DropPackets(protocol.EncryptionHandshake) + handler.DropPackets(protocol.EncryptionHandshake, time.Now()) Expect(handler.ptoCount).To(BeZero()) Expect(handler.SendMode(time.Now())).To(Equal(SendAny)) }) @@ -1415,7 +1416,7 @@ var _ = Describe("SentPacketHandler", func() { EncryptionLevel: protocol.EncryptionInitial, SendTime: now, })) - Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.OnLossDetectionTimeout(time.Now())).To(Succeed()) Expect(handler.SendMode(time.Now())).To(Equal(SendPTOInitial)) sentPacket(ackElicitingPacket(&packet{ PacketNumber: 43, diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index b6118712..b2134b16 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -44,15 +44,15 @@ func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { } // DropPackets mocks base method. -func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { +func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel, rcvTime time.Time) { m.ctrl.T.Helper() - m.ctrl.Call(m, "DropPackets", arg0) + m.ctrl.Call(m, "DropPackets", arg0, rcvTime) } // DropPackets indicates an expected call of DropPackets. -func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 any) *MockSentPacketHandlerDropPacketsCall { +func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0, rcvTime any) *MockSentPacketHandlerDropPacketsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0, rcvTime) return &MockSentPacketHandlerDropPacketsCall{Call: call} } @@ -68,13 +68,13 @@ func (c *MockSentPacketHandlerDropPacketsCall) Return() *MockSentPacketHandlerDr } // Do rewrite *gomock.Call.Do -func (c *MockSentPacketHandlerDropPacketsCall) Do(f func(protocol.EncryptionLevel)) *MockSentPacketHandlerDropPacketsCall { +func (c *MockSentPacketHandlerDropPacketsCall) Do(f func(protocol.EncryptionLevel, time.Time)) *MockSentPacketHandlerDropPacketsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSentPacketHandlerDropPacketsCall) DoAndReturn(f func(protocol.EncryptionLevel)) *MockSentPacketHandlerDropPacketsCall { +func (c *MockSentPacketHandlerDropPacketsCall) DoAndReturn(f func(protocol.EncryptionLevel, time.Time)) *MockSentPacketHandlerDropPacketsCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -156,17 +156,17 @@ func (c *MockSentPacketHandlerGetLossDetectionTimeoutCall) DoAndReturn(f func() } // OnLossDetectionTimeout mocks base method. -func (m *MockSentPacketHandler) OnLossDetectionTimeout() error { +func (m *MockSentPacketHandler) OnLossDetectionTimeout(now time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnLossDetectionTimeout") + ret := m.ctrl.Call(m, "OnLossDetectionTimeout", now) ret0, _ := ret[0].(error) return ret0 } // OnLossDetectionTimeout indicates an expected call of OnLossDetectionTimeout. -func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout() *MockSentPacketHandlerOnLossDetectionTimeoutCall { +func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout(now any) *MockSentPacketHandlerOnLossDetectionTimeoutCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout), now) return &MockSentPacketHandlerOnLossDetectionTimeoutCall{Call: call} } @@ -182,13 +182,13 @@ func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) Return(arg0 error) *Mo } // Do rewrite *gomock.Call.Do -func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) Do(f func() error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { +func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) Do(f func(time.Time) error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) DoAndReturn(f func() error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { +func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) DoAndReturn(f func(time.Time) error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -348,15 +348,15 @@ func (c *MockSentPacketHandlerReceivedAckCall) DoAndReturn(f func(*wire.AckFrame } // ReceivedBytes mocks base method. -func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { +func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount, rcvTime time.Time) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedBytes", arg0) + m.ctrl.Call(m, "ReceivedBytes", arg0, rcvTime) } // ReceivedBytes indicates an expected call of ReceivedBytes. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 any) *MockSentPacketHandlerReceivedBytesCall { +func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0, rcvTime any) *MockSentPacketHandlerReceivedBytesCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0, rcvTime) return &MockSentPacketHandlerReceivedBytesCall{Call: call} } @@ -372,13 +372,13 @@ func (c *MockSentPacketHandlerReceivedBytesCall) Return() *MockSentPacketHandler } // Do rewrite *gomock.Call.Do -func (c *MockSentPacketHandlerReceivedBytesCall) Do(f func(protocol.ByteCount)) *MockSentPacketHandlerReceivedBytesCall { +func (c *MockSentPacketHandlerReceivedBytesCall) Do(f func(protocol.ByteCount, time.Time)) *MockSentPacketHandlerReceivedBytesCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSentPacketHandlerReceivedBytesCall) DoAndReturn(f func(protocol.ByteCount)) *MockSentPacketHandlerReceivedBytesCall { +func (c *MockSentPacketHandlerReceivedBytesCall) DoAndReturn(f func(protocol.ByteCount, time.Time)) *MockSentPacketHandlerReceivedBytesCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -494,15 +494,15 @@ func (c *MockSentPacketHandlerSentPacketCall) DoAndReturn(f func(time.Time, prot } // SetHandshakeConfirmed mocks base method. -func (m *MockSentPacketHandler) SetHandshakeConfirmed() { +func (m *MockSentPacketHandler) SetHandshakeConfirmed(now time.Time) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeConfirmed") + m.ctrl.Call(m, "SetHandshakeConfirmed", now) } // SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. -func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *MockSentPacketHandlerSetHandshakeConfirmedCall { +func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed(now any) *MockSentPacketHandlerSetHandshakeConfirmedCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed), now) return &MockSentPacketHandlerSetHandshakeConfirmedCall{Call: call} } @@ -518,13 +518,13 @@ func (c *MockSentPacketHandlerSetHandshakeConfirmedCall) Return() *MockSentPacke } // Do rewrite *gomock.Call.Do -func (c *MockSentPacketHandlerSetHandshakeConfirmedCall) Do(f func()) *MockSentPacketHandlerSetHandshakeConfirmedCall { +func (c *MockSentPacketHandlerSetHandshakeConfirmedCall) Do(f func(time.Time)) *MockSentPacketHandlerSetHandshakeConfirmedCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSentPacketHandlerSetHandshakeConfirmedCall) DoAndReturn(f func()) *MockSentPacketHandlerSetHandshakeConfirmedCall { +func (c *MockSentPacketHandlerSetHandshakeConfirmedCall) DoAndReturn(f func(time.Time)) *MockSentPacketHandlerSetHandshakeConfirmedCall { c.Call = c.Call.DoAndReturn(f) return c }