use PADDING frames to pad Initial packets

This commit is contained in:
Marten Seemann 2020-11-14 18:10:20 +07:00
parent 2c975bca54
commit d1a784d092
2 changed files with 390 additions and 168 deletions

View file

@ -7,11 +7,10 @@ import (
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/ackhandler"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
@ -203,9 +202,13 @@ func (p *packetPacker) PackConnectionClose(quicErr *qerr.QuicError) (*coalescedP
reason = quicErr.ErrorMessage
}
buffer := getPacketBuffer()
contents := make([]*packetContents, 0, 1)
for _, encLevel := range []protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} {
var sealers [4]sealer
var hdrs [4]*wire.ExtendedHeader
var payloads [4]*payload
var size protocol.ByteCount
var numPackets uint8
encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT}
for i, encLevel := range encLevels {
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT {
continue
}
@ -224,7 +227,7 @@ func (p *packetPacker) PackConnectionClose(quicErr *qerr.QuicError) (*coalescedP
FrameType: quicErrToSend.FrameType,
ReasonPhrase: reasonPhrase,
}
payload := payload{
payload := &payload{
frames: []ackhandler.Frame{{Frame: ccf}},
length: ccf.Length(p.version),
}
@ -253,23 +256,49 @@ func (p *packetPacker) PackConnectionClose(quicErr *qerr.QuicError) (*coalescedP
if err != nil {
return nil, err
}
sealers[i] = sealer
var hdr *wire.ExtendedHeader
if encLevel == protocol.Encryption1RTT {
hdr = p.getShortHeader(keyPhase)
} else {
hdr = p.getLongHeader(encLevel)
}
c, err := p.appendPacket(buffer, hdr, payload, encLevel, sealer)
hdrs[i] = hdr
payloads[i] = payload
size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
numPackets++
}
contents := make([]*packetContents, 0, numPackets)
buffer := getPacketBuffer()
for i, encLevel := range encLevels {
if sealers[i] == nil {
continue
}
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.paddingLen(payloads[i].frames, size)
}
c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i])
if err != nil {
return nil, err
}
contents = append(contents, c)
}
p.maybePadPacket(contents[0], buffer)
return &coalescedPacket{buffer: buffer, packets: contents}, nil
}
// packetLength calculates the length of the serialized packet.
// It takes into account that packets that have a tiny payload need to be padded,
// such that len(payload) + packet number len >= 4 + AEAD overhead
func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(hdr.PacketNumberLen)
if payload.length < 4-pnLen {
paddingLen = 4 - pnLen - payload.length
}
return hdr.GetLength(p.version) + payload.length + paddingLen
}
func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) {
var encLevel protocol.EncryptionLevel
var ack *wire.AckFrame
@ -294,7 +323,7 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke
if ack == nil {
return nil, nil
}
payload := payload{
payload := &payload{
ack: ack,
length: ack.Length(p.version),
}
@ -303,93 +332,115 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke
if err != nil {
return nil, err
}
packet, err := p.writeSinglePacket(hdr, payload, encLevel, sealer)
if err != nil {
return nil, err
}
p.maybePadPacket(packet.packetContents, packet.buffer)
return packet, nil
return p.writeSinglePacket(hdr, payload, encLevel, sealer)
}
func (p *packetPacker) maybePadPacket(firstPacket *packetContents, buffer *packetBuffer) {
// Only Initial packets need to be padded.
if firstPacket.header.Type != protocol.PacketTypeInitial {
return
}
// only works for Initial packets
// The size is the expected size of the packet, if no padding was applied.
func (p *packetPacker) paddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount {
// For the server, only ack-eliciting Initial packets need to be padded.
if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(firstPacket.frames) {
return
if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) {
return 0
}
if dataLen := protocol.ByteCount(len(buffer.Data)); dataLen < p.maxPacketSize {
buffer.Data = buffer.Data[:p.maxPacketSize]
for n := dataLen; n < p.maxPacketSize; n++ {
buffer.Data[n] = 0
}
if size >= p.maxPacketSize {
return 0
}
return p.maxPacketSize - size
}
// PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
buffer := getPacketBuffer()
packet, err := p.packCoalescedPacket(buffer)
if err != nil {
return nil, err
}
if packet == nil || len(packet.packets) == 0 { // nothing to send
buffer.Release()
return nil, nil
}
p.maybePadPacket(packet.packets[0], buffer)
return packet, nil
}
func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPacket, error) {
maxPacketSize := p.maxPacketSize
if p.perspective == protocol.PerspectiveClient {
maxPacketSize = protocol.MinInitialPacketSize
}
packet := &coalescedPacket{
buffer: buffer,
packets: make([]*packetContents, 0, 3),
}
var initialHdr, handshakeHdr, appDataHdr *wire.ExtendedHeader
var initialPayload, handshakePayload, appDataPayload *payload
var numPackets int
// Try packing an Initial packet.
contents, err := p.maybeAppendCryptoPacket(buffer, maxPacketSize, protocol.EncryptionInitial)
initialSealer, err := p.cryptoSetup.GetInitialSealer()
if err != nil && err != handshake.ErrKeysDropped {
return nil, err
}
if contents != nil {
packet.packets = append(packet.packets, contents)
}
if buffer.Len() >= maxPacketSize-protocol.MinCoalescedPacketSize {
return packet, nil
var size protocol.ByteCount
if initialSealer != nil {
initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial)
if initialPayload != nil {
size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead())
numPackets++
}
}
// Add a Handshake packet.
contents, err = p.maybeAppendCryptoPacket(buffer, maxPacketSize, protocol.EncryptionHandshake)
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if contents != nil {
packet.packets = append(packet.packets, contents)
}
if buffer.Len() >= maxPacketSize-protocol.MinCoalescedPacketSize {
return packet, nil
var handshakeSealer sealer
if size < maxPacketSize-protocol.MinCoalescedPacketSize {
var err error
handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if handshakeSealer != nil {
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake)
if handshakePayload != nil {
s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead())
size += s
numPackets++
}
}
}
// Add a 0-RTT / 1-RTT packet.
contents, err = p.maybeAppendAppDataPacket(buffer, maxPacketSize)
if err == handshake.ErrKeysNotYetAvailable {
return packet, nil
var appDataSealer sealer
appDataEncLevel := protocol.Encryption1RTT
if size < maxPacketSize-protocol.MinCoalescedPacketSize {
var err error
appDataSealer, appDataHdr, appDataPayload = p.maybeGetAppDataPacket(maxPacketSize-size, size)
if err != nil {
return nil, err
}
if appDataHdr != nil {
if appDataHdr.IsLongHeader {
appDataEncLevel = protocol.Encryption0RTT
}
if appDataPayload != nil {
size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead())
numPackets++
}
}
}
if err != nil {
return nil, err
if numPackets == 0 {
return nil, nil
}
if contents != nil {
packet.packets = append(packet.packets, contents)
buffer := getPacketBuffer()
packet := &coalescedPacket{
buffer: buffer,
packets: make([]*packetContents, 0, numPackets),
}
if initialPayload != nil {
padding := p.paddingLen(initialPayload.frames, size)
cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
}
if handshakePayload != nil {
cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
}
if appDataPayload != nil {
cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
}
return packet, nil
}
@ -397,20 +448,26 @@ func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPack
// PackPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) PackPacket() (*packedPacket, error) {
sealer, hdr, payload := p.maybeGetAppDataPacket(p.maxPacketSize, 0)
if payload == nil {
return nil, nil
}
buffer := getPacketBuffer()
contents, err := p.maybeAppendAppDataPacket(buffer, p.maxPacketSize)
if err != nil || contents == nil {
buffer.Release()
encLevel := protocol.Encryption1RTT
if hdr.IsLongHeader {
encLevel = protocol.Encryption0RTT
}
cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
buffer: buffer,
packetContents: contents,
packetContents: cont,
}, nil
}
func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*packetContents, error) {
var sealer sealer
func (p *packetPacker) maybeGetCryptoPacket(maxSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) {
var s cryptoStream
var hasRetransmission bool
//nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
@ -418,24 +475,14 @@ func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, maxPacketSi
case protocol.EncryptionInitial:
s = p.initialStream
hasRetransmission = p.retransmissionQueue.HasInitialData()
var err error
sealer, err = p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, err
}
case protocol.EncryptionHandshake:
s = p.handshakeStream
hasRetransmission = p.retransmissionQueue.HasHandshakeData()
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, err
}
}
hasData := s.HasData()
var ack *wire.AckFrame
if encLevel != protocol.EncryptionHandshake || buffer.Len() == 0 {
if encLevel == protocol.EncryptionInitial || currentSize == 0 {
ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData)
}
if !hasData && !hasRetransmission && ack == nil {
@ -443,25 +490,23 @@ func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, maxPacketSi
return nil, nil
}
remainingLen := maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead())
var payload payload
if ack != nil {
payload.ack = ack
payload.length = ack.Length(p.version)
remainingLen -= payload.length
maxSize -= payload.length
}
hdr := p.getLongHeader(encLevel)
remainingLen -= hdr.GetLength(p.version)
maxSize -= hdr.GetLength(p.version)
if hasRetransmission {
for {
var f wire.Frame
//nolint:exhaustive // 0-RTT packets can't contain any retransmission.s
switch encLevel {
case protocol.EncryptionInitial:
f = p.retransmissionQueue.GetInitialFrame(remainingLen)
f = p.retransmissionQueue.GetInitialFrame(maxSize)
case protocol.EncryptionHandshake:
f = p.retransmissionQueue.GetHandshakeFrame(remainingLen)
f = p.retransmissionQueue.GetHandshakeFrame(maxSize)
}
if f == nil {
break
@ -469,45 +514,49 @@ func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, maxPacketSi
payload.frames = append(payload.frames, ackhandler.Frame{Frame: f})
frameLen := f.Length(p.version)
payload.length += frameLen
remainingLen -= frameLen
maxSize -= frameLen
}
} else if s.HasData() {
cf := s.PopCryptoFrame(remainingLen)
cf := s.PopCryptoFrame(maxSize)
payload.frames = []ackhandler.Frame{{Frame: cf}}
payload.length += cf.Length(p.version)
}
return p.appendPacket(buffer, hdr, payload, encLevel, sealer)
return hdr, &payload
}
func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer, maxPacketSize protocol.ByteCount) (*packetContents, error) {
func (p *packetPacker) maybeGetAppDataPacket(maxPacketSize, currentSize protocol.ByteCount) (sealer, *wire.ExtendedHeader, *payload) {
var sealer sealer
var header *wire.ExtendedHeader
var encLevel protocol.EncryptionLevel
var hdr *wire.ExtendedHeader
oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer()
if err == nil {
encLevel = protocol.Encryption1RTT
sealer = oneRTTSealer
header = p.getShortHeader(oneRTTSealer.KeyPhase())
hdr = p.getShortHeader(oneRTTSealer.KeyPhase())
} else {
// 1-RTT sealer not yet available
if p.perspective != protocol.PerspectiveClient {
return nil, nil
return nil, nil, nil
}
sealer, err = p.cryptoSetup.Get0RTTSealer()
if sealer == nil || err != nil {
return nil, nil
return nil, nil, nil
}
encLevel = protocol.Encryption0RTT
header = p.getLongHeader(protocol.Encryption0RTT)
hdr = p.getLongHeader(protocol.Encryption0RTT)
}
headerLen := header.GetLength(p.version)
maxSize := maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) - headerLen
payload := p.composeNextPacket(maxSize, encLevel == protocol.Encryption1RTT && buffer.Len() == 0)
maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead())
payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, currentSize, encLevel)
return sealer, hdr, payload
}
func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) *payload {
payload := p.composeNextPacket(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0)
// check if we have anything to send
if len(payload.frames) == 0 && payload.ack == nil {
return nil, nil
return nil
}
if len(payload.frames) == 0 { // the packet only contains an ACK
if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks {
@ -521,12 +570,11 @@ func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer, maxPacketS
} else {
p.numNonAckElicitingAcks = 0
}
return p.appendPacket(buffer, header, payload, encLevel, sealer)
return payload
}
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) payload {
var payload payload
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload {
payload := &payload{}
var ack *wire.AckFrame
hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasAppData()
@ -569,27 +617,54 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAll
}
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) {
var contents *packetContents
var err error
buffer := getPacketBuffer()
var hdr *wire.ExtendedHeader
var payload *payload
var sealer sealer
//nolint:exhaustive Probe packets are never sent for 0-RTT.
switch encLevel {
case protocol.EncryptionInitial:
contents, err = p.maybeAppendCryptoPacket(buffer, p.maxPacketSize, protocol.EncryptionInitial)
var err error
sealer, err = p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, err
}
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial)
case protocol.EncryptionHandshake:
contents, err = p.maybeAppendCryptoPacket(buffer, p.maxPacketSize, protocol.EncryptionHandshake)
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, err
}
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake)
case protocol.Encryption1RTT:
contents, err = p.maybeAppendAppDataPacket(buffer, p.maxPacketSize)
oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
sealer = oneRTTSealer
payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.Encryption1RTT)
if payload != nil {
hdr = p.getShortHeader(oneRTTSealer.KeyPhase())
}
default:
panic("unknown encryption level")
}
if err != nil || contents == nil {
if hdr == nil {
return nil, nil
}
size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
var padding protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
padding = p.paddingLen(payload.frames, size)
}
buffer := getPacketBuffer()
cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer)
if err != nil {
return nil, err
}
p.maybePadPacket(contents, buffer)
return &packedPacket{
buffer: buffer,
packetContents: contents,
packetContents: cont,
}, nil
}
@ -640,19 +715,15 @@ func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHea
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
hdr := &wire.ExtendedHeader{}
hdr := &wire.ExtendedHeader{
PacketNumber: pn,
PacketNumberLen: pnLen,
}
hdr.IsLongHeader = true
hdr.Version = p.version
hdr.SrcConnectionID = p.srcConnID
hdr.DestConnectionID = p.getDestConnID()
// Set the length to the maximum packet size.
// Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
hdr.Length = p.maxPacketSize
hdr.PacketNumber = pn
hdr.PacketNumberLen = pnLen
//nolint:exhaustive // 1-RTT packets are not long header packets.
switch encLevel {
case protocol.EncryptionInitial:
@ -663,19 +734,22 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
case protocol.Encryption0RTT:
hdr.Type = protocol.PacketType0RTT
}
return hdr
}
// writeSinglePacket packs a single packet.
func (p *packetPacker) writeSinglePacket(
header *wire.ExtendedHeader,
payload payload,
hdr *wire.ExtendedHeader,
payload *payload,
encLevel protocol.EncryptionLevel,
sealer sealer,
) (*packedPacket, error) {
buffer := getPacketBuffer()
contents, err := p.appendPacket(buffer, header, payload, encLevel, sealer)
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.paddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead()))
}
contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer)
if err != nil {
return nil, err
}
@ -688,7 +762,8 @@ func (p *packetPacker) writeSinglePacket(
func (p *packetPacker) appendPacket(
buffer *packetBuffer,
header *wire.ExtendedHeader,
payload payload,
payload *payload,
padding protocol.ByteCount, // add padding such that the packet has this length. 0 for no padding.
encLevel protocol.EncryptionLevel,
sealer sealer,
) (*packetContents, error) {
@ -697,6 +772,7 @@ func (p *packetPacker) appendPacket(
if payload.length < 4-pnLen {
paddingLen = 4 - pnLen - payload.length
}
paddingLen += padding
if header.IsLongHeader {
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen
}
@ -714,7 +790,7 @@ func (p *packetPacker) appendPacket(
}
}
if paddingLen > 0 {
buf.Write(bytes.Repeat([]byte{0}, int(paddingLen)))
buf.Write(make([]byte, paddingLen))
}
for _, frame := range payload.frames {
if err := frame.Write(buf, p.version); err != nil {