uquic/internal/wire/ietf_header.go
Marten Seemann 8f2fed1b10 parse IETF headers independent of the sender's perspective
The IETF header format allows parsing of the header without knowing
which peer sent the packet.
2018-05-09 13:28:30 +09:00

234 lines
6.7 KiB
Go

package wire
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// parseHeader parses the header.
func parseHeader(b *bytes.Reader) (*Header, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}
if typeByte&0x80 > 0 {
return parseLongHeader(b, typeByte)
}
return parseShortHeader(b, typeByte)
}
// parse long header and version negotiation packets
func parseLongHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, err
}
connIDLenByte, err := b.ReadByte()
if err != nil {
return nil, err
}
dcil, scil := decodeConnIDLen(connIDLenByte)
destConnID, err := protocol.ReadConnectionID(b, dcil)
if err != nil {
return nil, err
}
srcConnID, err := protocol.ReadConnectionID(b, scil)
if err != nil {
return nil, err
}
h := &Header{
IsLongHeader: true,
Version: protocol.VersionNumber(v),
DestConnectionID: destConnID,
SrcConnectionID: srcConnID,
}
if v == 0 { // version negotiation packet
if b.Len() == 0 {
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
}
h.IsVersionNegotiation = true
h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
for i := 0; b.Len() > 0; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, qerr.InvalidVersionNegotiationPacket
}
h.SupportedVersions[i] = protocol.VersionNumber(v)
}
return h, nil
}
pl, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
h.PayloadLen = protocol.ByteCount(pl)
pn, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, err
}
h.PacketNumber = protocol.PacketNumber(pn)
h.PacketNumberLen = protocol.PacketNumberLen4
h.Type = protocol.PacketType(typeByte & 0x7f)
if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
}
return h, nil
}
func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
connID := make(protocol.ConnectionID, 8)
if _, err := io.ReadFull(b, connID); err != nil {
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return nil, err
}
// bits 2 and 3 must be set, bit 4 must be unset
if typeByte&0x38 != 0x30 {
return nil, errors.New("invalid bits 3, 4 and 5")
}
var pnLen protocol.PacketNumberLen
switch typeByte & 0x3 {
case 0x0:
pnLen = protocol.PacketNumberLen1
case 0x1:
pnLen = protocol.PacketNumberLen2
case 0x2:
pnLen = protocol.PacketNumberLen4
default:
return nil, errors.New("invalid short header type")
}
pn, err := utils.BigEndian.ReadUintN(b, uint8(pnLen))
if err != nil {
return nil, err
}
return &Header{
KeyPhase: int(typeByte&0x40) >> 6,
DestConnectionID: connID,
PacketNumber: protocol.PacketNumber(pn),
PacketNumberLen: pnLen,
}, nil
}
// writeHeader writes the Header.
func (h *Header) writeHeader(b *bytes.Buffer) error {
if h.IsLongHeader {
return h.writeLongHeader(b)
}
return h.writeShortHeader(b)
}
// TODO: add support for the key phase
func (h *Header) writeLongHeader(b *bytes.Buffer) error {
if h.SrcConnectionID.Len() != protocol.ConnectionIDLen {
return fmt.Errorf("Header: source connection ID must be %d bytes, is %d", protocol.ConnectionIDLen, h.SrcConnectionID.Len())
}
b.WriteByte(byte(0x80 | h.Type))
utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
if err != nil {
return err
}
b.WriteByte(connIDLen)
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
utils.WriteVarInt(b, uint64(h.PayloadLen))
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
return nil
}
func (h *Header) writeShortHeader(b *bytes.Buffer) error {
typeByte := byte(0x30)
typeByte |= byte(h.KeyPhase << 6)
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
case protocol.PacketNumberLen2:
typeByte |= 0x1
case protocol.PacketNumberLen4:
typeByte |= 0x2
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
}
return nil
}
func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
if h.IsLongHeader {
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + 4 /* packet number */, nil
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
length += protocol.ByteCount(h.PacketNumberLen)
return length, nil
}
func (h *Header) logHeader(logger utils.Logger) {
if h.IsLongHeader {
if h.Version == 0 {
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
} else {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PayloadLen, h.Version)
}
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
}
}
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
dcil, err := encodeSingleConnIDLen(dest)
if err != nil {
return 0, err
}
scil, err := encodeSingleConnIDLen(src)
if err != nil {
return 0, err
}
return scil | dcil<<4, nil
}
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
len := id.Len()
if len == 0 {
return 0, nil
}
if len < 4 || len > 18 {
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
}
return byte(len - 3), nil
}
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
}
func decodeSingleConnIDLen(enc uint8) int {
if enc == 0 {
return 0
}
return int(enc) + 3
}