mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
parse the whole Long Header, if the version is known
This commit is contained in:
parent
b740d57c61
commit
70ce6a5814
11 changed files with 312 additions and 318 deletions
|
@ -2,6 +2,7 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -15,14 +16,23 @@ type Header struct {
|
|||
SrcConnectionID protocol.ConnectionID
|
||||
DestConnectionID protocol.ConnectionID
|
||||
|
||||
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
|
||||
|
||||
IsLongHeader bool
|
||||
typeByte byte
|
||||
len int // how many bytes were read while parsing this header
|
||||
Type protocol.PacketType
|
||||
Length protocol.ByteCount
|
||||
|
||||
Token []byte
|
||||
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
|
||||
OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
|
||||
|
||||
typeByte byte
|
||||
len int // how many bytes were read while parsing this header
|
||||
}
|
||||
|
||||
// ParseHeader parses the version independent part of the header
|
||||
// ParseHeader parses the header.
|
||||
// For short header packets: up to the packet number.
|
||||
// For long header packets:
|
||||
// * if we understand the version: up to the packet number
|
||||
// * if not, only the invariant part of the header
|
||||
func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
|
||||
startLen := b.Len()
|
||||
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
|
||||
|
@ -45,46 +55,107 @@ func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error)
|
|||
}
|
||||
|
||||
if !h.IsLongHeader {
|
||||
var err error
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
|
||||
if err != nil {
|
||||
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
// Long Header
|
||||
if err := h.parseLongHeader(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
|
||||
var err error
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
h.Version = protocol.VersionNumber(v)
|
||||
connIDLenByte, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
dcil, scil := decodeConnIDLen(connIDLenByte)
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
if h.Version == 0 {
|
||||
if b.Len() == 0 {
|
||||
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
||||
return h.parseVersionNegotiationPacket(b)
|
||||
}
|
||||
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
|
||||
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.Type = protocol.PacketType(h.typeByte & 0x7f)
|
||||
if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
|
||||
return qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
odcilByte, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
|
||||
for i := 0; b.Len() > 0; i++ {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return nil, qerr.InvalidVersionNegotiationPacket
|
||||
}
|
||||
h.SupportedVersions[i] = protocol.VersionNumber(v)
|
||||
odcil := decodeSingleConnIDLen(odcilByte & 0xf)
|
||||
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Token = make([]byte, b.Len())
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
tokenLen, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tokenLen > uint64(b.Len()) {
|
||||
return io.EOF
|
||||
}
|
||||
h.Token = make([]byte, tokenLen)
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return h, nil
|
||||
|
||||
pl, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Length = protocol.ByteCount(pl)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error {
|
||||
if b.Len() == 0 {
|
||||
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
||||
}
|
||||
h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
|
||||
for i := 0; b.Len() > 0; i++ {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return qerr.InvalidVersionNegotiationPacket
|
||||
}
|
||||
h.SupportedVersions[i] = protocol.VersionNumber(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsVersionNegotiation says if this a version negotiation packet
|
||||
|
@ -104,3 +175,14 @@ func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*Ex
|
|||
func (h *Header) toExtendedHeader() *ExtendedHeader {
|
||||
return &ExtendedHeader{Header: *h}
|
||||
}
|
||||
|
||||
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
|
||||
return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
|
||||
}
|
||||
|
||||
func decodeSingleConnIDLen(enc uint8) int {
|
||||
if enc == 0 {
|
||||
return 0
|
||||
}
|
||||
return int(enc) + 3
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue