diff --git a/packet_packer.go b/packet_packer.go index 419d90c1..e86d6b8a 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -199,7 +199,26 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac } func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { - ack := p.acks.GetAckFrame(protocol.Encryption1RTT) + var encLevel protocol.EncryptionLevel + var ack *wire.AckFrame + if !p.handshakeConfirmed { + ack = p.acks.GetAckFrame(protocol.EncryptionInitial) + if ack != nil { + encLevel = protocol.EncryptionInitial + } else { + ack = p.acks.GetAckFrame(protocol.EncryptionHandshake) + if ack != nil { + encLevel = protocol.EncryptionHandshake + } + } + } + if ack == nil { + ack = p.acks.GetAckFrame(protocol.Encryption1RTT) + if ack == nil { + return nil, nil + } + encLevel = protocol.Encryption1RTT + } if ack == nil { return nil, nil } @@ -207,13 +226,12 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { ack: ack, length: ack.Length(p.version), } - // TODO(#1534): only pack ACKs with the right encryption level - sealer, err := p.cryptoSetup.Get1RTTSealer() + + sealer, hdr, err := p.getSealerAndHeader(encLevel) if err != nil { return nil, err } - header := p.getShortHeader(sealer.KeyPhase()) - return p.writeAndSealPacket(header, payload, protocol.Encryption1RTT, sealer) + return p.writeAndSealPacket(hdr, payload, encLevel, sealer) } // PackRetransmission packs a retransmission @@ -239,34 +257,9 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP var frames []wire.Frame var length protocol.ByteCount - var sealer sealer - var hdr *wire.ExtendedHeader - switch packet.EncryptionLevel { - case protocol.EncryptionInitial: - var err error - sealer, err = p.cryptoSetup.GetInitialSealer() - if err != nil { - return nil, err - } - hdr = p.getLongHeader(protocol.EncryptionInitial) - case protocol.EncryptionHandshake: - var err error - sealer, err = p.cryptoSetup.GetHandshakeSealer() - if err != nil { - return nil, err - } - hdr = p.getLongHeader(protocol.EncryptionHandshake) - case protocol.Encryption1RTT: - var s handshake.ShortHeaderSealer - var err error - s, err = p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, err - } - sealer = s - hdr = p.getShortHeader(s.KeyPhase()) - default: - return nil, fmt.Errorf("unexpected encryption level: %s", packet.EncryptionLevel) + sealer, hdr, err := p.getSealerAndHeader(packet.EncryptionLevel) + if err != nil { + return nil, err } hdrLen := hdr.GetLength(p.version) @@ -432,6 +425,34 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo return payload, nil } +func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { + switch encLevel { + case protocol.EncryptionInitial: + sealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.EncryptionInitial) + return sealer, hdr, nil + case protocol.EncryptionHandshake: + sealer, err := p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.EncryptionHandshake) + return sealer, hdr, nil + case protocol.Encryption1RTT: + sealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getShortHeader(sealer.KeyPhase()) + return sealer, hdr, nil + default: + return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel) + } +} + func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdr := &wire.ExtendedHeader{} diff --git a/packet_packer_test.go b/packet_packer_test.go index 44605369..e016870a 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -212,6 +212,46 @@ var _ = Describe("Packet packer", func() { }).AnyTimes() }) + Context("packing ACK packets", func() { + It("doesn't pack a packet if there's no ACK to send", func() { + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) + p, err := packer.MaybePackAckPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + + It("packs Handshake ACK-only packets", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake).Return(ack) + p, err := packer.MaybePackAckPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.ack).To(Equal(ack)) + }) + + It("packs 1-RTT ACK-only packets", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) + p, err := packer.MaybePackAckPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(p.ack).To(Equal(ack)) + }) + }) + Context("packing normal packets", func() { BeforeEach(func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, nil).AnyTimes() @@ -367,26 +407,6 @@ var _ = Describe("Packet packer", func() { Expect(r.Len()).To(BeZero()) }) - Context("packing ACK packets", func() { - It("doesn't pack a packet if there's no ACK to send", func() { - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT) - p, err := packer.MaybePackAckPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - }) - - It("packs ACK packets", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT).Return(ack) - p, err := packer.MaybePackAckPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p.ack).To(Equal(ack)) - }) - }) - Context("making ACK packets ack-eliciting", func() { sendMaxNumNonAckElicitingAcks := func() { for i := 0; i < protocol.MaxNonAckElicitingAcks; i++ {