use separate functions per encryption level to get sealers

This commit is contained in:
Marten Seemann 2019-06-10 15:16:25 +08:00
parent d4d3f09ee3
commit c503769bcd
6 changed files with 180 additions and 124 deletions

View file

@ -93,8 +93,9 @@ type packetNumberManager interface {
}
type sealingManager interface {
GetSealer() (protocol.EncryptionLevel, handshake.Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error)
GetInitialSealer() (handshake.Sealer, error)
GetHandshakeSealer() (handshake.Sealer, error)
Get1RTTSealer() (handshake.Sealer, error)
}
type frameSource interface {
@ -163,7 +164,23 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
frames: []wire.Frame{ccf},
length: ccf.Length(p.version),
}
encLevel, sealer := p.cryptoSetup.GetSealer()
// send the CONNECTION_CLOSE frame with the highest available encryption level
var sealer handshake.Sealer
var err error
encLevel := protocol.Encryption1RTT
sealer, err = p.cryptoSetup.Get1RTTSealer()
if err != nil {
encLevel = protocol.EncryptionHandshake
sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
encLevel = protocol.EncryptionInitial
sealer, err = p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, err
}
}
}
header := p.getHeader(encLevel)
return p.writeAndSealPacket(header, payload, encLevel, sealer)
}
@ -178,9 +195,12 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
length: ack.Length(p.version),
}
// TODO(#1534): only pack ACKs with the right encryption level
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
return p.writeAndSealPacket(header, payload, encLevel, sealer)
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
header := p.getHeader(protocol.Encryption1RTT)
return p.writeAndSealPacket(header, payload, protocol.Encryption1RTT, sealer)
}
// PackRetransmission packs a retransmission
@ -202,8 +222,18 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
}
var packets []*packedPacket
encLevel := packet.EncryptionLevel
sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
var err error
var sealer handshake.Sealer
switch packet.EncryptionLevel {
case protocol.EncryptionInitial:
sealer, err = p.cryptoSetup.GetInitialSealer()
case protocol.EncryptionHandshake:
sealer, err = p.cryptoSetup.GetHandshakeSealer()
case protocol.Encryption1RTT:
sealer, err = p.cryptoSetup.Get1RTTSealer()
default:
return nil, fmt.Errorf("unexpected encryption level: %s", packet.EncryptionLevel)
}
if err != nil {
return nil, err
}
@ -211,7 +241,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
var frames []wire.Frame
var length protocol.ByteCount
header := p.getHeader(encLevel)
header := p.getHeader(packet.EncryptionLevel)
headerLen := header.GetLength(p.version)
maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen
@ -247,7 +277,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
sf.DataLenPresent = false
}
p, err := p.writeAndSealPacket(header, payload{frames: frames, length: length}, encLevel, sealer)
p, err := p.writeAndSealPacket(header, payload{frames: frames, length: length}, packet.EncryptionLevel, sealer)
if err != nil {
return nil, err
}
@ -267,8 +297,12 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
return packet, nil
}
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
// sealer not yet available
return nil, nil
}
header := p.getHeader(protocol.Encryption1RTT)
headerLen := header.GetLength(p.version)
if err != nil {
return nil, err
@ -297,7 +331,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
p.numNonAckElicitingAcks = 0
}
return p.writeAndSealPacket(header, payload, encLevel, sealer)
return p.writeAndSealPacket(header, payload, protocol.Encryption1RTT, sealer)
}
func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
@ -306,25 +340,27 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
hasData := p.initialStream.HasData()
ack := p.acks.GetAckFrame(protocol.EncryptionInitial)
var sealer handshake.Sealer
var err error
if hasData || ack != nil {
s = p.initialStream
encLevel = protocol.EncryptionInitial
sealer, err = p.cryptoSetup.GetInitialSealer()
} else {
hasData = p.handshakeStream.HasData()
ack = p.acks.GetAckFrame(protocol.EncryptionHandshake)
if hasData || ack != nil {
s = p.handshakeStream
encLevel = protocol.EncryptionHandshake
sealer, err = p.cryptoSetup.GetHandshakeSealer()
}
}
if err != nil {
return nil, err
}
if s == nil {
return nil, nil
}
sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
if err != nil {
// The sealer
return nil, err
}
var payload payload
if ack != nil {