add a function to drop received packets of a certain encryption level

This commit is contained in:
Marten Seemann 2019-05-30 13:46:45 +08:00
parent 4d5b4fd790
commit 4834962cbd
4 changed files with 57 additions and 4 deletions

View file

@ -45,6 +45,7 @@ type SentPacketHandler interface {
type ReceivedPacketHandler interface {
ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
DropPackets(protocol.EncryptionLevel)
GetAlarmTimeout() time.Time
GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame

View file

@ -75,9 +75,25 @@ func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) {
h.oneRTTPackets.IgnoreBelow(pn)
}
func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
initialAlarm := h.initialPackets.GetAlarmTimeout()
handshakeAlarm := h.handshakePackets.GetAlarmTimeout()
var initialAlarm, handshakeAlarm time.Time
if h.initialPackets != nil {
initialAlarm = h.initialPackets.GetAlarmTimeout()
}
if h.handshakePackets != nil {
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
}
oneRTTAlarm := h.oneRTTPackets.GetAlarmTimeout()
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
}
@ -86,9 +102,13 @@ func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) *
var ack *wire.AckFrame
switch encLevel {
case protocol.EncryptionInitial:
ack = h.initialPackets.GetAckFrame()
if h.initialPackets != nil {
ack = h.initialPackets.GetAckFrame()
}
case protocol.EncryptionHandshake:
ack = h.handshakePackets.GetAckFrame()
if h.handshakePackets != nil {
ack = h.handshakePackets.GetAckFrame()
}
case protocol.Encryption1RTT:
return h.oneRTTPackets.GetAckFrame()
default:

View file

@ -47,4 +47,24 @@ var _ = Describe("Received Packet Handler", func() {
Expect(oneRTTAck.AckRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
Expect(oneRTTAck.DelayTime).To(BeNumerically("~", time.Second, 50*time.Millisecond))
})
It("drops Initial packets", func() {
sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
Expect(handler.GetAckFrame(protocol.EncryptionInitial)).ToNot(BeNil())
handler.DropPackets(protocol.EncryptionInitial)
Expect(handler.GetAckFrame(protocol.EncryptionInitial)).To(BeNil())
Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil())
})
It("drops Handshake packets", func() {
sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(2, protocol.Encryption1RTT, sendTime, true)).To(Succeed())
Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).ToNot(BeNil())
handler.DropPackets(protocol.EncryptionInitial)
Expect(handler.GetAckFrame(protocol.EncryptionHandshake)).To(BeNil())
Expect(handler.GetAckFrame(protocol.Encryption1RTT)).ToNot(BeNil())
})
})

View file

@ -36,6 +36,18 @@ func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecor
return m.recorder
}
// DropPackets mocks base method
func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DropPackets", arg0)
}
// DropPackets indicates an expected call of DropPackets
func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0)
}
// GetAckFrame mocks base method
func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame {
m.ctrl.T.Helper()