limit available window to 3x of received bytes before address validation

This commit is contained in:
Marten Seemann 2020-04-23 13:34:31 +07:00
parent 98233f6743
commit 60a918a108
9 changed files with 186 additions and 13 deletions

View file

@ -25,12 +25,14 @@ type SentPacketHandler interface {
// SentPacket may modify the packet // SentPacket may modify the packet
SentPacket(packet *Packet) SentPacket(packet *Packet)
ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) error ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) error
ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel) DropPackets(protocol.EncryptionLevel)
ResetForRetry() error ResetForRetry() error
SetHandshakeComplete() SetHandshakeComplete()
// The SendMode determines if and what kind of packets can be sent. // The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode SendMode() SendMode
AmplificationWindow() protocol.ByteCount
// TimeUntilSend is the time when the next packet should be sent. // TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets. // It is used for pacing packets.
TimeUntilSend() time.Time TimeUntilSend() time.Time
@ -56,6 +58,7 @@ type SentPacketHandler interface {
type sentPacketTracker interface { type sentPacketTracker interface {
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
ReceivedPacket(protocol.EncryptionLevel)
} }
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets

View file

@ -47,3 +47,15 @@ func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked()
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked))
} }
// ReceivedPacket mocks base method
func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedPacket", arg0)
}
// ReceivedPacket indicates an expected call of ReceivedPacket
func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0)
}

View file

@ -64,6 +64,7 @@ func (h *receivedPacketHandler) ReceivedPacket(
rcvTime time.Time, rcvTime time.Time,
shouldInstigateAck bool, shouldInstigateAck bool,
) error { ) error {
h.sentPackets.ReceivedPacket(encLevel)
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck)

View file

@ -3,6 +3,8 @@ package ackhandler
import ( import (
"time" "time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
@ -29,6 +31,9 @@ var _ = Describe("Received Packet Handler", func() {
It("generates ACKs for different packet number spaces", func() { It("generates ACKs for different packet number spaces", func() {
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sendTime := time.Now().Add(-time.Second) sendTime := time.Now().Add(-time.Second)
sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial).Times(2)
sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake).Times(2)
sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT).Times(2)
Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(5, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(5, protocol.Encryption1RTT, sendTime, true)).To(Succeed())
@ -54,6 +59,8 @@ var _ = Describe("Received Packet Handler", func() {
It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { It("uses the same packet number space for 0-RTT and 1-RTT packets", func() {
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT)
sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT)
sendTime := time.Now().Add(-time.Second) sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed())
@ -64,6 +71,7 @@ var _ = Describe("Received Packet Handler", func() {
}) })
It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() { It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3)
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sendTime := time.Now() sendTime := time.Now()
Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed())
@ -72,6 +80,7 @@ var _ = Describe("Received Packet Handler", func() {
}) })
It("allows reordered 0-RTT packets", func() { It("allows reordered 0-RTT packets", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3)
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sendTime := time.Now() sendTime := time.Now()
Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed())
@ -80,6 +89,7 @@ var _ = Describe("Received Packet Handler", func() {
}) })
It("drops Initial packets", func() { It("drops Initial packets", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2)
sendTime := time.Now().Add(-time.Second) sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
@ -90,6 +100,7 @@ var _ = Describe("Received Packet Handler", func() {
}) })
It("drops Handshake packets", func() { It("drops Handshake packets", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2)
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sendTime := time.Now().Add(-time.Second) sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
@ -105,6 +116,7 @@ var _ = Describe("Received Packet Handler", func() {
}) })
It("drops old ACK ranges", func() { It("drops old ACK ranges", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes()
sendTime := time.Now() sendTime := time.Now()
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2)
Expect(handler.ReceivedPacket(1, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.Encryption1RTT, sendTime, true)).To(Succeed())

View file

@ -20,6 +20,8 @@ const (
timeThreshold = 9.0 / 8 timeThreshold = 9.0 / 8
// Maximum reordering in packets before packet threshold loss detection considers a packet lost. // Maximum reordering in packets before packet threshold loss detection considers a packet lost.
packetThreshold = 3 packetThreshold = 3
// Before validating the client's address, the server won't send more than 3x bytes than it received.
amplificationFactor = 3
) )
type packetNumberSpace struct { type packetNumberSpace struct {
@ -49,8 +51,16 @@ type sentPacketHandler struct {
handshakePackets *packetNumberSpace handshakePackets *packetNumberSpace
appDataPackets *packetNumberSpace appDataPackets *packetNumberSpace
// Do we know that the peer completed address validation yet?
// Always true for the server.
peerCompletedAddressValidation bool peerCompletedAddressValidation bool
handshakeComplete bool bytesReceived protocol.ByteCount
bytesSent protocol.ByteCount
// Have we validated the peer's address yet?
// Always true for the client.
peerAddressValidated bool
handshakeComplete bool
// lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20 // example: we send an ACK for packets 90-100 with packet number 20
@ -99,6 +109,7 @@ func newSentPacketHandler(
return &sentPacketHandler{ return &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient,
initialPackets: newPacketNumberSpace(initialPacketNumber), initialPackets: newPacketNumberSpace(initialPacketNumber),
handshakePackets: newPacketNumberSpace(0), handshakePackets: newPacketNumberSpace(0),
appDataPackets: newPacketNumberSpace(0), appDataPackets: newPacketNumberSpace(0),
@ -168,6 +179,16 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
h.ptoMode = SendNone h.ptoMode = SendNone
} }
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
h.bytesReceived += n
}
func (h *sentPacketHandler) ReceivedPacket(encLevel protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionHandshake {
h.peerAddressValidated = true
}
}
func (h *sentPacketHandler) packetsInFlight() int { func (h *sentPacketHandler) packetsInFlight() int {
packetsInFlight := h.appDataPackets.history.Len() packetsInFlight := h.appDataPackets.history.Len()
if h.handshakePackets != nil { if h.handshakePackets != nil {
@ -180,6 +201,7 @@ func (h *sentPacketHandler) packetsInFlight() int {
} }
func (h *sentPacketHandler) SentPacket(packet *Packet) { func (h *sentPacketHandler) SentPacket(packet *Packet) {
h.bytesSent += packet.Length
// For the client, drop the Initial packet number space when the first Handshake packet is sent. // For the client, drop the Initial packet number space when the first Handshake packet is sent.
if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil {
h.dropPackets(protocol.EncryptionInitial) h.dropPackets(protocol.EncryptionInitial)
@ -638,6 +660,10 @@ func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets += h.handshakePackets.history.Len() numTrackedPackets += h.handshakePackets.history.Len()
} }
if h.AmplificationWindow() == 0 {
h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent)
return SendNone
}
// Don't send any packets if we're keeping track of the maximum number of packets. // Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets, // we will stop sending out new data when reaching MaxOutstandingSentPackets,
@ -683,6 +709,16 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
} }
func (h *sentPacketHandler) AmplificationWindow() protocol.ByteCount {
if h.peerAddressValidated {
return protocol.MaxByteCount
}
if h.bytesSent >= amplificationFactor*h.bytesReceived {
return 0
}
return amplificationFactor*h.bytesReceived - h.bytesSent
}
func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
p := pnSpace.history.FirstOutstanding() p := pnSpace.history.FirstOutstanding()

View file

@ -495,13 +495,51 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed())
}) })
It("passes the bytes in flight to CanSend", func() { It("passes the bytes in flight to the congestion controller", func() {
handler.bytesInFlight = 42 handler.ReceivedPacket(protocol.EncryptionHandshake)
cong.EXPECT().CanSend(protocol.ByteCount(42)) cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(42), gomock.Any(), protocol.ByteCount(42), true)
cong.EXPECT().TimeUntilSend(gomock.Any())
handler.SentPacket(&Packet{
Length: 42,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
cong.EXPECT().CanSend(protocol.ByteCount(42)).Return(true)
handler.SendMode() handler.SendMode()
}) })
It("returns SendNone if limited by the 3x limit", func() {
handler.ReceivedBytes(100)
cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(300), gomock.Any(), protocol.ByteCount(300), true)
cong.EXPECT().TimeUntilSend(gomock.Any())
handler.SentPacket(&Packet{
Length: 300,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
cong.EXPECT().CanSend(protocol.ByteCount(300)).Return(true).AnyTimes()
Expect(handler.AmplificationWindow()).To(BeZero())
Expect(handler.SendMode()).To(Equal(SendNone))
})
It("limits the window to 3x the bytes received, to avoid amplification attacks", func() {
handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address
cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(50), gomock.Any(), protocol.ByteCount(50), true)
cong.EXPECT().TimeUntilSend(gomock.Any())
handler.SentPacket(&Packet{
Length: 50,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
handler.ReceivedBytes(100)
Expect(handler.AmplificationWindow()).To(Equal(protocol.ByteCount(3*100 - 50)))
})
It("allows sending of ACKs when congestion limited", func() { It("allows sending of ACKs when congestion limited", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
cong.EXPECT().CanSend(gomock.Any()).Return(true) cong.EXPECT().CanSend(gomock.Any()).Return(true)
Expect(handler.SendMode()).To(Equal(SendAny)) Expect(handler.SendMode()).To(Equal(SendAny))
cong.EXPECT().CanSend(gomock.Any()).Return(false) cong.EXPECT().CanSend(gomock.Any()).Return(false)
@ -509,6 +547,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("allows sending of ACKs when we're keeping track of MaxOutstandingSentPackets packets", func() { It("allows sending of ACKs when we're keeping track of MaxOutstandingSentPackets packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes() cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes()
cong.EXPECT().TimeUntilSend(gomock.Any()).AnyTimes() cong.EXPECT().TimeUntilSend(gomock.Any()).AnyTimes()
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
@ -521,6 +560,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("allows PTOs, even when congestion limited", func() { It("allows PTOs, even when congestion limited", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
// note that we don't EXPECT a call to GetCongestionWindow // note that we don't EXPECT a call to GetCongestionWindow
// that means retransmissions are sent without considering the congestion window // that means retransmissions are sent without considering the congestion window
handler.numProbesToSend = 1 handler.numProbesToSend = 1
@ -561,6 +601,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("doesn't set an alarm if there are no outstanding packets", func() { It("doesn't set an alarm if there are no outstanding packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}}
@ -569,6 +610,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("does nothing on OnAlarm if there are no outstanding packets", func() { It("does nothing on OnAlarm if there are no outstanding packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendAny)) Expect(handler.SendMode()).To(Equal(SendAny))
}) })
@ -602,6 +644,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("reset the PTO count when receiving an ACK", func() { It("reset the PTO count when receiving an ACK", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
now := time.Now() now := time.Now()
handler.SetHandshakeComplete() handler.SetHandshakeComplete()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)}))
@ -615,6 +658,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("resets the PTO mode and PTO count when a packet number space is dropped", func() { It("resets the PTO mode and PTO count when a packet number space is dropped", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
now := time.Now() now := time.Now()
handler.SentPacket(ackElicitingPacket(&Packet{ handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: 1, PacketNumber: 1,
@ -638,6 +682,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("allows two 1-RTT PTOs", func() { It("allows two 1-RTT PTOs", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete() handler.SetHandshakeComplete()
var lostPackets []protocol.PacketNumber var lostPackets []protocol.PacketNumber
handler.SentPacket(ackElicitingPacket(&Packet{ handler.SentPacket(ackElicitingPacket(&Packet{
@ -657,6 +702,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("only counts ack-eliciting packets as probe packets", func() { It("only counts ack-eliciting packets as probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete() handler.SetHandshakeComplete()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.OnLossDetectionTimeout()).To(Succeed())
@ -672,7 +718,8 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData))
}) })
It("gets two probe packets if RTO expires", func() { It("gets two probe packets if PTO expires", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete() handler.SetHandshakeComplete()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2}))
@ -698,6 +745,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("gets two probe packets if PTO expires, for Handshake packets", func() { It("gets two probe packets if PTO expires, for Handshake packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) handler.SentPacket(initialPacket(&Packet{PacketNumber: 1}))
handler.SentPacket(initialPacket(&Packet{PacketNumber: 2})) handler.SentPacket(initialPacket(&Packet{PacketNumber: 2}))
@ -714,6 +762,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("doesn't send 1-RTT probe packets before the handshake completes", func() { It("doesn't send 1-RTT probe packets before the handshake completes", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1}))
updateRTT(time.Hour) updateRTT(time.Hour)
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP
@ -726,6 +775,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete() handler.SetHandshakeComplete()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) handler.rttStats.UpdateRTT(time.Second, 0, time.Now())
@ -737,6 +787,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("handles ACKs for the original packet", func() { It("handles ACKs for the original packet", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)}))
handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) handler.rttStats.UpdateRTT(time.Second, 0, time.Now())
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.OnLossDetectionTimeout()).To(Succeed())
@ -993,6 +1044,7 @@ var _ = Describe("SentPacketHandler", func() {
}) })
It("cancels the PTO when dropping a packet number space", func() { It("cancels the PTO when dropping a packet number space", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
now := time.Now() now := time.Now()
handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)}))
handler.SentPacket(handshakePacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) handler.SentPacket(handshakePacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)}))
@ -1028,12 +1080,15 @@ var _ = Describe("SentPacketHandler", func() {
}) })
}) })
Context("resetting for retry", func() { Context("for the client", func() {
BeforeEach(func() { BeforeEach(func() {
perspective = protocol.PerspectiveClient perspective = protocol.PerspectiveClient
}) })
It("queues outstanding packets for retransmission, cancels alarms and resets PTO count", func() { It("considers the server's address validated right away", func() {
})
It("queues outstanding packets for retransmission, cancels alarms and resets PTO count when receiving a Retry", func() {
handler.SentPacket(initialPacket(&Packet{PacketNumber: 42})) handler.SentPacket(initialPacket(&Packet{PacketNumber: 42}))
Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero())
Expect(handler.bytesInFlight).ToNot(BeZero()) Expect(handler.bytesInFlight).ToNot(BeZero())
@ -1047,7 +1102,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.ptoCount).To(BeZero()) Expect(handler.ptoCount).To(BeZero())
}) })
It("queues outstanding frames for retransmission and cancels alarms", func() { It("queues outstanding frames for retransmission and cancels alarms when receiving a Retry", func() {
var lostInitial, lost0RTT bool var lostInitial, lost0RTT bool
handler.SentPacket(&Packet{ handler.SentPacket(&Packet{
PacketNumber: 13, PacketNumber: 13,

View file

@ -38,6 +38,20 @@ func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder {
return m.recorder return m.recorder
} }
// AmplificationWindow mocks base method
func (m *MockSentPacketHandler) AmplificationWindow() protocol.ByteCount {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AmplificationWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// AmplificationWindow indicates an expected call of AmplificationWindow
func (mr *MockSentPacketHandlerMockRecorder) AmplificationWindow() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AmplificationWindow", reflect.TypeOf((*MockSentPacketHandler)(nil).AmplificationWindow))
}
// DropPackets mocks base method // DropPackets mocks base method
func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -149,6 +163,18 @@ func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2)
} }
// ReceivedBytes mocks base method
func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedBytes", arg0)
}
// ReceivedBytes indicates an expected call of ReceivedBytes
func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0)
}
// ResetForRetry mocks base method // ResetForRetry mocks base method
func (m *MockSentPacketHandler) ResetForRetry() error { func (m *MockSentPacketHandler) ResetForRetry() error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -700,6 +700,7 @@ func (s *session) handlePacketImpl(rp *receivedPacket) bool {
var processed bool var processed bool
data := rp.data data := rp.data
p := rp p := rp
s.sentPacketHandler.ReceivedBytes(protocol.ByteCount(len(data)))
for len(data) > 0 { for len(data) > 0 {
if counter > 0 { if counter > 0 {
p = p.Clone() p = p.Clone()
@ -1427,7 +1428,7 @@ func (s *session) sendPacket() (bool, error) {
if !s.handshakeConfirmed { if !s.handshakeConfirmed {
now := time.Now() now := time.Now()
packet, err := s.packer.PackCoalescedPacket(protocol.MaxByteCount) packet, err := s.packer.PackCoalescedPacket(s.sentPacketHandler.AmplificationWindow())
if err != nil || packet == nil { if err != nil || packet == nil {
return false, err return false, err
} }

View file

@ -1033,6 +1033,13 @@ var _ = Describe("Session", func() {
It("sends packets", func() { It("sends packets", func() {
sess.handshakeConfirmed = true sess.handshakeConfirmed = true
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ShouldSendNumPackets().Return(1000)
sph.EXPECT().SentPacket(gomock.Any())
sess.sentPacketHandler = sph
runSession() runSession()
p := getPacket(1) p := getPacket(1)
packer.EXPECT().PackPacket().Return(p, nil) packer.EXPECT().PackPacket().Return(p, nil)
@ -1069,6 +1076,13 @@ var _ = Describe("Session", func() {
It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { It("adds a BLOCKED frame when it is connection-level flow control blocked", func() {
sess.handshakeConfirmed = true sess.handshakeConfirmed = true
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ShouldSendNumPackets().Return(1000)
sph.EXPECT().SentPacket(gomock.Any())
sess.sentPacketHandler = sph
fc := mocks.NewMockConnectionFlowController(mockCtrl) fc := mocks.NewMockConnectionFlowController(mockCtrl)
fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337))
fc.EXPECT().IsNewlyBlocked() fc.EXPECT().IsNewlyBlocked()
@ -1366,10 +1380,12 @@ var _ = Describe("Session", func() {
It("sends coalesced packets before the handshake is confirmed", func() { It("sends coalesced packets before the handshake is confirmed", func() {
sess.handshakeConfirmed = false sess.handshakeConfirmed = false
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
const window protocol.ByteCount = 321
sph.EXPECT().AmplificationWindow().Return(window).AnyTimes()
sess.sentPacketHandler = sph sess.sentPacketHandler = sph
buffer := getPacketBuffer() buffer := getPacketBuffer()
buffer.Data = append(buffer.Data, []byte("foobar")...) buffer.Data = append(buffer.Data, []byte("foobar")...)
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).Return(&coalescedPacket{ packer.EXPECT().PackCoalescedPacket(window).Return(&coalescedPacket{
buffer: buffer, buffer: buffer,
packets: []*packetContents{ packets: []*packetContents{
{ {
@ -1394,7 +1410,7 @@ var _ = Describe("Session", func() {
}, },
}, },
}, nil) }, nil)
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() packer.EXPECT().PackCoalescedPacket(window).AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
@ -1545,9 +1561,17 @@ var _ = Describe("Session", func() {
}) })
It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount)
sph.EXPECT().SetHandshakeComplete()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(10)
sess.sentPacketHandler = sph
done := make(chan struct{}) done := make(chan struct{})
sessionRunner.EXPECT().Retire(clientDestConnID) sessionRunner.EXPECT().Retire(clientDestConnID)
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { packer.EXPECT().PackCoalescedPacket(gomock.Any()).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) {
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
Expect(frames).ToNot(BeEmpty()) Expect(frames).ToNot(BeEmpty())
Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{}))
@ -1559,7 +1583,7 @@ var _ = Describe("Session", func() {
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
}, nil }, nil
}) })
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() packer.EXPECT().PackCoalescedPacket(gomock.Any()).AnyTimes()
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().RunHandshake()
@ -1659,6 +1683,7 @@ var _ = Describe("Session", func() {
BeforeEach(func() { BeforeEach(func() {
sess.config.MaxIdleTimeout = 30 * time.Second sess.config.MaxIdleTimeout = 30 * time.Second
sess.config.KeepAlive = true sess.config.KeepAlive = true
sess.receivedPacketHandler.ReceivedPacket(0, protocol.EncryptionHandshake, time.Now(), true)
}) })
AfterEach(func() { AfterEach(func() {
@ -2098,6 +2123,7 @@ var _ = Describe("Client Session", func() {
It("handles Retry packets", func() { It("handles Retry packets", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph sess.sentPacketHandler = sph
sph.EXPECT().ReceivedBytes(gomock.Any())
sph.EXPECT().ResetForRetry() sph.EXPECT().ResetForRetry()
cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
packer.EXPECT().SetToken([]byte("foobar")) packer.EXPECT().SetToken([]byte("foobar"))
@ -2333,6 +2359,7 @@ var _ = Describe("Client Session", func() {
It("ignores Initial packets which use original source id, after accepting a Retry", func() { It("ignores Initial packets which use original source id, after accepting a Retry", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph sess.sentPacketHandler = sph
sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2)
sph.EXPECT().ResetForRetry() sph.EXPECT().ResetForRetry()
newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID)