add a function to drop sent packets of a certain encryption level

This commit is contained in:
Marten Seemann 2019-05-30 01:01:27 +08:00
parent 195bdc9944
commit 4d5b4fd790
5 changed files with 62 additions and 40 deletions

View file

@ -14,7 +14,7 @@ type SentPacketHandler interface {
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetMaxAckDelay(time.Duration)
SetHandshakeComplete()
DropPackets(protocol.EncryptionLevel)
ResetForRetry() error
// The SendMode determines if and what kind of packets can be sent.

View file

@ -102,27 +102,27 @@ func NewSentPacketHandler(
}
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding crypto packets.")
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
// remove outstanding packets from bytes_in_flight
pnSpace := h.getPacketNumberSpace(encLevel)
var packets []*Packet
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
packets = append(packets, p)
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
return true, nil
})
for _, p := range packets {
pnSpace.history.Remove(p.PacketNumber)
}
// remove packets from the retransmission queue
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.Encryption1RTT {
if packet.EncryptionLevel != encLevel {
queue = append(queue, packet)
}
}
for _, pnSpace := range []*packetNumberSpace{h.initialPackets, h.handshakePackets} {
var cryptoPackets []*Packet
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
cryptoPackets = append(cryptoPackets, p)
return true, nil
})
for _, p := range cryptoPackets {
pnSpace.history.Remove(p.PacketNumber)
}
}
h.retransmissionQueue = queue
}

View file

@ -49,7 +49,6 @@ var _ = Describe("SentPacketHandler", func() {
BeforeEach(func() {
rttStats := &congestion.RTTStats{}
handler = NewSentPacketHandler(42, rttStats, utils.DefaultLogger).(*sentPacketHandler)
handler.SetHandshakeComplete()
streamFrame = wire.StreamFrame{
StreamID: 5,
Data: []byte{0x13, 0x37},
@ -847,24 +846,45 @@ var _ = Describe("SentPacketHandler", func() {
Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet"))
})
It("deletes crypto packets when the handshake completes", func() {
It("deletes Initial packets", func() {
for i := protocol.PacketNumber(0); i < 6; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionInitial})
handler.SentPacket(p)
}
for i := protocol.PacketNumber(0); i <= 6; i++ {
for i := protocol.PacketNumber(0); i < 10; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake})
handler.SentPacket(p)
}
Expect(handler.bytesInFlight).ToNot(BeZero())
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16)))
handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionInitial), handler.getPacketNumberSpace(protocol.EncryptionInitial))
handler.queuePacketForRetransmission(getPacket(3, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.SetHandshakeComplete()
lostPacket := getPacket(3, protocol.EncryptionHandshake)
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.DropPackets(protocol.EncryptionInitial)
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
Expect(handler.initialPackets.history.Len()).To(BeZero())
Expect(handler.handshakePackets.history.Len()).To(BeZero())
Expect(handler.bytesInFlight).To(BeZero())
Expect(handler.handshakePackets.history.Len()).ToNot(BeZero())
packet := handler.DequeuePacketForRetransmission()
Expect(packet).To(BeNil())
Expect(packet).To(Equal(lostPacket))
})
It("deletes Handshake packets", func() {
for i := protocol.PacketNumber(0); i < 6; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake})
handler.SentPacket(p)
}
for i := protocol.PacketNumber(0); i < 10; i++ {
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.Encryption1RTT})
handler.SentPacket(p)
}
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16)))
handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionInitial))
lostPacket := getPacket(3, protocol.Encryption1RTT)
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.DropPackets(protocol.EncryptionHandshake)
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
Expect(handler.handshakePackets.history.Len()).To(BeZero())
packet := handler.DequeuePacketForRetransmission()
Expect(packet).To(Equal(lostPacket))
})
})

View file

@ -66,6 +66,18 @@ func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket))
}
// DropPackets mocks base method
func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DropPackets", arg0)
}
// DropPackets indicates an expected call of DropPackets
func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0)
}
// GetAlarmTimeout mocks base method
func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time {
m.ctrl.T.Helper()
@ -203,18 +215,6 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacketsAsRetransmission(arg0, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacketsAsRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacketsAsRetransmission), arg0, arg1)
}
// SetHandshakeComplete mocks base method
func (m *MockSentPacketHandler) SetHandshakeComplete() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetHandshakeComplete")
}
// SetHandshakeComplete indicates an expected call of SetHandshakeComplete
func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete))
}
// SetMaxAckDelay mocks base method
func (m *MockSentPacketHandler) SetMaxAckDelay(arg0 time.Duration) {
m.ctrl.T.Helper()

View file

@ -485,7 +485,8 @@ func (s *session) handleHandshakeComplete() {
// independent from the application protocol.
if s.perspective == protocol.PerspectiveServer {
s.queueControlFrame(&wire.PingFrame{})
s.sentPacketHandler.SetHandshakeComplete()
s.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake)
}
}
@ -646,7 +647,8 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time
if s.perspective == protocol.PerspectiveClient {
if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.Encryption1RTT {
s.receivedFirstForwardSecurePacket = true
s.sentPacketHandler.SetHandshakeComplete()
s.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake)
}
}