diff --git a/packet_packer.go b/packet_packer.go index 83d5fa86..26562e0b 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -234,9 +234,12 @@ func (p *packetPacker) packConnectionClose( reason string, ) (*coalescedPacket, error) { var sealers [4]sealer - var hdrs [4]*wire.ExtendedHeader + var hdrs [3]*wire.ExtendedHeader var payloads [4]*payload var size protocol.ByteCount + var connID protocol.ConnectionID + var oneRTTPacketNumber protocol.PacketNumber + var oneRTTPacketNumberLen protocol.PacketNumberLen var keyPhase protocol.KeyPhaseBit // only set for 1-RTT var numLongHdrPackets uint8 encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} @@ -287,16 +290,16 @@ func (p *packetPacker) packConnectionClose( sealers[i] = sealer var hdr *wire.ExtendedHeader if encLevel == protocol.Encryption1RTT { - hdr = p.getShortHeader(keyPhase) + connID = p.getDestConnID() + oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, payload) } else { hdr = p.getLongHeader(encLevel) - } - hdrs[i] = hdr - payloads[i] = payload - size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) - if encLevel != protocol.Encryption1RTT { + hdrs[i] = hdr + size += p.longHeaderPacketLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) numLongHdrPackets++ } + payloads[i] = payload } buffer := getPacketBuffer() packet := &coalescedPacket{ @@ -312,7 +315,7 @@ func (p *packetPacker) packConnectionClose( paddingLen = p.initialPaddingLen(payloads[i].frames, size) } if encLevel == protocol.Encryption1RTT { - shortHdrPacket, err := p.appendShortHeaderPacket(buffer, hdrs[i].PacketNumber, hdrs[i].PacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false) + shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false) if err != nil { return nil, err } @@ -328,10 +331,13 @@ func (p *packetPacker) packConnectionClose( return packet, nil } -// packetLength calculates the length of the serialized packet. +// longHeaderPacketLength calculates the length of a serialized long header packet. // It takes into account that packets that have a tiny payload need to be padded, // such that len(payload) + packet number len >= 4 + AEAD overhead -func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount { +func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount { + if !hdr.IsLongHeader { + panic("wrong code path") + } var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(hdr.PacketNumberLen) if payload.length < 4-pnLen { @@ -340,6 +346,17 @@ func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) return hdr.GetLength(p.version) + payload.length + paddingLen } +// shortHeaderPacketLength calculates the length of a serialized short header packet. +// It takes into account that packets that have a tiny payload need to be padded, +// such that len(payload) + packet number len >= 4 + AEAD overhead +func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnLen protocol.PacketNumberLen, payload *payload) protocol.ByteCount { + var paddingLen protocol.ByteCount + if payload.length < 4-protocol.ByteCount(pnLen) { + paddingLen = 4 - protocol.ByteCount(pnLen) - payload.length + } + return wire.ShortHeaderLen(connID, pnLen) + payload.length + paddingLen +} + // size is the expected size of the packet, if no padding was applied. func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { // For the server, only ack-eliciting Initial packets need to be padded. @@ -360,9 +377,12 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro if p.perspective == protocol.PerspectiveClient { maxPacketSize = protocol.MinInitialPacketSize } - var initialHdr, handshakeHdr, appDataHdr *wire.ExtendedHeader - var initialPayload, handshakePayload, appDataPayload *payload - var numPackets int + var ( + initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader + initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload *payload + oneRTTPacketNumber protocol.PacketNumber + oneRTTPacketNumberLen protocol.PacketNumberLen + ) // Try packing an Initial packet. initialSealer, err := p.cryptoSetup.GetInitialSealer() if err != nil && err != handshake.ErrKeysDropped { @@ -372,8 +392,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro if initialSealer != nil { initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true) if initialPayload != nil { - size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) - numPackets++ + size += p.longHeaderPacketLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) } } @@ -388,50 +407,55 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro if handshakeSealer != nil { handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0) if handshakePayload != nil { - s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) + s := p.longHeaderPacketLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) size += s - numPackets++ } } } // Add a 0-RTT / 1-RTT packet. - var appDataSealer sealer + var zeroRTTSealer sealer + var oneRTTSealer handshake.ShortHeaderSealer + var connID protocol.ConnectionID var kp protocol.KeyPhaseBit - appDataEncLevel := protocol.Encryption1RTT if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { - var sErr error - var oneRTTSealer handshake.ShortHeaderSealer - oneRTTSealer, sErr = p.cryptoSetup.Get1RTTSealer() - appDataSealer = oneRTTSealer - if sErr != nil && p.perspective == protocol.PerspectiveClient { - appDataSealer, sErr = p.cryptoSetup.Get0RTTSealer() - appDataEncLevel = protocol.Encryption0RTT + var err error + oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err } - if appDataSealer != nil && sErr == nil { - //nolint:exhaustive // 0-RTT and 1-RTT are the only two application data encryption levels. - switch appDataEncLevel { - case protocol.Encryption0RTT: - appDataHdr, appDataPayload = p.maybeGetAppDataPacketFor0RTT(appDataSealer, maxPacketSize-size) - case protocol.Encryption1RTT: - kp = oneRTTSealer.KeyPhase() - appDataHdr, appDataPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, kp, maxPacketSize-size, onlyAck, size == 0) + if err == nil { // 1-RTT + kp = oneRTTSealer.KeyPhase() + connID = p.getDestConnID() + oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen) + oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0) + if oneRTTPayload != nil { + size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead()) } - if appDataHdr != nil && appDataPayload != nil { - size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead()) - numPackets++ + } else if p.perspective == protocol.PerspectiveClient { // 0-RTT + var err error + zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if zeroRTTSealer != nil { + zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size) + if zeroRTTPayload != nil { + size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload) + protocol.ByteCount(zeroRTTSealer.Overhead()) + } } } } - if numPackets == 0 { + if initialPayload == nil && handshakePayload == nil && zeroRTTPayload == nil && oneRTTPayload == nil { return nil, nil } buffer := getPacketBuffer() packet := &coalescedPacket{ buffer: buffer, - longHdrPackets: make([]*longHeaderPacket, 0, numPackets), + longHdrPackets: make([]*longHeaderPacket, 0, 3), } if initialPayload != nil { padding := p.initialPaddingLen(initialPayload.frames, size) @@ -448,20 +472,18 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro } packet.longHdrPackets = append(packet.longHdrPackets, cont) } - if appDataPayload != nil { - if appDataEncLevel == protocol.Encryption0RTT { - longHdrPacket, err := p.appendLongHeaderPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer) - if err != nil { - return nil, err - } - packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) - } else { - shortHdrPacket, err := p.appendShortHeaderPacket(buffer, appDataHdr.PacketNumber, appDataHdr.PacketNumberLen, kp, appDataPayload, 0, appDataSealer, false) - if err != nil { - return nil, err - } - packet.shortHdrPacket = shortHdrPacket + if zeroRTTPayload != nil { + longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer) + if err != nil { + return nil, err } + packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) + } else if oneRTTPayload != nil { + shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false) + if err != nil { + return nil, err + } + packet.shortHdrPacket = shortHdrPacket } return packet, nil } @@ -473,13 +495,16 @@ func (p *packetPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacke if err != nil { return shortHeaderPacket{}, nil, err } - kp := sealer.KeyPhase() - hdr, payload := p.maybeGetShortHeaderPacket(sealer, kp, p.maxPacketSize, onlyAck, true) + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + connID := p.getDestConnID() + hdrLen := wire.ShortHeaderLen(connID, pnLen) + payload := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true) if payload == nil { return shortHeaderPacket{}, nil, errNothingToPack } + kp := sealer.KeyPhase() buffer := getPacketBuffer() - packet, err := p.appendShortHeaderPacket(buffer, hdr.PacketNumber, hdr.PacketNumberLen, kp, payload, 0, sealer, false) + packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, sealer, false) if err != nil { return shortHeaderPacket{}, nil, err } @@ -564,11 +589,9 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize return hdr, payload } -func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, kp protocol.KeyPhaseBit, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { - hdr := p.getShortHeader(kp) - maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed) - return hdr, payload +func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload { + maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) + return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed) } func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload { @@ -666,10 +689,32 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc } func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*coalescedPacket, error) { + if encLevel == protocol.Encryption1RTT { + s, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + kp := s.KeyPhase() + connID := p.getDestConnID() + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdrLen := wire.ShortHeaderLen(connID, pnLen) + payload := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true) + if payload == nil { + return nil, nil + } + buffer := getPacketBuffer() + packet := &coalescedPacket{buffer: buffer} + shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, s, false) + if err != nil { + return nil, err + } + packet.shortHdrPacket = shortHdrPacket + return packet, nil + } + var hdr *wire.ExtendedHeader var payload *payload - var sealer sealer - var kp protocol.KeyPhaseBit + var sealer handshake.LongHeaderSealer //nolint:exhaustive // Probe packets are never sent for 0-RTT. switch encLevel { case protocol.EncryptionInitial: @@ -686,36 +731,20 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( return nil, err } hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true) - case protocol.Encryption1RTT: - oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, err - } - kp = oneRTTSealer.KeyPhase() - sealer = oneRTTSealer - hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - payload = p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), false, true) default: panic("unknown encryption level") } + if payload == nil { return nil, nil } - size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) + buffer := getPacketBuffer() + packet := &coalescedPacket{buffer: buffer} + size := p.longHeaderPacketLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) var padding protocol.ByteCount if encLevel == protocol.EncryptionInitial { padding = p.initialPaddingLen(payload.frames, size) } - buffer := getPacketBuffer() - packet := &coalescedPacket{buffer: buffer} - if encLevel == protocol.Encryption1RTT { - shortHdrPacket, err := p.appendShortHeaderPacket(buffer, hdr.PacketNumber, hdr.PacketNumberLen, kp, payload, padding, sealer, false) - if err != nil { - return nil, err - } - packet.shortHdrPacket = shortHdrPacket - return packet, nil - } longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, payload, padding, encLevel, sealer) if err != nil { @@ -731,29 +760,20 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B length: ping.Length(p.version), } buffer := getPacketBuffer() - sealer, err := p.cryptoSetup.Get1RTTSealer() + s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return shortHeaderPacket{}, nil, err } - hdr := p.getShortHeader(sealer.KeyPhase()) - padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) - packet, err := p.appendShortHeaderPacket(buffer, hdr.PacketNumber, hdr.PacketNumberLen, sealer.KeyPhase(), payload, padding, sealer, true) + connID := p.getDestConnID() + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + padding := size - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead()) + packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, s.KeyPhase(), payload, padding, s, true) if err != nil { return shortHeaderPacket{}, nil, err } return *packet, buffer, nil } -func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { - pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) - hdr := &wire.ExtendedHeader{} - hdr.PacketNumber = pn - hdr.PacketNumberLen = pnLen - hdr.DestConnectionID = p.getDestConnID() - hdr.KeyPhase = kp - return hdr -} - func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) hdr := &wire.ExtendedHeader{ @@ -823,6 +843,7 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire func (p *packetPacker) appendShortHeaderPacket( buffer *packetBuffer, + connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, @@ -840,7 +861,6 @@ func (p *packetPacker) appendShortHeaderPacket( raw := buffer.Data[len(buffer.Data):] buf := bytes.NewBuffer(buffer.Data) startLen := buf.Len() - connID := p.getDestConnID() if err := wire.WriteShortHeader(buf, connID, pn, pnLen, kp); err != nil { return nil, err } diff --git a/packet_packer_test.go b/packet_packer_test.go index b23b424f..87087a43 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -148,15 +148,6 @@ var _ = Describe("Packet packer", func() { Expect(h.SrcConnectionID).To(Equal(srcConnID)) Expect(h.DestConnectionID).To(Equal(destConnID)) }) - - It("gets a short header", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4) - h := packer.getShortHeader(protocol.KeyPhaseOne) - Expect(h.IsLongHeader).To(BeFalse()) - Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) - Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) - Expect(h.KeyPhase).To(Equal(protocol.KeyPhaseOne)) - }) }) Context("encrypting packets", func() {