simplify packing of long header ACK-only packets

This commit is contained in:
Marten Seemann 2022-09-04 11:27:04 +03:00
parent 2873125c15
commit 7bc2ba6b81

View file

@ -322,23 +322,26 @@ func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload)
func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) {
var pay *payload var pay *payload
var encLevel protocol.EncryptionLevel var sealer sealer
var hdr *wire.ExtendedHeader
encLevel := protocol.EncryptionInitial
if !handshakeConfirmed { if !handshakeConfirmed {
var ack *wire.AckFrame hdr, pay = p.maybeGetCryptoPacket(p.maxPacketSize, protocol.EncryptionInitial, true, true)
ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) if pay != nil {
if ack != nil { var err error
encLevel = protocol.EncryptionInitial sealer, err = p.cryptoSetup.GetInitialSealer()
} else { if err != nil {
ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true) return nil, err
if ack != nil {
encLevel = protocol.EncryptionHandshake
} }
} } else {
encLevel = protocol.EncryptionHandshake
if ack != nil { hdr, pay = p.maybeGetCryptoPacket(p.maxPacketSize, protocol.EncryptionHandshake, true, true)
pay = &payload{ if pay != nil {
ack: ack, var err error
length: ack.Length(p.version), sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, err
}
} }
} }
} }
@ -348,12 +351,14 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke
return nil, nil return nil, nil
} }
encLevel = protocol.Encryption1RTT encLevel = protocol.Encryption1RTT
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
hdr = p.getShortHeader(s.KeyPhase())
sealer = s
} }
sealer, hdr, err := p.getSealerAndHeader(encLevel)
if err != nil {
return nil, err
}
return p.writeSinglePacket(hdr, pay, encLevel, sealer) return p.writeSinglePacket(hdr, pay, encLevel, sealer)
} }
@ -387,7 +392,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
} }
var size protocol.ByteCount var size protocol.ByteCount
if initialSealer != nil { if initialSealer != nil {
initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial) initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, false, size == 0)
if initialPayload != nil { if initialPayload != nil {
size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead())
numPackets++ numPackets++
@ -403,7 +408,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
return nil, err return nil, err
} }
if handshakeSealer != nil { if handshakeSealer != nil {
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake) handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, false, size == 0)
if handshakePayload != nil { if handshakePayload != nil {
s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead())
size += s size += s
@ -495,7 +500,17 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
}, nil }, nil
} }
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) {
if onlyAck {
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
var payload payload
payload.ack = ack
payload.length = ack.Length(p.version)
return p.getLongHeader(encLevel), &payload
}
return nil, nil
}
var s cryptoStream var s cryptoStream
var hasRetransmission bool var hasRetransmission bool
//nolint:exhaustive // Initial and Handshake are the only two encryption levels here. //nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
@ -510,7 +525,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.
hasData := s.HasData() hasData := s.HasData()
var ack *wire.AckFrame var ack *wire.AckFrame
if encLevel == protocol.EncryptionInitial || currentSize == 0 { if ackAllowed {
ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData)
} }
if !hasData && !hasRetransmission && ack == nil { if !hasData && !hasRetransmission && ack == nil {
@ -677,14 +692,14 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial) hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true)
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
var err error var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer() sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake) hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true)
case protocol.Encryption1RTT: case protocol.Encryption1RTT:
oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil { if err != nil {
@ -738,41 +753,6 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
}, nil }, nil
} }
func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) {
switch encLevel {
case protocol.EncryptionInitial:
sealer, err := p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getLongHeader(protocol.EncryptionInitial)
return sealer, hdr, nil
case protocol.Encryption0RTT:
sealer, err := p.cryptoSetup.Get0RTTSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getLongHeader(protocol.Encryption0RTT)
return sealer, hdr, nil
case protocol.EncryptionHandshake:
sealer, err := p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getLongHeader(protocol.EncryptionHandshake)
return sealer, hdr, nil
case protocol.Encryption1RTT:
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getShortHeader(sealer.KeyPhase())
return sealer, hdr, nil
default:
return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel)
}
}
func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdr := &wire.ExtendedHeader{} hdr := &wire.ExtendedHeader{}