use separate opener interfaces for long and short header packets

This commit is contained in:
Marten Seemann 2019-06-10 17:11:47 +08:00
parent b2d3ef691e
commit 598628d05b
12 changed files with 320 additions and 149 deletions

View file

@ -10,6 +10,10 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
type unpackedPacket struct {
packetNumber protocol.PacketNumber // the decoded packet number
hdr *wire.ExtendedHeader
@ -36,28 +40,85 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
}
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
r := bytes.NewReader(data)
var opener handshake.Opener
var encLevel protocol.EncryptionLevel
var err error
var extHdr *wire.ExtendedHeader
var decrypted []byte
switch hdr.Type {
case protocol.PacketTypeInitial:
encLevel = protocol.EncryptionInitial
opener, err = u.cs.GetInitialOpener()
opener, err := u.cs.GetInitialOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
case protocol.PacketTypeHandshake:
encLevel = protocol.EncryptionHandshake
opener, err = u.cs.GetHandshakeOpener()
opener, err := u.cs.GetHandshakeOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
default:
if hdr.IsLongHeader {
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
}
encLevel = protocol.Encryption1RTT
opener, err = u.cs.Get1RTTOpener()
opener, err := u.cs.Get1RTTOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
}
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber)
return &unpackedPacket{
hdr: extHdr,
packetNumber: extHdr.PacketNumber,
encryptionLevel: encLevel,
data: decrypted,
}, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, err := u.unpack(opener, hdr, data)
if err != nil {
return nil, err
return nil, nil, err
}
extHdrLen := extHdr.GetLength(u.version)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
if err != nil {
return nil, nil, err
}
return extHdr, decrypted, nil
}
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, err := u.unpack(opener, hdr, data)
if err != nil {
return nil, nil, err
}
extHdrLen := extHdr.GetLength(u.version)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
if err != nil {
return nil, nil, err
}
return extHdr, decrypted, nil
}
func (u *packetUnpacker) unpack(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data)
hdrLen := int(hdr.ParsedLen())
if len(data) < hdrLen+4+16 {
return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
@ -67,7 +128,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4])
// 2. decrypt the header, assuming a 4 byte packet number
opener.DecryptHeader(
hd.DecryptHeader(
data[hdrLen+4:hdrLen+4+16],
&data[0],
data[hdrLen:hdrLen+4],
@ -77,30 +138,15 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
if err != nil {
return nil, fmt.Errorf("error parsing extended header: %s", err)
}
extHdrLen := hdrLen + int(extHdr.PacketNumberLen)
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
copy(data[extHdrLen:hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
copy(data[extHdr.GetLength(u.version):hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
}
pn := protocol.DecodePacketNumber(
extHdr.PacketNumber = protocol.DecodePacketNumber(
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
extHdr.PacketNumber,
)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], pn, data[:extHdrLen])
if err != nil {
return nil, err
}
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, pn)
return &unpackedPacket{
hdr: extHdr,
packetNumber: pn,
encryptionLevel: encLevel,
data: decrypted,
}, nil
return extHdr, nil
}