pass payload around by value in the packet packer (#3648)

This commit is contained in:
Marten Seemann 2023-01-17 22:53:05 -08:00 committed by GitHub
parent 4d9ab7b604
commit b77d8570df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -231,7 +231,7 @@ func (p *packetPacker) packConnectionClose(
) (*coalescedPacket, error) { ) (*coalescedPacket, error) {
var sealers [4]sealer var sealers [4]sealer
var hdrs [3]*wire.ExtendedHeader var hdrs [3]*wire.ExtendedHeader
var payloads [4]*payload var payloads [4]payload
var size protocol.ByteCount var size protocol.ByteCount
var connID protocol.ConnectionID var connID protocol.ConnectionID
var oneRTTPacketNumber protocol.PacketNumber var oneRTTPacketNumber protocol.PacketNumber
@ -255,7 +255,7 @@ func (p *packetPacker) packConnectionClose(
ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode) ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode)
ccf.ReasonPhrase = "" ccf.ReasonPhrase = ""
} }
payload := &payload{ pl := payload{
frames: []ackhandler.Frame{{Frame: ccf}}, frames: []ackhandler.Frame{{Frame: ccf}},
length: ccf.Length(p.version), length: ccf.Length(p.version),
} }
@ -288,14 +288,14 @@ func (p *packetPacker) packConnectionClose(
if encLevel == protocol.Encryption1RTT { if encLevel == protocol.Encryption1RTT {
connID = p.getDestConnID() connID = p.getDestConnID()
oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, payload) size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, pl)
} else { } else {
hdr = p.getLongHeader(encLevel) hdr = p.getLongHeader(encLevel)
hdrs[i] = hdr hdrs[i] = hdr
size += p.longHeaderPacketLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) size += p.longHeaderPacketLength(hdr, pl) + protocol.ByteCount(sealer.Overhead())
numLongHdrPackets++ numLongHdrPackets++
} }
payloads[i] = payload payloads[i] = pl
} }
buffer := getPacketBuffer() buffer := getPacketBuffer()
packet := &coalescedPacket{ packet := &coalescedPacket{
@ -336,24 +336,24 @@ func (p *packetPacker) packConnectionClose(
// longHeaderPacketLength calculates the length of a serialized long header packet. // longHeaderPacketLength calculates the length of a serialized long header packet.
// It takes into account that packets that have a tiny payload need to be padded, // It takes into account that packets that have a tiny payload need to be padded,
// such that len(payload) + packet number len >= 4 + AEAD overhead // such that len(payload) + packet number len >= 4 + AEAD overhead
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount { func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload) protocol.ByteCount {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(hdr.PacketNumberLen) pnLen := protocol.ByteCount(hdr.PacketNumberLen)
if payload.length < 4-pnLen { if pl.length < 4-pnLen {
paddingLen = 4 - pnLen - payload.length paddingLen = 4 - pnLen - pl.length
} }
return hdr.GetLength(p.version) + payload.length + paddingLen return hdr.GetLength(p.version) + pl.length + paddingLen
} }
// shortHeaderPacketLength calculates the length of a serialized short header packet. // shortHeaderPacketLength calculates the length of a serialized short header packet.
// It takes into account that packets that have a tiny payload need to be padded, // It takes into account that packets that have a tiny payload need to be padded,
// such that len(payload) + packet number len >= 4 + AEAD overhead // such that len(payload) + packet number len >= 4 + AEAD overhead
func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnLen protocol.PacketNumberLen, payload *payload) protocol.ByteCount { func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnLen protocol.PacketNumberLen, pl payload) protocol.ByteCount {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
if payload.length < 4-protocol.ByteCount(pnLen) { if pl.length < 4-protocol.ByteCount(pnLen) {
paddingLen = 4 - protocol.ByteCount(pnLen) - payload.length paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length
} }
return wire.ShortHeaderLen(connID, pnLen) + payload.length + paddingLen return wire.ShortHeaderLen(connID, pnLen) + pl.length + paddingLen
} }
// size is the expected size of the packet, if no padding was applied. // size is the expected size of the packet, if no padding was applied.
@ -378,7 +378,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
} }
var ( var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload *payload initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
oneRTTPacketNumber protocol.PacketNumber oneRTTPacketNumber protocol.PacketNumber
oneRTTPacketNumberLen protocol.PacketNumberLen oneRTTPacketNumberLen protocol.PacketNumberLen
) )
@ -390,7 +390,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
var size protocol.ByteCount var size protocol.ByteCount
if initialSealer != nil { if initialSealer != nil {
initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true) initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true)
if initialPayload != nil { if initialPayload.length > 0 {
size += p.longHeaderPacketLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) size += p.longHeaderPacketLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead())
} }
} }
@ -405,7 +405,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
} }
if handshakeSealer != nil { if handshakeSealer != nil {
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0) handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0)
if handshakePayload != nil { if handshakePayload.length > 0 {
s := p.longHeaderPacketLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) s := p.longHeaderPacketLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead())
size += s size += s
} }
@ -429,7 +429,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen) hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen)
oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0) oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0)
if oneRTTPayload != nil { if oneRTTPayload.length > 0 {
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead()) size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead())
} }
} else if p.perspective == protocol.PerspectiveClient { // 0-RTT } else if p.perspective == protocol.PerspectiveClient { // 0-RTT
@ -440,14 +440,14 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
} }
if zeroRTTSealer != nil { if zeroRTTSealer != nil {
zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size) zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size)
if zeroRTTPayload != nil { if zeroRTTPayload.length > 0 {
size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload) + protocol.ByteCount(zeroRTTSealer.Overhead()) size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload) + protocol.ByteCount(zeroRTTSealer.Overhead())
} }
} }
} }
} }
if initialPayload == nil && handshakePayload == nil && zeroRTTPayload == nil && oneRTTPayload == nil { if initialPayload.length == 0 && handshakePayload.length == 0 && zeroRTTPayload.length == 0 && oneRTTPayload.length == 0 {
return nil, nil return nil, nil
} }
@ -456,7 +456,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
buffer: buffer, buffer: buffer,
longHdrPackets: make([]*longHeaderPacket, 0, 3), longHdrPackets: make([]*longHeaderPacket, 0, 3),
} }
if initialPayload != nil { if initialPayload.length > 0 {
padding := p.initialPaddingLen(initialPayload.frames, size) padding := p.initialPaddingLen(initialPayload.frames, size)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer) cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer)
if err != nil { if err != nil {
@ -464,20 +464,20 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
} }
packet.longHdrPackets = append(packet.longHdrPackets, cont) packet.longHdrPackets = append(packet.longHdrPackets, cont)
} }
if handshakePayload != nil { if handshakePayload.length > 0 {
cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer) cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.longHdrPackets = append(packet.longHdrPackets, cont) packet.longHdrPackets = append(packet.longHdrPackets, cont)
} }
if zeroRTTPayload != nil { if zeroRTTPayload.length > 0 {
longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer) longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
} else if oneRTTPayload != nil { } else if oneRTTPayload.length > 0 {
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false) ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -503,13 +503,13 @@ func (p *packetPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacke
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
connID := p.getDestConnID() connID := p.getDestConnID()
hdrLen := wire.ShortHeaderLen(connID, pnLen) hdrLen := wire.ShortHeaderLen(connID, pnLen)
payload := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true) pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true)
if payload == nil { if pl.length == 0 {
return shortHeaderPacket{}, nil, errNothingToPack return shortHeaderPacket{}, nil, errNothingToPack
} }
kp := sealer.KeyPhase() kp := sealer.KeyPhase()
buffer := getPacketBuffer() buffer := getPacketBuffer()
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, sealer, false) ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, sealer, false)
if err != nil { if err != nil {
return shortHeaderPacket{}, nil, err return shortHeaderPacket{}, nil, err
} }
@ -522,15 +522,15 @@ func (p *packetPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacke
}, buffer, nil }, buffer, nil
} }
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, payload) {
if onlyAck { if onlyAck {
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil { if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
var payload payload return p.getLongHeader(encLevel), payload{
payload.ack = ack ack: ack,
payload.length = ack.Length(p.version) length: ack.Length(p.version),
return p.getLongHeader(encLevel), &payload }
} }
return nil, nil return nil, payload{}
} }
var s cryptoStream var s cryptoStream
@ -552,14 +552,14 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
} }
if !hasData && !hasRetransmission && ack == nil { if !hasData && !hasRetransmission && ack == nil {
// nothing to send // nothing to send
return nil, nil return nil, payload{}
} }
var payload payload var pl payload
if ack != nil { if ack != nil {
payload.ack = ack pl.ack = ack
payload.length = ack.Length(p.version) pl.length = ack.Length(p.version)
maxPacketSize -= payload.length maxPacketSize -= pl.length
} }
hdr := p.getLongHeader(encLevel) hdr := p.getLongHeader(encLevel)
maxPacketSize -= hdr.GetLength(p.version) maxPacketSize -= hdr.GetLength(p.version)
@ -576,49 +576,48 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
if f == nil { if f == nil {
break break
} }
payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
frameLen := f.Length(p.version) frameLen := f.Length(p.version)
payload.length += frameLen pl.length += frameLen
maxPacketSize -= frameLen maxPacketSize -= frameLen
} }
} else if s.HasData() { } else if s.HasData() {
cf := s.PopCryptoFrame(maxPacketSize) cf := s.PopCryptoFrame(maxPacketSize)
payload.frames = []ackhandler.Frame{{Frame: cf}} pl.frames = []ackhandler.Frame{{Frame: cf}}
payload.length += cf.Length(p.version) pl.length += cf.Length(p.version)
} }
return hdr, &payload return hdr, pl
} }
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount) (*wire.ExtendedHeader, *payload) { func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount) (*wire.ExtendedHeader, payload) {
if p.perspective != protocol.PerspectiveClient { if p.perspective != protocol.PerspectiveClient {
return nil, nil return nil, payload{}
} }
hdr := p.getLongHeader(protocol.Encryption0RTT) hdr := p.getLongHeader(protocol.Encryption0RTT)
maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead())
payload := p.maybeGetAppDataPacket(maxPayloadSize, false, false) return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false)
return hdr, payload
} }
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload { func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) payload {
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed) return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed)
} }
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload { func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool) payload {
payload := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed) pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed)
// check if we have anything to send // check if we have anything to send
if len(payload.frames) == 0 { if len(pl.frames) == 0 {
if payload.ack == nil { if pl.ack == nil {
return nil return payload{}
} }
// the packet only contains an ACK // the packet only contains an ACK
if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks {
ping := &wire.PingFrame{} ping := &wire.PingFrame{}
// don't retransmit the PING frame when it is lost // don't retransmit the PING frame when it is lost
payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping, OnLost: func(wire.Frame) {}}) pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping, OnLost: func(wire.Frame) {}})
payload.length += ping.Length(p.version) pl.length += ping.Length(p.version)
p.numNonAckElicitingAcks = 0 p.numNonAckElicitingAcks = 0
} else { } else {
p.numNonAckElicitingAcks++ p.numNonAckElicitingAcks++
@ -626,21 +625,21 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
} else { } else {
p.numNonAckElicitingAcks = 0 p.numNonAckElicitingAcks = 0
} }
return payload return pl
} }
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload { func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool) payload {
if onlyAck { if onlyAck {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
payload := &payload{} return payload{
payload.ack = ack ack: ack,
payload.length += ack.Length(p.version) length: ack.Length(p.version),
return payload }
} }
return &payload{} return payload{}
} }
payload := &payload{frames: make([]ackhandler.Frame, 0, 1)} pl := payload{frames: make([]ackhandler.Frame, 0, 1)}
hasData := p.framer.HasData() hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasAppData() hasRetransmission := p.retransmissionQueue.HasAppData()
@ -648,8 +647,8 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
var hasAck bool var hasAck bool
if ackAllowed { if ackAllowed {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil {
payload.ack = ack pl.ack = ack
payload.length += ack.Length(p.version) pl.length += ack.Length(p.version)
hasAck = true hasAck = true
} }
} }
@ -657,25 +656,25 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if p.datagramQueue != nil { if p.datagramQueue != nil {
if f := p.datagramQueue.Peek(); f != nil { if f := p.datagramQueue.Peek(); f != nil {
size := f.Length(p.version) size := f.Length(p.version)
if size <= maxFrameSize-payload.length { if size <= maxFrameSize-pl.length {
payload.frames = append(payload.frames, ackhandler.Frame{ pl.frames = append(pl.frames, ackhandler.Frame{
Frame: f, Frame: f,
// set it to a no-op. Then we won't set the default callback, which would retransmit the frame. // set it to a no-op. Then we won't set the default callback, which would retransmit the frame.
OnLost: func(wire.Frame) {}, OnLost: func(wire.Frame) {},
}) })
payload.length += size pl.length += size
p.datagramQueue.Pop() p.datagramQueue.Pop()
} }
} }
} }
if hasAck && !hasData && !hasRetransmission { if hasAck && !hasData && !hasRetransmission {
return payload return pl
} }
if hasRetransmission { if hasRetransmission {
for { for {
remainingLen := maxFrameSize - payload.length remainingLen := maxFrameSize - pl.length
if remainingLen < protocol.MinStreamFrameSize { if remainingLen < protocol.MinStreamFrameSize {
break break
} }
@ -683,20 +682,20 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if f == nil { if f == nil {
break break
} }
payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
payload.length += f.Length(p.version) pl.length += f.Length(p.version)
} }
} }
if hasData { if hasData {
var lengthAdded protocol.ByteCount var lengthAdded protocol.ByteCount
payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length)
payload.length += lengthAdded pl.length += lengthAdded
payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length) pl.frames, lengthAdded = p.framer.AppendStreamFrames(pl.frames, maxFrameSize-pl.length)
payload.length += lengthAdded pl.length += lengthAdded
} }
return payload return pl
} }
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*coalescedPacket, error) { func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*coalescedPacket, error) {
@ -709,13 +708,13 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
connID := p.getDestConnID() connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, pnLen) hdrLen := wire.ShortHeaderLen(connID, pnLen)
payload := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true) pl := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true)
if payload == nil { if pl.length == 0 {
return nil, nil return nil, nil
} }
buffer := getPacketBuffer() buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer} packet := &coalescedPacket{buffer: buffer}
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, s, false) ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, s, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -730,7 +729,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
} }
var hdr *wire.ExtendedHeader var hdr *wire.ExtendedHeader
var payload *payload var pl payload
var sealer handshake.LongHeaderSealer var sealer handshake.LongHeaderSealer
//nolint:exhaustive // Probe packets are never sent for 0-RTT. //nolint:exhaustive // Probe packets are never sent for 0-RTT.
switch encLevel { switch encLevel {
@ -740,30 +739,30 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true) hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true)
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
var err error var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer() sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true) hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true)
default: default:
panic("unknown encryption level") panic("unknown encryption level")
} }
if payload == nil { if pl.length == 0 {
return nil, nil return nil, nil
} }
buffer := getPacketBuffer() buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer} packet := &coalescedPacket{buffer: buffer}
size := p.longHeaderPacketLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) size := p.longHeaderPacketLength(hdr, pl) + protocol.ByteCount(sealer.Overhead())
var padding protocol.ByteCount var padding protocol.ByteCount
if encLevel == protocol.EncryptionInitial { if encLevel == protocol.EncryptionInitial {
padding = p.initialPaddingLen(payload.frames, size) padding = p.initialPaddingLen(pl.frames, size)
} }
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, payload, padding, encLevel, sealer) longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -772,7 +771,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
} }
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error) { func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error) {
payload := &payload{ pl := payload{
frames: []ackhandler.Frame{ping}, frames: []ackhandler.Frame{ping},
length: ping.Length(p.version), length: ping.Length(p.version),
} }
@ -783,9 +782,9 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
} }
connID := p.getDestConnID() connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
padding := size - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead()) padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead())
kp := s.KeyPhase() kp := s.KeyPhase()
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, padding, s, true) ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, s, true)
if err != nil { if err != nil {
return shortHeaderPacket{}, nil, err return shortHeaderPacket{}, nil, err
} }
@ -821,14 +820,14 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
return hdr return hdr
} }
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer) (*longHeaderPacket, error) { func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer) (*longHeaderPacket, error) {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen) pnLen := protocol.ByteCount(header.PacketNumberLen)
if payload.length < 4-pnLen { if pl.length < 4-pnLen {
paddingLen = 4 - pnLen - payload.length paddingLen = 4 - pnLen - pl.length
} }
paddingLen += padding paddingLen += padding
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen
startLen := len(buffer.Data) startLen := len(buffer.Data)
raw := buffer.Data[startLen:] raw := buffer.Data[startLen:]
@ -843,7 +842,7 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
} }
raw, err = p.appendPacketPayload(raw, payload, paddingLen) raw, err = p.appendPacketPayload(raw, pl, paddingLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -852,8 +851,8 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
return &longHeaderPacket{ return &longHeaderPacket{
header: header, header: header,
ack: payload.ack, ack: pl.ack,
frames: payload.frames, frames: pl.frames,
length: protocol.ByteCount(len(raw)), length: protocol.ByteCount(len(raw)),
}, nil }, nil
} }
@ -864,14 +863,14 @@ func (p *packetPacker) appendShortHeaderPacket(
pn protocol.PacketNumber, pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen, pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit, kp protocol.KeyPhaseBit,
payload *payload, pl payload,
padding protocol.ByteCount, padding protocol.ByteCount,
sealer sealer, sealer sealer,
isMTUProbePacket bool, isMTUProbePacket bool,
) (*ackhandler.Packet, *wire.AckFrame, error) { ) (*ackhandler.Packet, *wire.AckFrame, error) {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
if payload.length < 4-protocol.ByteCount(pnLen) { if pl.length < 4-protocol.ByteCount(pnLen) {
paddingLen = 4 - protocol.ByteCount(pnLen) - payload.length paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length
} }
paddingLen += padding paddingLen += padding
@ -887,7 +886,7 @@ func (p *packetPacker) appendShortHeaderPacket(
return nil, nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") return nil, nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
} }
raw, err = p.appendPacketPayload(raw, payload, paddingLen) raw, err = p.appendPacketPayload(raw, pl, paddingLen)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -901,33 +900,33 @@ func (p *packetPacker) appendShortHeaderPacket(
// create the ackhandler.Packet // create the ackhandler.Packet
largestAcked := protocol.InvalidPacketNumber largestAcked := protocol.InvalidPacketNumber
if payload.ack != nil { if pl.ack != nil {
largestAcked = payload.ack.LargestAcked() largestAcked = pl.ack.LargestAcked()
} }
for i := range payload.frames { for i := range pl.frames {
if payload.frames[i].OnLost != nil { if pl.frames[i].OnLost != nil {
continue continue
} }
payload.frames[i].OnLost = p.retransmissionQueue.AddAppData pl.frames[i].OnLost = p.retransmissionQueue.AddAppData
} }
ap := ackhandler.GetPacket() ap := ackhandler.GetPacket()
ap.PacketNumber = pn ap.PacketNumber = pn
ap.LargestAcked = largestAcked ap.LargestAcked = largestAcked
ap.Frames = payload.frames ap.Frames = pl.frames
ap.Length = protocol.ByteCount(len(raw)) ap.Length = protocol.ByteCount(len(raw))
ap.EncryptionLevel = protocol.Encryption1RTT ap.EncryptionLevel = protocol.Encryption1RTT
ap.SendTime = time.Now() ap.SendTime = time.Now()
ap.IsPathMTUProbePacket = isMTUProbePacket ap.IsPathMTUProbePacket = isMTUProbePacket
return ap, payload.ack, nil return ap, pl.ack, nil
} }
func (p *packetPacker) appendPacketPayload(raw []byte, payload *payload, paddingLen protocol.ByteCount) ([]byte, error) { func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount) ([]byte, error) {
payloadOffset := len(raw) payloadOffset := len(raw)
if payload.ack != nil { if pl.ack != nil {
var err error var err error
raw, err = payload.ack.Append(raw, p.version) raw, err = pl.ack.Append(raw, p.version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -935,7 +934,7 @@ func (p *packetPacker) appendPacketPayload(raw []byte, payload *payload, padding
if paddingLen > 0 { if paddingLen > 0 {
raw = append(raw, make([]byte, paddingLen)...) raw = append(raw, make([]byte, paddingLen)...)
} }
for _, frame := range payload.frames { for _, frame := range pl.frames {
var err error var err error
raw, err = frame.Append(raw, p.version) raw, err = frame.Append(raw, p.version)
if err != nil { if err != nil {
@ -943,8 +942,8 @@ func (p *packetPacker) appendPacketPayload(raw []byte, payload *payload, padding
} }
} }
if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != payload.length { if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length {
return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize)
} }
return raw, nil return raw, nil
} }