diff --git a/packet_packer.go b/packet_packer.go index 87bf838a..fe5efe37 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -298,7 +298,13 @@ func (p *packetPacker) packConnectionClose( if encLevel == protocol.EncryptionInitial { paddingLen = p.initialPaddingLen(payloads[i].frames, size) } - c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false) + var c *packetContents + var err error + if encLevel == protocol.Encryption1RTT { + c, err = p.appendShortHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, sealers[i], false) + } else { + c, err = p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i]) + } if err != nil { return nil, err } @@ -412,21 +418,27 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro } if initialPayload != nil { padding := p.initialPaddingLen(initialPayload.frames, size) - cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false) + cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer) if err != nil { return nil, err } packet.packets = append(packet.packets, cont) } if handshakePayload != nil { - cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false) + cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer) if err != nil { return nil, err } packet.packets = append(packet.packets, cont) } if appDataPayload != nil { - cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false) + var cont *packetContents + var err error + if appDataEncLevel == protocol.Encryption0RTT { + cont, err = p.appendLongHeaderPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer) + } else { + cont, err = p.appendShortHeaderPacket(buffer, appDataHdr, appDataPayload, 0, appDataSealer, false) + } if err != nil { return nil, err } @@ -447,7 +459,7 @@ func (p *packetPacker) PackPacket(onlyAck bool) (*packedPacket, error) { return nil, nil } buffer := getPacketBuffer() - cont, err := p.appendPacket(buffer, hdr, payload, 0, protocol.Encryption1RTT, sealer, false) + cont, err := p.appendShortHeaderPacket(buffer, hdr, payload, 0, sealer, false) if err != nil { return nil, err } @@ -676,7 +688,13 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( padding = p.initialPaddingLen(payload.frames, size) } buffer := getPacketBuffer() - cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false) + var cont *packetContents + var err error + if encLevel == protocol.Encryption1RTT { + cont, err = p.appendShortHeaderPacket(buffer, hdr, payload, padding, sealer, false) + } else { + cont, err = p.appendLongHeaderPacket(buffer, hdr, payload, padding, encLevel, sealer) + } if err != nil { return nil, err } @@ -698,7 +716,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B } hdr := p.getShortHeader(sealer.KeyPhase()) padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) - contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true) + contents, err := p.appendShortHeaderPacket(buffer, hdr, payload, padding, sealer, true) if err != nil { return nil, err } @@ -743,7 +761,10 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex return hdr } -func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { +func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer) (*packetContents, error) { + if !header.IsLongHeader { + panic("shouldn't have called appendLongHeaderPacket") + } var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) if payload.length < 4-pnLen { @@ -768,6 +789,46 @@ func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedH return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } + raw, err := p.appendPacketPayload(raw, payload, paddingLen) + if err != nil { + return nil, err + } + raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen) + buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] + + return &packetContents{ + header: header, + ack: payload.ack, + frames: payload.frames, + length: protocol.ByteCount(len(raw)), + }, nil +} + +func (p *packetPacker) appendShortHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { + if header.IsLongHeader { + panic("shouldn't have called appendShortHeaderPacket") + } + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(header.PacketNumberLen) + if payload.length < 4-pnLen { + paddingLen = 4 - pnLen - payload.length + } + paddingLen += padding + + raw := buffer.Data[len(buffer.Data):] + buf := bytes.NewBuffer(buffer.Data) + startLen := buf.Len() + if err := header.Write(buf, p.version); err != nil { + return nil, err + } + raw = raw[:buf.Len()-startLen] + payloadOffset := protocol.ByteCount(len(raw)) + + pn := p.pnManager.PopPacketNumber(protocol.Encryption1RTT) + if pn != header.PacketNumber { + return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + } + raw, err := p.appendPacketPayload(raw, payload, paddingLen) if err != nil { return nil, err