refactor header parsing in a version independent and dependent part

This commit is contained in:
Marten Seemann 2018-06-30 23:32:26 +07:00
parent 74ed3f7037
commit 4109c85c8a
20 changed files with 1587 additions and 1804 deletions

View file

@ -2,6 +2,8 @@ package wire
import (
"bytes"
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@ -37,54 +39,7 @@ type Header struct {
PayloadLen protocol.ByteCount
}
// ParseHeaderSentByServer parses the header for a packet that was sent by the server.
func ParseHeaderSentByServer(b *bytes.Reader) (*Header, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}
_ = b.UnreadByte() // unread the type byte
var isPublicHeader bool
if typeByte&0x80 > 0 { // gQUIC always has 0x80 unset. IETF Long Header or Version Negotiation
isPublicHeader = false
} else {
// gQUIC never uses 6 byte packet numbers, so the third and fourth bit will never be 11
isPublicHeader = typeByte&0x30 != 0x30
}
return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader)
}
// ParseHeaderSentByClient parses the header for a packet that was sent by the client.
func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}
_ = b.UnreadByte() // unread the type byte
// In an IETF QUIC packet header
// * either 0x80 is set (for the Long Header)
// * or 0x8 is unset (for the Short Header)
// In a gQUIC Public Header
// * 0x80 is always unset and
// * and 0x8 is always set (this is the Connection ID flag, which the client always sets)
isPublicHeader := typeByte&0x88 == 0x8
return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader)
}
func parsePacketHeader(b *bytes.Reader, sentBy protocol.Perspective, isPublicHeader bool) (*Header, error) {
// This is a gQUIC Public Header.
if isPublicHeader {
hdr, err := parsePublicHeader(b, sentBy)
if err != nil {
return nil, err
}
hdr.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
return hdr, nil
}
return parseHeader(b)
}
var errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes")
// Write writes the Header.
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error {
@ -92,17 +47,142 @@ func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version proto
h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
return h.writePublicHeader(b, pers, version)
}
return h.writeHeader(b)
// write an IETF QUIC header
if h.IsLongHeader {
return h.writeLongHeader(b)
}
return h.writeShortHeader(b)
}
// TODO: add support for the key phase
func (h *Header) writeLongHeader(b *bytes.Buffer) error {
if h.SrcConnectionID.Len() != protocol.ConnectionIDLen {
return fmt.Errorf("Header: source connection ID must be %d bytes, is %d", protocol.ConnectionIDLen, h.SrcConnectionID.Len())
}
b.WriteByte(byte(0x80 | h.Type))
utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
if err != nil {
return err
}
b.WriteByte(connIDLen)
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
utils.WriteVarInt(b, uint64(h.PayloadLen))
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
func (h *Header) writeShortHeader(b *bytes.Buffer) error {
typeByte := byte(0x30)
typeByte |= byte(h.KeyPhase << 6)
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
// writePublicHeader writes a Public Header.
func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
return errors.New("PublicHeader: Can only write regular packets")
}
if h.SrcConnectionID.Len() != 0 {
return errors.New("PublicHeader: SrcConnectionID must not be set")
}
if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
}
publicFlagByte := uint8(0x00)
if h.VersionFlag {
publicFlagByte |= 0x01
}
if h.DestConnectionID.Len() > 0 {
publicFlagByte |= 0x08
}
if len(h.DiversificationNonce) > 0 {
if len(h.DiversificationNonce) != 32 {
return errors.New("invalid diversification nonce length")
}
publicFlagByte |= 0x04
}
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
publicFlagByte |= 0x00
case protocol.PacketNumberLen2:
publicFlagByte |= 0x10
case protocol.PacketNumberLen4:
publicFlagByte |= 0x20
}
b.WriteByte(publicFlagByte)
if h.DestConnectionID.Len() > 0 {
b.Write(h.DestConnectionID)
}
if h.VersionFlag && pers == protocol.PerspectiveClient {
utils.BigEndian.WriteUint32(b, uint32(h.Version))
}
if len(h.DiversificationNonce) > 0 {
b.Write(h.DiversificationNonce)
}
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen6:
return errInvalidPacketNumberLen6
default:
return errors.New("PublicHeader: PacketNumberLen not set")
}
return nil
}
// GetLength determines the length of the Header.
func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNumber) (protocol.ByteCount, error) {
func (h *Header) GetLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
if !version.UsesTLS() {
return h.getPublicHeaderLength(pers)
return h.getPublicHeaderLength()
}
return h.getHeaderLength()
}
func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
if h.IsLongHeader {
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + protocol.ByteCount(h.PacketNumberLen), nil
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
length += protocol.ByteCount(h.PacketNumberLen)
return length, nil
}
// getPublicHeaderLength gets the length of the publicHeader in bytes.
// It can only be called for regular packets.
func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
length := protocol.ByteCount(1) // 1 byte for public flags
if h.PacketNumberLen == protocol.PacketNumberLen6 {
return 0, errInvalidPacketNumberLen6
}
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
return 0, errPacketNumberLenNotSet
}
length += protocol.ByteCount(h.PacketNumberLen)
length += protocol.ByteCount(h.DestConnectionID.Len()) // if set, always 8 bytes
// Version Number in packets sent by the client
if h.VersionFlag {
length += 4
}
length += protocol.ByteCount(len(h.DiversificationNonce))
return length, nil
}
// Log logs the Header
func (h *Header) Log(logger utils.Logger) {
if h.IsPublicHeader {
@ -111,3 +191,57 @@ func (h *Header) Log(logger utils.Logger) {
h.logHeader(logger)
}
}
func (h *Header) logHeader(logger utils.Logger) {
if h.IsLongHeader {
if h.Version == 0 {
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
} else {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
}
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
}
}
func (h *Header) logPublicHeader(logger utils.Logger) {
ver := "(unset)"
if h.Version != 0 {
ver = h.Version.String()
}
logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
}
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
dcil, err := encodeSingleConnIDLen(dest)
if err != nil {
return 0, err
}
scil, err := encodeSingleConnIDLen(src)
if err != nil {
return 0, err
}
return scil | dcil<<4, nil
}
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
len := id.Len()
if len == 0 {
return 0, nil
}
if len < 4 || len > 18 {
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
}
return byte(len - 3), nil
}
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
}
func decodeSingleConnIDLen(enc uint8) int {
if enc == 0 {
return 0
}
return int(enc) + 3
}