mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
add a function to drop sent packets of a certain encryption level
This commit is contained in:
parent
195bdc9944
commit
4d5b4fd790
5 changed files with 62 additions and 40 deletions
|
@ -14,7 +14,7 @@ type SentPacketHandler interface {
|
|||
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
|
||||
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||
SetMaxAckDelay(time.Duration)
|
||||
SetHandshakeComplete()
|
||||
DropPackets(protocol.EncryptionLevel)
|
||||
ResetForRetry() error
|
||||
|
||||
// The SendMode determines if and what kind of packets can be sent.
|
||||
|
|
|
@ -102,27 +102,27 @@ func NewSentPacketHandler(
|
|||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||
h.logger.Debugf("Handshake complete. Discarding all outstanding crypto packets.")
|
||||
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
|
||||
// remove outstanding packets from bytes_in_flight
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
var packets []*Packet
|
||||
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
|
||||
packets = append(packets, p)
|
||||
if p.includedInBytesInFlight {
|
||||
h.bytesInFlight -= p.Length
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
for _, p := range packets {
|
||||
pnSpace.history.Remove(p.PacketNumber)
|
||||
}
|
||||
// remove packets from the retransmission queue
|
||||
var queue []*Packet
|
||||
for _, packet := range h.retransmissionQueue {
|
||||
if packet.EncryptionLevel == protocol.Encryption1RTT {
|
||||
if packet.EncryptionLevel != encLevel {
|
||||
queue = append(queue, packet)
|
||||
}
|
||||
}
|
||||
for _, pnSpace := range []*packetNumberSpace{h.initialPackets, h.handshakePackets} {
|
||||
var cryptoPackets []*Packet
|
||||
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.includedInBytesInFlight {
|
||||
h.bytesInFlight -= p.Length
|
||||
}
|
||||
cryptoPackets = append(cryptoPackets, p)
|
||||
return true, nil
|
||||
})
|
||||
for _, p := range cryptoPackets {
|
||||
pnSpace.history.Remove(p.PacketNumber)
|
||||
}
|
||||
}
|
||||
h.retransmissionQueue = queue
|
||||
}
|
||||
|
||||
|
|
|
@ -49,7 +49,6 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
BeforeEach(func() {
|
||||
rttStats := &congestion.RTTStats{}
|
||||
handler = NewSentPacketHandler(42, rttStats, utils.DefaultLogger).(*sentPacketHandler)
|
||||
handler.SetHandshakeComplete()
|
||||
streamFrame = wire.StreamFrame{
|
||||
StreamID: 5,
|
||||
Data: []byte{0x13, 0x37},
|
||||
|
@ -847,24 +846,45 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet"))
|
||||
})
|
||||
|
||||
It("deletes crypto packets when the handshake completes", func() {
|
||||
It("deletes Initial packets", func() {
|
||||
for i := protocol.PacketNumber(0); i < 6; i++ {
|
||||
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionInitial})
|
||||
handler.SentPacket(p)
|
||||
}
|
||||
for i := protocol.PacketNumber(0); i <= 6; i++ {
|
||||
for i := protocol.PacketNumber(0); i < 10; i++ {
|
||||
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake})
|
||||
handler.SentPacket(p)
|
||||
}
|
||||
Expect(handler.bytesInFlight).ToNot(BeZero())
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16)))
|
||||
handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionInitial), handler.getPacketNumberSpace(protocol.EncryptionInitial))
|
||||
handler.queuePacketForRetransmission(getPacket(3, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionHandshake))
|
||||
handler.SetHandshakeComplete()
|
||||
lostPacket := getPacket(3, protocol.EncryptionHandshake)
|
||||
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
|
||||
handler.DropPackets(protocol.EncryptionInitial)
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
|
||||
Expect(handler.initialPackets.history.Len()).To(BeZero())
|
||||
Expect(handler.handshakePackets.history.Len()).To(BeZero())
|
||||
Expect(handler.bytesInFlight).To(BeZero())
|
||||
Expect(handler.handshakePackets.history.Len()).ToNot(BeZero())
|
||||
packet := handler.DequeuePacketForRetransmission()
|
||||
Expect(packet).To(BeNil())
|
||||
Expect(packet).To(Equal(lostPacket))
|
||||
})
|
||||
|
||||
It("deletes Handshake packets", func() {
|
||||
for i := protocol.PacketNumber(0); i < 6; i++ {
|
||||
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.EncryptionHandshake})
|
||||
handler.SentPacket(p)
|
||||
}
|
||||
for i := protocol.PacketNumber(0); i < 10; i++ {
|
||||
p := ackElicitingPacket(&Packet{PacketNumber: i, EncryptionLevel: protocol.Encryption1RTT})
|
||||
handler.SentPacket(p)
|
||||
}
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(16)))
|
||||
handler.queuePacketForRetransmission(getPacket(1, protocol.EncryptionHandshake), handler.getPacketNumberSpace(protocol.EncryptionInitial))
|
||||
lostPacket := getPacket(3, protocol.Encryption1RTT)
|
||||
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
|
||||
handler.DropPackets(protocol.EncryptionHandshake)
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
|
||||
Expect(handler.handshakePackets.history.Len()).To(BeZero())
|
||||
packet := handler.DequeuePacketForRetransmission()
|
||||
Expect(packet).To(Equal(lostPacket))
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -66,6 +66,18 @@ func (mr *MockSentPacketHandlerMockRecorder) DequeueProbePacket() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DequeueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeueProbePacket))
|
||||
}
|
||||
|
||||
// DropPackets mocks base method
|
||||
func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DropPackets", arg0)
|
||||
}
|
||||
|
||||
// DropPackets indicates an expected call of DropPackets
|
||||
func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0)
|
||||
}
|
||||
|
||||
// GetAlarmTimeout mocks base method
|
||||
func (m *MockSentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -203,18 +215,6 @@ func (mr *MockSentPacketHandlerMockRecorder) SentPacketsAsRetransmission(arg0, a
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacketsAsRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacketsAsRetransmission), arg0, arg1)
|
||||
}
|
||||
|
||||
// SetHandshakeComplete mocks base method
|
||||
func (m *MockSentPacketHandler) SetHandshakeComplete() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetHandshakeComplete")
|
||||
}
|
||||
|
||||
// SetHandshakeComplete indicates an expected call of SetHandshakeComplete
|
||||
func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete))
|
||||
}
|
||||
|
||||
// SetMaxAckDelay mocks base method
|
||||
func (m *MockSentPacketHandler) SetMaxAckDelay(arg0 time.Duration) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -485,7 +485,8 @@ func (s *session) handleHandshakeComplete() {
|
|||
// independent from the application protocol.
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
s.queueControlFrame(&wire.PingFrame{})
|
||||
s.sentPacketHandler.SetHandshakeComplete()
|
||||
s.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
|
||||
s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -646,7 +647,8 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time
|
|||
if s.perspective == protocol.PerspectiveClient {
|
||||
if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.Encryption1RTT {
|
||||
s.receivedFirstForwardSecurePacket = true
|
||||
s.sentPacketHandler.SetHandshakeComplete()
|
||||
s.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
|
||||
s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue