refactor sealing of packets

This commit is contained in:
Marten Seemann 2020-02-10 19:32:52 +07:00
parent a4b4d52063
commit 077504f557

View file

@ -215,7 +215,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
hdr = p.getShortHeader(s.KeyPhase()) hdr = p.getShortHeader(s.KeyPhase())
} }
return p.writeAndSealPacket(hdr, payload, encLevel, sealer) return p.writeSinglePacket(hdr, payload, encLevel, sealer)
} }
func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) {
@ -251,7 +251,7 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke
if err != nil { if err != nil {
return nil, err return nil, err
} }
return p.writeAndSealPacket(hdr, payload, encLevel, sealer) return p.writeSinglePacket(hdr, payload, encLevel, sealer)
} }
// PackPacket packs a new packet. // PackPacket packs a new packet.
@ -357,7 +357,7 @@ func (p *packetPacker) packCryptoPacket(
payload.frames = []ackhandler.Frame{{Frame: cf}} payload.frames = []ackhandler.Frame{{Frame: cf}}
payload.length += cf.Length(p.version) payload.length += cf.Length(p.version)
} }
return p.writeAndSealPacket(hdr, payload, encLevel, sealer) return p.writeSinglePacket(hdr, payload, encLevel, sealer)
} }
func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) {
@ -403,7 +403,7 @@ func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) {
p.numNonAckElicitingAcks = 0 p.numNonAckElicitingAcks = 0
} }
return p.writeAndSealPacket(header, payload, encLevel, sealer) return p.writeSinglePacket(header, payload, encLevel, sealer)
} }
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payload { func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payload {
@ -529,15 +529,37 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
return hdr return hdr
} }
func (p *packetPacker) writeAndSealPacket( // writeSinglePacket packs a single packet.
func (p *packetPacker) writeSinglePacket(
header *wire.ExtendedHeader, header *wire.ExtendedHeader,
payload payload, payload payload,
encLevel protocol.EncryptionLevel, encLevel protocol.EncryptionLevel,
sealer sealer, sealer sealer,
) (*packedPacket, error) { ) (*packedPacket, error) {
packetBuffer := getPacketBuffer()
n, err := p.appendPacket(packetBuffer.Slice[:0], header, payload, encLevel, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
header: header,
raw: packetBuffer.Slice[:n],
ack: payload.ack,
frames: payload.frames,
buffer: packetBuffer,
}, nil
}
func (p *packetPacker) appendPacket(
raw []byte,
header *wire.ExtendedHeader,
payload payload,
encLevel protocol.EncryptionLevel,
sealer sealer,
) (int, error) {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen) pnLen := protocol.ByteCount(header.PacketNumberLen)
if encLevel != protocol.Encryption1RTT { if encLevel != protocol.Encryption1RTT {
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
headerLen := header.GetLength(p.version) headerLen := header.GetLength(p.version)
@ -549,27 +571,17 @@ func (p *packetPacker) writeAndSealPacket(
} else if payload.length < 4-pnLen { } else if payload.length < 4-pnLen {
paddingLen = 4 - pnLen - payload.length paddingLen = 4 - pnLen - payload.length
} }
return p.writeAndSealPacketWithPadding(header, payload, paddingLen, encLevel, sealer)
}
func (p *packetPacker) writeAndSealPacketWithPadding(
header *wire.ExtendedHeader,
payload payload,
paddingLen protocol.ByteCount,
encLevel protocol.EncryptionLevel,
sealer sealer,
) (*packedPacket, error) {
packetBuffer := getPacketBuffer()
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
hdrOffset := len(raw)
buffer := bytes.NewBuffer(raw)
if err := header.Write(buffer, p.version); err != nil { if err := header.Write(buffer, p.version); err != nil {
return nil, err return 0, err
} }
payloadOffset := buffer.Len() payloadOffset := buffer.Len()
if payload.ack != nil { if payload.ack != nil {
if err := payload.ack.Write(buffer, p.version); err != nil { if err := payload.ack.Write(buffer, p.version); err != nil {
return nil, err return 0, err
} }
} }
if paddingLen > 0 { if paddingLen > 0 {
@ -577,40 +589,29 @@ func (p *packetPacker) writeAndSealPacketWithPadding(
} }
for _, frame := range payload.frames { for _, frame := range payload.frames {
if err := frame.Write(buffer, p.version); err != nil { if err := frame.Write(buffer, p.version); err != nil {
return nil, err return 0, err
} }
} }
if payloadSize := protocol.ByteCount(buffer.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { if payloadSize := protocol.ByteCount(buffer.Len()-payloadOffset) - paddingLen; payloadSize != payload.length {
fmt.Printf("%#v\n", payload) return 0, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize)
return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize)
} }
if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) return 0, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
} }
raw := buffer.Bytes() raw = raw[:buffer.Len()]
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset]) _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset])
raw = raw[0 : buffer.Len()+sealer.Overhead()] raw = raw[0 : buffer.Len()+sealer.Overhead()]
pnOffset := payloadOffset - int(header.PacketNumberLen) pnOffset := payloadOffset - int(header.PacketNumberLen)
sealer.EncryptHeader( sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset])
raw[pnOffset+4:pnOffset+4+16],
&raw[0],
raw[pnOffset:payloadOffset],
)
num := p.pnManager.PopPacketNumber(encLevel) num := p.pnManager.PopPacketNumber(encLevel)
if num != header.PacketNumber { if num != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") return 0, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
} }
return &packedPacket{ return len(raw) - hdrOffset, nil
header: header,
raw: raw,
ack: payload.ack,
frames: payload.frames,
buffer: packetBuffer,
}, nil
} }
func (p *packetPacker) SetToken(token []byte) { func (p *packetPacker) SetToken(token []byte) {