pack ACK-only packets for all encryption levels

This commit is contained in:
Marten Seemann 2019-06-30 19:46:37 +07:00
parent 5929a83210
commit 0ce749b5f1
2 changed files with 94 additions and 53 deletions

View file

@ -199,7 +199,26 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
}
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
ack := p.acks.GetAckFrame(protocol.Encryption1RTT)
var encLevel protocol.EncryptionLevel
var ack *wire.AckFrame
if !p.handshakeConfirmed {
ack = p.acks.GetAckFrame(protocol.EncryptionInitial)
if ack != nil {
encLevel = protocol.EncryptionInitial
} else {
ack = p.acks.GetAckFrame(protocol.EncryptionHandshake)
if ack != nil {
encLevel = protocol.EncryptionHandshake
}
}
}
if ack == nil {
ack = p.acks.GetAckFrame(protocol.Encryption1RTT)
if ack == nil {
return nil, nil
}
encLevel = protocol.Encryption1RTT
}
if ack == nil {
return nil, nil
}
@ -207,13 +226,12 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
ack: ack,
length: ack.Length(p.version),
}
// TODO(#1534): only pack ACKs with the right encryption level
sealer, err := p.cryptoSetup.Get1RTTSealer()
sealer, hdr, err := p.getSealerAndHeader(encLevel)
if err != nil {
return nil, err
}
header := p.getShortHeader(sealer.KeyPhase())
return p.writeAndSealPacket(header, payload, protocol.Encryption1RTT, sealer)
return p.writeAndSealPacket(hdr, payload, encLevel, sealer)
}
// PackRetransmission packs a retransmission
@ -239,34 +257,9 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
var frames []wire.Frame
var length protocol.ByteCount
var sealer sealer
var hdr *wire.ExtendedHeader
switch packet.EncryptionLevel {
case protocol.EncryptionInitial:
var err error
sealer, err = p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, err
}
hdr = p.getLongHeader(protocol.EncryptionInitial)
case protocol.EncryptionHandshake:
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, err
}
hdr = p.getLongHeader(protocol.EncryptionHandshake)
case protocol.Encryption1RTT:
var s handshake.ShortHeaderSealer
var err error
s, err = p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
sealer = s
hdr = p.getShortHeader(s.KeyPhase())
default:
return nil, fmt.Errorf("unexpected encryption level: %s", packet.EncryptionLevel)
sealer, hdr, err := p.getSealerAndHeader(packet.EncryptionLevel)
if err != nil {
return nil, err
}
hdrLen := hdr.GetLength(p.version)
@ -432,6 +425,34 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo
return payload, nil
}
func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) {
switch encLevel {
case protocol.EncryptionInitial:
sealer, err := p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getLongHeader(protocol.EncryptionInitial)
return sealer, hdr, nil
case protocol.EncryptionHandshake:
sealer, err := p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getLongHeader(protocol.EncryptionHandshake)
return sealer, hdr, nil
case protocol.Encryption1RTT:
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getShortHeader(sealer.KeyPhase())
return sealer, hdr, nil
default:
return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel)
}
}
func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdr := &wire.ExtendedHeader{}

View file

@ -212,6 +212,46 @@ var _ = Describe("Packet packer", func() {
}).AnyTimes()
})
Context("packing ACK packets", func() {
It("doesn't pack a packet if there's no ACK to send", func() {
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
p, err := packer.MaybePackAckPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
It("packs Handshake ACK-only packets", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack)
p, err := packer.MaybePackAckPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake))
Expect(p.ack).To(Equal(ack))
})
It("packs 1-RTT ACK-only packets", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake)
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack)
p, err := packer.MaybePackAckPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
Expect(p.ack).To(Equal(ack))
})
})
Context("packing normal packets", func() {
BeforeEach(func() {
sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes()
@ -367,26 +407,6 @@ var _ = Describe("Packet packer", func() {
Expect(r.Len()).To(BeZero())
})
Context("packing ACK packets", func() {
It("doesn't pack a packet if there's no ACK to send", func() {
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT)
p, err := packer.MaybePackAckPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
It("packs ACK packets", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}
ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack)
p, err := packer.MaybePackAckPacket()
Expect(err).NotTo(HaveOccurred())
Expect(p.ack).To(Equal(ack))
})
})
Context("making ACK packets ack-eliciting", func() {
sendMaxNumNonAckElicitingAcks := func() {
for i := 0; i < protocol.MaxNonAckElicitingAcks; i++ {