diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index c66e5008..02979faa 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -27,6 +27,7 @@ type SentPacketHandler interface { ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error DropPackets(protocol.EncryptionLevel) ResetForRetry() error + SetHandshakeComplete() // The SendMode determines if and what kind of packets can be sent. SendMode() SendMode @@ -42,7 +43,7 @@ type SentPacketHandler interface { // only to be called once the handshake is complete GetLowestPacketNotConfirmedAcked() protocol.PacketNumber - QueueProbePacket() bool /* was a packet queued */ + QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber diff --git a/internal/ackhandler/send_mode.go b/internal/ackhandler/send_mode.go index 360ce2e3..3d5fe560 100644 --- a/internal/ackhandler/send_mode.go +++ b/internal/ackhandler/send_mode.go @@ -10,8 +10,12 @@ const ( SendNone SendMode = iota // SendAck means an ACK-only packet should be sent SendAck - // SendPTO means that a probe packet should be sent - SendPTO + // SendPTOInitial means that an Initial probe packet should be sent + SendPTOInitial + // SendPTOHandshake means that a Handshake probe packet should be sent + SendPTOHandshake + // SendPTOAppData means that an Application data probe packet should be sent + SendPTOAppData // SendAny means that any packet should be sent SendAny ) @@ -22,8 +26,12 @@ func (s SendMode) String() string { return "none" case SendAck: return "ack" - case SendPTO: - return "pto" + case SendPTOInitial: + return "pto (Initial)" + case SendPTOHandshake: + return "pto (Handshake)" + case SendPTOAppData: + return "pto (Application Data)" case SendAny: return "any" default: diff --git a/internal/ackhandler/send_mode_test.go b/internal/ackhandler/send_mode_test.go index 0b846b85..86515d74 100644 --- a/internal/ackhandler/send_mode_test.go +++ b/internal/ackhandler/send_mode_test.go @@ -10,7 +10,9 @@ var _ = Describe("Send Mode", func() { Expect(SendNone.String()).To(Equal("none")) Expect(SendAny.String()).To(Equal("any")) Expect(SendAck.String()).To(Equal("ack")) - Expect(SendPTO.String()).To(Equal("pto")) + Expect(SendPTOInitial.String()).To(Equal("pto (Initial)")) + Expect(SendPTOHandshake.String()).To(Equal("pto (Handshake)")) + Expect(SendPTOAppData.String()).To(Equal("pto (Application Data)")) Expect(SendMode(123).String()).To(Equal("invalid send mode: 123")) }) }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index d52328fe..3bbee31f 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -25,7 +25,9 @@ type packetNumberSpace struct { history *sentPacketHistory pns *packetNumberGenerator - lossTime time.Time + lossTime time.Time + lastSentAckElicitingPacketTime time.Time + largestAcked protocol.PacketNumber largestSent protocol.PacketNumber } @@ -40,15 +42,14 @@ func newPacketNumberSpace(initialPN protocol.PacketNumber) *packetNumberSpace { } type sentPacketHandler struct { - lastSentAckElicitingPacketTime time.Time // only applies to the application-data packet number space - lastSentCryptoPacketTime time.Time - nextSendTime time.Time initialPackets *packetNumberSpace handshakePackets *packetNumberSpace oneRTTPackets *packetNumberSpace + handshakeComplete bool + // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived // example: we send an ACK for packets 90-100 with packet number 20 // once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101 @@ -62,6 +63,7 @@ type sentPacketHandler struct { // The number of times a PTO has been sent without receiving an ack. ptoCount uint32 + ptoMode SendMode // The number of PTO probe packets that should be sent. // Only applies to the application-data packet number space. numProbesToSend int @@ -153,10 +155,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-elicit isAckEliciting := len(packet.Frames) > 0 if isAckEliciting { - if packet.EncryptionLevel != protocol.Encryption1RTT { - h.lastSentCryptoPacketTime = packet.SendTime - } - h.lastSentAckElicitingPacketTime = packet.SendTime + pnSpace.lastSentAckElicitingPacketTime = packet.SendTime packet.includedInBytesInFlight = true h.bytesInFlight += packet.Length if h.numProbesToSend > 0 { @@ -281,7 +280,7 @@ func (h *sentPacketHandler) determineNewlyAckedPackets( return ackedPackets, err } -func (h *sentPacketHandler) getEarliestLossTime() (time.Time, protocol.EncryptionLevel) { +func (h *sentPacketHandler) getEarliestLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) { var encLevel protocol.EncryptionLevel var lossTime time.Time @@ -293,13 +292,35 @@ func (h *sentPacketHandler) getEarliestLossTime() (time.Time, protocol.Encryptio lossTime = h.handshakePackets.lossTime encLevel = protocol.EncryptionHandshake } - if lossTime.IsZero() || (!h.oneRTTPackets.lossTime.IsZero() && h.oneRTTPackets.lossTime.Before(lossTime)) { + if h.handshakeComplete && + (lossTime.IsZero() || (!h.oneRTTPackets.lossTime.IsZero() && h.oneRTTPackets.lossTime.Before(lossTime))) { lossTime = h.oneRTTPackets.lossTime encLevel = protocol.Encryption1RTT } return lossTime, encLevel } +// same logic as getEarliestLossTimeAndSpace, but for lastSentAckElicitingPacketTime instead of lossTime +func (h *sentPacketHandler) getEarliestSentTimeAndSpace() (time.Time, protocol.EncryptionLevel) { + var encLevel protocol.EncryptionLevel + var sentTime time.Time + + if h.initialPackets != nil { + sentTime = h.initialPackets.lastSentAckElicitingPacketTime + encLevel = protocol.EncryptionInitial + } + if h.handshakePackets != nil && (sentTime.IsZero() || (!h.handshakePackets.lastSentAckElicitingPacketTime.IsZero() && h.handshakePackets.lastSentAckElicitingPacketTime.Before(sentTime))) { + sentTime = h.handshakePackets.lastSentAckElicitingPacketTime + encLevel = protocol.EncryptionHandshake + } + if h.handshakeComplete && + (sentTime.IsZero() || (!h.oneRTTPackets.lastSentAckElicitingPacketTime.IsZero() && h.oneRTTPackets.lastSentAckElicitingPacketTime.Before(sentTime))) { + sentTime = h.oneRTTPackets.lastSentAckElicitingPacketTime + encLevel = protocol.Encryption1RTT + } + return sentTime, encLevel +} + func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { var hasInitial, hasHandshake bool if h.initialPackets != nil { @@ -312,11 +333,14 @@ func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { } func (h *sentPacketHandler) hasOutstandingPackets() bool { - return h.oneRTTPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets() + // We only send application data probe packets once the handshake completes, + // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. + return (h.handshakeComplete && h.oneRTTPackets.history.HasOutstandingPackets()) || + h.hasOutstandingCryptoPackets() } func (h *sentPacketHandler) setLossDetectionTimer() { - if lossTime, _ := h.getEarliestLossTime(); !lossTime.IsZero() { + if lossTime, _ := h.getEarliestLossTimeAndSpace(); !lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = lossTime } @@ -329,7 +353,8 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } // PTO alarm - h.alarm = h.lastSentAckElicitingPacketTime.Add(h.rttStats.PTO() << h.ptoCount) + sentTime, encLevel := h.getEarliestSentTimeAndSpace() + h.alarm = sentTime.Add(h.rttStats.PTO(encLevel == protocol.Encryption1RTT) << h.ptoCount) } func (h *sentPacketHandler) detectLostPackets( @@ -405,7 +430,7 @@ func (h *sentPacketHandler) detectLostPackets( func (h *sentPacketHandler) OnLossDetectionTimeout() error { // When all outstanding are acknowledged, the alarm is canceled in - // updateLossDetectionAlarm. This doesn't reset the timer in the session though. + // setLossDetectionTimer. This doesn't reset the timer in the session though. // When OnAlarm is called, we therefore need to make sure that there are // actually packets outstanding. if h.hasOutstandingPackets() { @@ -418,10 +443,10 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { } func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { - lossTime, encLevel := h.getEarliestLossTime() - if !lossTime.IsZero() { + earliestLossTime, encLevel := h.getEarliestLossTimeAndSpace() + if !earliestLossTime.IsZero() { if h.logger.Debug() { - h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", lossTime) + h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } // Early retransmit or time loss detection return h.detectLostPackets(time.Now(), encLevel, h.bytesInFlight) @@ -429,10 +454,21 @@ func (h *sentPacketHandler) onVerifiedLossDetectionTimeout() error { // PTO if h.logger.Debug() { - h.logger.Debugf("Loss detection alarm fired in PTO mode. PTO count: %d", h.ptoCount) + h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) } + _, encLevel = h.getEarliestSentTimeAndSpace() h.ptoCount++ h.numProbesToSend += 2 + switch encLevel { + case protocol.EncryptionInitial: + h.ptoMode = SendPTOInitial + case protocol.EncryptionHandshake: + h.ptoMode = SendPTOHandshake + case protocol.Encryption1RTT: + h.ptoMode = SendPTOAppData + default: + return fmt.Errorf("TPO timer in unexpected encryption level: %s", encLevel) + } return nil } @@ -495,7 +531,7 @@ func (h *sentPacketHandler) SendMode() SendMode { return SendNone } if h.numProbesToSend > 0 { - return SendPTO + return h.ptoMode } // Only send ACKs if we're congestion limited. if !h.congestion.CanSend(h.bytesInFlight) { @@ -529,17 +565,9 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) } -func (h *sentPacketHandler) QueueProbePacket() bool { - var p *Packet - if h.initialPackets != nil { - p = h.initialPackets.history.FirstOutstanding() - } - if p == nil && h.handshakePackets != nil { - p = h.handshakePackets.history.FirstOutstanding() - } - if p == nil { - p = h.oneRTTPackets.history.FirstOutstanding() - } +func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { + pnSpace := h.getPacketNumberSpace(encLevel) + p := pnSpace.history.FirstOutstanding() if p == nil { return false } @@ -549,8 +577,8 @@ func (h *sentPacketHandler) QueueProbePacket() bool { if p.includedInBytesInFlight { h.bytesInFlight -= p.Length } - if err := h.getPacketNumberSpace(p.EncryptionLevel).history.Remove(p.PacketNumber); err != nil { - // should never happen. We just got this packet from the history a lines above. + if err := pnSpace.history.Remove(p.PacketNumber); err != nil { + // should never happen. We just got this packet from the history. panic(err) } return true @@ -573,6 +601,13 @@ func (h *sentPacketHandler) ResetForRetry() error { return nil } +func (h *sentPacketHandler) SetHandshakeComplete() { + h.handshakeComplete = true + // We don't send PTOs for application data packets before the handshake completes. + // Make sure the timer is armed now, if necessary. + h.setLossDetectionTimer() +} + func (h *sentPacketHandler) GetStats() *quictrace.TransportState { return &quictrace.TransportState{ MinRTT: h.rttStats.MinRTT(), diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 0859a0a7..772f0c14 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -103,20 +103,20 @@ var _ = Describe("SentPacketHandler", func() { It("stores the sent time", func() { sendTime := time.Now().Add(-time.Minute) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) - Expect(handler.lastSentAckElicitingPacketTime).To(Equal(sendTime)) + Expect(handler.oneRTTPackets.lastSentAckElicitingPacketTime).To(Equal(sendTime)) }) - It("stores the sent time of crypto packets", func() { + It("stores the sent time of Initial packets", func() { sendTime := time.Now().Add(-time.Minute) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime, EncryptionLevel: protocol.EncryptionInitial})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2, SendTime: sendTime.Add(time.Hour), EncryptionLevel: protocol.Encryption1RTT})) - Expect(handler.lastSentCryptoPacketTime).To(Equal(sendTime)) + Expect(handler.initialPackets.lastSentAckElicitingPacketTime).To(Equal(sendTime)) }) It("does not store non-ack-eliciting packets", func() { - handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: 1, EncryptionLevel: protocol.Encryption1RTT})) + handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: 1})) Expect(handler.oneRTTPackets.history.Len()).To(BeZero()) - Expect(handler.lastSentAckElicitingPacketTime).To(BeZero()) + Expect(handler.oneRTTPackets.lastSentAckElicitingPacketTime).To(BeZero()) Expect(handler.bytesInFlight).To(BeZero()) }) }) @@ -508,11 +508,12 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode()).To(Equal(SendAck)) }) - It("allows RTOs, even when congestion limited", func() { + It("allows PTOs, even when congestion limited", func() { // note that we don't EXPECT a call to GetCongestionWindow // that means retransmissions are sent without considering the congestion window handler.numProbesToSend = 1 - Expect(handler.SendMode()).To(Equal(SendPTO)) + handler.ptoMode = SendPTOHandshake + Expect(handler.SendMode()).To(Equal(SendPTOHandshake)) }) It("gets the pacing delay", func() { @@ -565,17 +566,18 @@ var _ = Describe("SentPacketHandler", func() { It("queues a probe packet", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) - queued := handler.QueueProbePacket() + queued := handler.QueueProbePacket(protocol.Encryption1RTT) Expect(queued).To(BeTrue()) Expect(lostPackets).To(Equal([]protocol.PacketNumber{10})) }) It("says when it can't queue a probe packet", func() { - queued := handler.QueueProbePacket() + queued := handler.QueueProbePacket(protocol.Encryption1RTT) Expect(queued).To(BeFalse()) }) It("implements exponential backoff", func() { + handler.SetHandshakeComplete() sendTime := time.Now().Add(-time.Hour) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: sendTime})) timeout := handler.GetLossDetectionTimeout().Sub(sendTime) @@ -588,7 +590,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout)) }) - It("sets the PTO send mode until two packets is sent", func() { + It("allows two 1-RTT PTOs", func() { + handler.SetHandshakeComplete() var lostPackets []protocol.PacketNumber handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: 1, @@ -598,30 +601,32 @@ var _ = Describe("SentPacketHandler", func() { }, })) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) Expect(handler.ShouldSendNumPackets()).To(Equal(2)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).ToNot(Equal(SendPTO)) + Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) }) It("only counts ack-eliciting packets as probe packets", func() { + handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) Expect(handler.ShouldSendNumPackets()).To(Equal(2)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) for p := protocol.PacketNumber(3); p < 30; p++ { handler.SentPacket(nonAckElicitingPacket(&Packet{PacketNumber: p})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) } handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 30})) - Expect(handler.SendMode()).ToNot(Equal(SendPTO)) + Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) }) It("gets two probe packets if RTO expires", func() { + handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) @@ -630,22 +635,22 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP Expect(handler.ptoCount).To(BeEquivalentTo(1)) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 4})) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // PTO Expect(handler.ptoCount).To(BeEquivalentTo(2)) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 6})) Expect(handler.SendMode()).To(Equal(SendAny)) }) - It("gets two probe packets if PTO expires, for crypto packets", func() { + It("gets two probe packets if PTO expires, for Handshake packets", func() { handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 1})) handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 2})) @@ -653,19 +658,32 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.initialPackets.lossTime.IsZero()).To(BeTrue()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 3})) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOInitial)) handler.SentPacket(cryptoPacket(&Packet{PacketNumber: 3})) Expect(handler.SendMode()).To(Equal(SendAny)) }) + It("doesn't send 1-RTT probe packets before the handshake completes", func() { + handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) + updateRTT(time.Hour) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP + Expect(handler.GetLossDetectionTimeout()).To(BeZero()) + Expect(handler.SendMode()).To(Equal(SendAny)) + handler.SetHandshakeComplete() + Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) + Expect(handler.OnLossDetectionTimeout()).To(Succeed()) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) + }) + It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { + handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) - Expect(handler.SendMode()).To(Equal(SendPTO)) + Expect(handler.SendMode()).To(Equal(SendPTOAppData)) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} Expect(handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendAny)) diff --git a/internal/congestion/rtt_stats.go b/internal/congestion/rtt_stats.go index 0b17fc10..9ae42706 100644 --- a/internal/congestion/rtt_stats.go +++ b/internal/congestion/rtt_stats.go @@ -46,13 +46,19 @@ func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } // MeanDeviation gets the mean deviation func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } +// MaxAckDelay gets the max_ack_delay advertized by the peer func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay } -func (r *RTTStats) PTO() time.Duration { +// PTO gets the probe timeout duration. +func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { if r.SmoothedRTT() == 0 { return 2 * defaultInitialRTT } - return r.SmoothedRTT() + utils.MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + r.MaxAckDelay() + pto := r.SmoothedRTT() + utils.MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + if includeMaxAckDelay { + pto += r.MaxAckDelay() + } + return pto } // UpdateRTT updates the RTT based on a new sample. diff --git a/internal/congestion/rtt_stats_test.go b/internal/congestion/rtt_stats_test.go index c899dde0..fe722281 100644 --- a/internal/congestion/rtt_stats_test.go +++ b/internal/congestion/rtt_stats_test.go @@ -66,13 +66,14 @@ var _ = Describe("RTT stats", func() { rttStats.UpdateRTT(rtt, 0, time.Time{}) Expect(rttStats.SmoothedRTT()).To(Equal(rtt)) Expect(rttStats.MeanDeviation()).To(Equal(rtt / 2)) - Expect(rttStats.PTO()).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) + Expect(rttStats.PTO(false)).To(Equal(rtt + 4*(rtt/2))) + Expect(rttStats.PTO(true)).To(Equal(rtt + 4*(rtt/2) + maxAckDelay)) }) It("uses the granularity for computing the PTO for short RTTs", func() { rtt := time.Microsecond rttStats.UpdateRTT(rtt, 0, time.Time{}) - Expect(rttStats.PTO()).To(Equal(rtt + protocol.TimerGranularity)) + Expect(rttStats.PTO(true)).To(Equal(rtt + protocol.TimerGranularity)) }) It("ExpireSmoothedMetrics", func() { diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index c7f0a887..a075e609 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -576,7 +576,7 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { if h.initialSealer == nil { return nil, ErrKeysDropped } - return nil, errors.New("CryptoSetup: no sealer with encryption level Handshake") + return nil, ErrKeysNotYetAvailable } return h.handshakeSealer, nil } @@ -586,7 +586,7 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { defer h.mutex.Unlock() if !h.has1RTTSealer { - return nil, errors.New("CryptoSetup: no sealer with encryption level 1-RTT") + return nil, ErrKeysNotYetAvailable } return h.aead, nil } @@ -607,7 +607,7 @@ func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { if h.handshakeOpener == nil { if h.initialOpener != nil { - return nil, ErrOpenerNotYetAvailable + return nil, ErrKeysNotYetAvailable } // if the initial opener is also not available, the keys were already dropped return nil, ErrKeysDropped @@ -620,7 +620,7 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { defer h.mutex.Unlock() if !h.has1RTTOpener { - return nil, ErrOpenerNotYetAvailable + return nil, ErrKeysNotYetAvailable } return h.aead, nil } diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 25b22b2c..1baee25a 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -11,10 +11,10 @@ import ( ) var ( - // ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level, + // ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level, // but the corresponding opener has not yet been initialized // This can happen when packets arrive out of order. - ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available") + ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available") // ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level, // but the corresponding keys have already been dropped. ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped") diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index e7cdcaeb..63979acf 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -99,7 +99,7 @@ func (a *updatableAEAD) rollKeys(now time.Time) { a.numRcvdWithCurrentKey = 0 a.numSentWithCurrentKey = 0 a.prevRcvAEAD = a.rcvAEAD - a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO()) + a.prevRcvAEADExpiry = now.Add(3 * a.rttStats.PTO(true)) a.rcvAEAD = a.nextRcvAEAD a.sendAEAD = a.nextSendAEAD diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 51000a44..54d071de 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -151,7 +151,7 @@ var _ = Describe("Updatable AEAD", func() { It("drops keys 3 PTOs after a key update", func() { now := time.Now() rttStats.UpdateRTT(10*time.Millisecond, 0, now) - pto := rttStats.PTO() + pto := rttStats.PTO(true) encrypted01 := client.Seal(nil, msg, 0x42, ad) encrypted02 := client.Seal(nil, msg, 0x43, ad) // receive the first packet with key phase 0 diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 7c032987..abde89aa 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -136,17 +136,17 @@ func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 interface{}) * } // QueueProbePacket mocks base method -func (m *MockSentPacketHandler) QueueProbePacket() bool { +func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueueProbePacket") + ret := m.ctrl.Call(m, "QueueProbePacket", arg0) ret0, _ := ret[0].(bool) return ret0 } // QueueProbePacket indicates an expected call of QueueProbePacket -func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) } // ReceivedAck mocks base method @@ -203,6 +203,18 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0) } +// SetHandshakeComplete mocks base method +func (m *MockSentPacketHandler) SetHandshakeComplete() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetHandshakeComplete") +} + +// SetHandshakeComplete indicates an expected call of SetHandshakeComplete +func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete)) +} + // ShouldSendNumPackets mocks base method func (m *MockSentPacketHandler) ShouldSendNumPackets() int { m.ctrl.T.Helper() diff --git a/mock_packer_test.go b/mock_packer_test.go index e2bf3507..6425c1ea 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -9,6 +9,7 @@ import ( gomock "github.com/golang/mock/gomock" handshake "github.com/lucas-clemente/quic-go/internal/handshake" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" wire "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -62,6 +63,21 @@ func (mr *MockPackerMockRecorder) MaybePackAckPacket() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket)) } +// MaybePackProbePacket mocks base method +func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel) (*packedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaybePackProbePacket", arg0) + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MaybePackProbePacket indicates an expected call of MaybePackProbePacket +func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0) +} + // PackConnectionClose mocks base method func (m *MockPacker) PackConnectionClose(arg0 *wire.ConnectionCloseFrame) (*packedPacket, error) { m.ctrl.T.Helper() diff --git a/packet_packer.go b/packet_packer.go index 357e77dc..d31253ef 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -16,6 +16,7 @@ import ( type packer interface { PackPacket() (*packedPacket, error) + MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) MaybePackAckPacket() (*packedPacket, error) PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) @@ -134,8 +135,9 @@ type packetPacker struct { version protocol.VersionNumber cryptoSetup sealingManager - // Once the handshake is confirmed, we only need to send 1-RTT packets. - handshakeConfirmed bool + // Once both Initial and Handshake keys are dropped, we only send 1-RTT packets. + droppedInitial bool + droppedHandshake bool initialStream cryptoStream handshakeStream cryptoStream @@ -183,6 +185,10 @@ func newPacketPacker( } } +func (p *packetPacker) handshakeConfirmed() bool { + return p.droppedInitial && p.droppedHandshake +} + // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { payload := payload{ @@ -219,7 +225,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { var encLevel protocol.EncryptionLevel var ack *wire.AckFrame - if !p.handshakeConfirmed { + if !p.handshakeConfirmed() { ack = p.acks.GetAckFrame(protocol.EncryptionInitial) if ack != nil { encLevel = protocol.EncryptionInitial @@ -255,7 +261,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { // PackPacket packs a new packet // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise func (p *packetPacker) PackPacket() (*packedPacket, error) { - if !p.handshakeConfirmed { + if !p.handshakeConfirmed() { packet, err := p.maybePackCryptoPacket() if err != nil { return nil, err @@ -265,6 +271,105 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { } } + return p.maybePackAppDataPacket() +} + +func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { + // Try packing an Initial packet. + packet, err := p.maybePackInitialPacket() + if err == handshake.ErrKeysDropped { + p.droppedInitial = true + } else if err != nil || packet != nil { + return packet, err + } + + // No Initial was packed. Try packing a Handshake packet. + packet, err = p.maybePackHandshakePacket() + if err == handshake.ErrKeysDropped { + p.droppedHandshake = true + return nil, nil + } + if err == handshake.ErrKeysNotYetAvailable { + return nil, nil + } + return packet, err +} + +func (p *packetPacker) maybePackInitialPacket() (*packedPacket, error) { + sealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, err + } + + hasRetransmission := p.retransmissionQueue.HasInitialData() + ack := p.acks.GetAckFrame(protocol.EncryptionInitial) + if !p.initialStream.HasData() && !hasRetransmission && ack == nil { + // nothing to send + return nil, nil + } + return p.packCryptoPacket(protocol.EncryptionInitial, sealer, ack, hasRetransmission) +} + +func (p *packetPacker) maybePackHandshakePacket() (*packedPacket, error) { + sealer, err := p.cryptoSetup.GetHandshakeSealer() + + if err != nil { + return nil, err + } + + hasRetransmission := p.retransmissionQueue.HasHandshakeData() + ack := p.acks.GetAckFrame(protocol.EncryptionHandshake) + if !p.handshakeStream.HasData() && !hasRetransmission && ack == nil { + // nothing to send + return nil, nil + } + return p.packCryptoPacket(protocol.EncryptionHandshake, sealer, ack, hasRetransmission) +} + +func (p *packetPacker) packCryptoPacket( + encLevel protocol.EncryptionLevel, + sealer handshake.LongHeaderSealer, + ack *wire.AckFrame, + hasRetransmission bool, +) (*packedPacket, error) { + s := p.initialStream + if encLevel == protocol.EncryptionHandshake { + s = p.handshakeStream + } + + var payload payload + if ack != nil { + payload.ack = ack + payload.length = ack.Length(p.version) + } + hdr := p.getLongHeader(encLevel) + hdrLen := hdr.GetLength(p.version) + if hasRetransmission { + for { + var f wire.Frame + switch encLevel { + case protocol.EncryptionInitial: + remainingLen := protocol.MinInitialPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length + f = p.retransmissionQueue.GetInitialFrame(remainingLen) + case protocol.EncryptionHandshake: + remainingLen := p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length + f = p.retransmissionQueue.GetHandshakeFrame(remainingLen) + } + if f == nil { + break + } + payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) + payload.length += f.Length(p.version) + } + } else if s.HasData() { + cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length) + payload.frames = []ackhandler.Frame{{Frame: cf}} + payload.length += cf.Length(p.version) + } + return p.writeAndSealPacket(hdr, payload, encLevel, sealer) +} + +func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { // sealer not yet available @@ -296,78 +401,6 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { return p.writeAndSealPacket(header, payload, protocol.Encryption1RTT, sealer) } -func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { - var s cryptoStream - var encLevel protocol.EncryptionLevel - - initialSealer, errInitialSealer := p.cryptoSetup.GetInitialSealer() - handshakeSealer, errHandshakeSealer := p.cryptoSetup.GetHandshakeSealer() - - if errInitialSealer == handshake.ErrKeysDropped && - errHandshakeSealer == handshake.ErrKeysDropped { - p.handshakeConfirmed = true - } - - hasData := p.initialStream.HasData() - hasRetransmission := p.retransmissionQueue.HasInitialData() - ack := p.acks.GetAckFrame(protocol.EncryptionInitial) - var sealer handshake.LongHeaderSealer - if hasData || hasRetransmission || ack != nil { - s = p.initialStream - encLevel = protocol.EncryptionInitial - sealer = initialSealer - if errInitialSealer != nil { - return nil, fmt.Errorf("PacketPacker BUG: no Initial sealer: %s", errInitialSealer) - } - } else { - hasData = p.handshakeStream.HasData() - hasRetransmission = p.retransmissionQueue.HasHandshakeData() - ack = p.acks.GetAckFrame(protocol.EncryptionHandshake) - if hasData || hasRetransmission || ack != nil { - s = p.handshakeStream - encLevel = protocol.EncryptionHandshake - sealer = handshakeSealer - if errHandshakeSealer != nil { - return nil, fmt.Errorf("PacketPacker BUG: no Handshake sealer: %s", errHandshakeSealer) - } - } - } - if s == nil { - return nil, nil - } - - var payload payload - if ack != nil { - payload.ack = ack - payload.length = ack.Length(p.version) - } - hdr := p.getLongHeader(encLevel) - hdrLen := hdr.GetLength(p.version) - if hasRetransmission { - for { - var f wire.Frame - switch encLevel { - case protocol.EncryptionInitial: - remainingLen := protocol.MinInitialPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length - f = p.retransmissionQueue.GetInitialFrame(remainingLen) - case protocol.EncryptionHandshake: - remainingLen := p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length - f = p.retransmissionQueue.GetHandshakeFrame(remainingLen) - } - if f == nil { - break - } - payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) - payload.length += f.Length(p.version) - } - } else if hasData { - cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length) - payload.frames = []ackhandler.Frame{{Frame: cf}} - payload.length += cf.Length(p.version) - } - return p.writeAndSealPacket(hdr, payload, encLevel, sealer) -} - func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payload { var payload payload @@ -398,6 +431,19 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payloa return payload } +func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) { + switch encLevel { + case protocol.EncryptionInitial: + return p.maybePackInitialPacket() + case protocol.EncryptionHandshake: + return p.maybePackHandshakePacket() + case protocol.Encryption1RTT: + return p.maybePackAppDataPacket() + default: + panic("unknown encryption level") + } +} + func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { switch encLevel { case protocol.EncryptionInitial: diff --git a/packet_packer_test.go b/packet_packer_test.go index d8a95002..683badff 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "errors" "math/rand" "net" "time" @@ -580,25 +579,24 @@ var _ = Describe("Packet packer", func() { Data: []byte("foobar"), } ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().HasData().Return(true).AnyTimes() initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) checkLength(p.raw) }) - It("packs a maximum size crypto packet", func() { + It("packs a maximum size Handshake packet", func() { var f *wire.CryptoFrame pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(nil, errors.New("no sealer")) + sealingManager.EXPECT().GetInitialSealer().Return(mocks.NewMockShortHeaderSealer(mockCtrl), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData().Return(true) + handshakeStream.EXPECT().HasData().Return(true).Times(2) handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { f = &wire.CryptoFrame{Offset: 0x1337} f.Data = bytes.Repeat([]byte{'f'}, int(size-f.Length(packer.version)-1)) @@ -620,7 +618,6 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData() p, err := packer.PackPacket() @@ -634,9 +631,8 @@ var _ = Describe("Packet packer", func() { It("sends an Initial packet containing only an ACK", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) - initialStream.EXPECT().HasData() + initialStream.EXPECT().HasData().Times(2) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) p, err := packer.PackPacket() @@ -644,13 +640,24 @@ var _ = Describe("Packet packer", func() { Expect(p.ack).To(Equal(ack)) }) + It("doesn't pack anything if there's nothing to send at Initial and Handshake keys are not yet available", func() { + sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) + initialStream.EXPECT().HasData() + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + It("sends a Handshake packet containing only an ACK", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 20}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData() - sealingManager.EXPECT().GetInitialSealer().Return(nil, errors.New("no sealer")) + handshakeStream.EXPECT().HasData().Times(2) + sealingManager.EXPECT().GetInitialSealer().Return(mocks.NewMockShortHeaderSealer(mockCtrl), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) @@ -666,9 +673,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient packet, err := packer.PackPacket() @@ -678,21 +684,6 @@ var _ = Describe("Packet packer", func() { Expect(packet.frames).To(HaveLen(1)) cf := packet.frames[0].Frame.(*wire.CryptoFrame) Expect(cf.Data).To(Equal([]byte("foobar"))) - }) - - It("sets the correct length for an Initial packet", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - initialStream.EXPECT().HasData().Return(true) - initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(&wire.CryptoFrame{ - Data: []byte("foobar"), - }) - packer.perspective = protocol.PerspectiveClient - packet, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) checkLength(packet.raw) }) @@ -702,9 +693,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) - sealingManager.EXPECT().GetHandshakeSealer().Return(nil, errors.New("no sealer")) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial).Return(ack) - initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.perspective = protocol.PerspectiveClient @@ -718,11 +708,7 @@ var _ = Describe("Packet packer", func() { It("stops packing crypto packets when the keys are dropped", func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) - initialStream.EXPECT().HasData() - handshakeStream.EXPECT().HasData() sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) @@ -744,6 +730,62 @@ var _ = Describe("Packet packer", func() { Expect(packet).ToNot(BeNil()) }) }) + + Context("packing probe packets", func() { + It("packs an Initial probe packet", func() { + f := &wire.CryptoFrame{Data: []byte("Initial")} + retransmissionQueue.AddInitial(f) + sealingManager.EXPECT().GetInitialSealer().Return(sealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + initialStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(Equal(f)) + checkLength(packet.raw) + }) + + It("packs a Handshake probe packet", func() { + f := &wire.CryptoFrame{Data: []byte("Handshake")} + retransmissionQueue.AddHandshake(f) + sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + handshakeStream.EXPECT().HasData() + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + + packet, err := packer.MaybePackProbePacket(protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(Equal(f)) + checkLength(packet.raw) + }) + + It("packs a 1-RTT probe packet", func() { + f := &wire.StreamFrame{Data: []byte("1-RTT")} + retransmissionQueue.AddInitial(f) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + expectAppendControlFrames() + expectAppendStreamFrames(ackhandler.Frame{Frame: f}) + + packet, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(packet).ToNot(BeNil()) + Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(packet.frames).To(HaveLen(1)) + Expect(packet.frames[0].Frame).To(Equal(f)) + }) + }) }) }) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 0f5623e7..9b1eb260 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -85,9 +85,9 @@ var _ = Describe("Packet Unpacker", func() { PacketNumberLen: 2, } hdr, hdrRaw := getHeader(extHdr) - cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrOpenerNotYetAvailable) + cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError(handshake.ErrOpenerNotYetAvailable)) + Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable)) }) It("returns the error when unpacking fails", func() { diff --git a/session.go b/session.go index e7be6c23..f15fd193 100644 --- a/session.go +++ b/session.go @@ -590,6 +590,7 @@ func (s *session) handleHandshakeComplete() { s.handshakeCompleteChan = nil // prevent this case from ever being selected again s.handshakeCtxCancel() + s.sentPacketHandler.SetHandshakeComplete() // The client completes the handshake first (after sending the CFIN). // We need to make sure it learns about the server completing the handshake, // in order to stop retransmitting handshake packets. @@ -677,7 +678,7 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / switch err { case handshake.ErrKeysDropped: s.logger.Debugf("Dropping packet because we already dropped the keys.") - case handshake.ErrOpenerNotYetAvailable: + case handshake.ErrKeysNotYetAvailable: // Sealer for this encryption level not yet available. // Try again later. wasQueued = true @@ -1146,8 +1147,18 @@ sendLoop: // There will only be a new ACK after receiving new packets. // SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer. return s.maybeSendAckOnlyPacket() - case ackhandler.SendPTO: - if err := s.sendProbePacket(); err != nil { + case ackhandler.SendPTOInitial: + if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { + return err + } + numPacketsSent++ + case ackhandler.SendPTOHandshake: + if err := s.sendProbePacket(protocol.EncryptionHandshake); err != nil { + return err + } + numPacketsSent++ + case ackhandler.SendPTOAppData: + if err := s.sendProbePacket(protocol.Encryption1RTT); err != nil { return err } numPacketsSent++ @@ -1189,24 +1200,46 @@ func (s *session) maybeSendAckOnlyPacket() error { return nil } -func (s *session) sendProbePacket() error { - // Queue probe packets until we actually send out a packet. +func (s *session) sendProbePacket(encLevel protocol.EncryptionLevel) error { + // Queue probe packets until we actually send out a packet, + // or until there are no more packets to queue. + var packet *packedPacket for { - if wasQueued := s.sentPacketHandler.QueueProbePacket(); !wasQueued { + if wasQueued := s.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued { break } - sent, err := s.sendPacket() + var err error + packet, err = s.packer.MaybePackProbePacket(encLevel) if err != nil { return err } - if sent { - return nil + if packet != nil { + break } } - // If there is nothing else to queue, make sure we send out something. - s.framer.QueueControlFrame(&wire.PingFrame{}) - _, err := s.sendPacket() - return err + if packet == nil { + switch encLevel { + case protocol.EncryptionInitial: + s.retransmissionQueue.AddInitial(&wire.PingFrame{}) + case protocol.EncryptionHandshake: + s.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + case protocol.Encryption1RTT: + s.retransmissionQueue.AddAppData(&wire.PingFrame{}) + default: + panic("unexpected encryption level") + } + var err error + packet, err = s.packer.MaybePackProbePacket(encLevel) + if err != nil { + return err + } + } + if packet == nil { + return fmt.Errorf("session BUG: couldn't pack %s probe packet", encLevel) + } + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(s.retransmissionQueue)) + s.sendPackedPacket(packet) + return nil } func (s *session) sendPacket() (bool, error) { diff --git a/session_test.go b/session_test.go index fbe5e3c4..80ff56de 100644 --- a/session_test.go +++ b/session_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/tls" "errors" + "fmt" "net" "runtime/pprof" "strings" @@ -749,7 +750,7 @@ var _ = Describe("Session", func() { PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 1, } - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) packet := getPacket(hdr, nil) Expect(sess.handlePacketImpl(packet)).To(BeFalse()) Expect(sess.undecryptablePackets).To(Equal([]*receivedPacket{packet})) @@ -832,7 +833,7 @@ var _ = Describe("Session", func() { hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) gomock.InOrder( - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable), + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) return &unpackedPacket{ @@ -924,40 +925,6 @@ var _ = Describe("Session", func() { Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.DataBlockedFrame{DataLimit: 1337}}})) }) - It("sends a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend() - sph.EXPECT().SendMode().Return(ackhandler.SendPTO) - sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().QueueProbePacket() - packer.EXPECT().PackPacket().Return(getPacket(123), nil) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) - }) - sess.sentPacketHandler = sph - Expect(sess.sendPackets()).To(Succeed()) - }) - - It("sends a PING as a probe packet", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend() - sph.EXPECT().SendMode().Return(ackhandler.SendPTO) - sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().QueueProbePacket().Return(false) - packer.EXPECT().PackPacket().Return(getPacket(123), nil) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { - Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) - }) - sess.sentPacketHandler = sph - Expect(sess.sendPackets()).To(Succeed()) - // We're using a mock packet packer in this test. - // We therefore need to test separately that the PING was actually queued. - frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) - Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PingFrame{}}})) - }) - It("doesn't send when the SentPacketHandler doesn't allow it", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() @@ -966,6 +933,62 @@ var _ = Describe("Session", func() { err := sess.sendPackets() Expect(err).ToNot(HaveOccurred()) }) + + for _, enc := range []protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption1RTT} { + encLevel := enc + + Context(fmt.Sprintf("sending %s probe packets", encLevel), func() { + var sendMode ackhandler.SendMode + var getFrame func(protocol.ByteCount) wire.Frame + + BeforeEach(func() { + switch encLevel { + case protocol.EncryptionInitial: + sendMode = ackhandler.SendPTOInitial + getFrame = sess.retransmissionQueue.GetInitialFrame + case protocol.EncryptionHandshake: + sendMode = ackhandler.SendPTOHandshake + getFrame = sess.retransmissionQueue.GetHandshakeFrame + case protocol.Encryption1RTT: + sendMode = ackhandler.SendPTOAppData + getFrame = sess.retransmissionQueue.GetAppDataFrame + } + }) + + It("sends a probe packet", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend() + sph.EXPECT().SendMode().Return(sendMode) + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().QueueProbePacket(encLevel) + packer.EXPECT().MaybePackProbePacket(encLevel).Return(getPacket(123), nil) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) + }) + sess.sentPacketHandler = sph + Expect(sess.sendPackets()).To(Succeed()) + }) + + It("sends a PING as a probe packet", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend() + sph.EXPECT().SendMode().Return(sendMode) + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().QueueProbePacket(encLevel).Return(false) + packer.EXPECT().MaybePackProbePacket(encLevel).Return(getPacket(123), nil) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) + }) + sess.sentPacketHandler = sph + Expect(sess.sendPackets()).To(Succeed()) + // We're using a mock packet packer in this test. + // We therefore need to test separately that the PING was actually queued. + Expect(getFrame(1000)).To(BeAssignableToTypeOf(&wire.PingFrame{})) + }) + }) + } }) Context("packet pacing", func() { @@ -1140,9 +1163,16 @@ var _ = Describe("Session", func() { }) }) - It("cancels the HandshakeComplete context when the handshake completes", func() { + It("cancels the HandshakeComplete context and informs the SentPacketHandler when the handshake completes", func() { packer.EXPECT().PackPacket().AnyTimes() finishHandshake := make(chan struct{}) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sess.sentPacketHandler = sph + sphNotified := make(chan struct{}) + sph.EXPECT().SetHandshakeComplete().Do(func() { close(sphNotified) }) + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SendMode().AnyTimes() go func() { defer GinkgoRecover() <-finishHandshake @@ -1154,6 +1184,7 @@ var _ = Describe("Session", func() { Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) close(finishHandshake) Eventually(handshakeCtx.Done()).Should(BeClosed()) + Eventually(sphNotified).Should(BeClosed()) // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed()