diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index c66e5008..d010c2b3 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -42,7 +42,7 @@ type SentPacketHandler interface { // only to be called once the handshake is complete GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - QueueProbePacket() bool /* was a packet queued */ + QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber diff --git a/internal/ackhandler/send_mode.go b/internal/ackhandler/send_mode.go index 360ce2e3..3d5fe560 100644 --- a/internal/ackhandler/send_mode.go +++ b/internal/ackhandler/send_mode.go @@ -10,8 +10,12 @@ const ( SendNone SendMode = iota // SendAck means an ACK-only packet should be sent SendAck - // SendPTO means that a probe packet should be sent - SendPTO + // SendPTOInitial means that an Initial probe packet should be sent + SendPTOInitial + // SendPTOHandshake means that a Handshake probe packet should be sent + SendPTOHandshake + // SendPTOAppData means that an Application data probe packet should be sent + SendPTOAppData // SendAny means that any packet should be sent SendAny ) @@ -22,8 +26,12 @@ func (s SendMode) String() string { return "none" case SendAck: return "ack" - case SendPTO: - return "pto" + case SendPTOInitial: + return "pto (Initial)" + case SendPTOHandshake: + return "pto (Handshake)" + case SendPTOAppData: + return "pto (Application Data)" case SendAny: return "any" default: diff --git a/internal/ackhandler/send_mode_test.go b/internal/ackhandler/send_mode_test.go index 0b846b85..86515d74 100644 --- a/internal/ackhandler/send_mode_test.go +++ b/internal/ackhandler/send_mode_test.go @@ -10,7 +10,9 @@ var _ = Describe("Send Mode", func() { Expect(SendNone.String()).To(Equal("none")) Expect(SendAny.String()).To(Equal("any")) Expect(SendAck.String()).To(Equal("ack")) - Expect(SendPTO.String()).To(Equal("pto")) + Expect(SendPTOInitial.String()).To(Equal("pto (Initial)")) + Expect(SendPTOHandshake.String()).To(Equal("pto (Handshake)")) + Expect(SendPTOAppData.String()).To(Equal("pto (Application Data)")) Expect(SendMode(123).String()).To(Equal("invalid send mode: 123")) }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 8dead2a9..ff35e330 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -25,7 +25,9 @@ type packetNumberSpace struct { history *sentPacketHistory pns *packetNumberGenerator - lossTime time.Time + lossTime time.Time + lastSentAckElicitingPacketTime time.Time + largestAcked protocol.PacketNumber largestSent protocol.PacketNumber } @@ -40,9 +42,6 @@ func newPacketNumberSpace(initialPN protocol.PacketNumber) *packetNumberSpace { } type sentPacketHandler struct { - lastSentAckElicitingPacketTime time.Time // only applies to the application-data packet number space - lastSentCryptoPacketTime time.Time - nextSendTime time.Time initialPackets *packetNumberSpace @@ -62,6 +61,7 @@ type sentPacketHandler struct { // The number of times a PTO has been sent without receiving an ack. ptoCount uint32 + ptoMode SendMode // The number of PTO probe packets that should be sent. // Only applies to the application-data packet number space. numProbesToSend int @@ -153,10 +153,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-elicit isAckEliciting := len(packet.Frames) > 0 if isAckEliciting { - if packet.EncryptionLevel != protocol.Encryption1RTT { - h.lastSentCryptoPacketTime = packet.SendTime - } - h.lastSentAckElicitingPacketTime = packet.SendTime + pnSpace.lastSentAckElicitingPacketTime = packet.SendTime packet.includedInBytesInFlight = true h.bytesInFlight += packet.Length if h.numProbesToSend > 0 { @@ -281,7 +278,7 @@ func (h *sentPacketHandler) determineNewlyAckedPackets( return ackedPackets, err } -func (h *sentPacketHandler) getEarliestLossTime() (time.Time, protocol.EncryptionLevel) { +func (h *sentPacketHandler) getEarliestLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) { var encLevel protocol.EncryptionLevel var lossTime time.Time @@ -300,6 +297,26 @@ func (h *sentPacketHandler) getEarliestLossTime() (time.Time, protocol.Encryptio return lossTime, encLevel } +// same logic as getEarliestLossTimeAndSpace, but for lastSentAckElicitingPacketTime instead of lossTime +func (h *sentPacketHandler) getEarliestSentTimeAndSpace() (time.Time, protocol.EncryptionLevel) { + var encLevel protocol.EncryptionLevel + var sentTime time.Time + + if h.initialPackets != nil { + sentTime = h.initialPackets.lastSentAckElicitingPacketTime + encLevel = protocol.EncryptionInitial + } + if h.handshakePackets != nil && (sentTime.IsZero() || (!h.handshakePackets.lastSentAckElicitingPacketTime.IsZero() && h.handshakePackets.lastSentAckElicitingPacketTime.Before(sentTime))) { + sentTime = h.handshakePackets.lastSentAckElicitingPacketTime + encLevel = protocol.EncryptionHandshake + } + if sentTime.IsZero() || (!h.oneRTTPackets.lastSentAckElicitingPacketTime.IsZero() && h.oneRTTPackets.lastSentAckElicitingPacketTime.Before(sentTime)) { + sentTime = h.oneRTTPackets.lastSentAckElicitingPacketTime + encLevel = protocol.Encryption1RTT + } + return sentTime, encLevel +} + func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { var hasInitial, hasHandshake bool if h.initialPackets != nil { @@ -316,7 +333,7 @@ func (h *sentPacketHandler) hasOutstandingPackets() bool { } func (h *sentPacketHandler) setLossDetectionTimer() { - if lossTime, _ := h.getEarliestLossTime(); !lossTime.IsZero() { + if lossTime, _ := h.getEarliestLossTimeAndSpace(); !lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = lossTime } @@ -329,7 +346,8 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } // PTO alarm - h.alarm = h.lastSentAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) + sentTime, encLevel := h.getEarliestSentTimeAndSpace() + h.alarm = sentTime.Add(h.rttStats.PTO(encLevel == protocol.Encryption1RTT) << h.ptoCount) } func (h *sentPacketHandler) detectLostPackets( @@ -415,10 +433,10 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { } func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { - lossTime, encLevel := h.getEarliestLossTime() - if !lossTime.IsZero() { + earliestLossTime, encLevel := h.getEarliestLossTimeAndSpace() + if !earliestLossTime.IsZero() { if h.logger.Debug() { - h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", lossTime) + h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } // Early retransmit or time loss detection return h.detectLostPackets(time.Now(), encLevel, h.bytesInFlight) @@ -426,10 +444,21 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { // PTO if h.logger.Debug() { - h.logger.Debugf("Loss detection alarm fired in PTO mode. PTO count: %d", h.ptoCount) + h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) } + _, encLevel = h.getEarliestSentTimeAndSpace() h.ptoCount++ h.numProbesToSend += 2 + switch encLevel { + case protocol.EncryptionInitial: + h.ptoMode = SendPTOInitial + case protocol.EncryptionHandshake: + h.ptoMode = SendPTOHandshake + case protocol.Encryption1RTT: + h.ptoMode = SendPTOAppData + default: + return fmt.Errorf("TPO timer in unexpected encryption level: %s", encLevel) + } return nil } @@ -492,7 +521,7 @@ func (h *sentPacketHandler) SendMode() SendMode { return SendNone } if h.numProbesToSend > 0 { - return SendPTO + return h.ptoMode } // Only send ACKs if we're congestion limited. if !h.congestion.CanSend(h.bytesInFlight) { @@ -526,17 +555,9 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) } -func (h *sentPacketHandler) QueueProbePacket() bool { - var p *Packet - if h.initialPackets != nil { - p = h.initialPackets.history.FirstOutstanding() - } - if p == nil && h.handshakePackets != nil { - p = h.handshakePackets.history.FirstOutstanding() - } - if p == nil { - p = h.oneRTTPackets.history.FirstOutstanding() - } +func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { + pnSpace := h.getPacketNumberSpace(encLevel) + p := pnSpace.history.FirstOutstanding() if p == nil { return false } @@ -546,8 +567,8 @@ func (h *sentPacketHandler) QueueProbePacket() bool { if p.includedInBytesInFlight { h.bytesInFlight -= p.Length } - if err := h.getPacketNumberSpace(p.EncryptionLevel).history.Remove(p.PacketNumber); err != nil { - // should never happen. We just got this packet from the history a lines above. + if err := pnSpace.history.Remove(p.PacketNumber); err != nil { + // should never happen. We just got this packet from the history. panic(err) } return true diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 0859a0a7..adb3ec75 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -103,20 +103,20 @@ var _ = Describe("SentPacketHandler", func() { It("stores the sent time", func() { sendTime := time.Now().Add(-time.Minute) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) - Expect(handler.lastSentAckElicitingPacketTime).To(Equal(sendTime)) + Expect(handler.oneRTTPackets.lastSentAckElicitingPacketTime).To(Equal(sendTime)) }) - It("stores the sent time of crypto packets", func() { + It("stores the sent time of Initial packets", func() { sendTime := time.Now().Add(-time.Minute) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime, EncryptionLevel: protocol.EncryptionInitial})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: sendTime.Add(time.Hour), EncryptionLevel: protocol.Encryption1RTT})) - Expect(handler.lastSentCryptoPacketTime).To(Equal(sendTime)) + Expect(handler.initialPackets.lastSentAckElicitingPacketTime).To(Equal(sendTime)) }) It("does not store non-ack-eliciting packets", func() { - handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.Encryption1RTT})) + handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: 1})) Expect(handler.oneRTTPackets.history.Len()).To(BeZero()) - Expect(handler.lastSentAckElicitingPacketTime).To(BeZero()) + Expect(handler.oneRTTPackets.lastSentAckElicitingPacketTime).To(BeZero()) Expect(handler.bytesInFlight).To(BeZero()) }) }) @@ -508,11 +508,12 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode()).To(Equal(SendAck)) }) - It("allows RTOs, even when congestion limited", func() { + It("allows PTOs, even when congestion limited", func() { // note that we don't EXPECT a call to GetCongestionWindow // that means retransmissions are sent without considering the congestion window handler.numProbesToSend = 1 - Expect(handler.SendMode()).To(Equal(SendPTO)) + handler.ptoMode = SendPTOHandshake + Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) }) It("gets the pacing delay", func() { @@ -565,13 +566,13 @@ var _ = Describe("SentPacketHandler", func() { It("queues a probe packet", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) - queued := handler.QueueProbePacket() + queued := handler.QueueProbePacket(protocol.Encryption1RTT) Expect(queued).To(BeTrue()) Expect(lostPackets).To(Equal([]protocol.PacketNumber{10})) }) It("says when it can't queue a probe packet", func() { - queued := handler.QueueProbePacket() + queued := handler.QueueProbePacket(protocol.Encryption1RTT) Expect(queued).To(BeFalse()) }) @@ -588,7 +589,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout)) }) - It("sets the PTO send mode until two packets is sent", func() { + It("allows two 1-RTT PTOs", func() { var lostPackets []protocol.PacketNumber handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: 1, @@ -598,27 +599,27 @@ var _ = Describe("SentPacketHandler", func() { }, })) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) Expect(handler.ShouldSendNumPackets()).To(Equal(2)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).ToNot(Equal(SendPTO)) + Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) }) It("only counts ack-eliciting packets as probe packets", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) Expect(handler.ShouldSendNumPackets()).To(Equal(2)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) for p := protocol.PacketNumber(3); p < 30; p++ { handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: p})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) } handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 30})) - Expect(handler.SendMode()).ToNot(Equal(SendPTO)) + Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) }) It("gets two probe packets if RTO expires", func() { @@ -630,22 +631,22 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP Expect(handler.ptoCount).To(BeEquivalentTo(1)) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4})) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // PTO Expect(handler.ptoCount).To(BeEquivalentTo(2)) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6})) Expect(handler.SendMode()).To(Equal(SendAny)) }) - It("gets two probe packets if PTO expires, for crypto packets", func() { + It("gets two probe packets if PTO expires, for Handshake packets", func() { handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 1})) handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 2})) @@ -653,9 +654,9 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 3})) Expect(handler.SendMode()).To(Equal(SendAny)) @@ -665,7 +666,7 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendAny)) diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 7c032987..3879ee7f 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -136,17 +136,17 @@ func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) * } // QueueProbePacket mocks base method -func (m *MockSentPacketHandler) QueueProbePacket() bool { +func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueueProbePacket") + ret := m.ctrl.Call(m, "QueueProbePacket", arg0) ret0, _ := ret[0].(bool) return ret0 } // QueueProbePacket indicates an expected call of QueueProbePacket -func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) } // ReceivedAck mocks base method diff --git a/session.go b/session.go index 1c9040f3..a0d5407d 100644 --- a/session.go +++ b/session.go @@ -1146,8 +1146,18 @@ sendLoop: // There will only be a new ACK after receiving new packets. // SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer. return s.maybeSendAckOnlyPacket() - case ackhandler.SendPTO: - if err := s.sendProbePacket(); err != nil { + case ackhandler.SendPTOInitial: + if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { + return err + } + numPacketsSent++ + case ackhandler.SendPTOHandshake: + if err := s.sendProbePacket(protocol.EncryptionHandshake); err != nil { + return err + } + numPacketsSent++ + case ackhandler.SendPTOAppData: + if err := s.sendProbePacket(protocol.Encryption1RTT); err != nil { return err } numPacketsSent++ @@ -1189,24 +1199,46 @@ func (s *session) maybeSendAckOnlyPacket() error { return nil } -func (s *session) sendProbePacket() error { - // Queue probe packets until we actually send out a packet. +func (s *session) sendProbePacket(encLevel protocol.EncryptionLevel) error { + // Queue probe packets until we actually send out a packet, + // or until there are no more packets to queue. + var packet *packedPacket for { - if wasQueued := s.sentPacketHandler.QueueProbePacket(); !wasQueued { + if wasQueued := s.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued { break } - sent, err := s.sendPacket() + var err error + packet, err = s.packer.MaybePackProbePacket(encLevel) if err != nil { return err } - if sent { - return nil + if packet != nil { + break } } - // If there is nothing else to queue, make sure we send out something. - s.framer.QueueControlFrame(&wire.PingFrame{}) - _, err := s.sendPacket() - return err + if packet == nil { + switch encLevel { + case protocol.EncryptionInitial: + s.retransmissionQueue.AddInitial(&wire.PingFrame{}) + case protocol.EncryptionHandshake: + s.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + case protocol.Encryption1RTT: + s.retransmissionQueue.AddAppData(&wire.PingFrame{}) + default: + panic("unexpected encryption level") + } + var err error + packet, err = s.packer.MaybePackProbePacket(encLevel) + if err != nil { + return err + } + } + if packet == nil { + return fmt.Errorf("session BUG: couldn't pack %s probe packet", encLevel) + } + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(s.retransmissionQueue)) + s.sendPackedPacket(packet) + return nil } func (s *session) sendPacket() (bool, error) { diff --git a/session_test.go b/session_test.go index 8aad69fc..a06a849d 100644 --- a/session_test.go +++ b/session_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/tls" "errors" + "fmt" "net" "runtime/pprof" "strings" @@ -924,40 +925,6 @@ var _ = Describe("Session", func() { Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.DataBlockedFrame{DataLimit: 1337}}})) }) - It("sends a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend() - sph.EXPECT().SendMode().Return(ackhandler.SendPTO) - sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().QueueProbePacket() - packer.EXPECT().PackPacket().Return(getPacket(123), nil) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) - }) - sess.sentPacketHandler = sph - Expect(sess.sendPackets()).To(Succeed()) - }) - - It("sends a PING as a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend() - sph.EXPECT().SendMode().Return(ackhandler.SendPTO) - sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().QueueProbePacket().Return(false) - packer.EXPECT().PackPacket().Return(getPacket(123), nil) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) - }) - sess.sentPacketHandler = sph - Expect(sess.sendPackets()).To(Succeed()) - // We're using a mock packet packer in this test. - // We therefore need to test separately that the PING was actually queued. - frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) - Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}})) - }) - It("doesn't send when the SentPacketHandler doesn't allow it", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() @@ -966,6 +933,62 @@ var _ = Describe("Session", func() { err := sess.sendPackets() Expect(err).ToNot(HaveOccurred()) }) + + for _, enc := range []protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption1RTT} { + encLevel := enc + + Context(fmt.Sprintf("sending %s probe packets", encLevel), func() { + var sendMode ackhandler.SendMode + var getFrame func(protocol.ByteCount) wire.Frame + + BeforeEach(func() { + switch encLevel { + case protocol.EncryptionInitial: + sendMode = ackhandler.SendPTOInitial + getFrame = sess.retransmissionQueue.GetInitialFrame + case protocol.EncryptionHandshake: + sendMode = ackhandler.SendPTOHandshake + getFrame = sess.retransmissionQueue.GetHandshakeFrame + case protocol.Encryption1RTT: + sendMode = ackhandler.SendPTOAppData + getFrame = sess.retransmissionQueue.GetAppDataFrame + } + }) + + It("sends a probe packet", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend() + sph.EXPECT().SendMode().Return(sendMode) + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().QueueProbePacket(encLevel) + packer.EXPECT().MaybePackProbePacket(encLevel).Return(getPacket(123), nil) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) + }) + sess.sentPacketHandler = sph + Expect(sess.sendPackets()).To(Succeed()) + }) + + It("sends a PING as a probe packet", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend() + sph.EXPECT().SendMode().Return(sendMode) + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().QueueProbePacket(encLevel).Return(false) + packer.EXPECT().MaybePackProbePacket(encLevel).Return(getPacket(123), nil) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) + }) + sess.sentPacketHandler = sph + Expect(sess.sendPackets()).To(Succeed()) + // We're using a mock packet packer in this test. + // We therefore need to test separately that the PING was actually queued. + Expect(getFrame(1000)).To(BeAssignableToTypeOf(&wire.PingFrame{})) + }) + }) + } }) Context("packet pacing", func() {