use separate append functions for long and short header packets

This commit is contained in:
Marten Seemann 2022-09-04 15:43:40 +03:00
parent 108f152181
commit 3e7bad5efc

View file

@ -298,7 +298,13 @@ func (p *packetPacker) packConnectionClose(
if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size)
}
c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false)
var c *packetContents
var err error
if encLevel == protocol.Encryption1RTT {
c, err = p.appendShortHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, sealers[i], false)
} else {
c, err = p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i])
}
if err != nil {
return nil, err
}
@ -412,21 +418,27 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
}
if initialPayload != nil {
padding := p.initialPaddingLen(initialPayload.frames, size)
cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
}
if handshakePayload != nil {
cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false)
cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
}
if appDataPayload != nil {
cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false)
var cont *packetContents
var err error
if appDataEncLevel == protocol.Encryption0RTT {
cont, err = p.appendLongHeaderPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer)
} else {
cont, err = p.appendShortHeaderPacket(buffer, appDataHdr, appDataPayload, 0, appDataSealer, false)
}
if err != nil {
return nil, err
}
@ -447,7 +459,7 @@ func (p *packetPacker) PackPacket(onlyAck bool) (*packedPacket, error) {
return nil, nil
}
buffer := getPacketBuffer()
cont, err := p.appendPacket(buffer, hdr, payload, 0, protocol.Encryption1RTT, sealer, false)
cont, err := p.appendShortHeaderPacket(buffer, hdr, payload, 0, sealer, false)
if err != nil {
return nil, err
}
@ -676,7 +688,13 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
padding = p.initialPaddingLen(payload.frames, size)
}
buffer := getPacketBuffer()
cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false)
var cont *packetContents
var err error
if encLevel == protocol.Encryption1RTT {
cont, err = p.appendShortHeaderPacket(buffer, hdr, payload, padding, sealer, false)
} else {
cont, err = p.appendLongHeaderPacket(buffer, hdr, payload, padding, encLevel, sealer)
}
if err != nil {
return nil, err
}
@ -698,7 +716,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
}
hdr := p.getShortHeader(sealer.KeyPhase())
padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead())
contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true)
contents, err := p.appendShortHeaderPacket(buffer, hdr, payload, padding, sealer, true)
if err != nil {
return nil, err
}
@ -743,7 +761,10 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
return hdr
}
func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) {
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer) (*packetContents, error) {
if !header.IsLongHeader {
panic("shouldn't have called appendLongHeaderPacket")
}
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen)
if payload.length < 4-pnLen {
@ -768,6 +789,46 @@ func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedH
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err := p.appendPacketPayload(raw, payload, paddingLen)
if err != nil {
return nil, err
}
raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
return &packetContents{
header: header,
ack: payload.ack,
frames: payload.frames,
length: protocol.ByteCount(len(raw)),
}, nil
}
func (p *packetPacker) appendShortHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, sealer sealer, isMTUProbePacket bool) (*packetContents, error) {
if header.IsLongHeader {
panic("shouldn't have called appendShortHeaderPacket")
}
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen)
if payload.length < 4-pnLen {
paddingLen = 4 - pnLen - payload.length
}
paddingLen += padding
raw := buffer.Data[len(buffer.Data):]
buf := bytes.NewBuffer(buffer.Data)
startLen := buf.Len()
if err := header.Write(buf, p.version); err != nil {
return nil, err
}
raw = raw[:buf.Len()-startLen]
payloadOffset := protocol.ByteCount(len(raw))
pn := p.pnManager.PopPacketNumber(protocol.Encryption1RTT)
if pn != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err := p.appendPacketPayload(raw, payload, paddingLen)
if err != nil {
return nil, err