mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
247 lines
7.8 KiB
Go
247 lines
7.8 KiB
Go
package wire
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
)
|
|
|
|
// Header is the header of a QUIC packet.
|
|
// It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header.
|
|
type Header struct {
|
|
IsPublicHeader bool
|
|
|
|
Raw []byte
|
|
|
|
Version protocol.VersionNumber
|
|
|
|
DestConnectionID protocol.ConnectionID
|
|
SrcConnectionID protocol.ConnectionID
|
|
|
|
PacketNumberLen protocol.PacketNumberLen
|
|
PacketNumber protocol.PacketNumber
|
|
|
|
IsVersionNegotiation bool
|
|
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
|
|
|
|
// only needed for the gQUIC Public Header
|
|
VersionFlag bool
|
|
ResetFlag bool
|
|
DiversificationNonce []byte
|
|
|
|
// only needed for the IETF Header
|
|
Type protocol.PacketType
|
|
IsLongHeader bool
|
|
KeyPhase int
|
|
PayloadLen protocol.ByteCount
|
|
}
|
|
|
|
// errInvalidPacketNumberLen6 is returned by the gQUIC Public Header code paths
// in this file (writePublicHeader and getPublicHeaderLength) when
// PacketNumberLen is protocol.PacketNumberLen6.
var errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes")
|
|
|
|
// Write writes the Header.
|
|
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error {
|
|
if !version.UsesTLS() {
|
|
h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
|
|
return h.writePublicHeader(b, pers, version)
|
|
}
|
|
// write an IETF QUIC header
|
|
if h.IsLongHeader {
|
|
return h.writeLongHeader(b)
|
|
}
|
|
return h.writeShortHeader(b)
|
|
}
|
|
|
|
// TODO: add support for the key phase
|
|
func (h *Header) writeLongHeader(b *bytes.Buffer) error {
|
|
if h.SrcConnectionID.Len() != protocol.ConnectionIDLen {
|
|
return fmt.Errorf("Header: source connection ID must be %d bytes, is %d", protocol.ConnectionIDLen, h.SrcConnectionID.Len())
|
|
}
|
|
b.WriteByte(byte(0x80 | h.Type))
|
|
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
|
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.WriteByte(connIDLen)
|
|
b.Write(h.DestConnectionID.Bytes())
|
|
b.Write(h.SrcConnectionID.Bytes())
|
|
utils.WriteVarInt(b, uint64(h.PayloadLen))
|
|
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
|
}
|
|
|
|
func (h *Header) writeShortHeader(b *bytes.Buffer) error {
|
|
typeByte := byte(0x30)
|
|
typeByte |= byte(h.KeyPhase << 6)
|
|
b.WriteByte(typeByte)
|
|
|
|
b.Write(h.DestConnectionID.Bytes())
|
|
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
|
}
|
|
|
|
// writePublicHeader writes a Public Header.
|
|
func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
|
|
if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
|
|
return errors.New("PublicHeader: Can only write regular packets")
|
|
}
|
|
if h.SrcConnectionID.Len() != 0 {
|
|
return errors.New("PublicHeader: SrcConnectionID must not be set")
|
|
}
|
|
if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
|
|
return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
|
|
}
|
|
|
|
publicFlagByte := uint8(0x00)
|
|
if h.VersionFlag {
|
|
publicFlagByte |= 0x01
|
|
}
|
|
if h.DestConnectionID.Len() > 0 {
|
|
publicFlagByte |= 0x08
|
|
}
|
|
if len(h.DiversificationNonce) > 0 {
|
|
if len(h.DiversificationNonce) != 32 {
|
|
return errors.New("invalid diversification nonce length")
|
|
}
|
|
publicFlagByte |= 0x04
|
|
}
|
|
switch h.PacketNumberLen {
|
|
case protocol.PacketNumberLen1:
|
|
publicFlagByte |= 0x00
|
|
case protocol.PacketNumberLen2:
|
|
publicFlagByte |= 0x10
|
|
case protocol.PacketNumberLen4:
|
|
publicFlagByte |= 0x20
|
|
}
|
|
b.WriteByte(publicFlagByte)
|
|
|
|
if h.DestConnectionID.Len() > 0 {
|
|
b.Write(h.DestConnectionID)
|
|
}
|
|
if h.VersionFlag && pers == protocol.PerspectiveClient {
|
|
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
|
}
|
|
if len(h.DiversificationNonce) > 0 {
|
|
b.Write(h.DiversificationNonce)
|
|
}
|
|
|
|
switch h.PacketNumberLen {
|
|
case protocol.PacketNumberLen1:
|
|
b.WriteByte(uint8(h.PacketNumber))
|
|
case protocol.PacketNumberLen2:
|
|
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
|
|
case protocol.PacketNumberLen4:
|
|
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
|
case protocol.PacketNumberLen6:
|
|
return errInvalidPacketNumberLen6
|
|
default:
|
|
return errors.New("PublicHeader: PacketNumberLen not set")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetLength determines the length of the Header.
|
|
func (h *Header) GetLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
if !version.UsesTLS() {
|
|
return h.getPublicHeaderLength()
|
|
}
|
|
return h.getHeaderLength()
|
|
}
|
|
|
|
func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
|
|
if h.IsLongHeader {
|
|
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + protocol.ByteCount(h.PacketNumberLen), nil
|
|
}
|
|
|
|
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
|
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
|
|
return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
|
|
}
|
|
length += protocol.ByteCount(h.PacketNumberLen)
|
|
return length, nil
|
|
}
|
|
|
|
// getPublicHeaderLength gets the length of the publicHeader in bytes.
|
|
// It can only be called for regular packets.
|
|
func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
|
|
length := protocol.ByteCount(1) // 1 byte for public flags
|
|
if h.PacketNumberLen == protocol.PacketNumberLen6 {
|
|
return 0, errInvalidPacketNumberLen6
|
|
}
|
|
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
|
|
return 0, errPacketNumberLenNotSet
|
|
}
|
|
length += protocol.ByteCount(h.PacketNumberLen)
|
|
length += protocol.ByteCount(h.DestConnectionID.Len()) // if set, always 8 bytes
|
|
// Version Number in packets sent by the client
|
|
if h.VersionFlag {
|
|
length += 4
|
|
}
|
|
length += protocol.ByteCount(len(h.DiversificationNonce))
|
|
return length, nil
|
|
}
|
|
|
|
// Log logs the Header
|
|
func (h *Header) Log(logger utils.Logger) {
|
|
if h.IsPublicHeader {
|
|
h.logPublicHeader(logger)
|
|
} else {
|
|
h.logHeader(logger)
|
|
}
|
|
}
|
|
|
|
func (h *Header) logHeader(logger utils.Logger) {
|
|
if h.IsLongHeader {
|
|
if h.Version == 0 {
|
|
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
|
|
} else {
|
|
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
|
|
}
|
|
} else {
|
|
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
|
|
}
|
|
}
|
|
|
|
func (h *Header) logPublicHeader(logger utils.Logger) {
|
|
ver := "(unset)"
|
|
if h.Version != 0 {
|
|
ver = h.Version.String()
|
|
}
|
|
logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
|
|
}
|
|
|
|
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
|
dcil, err := encodeSingleConnIDLen(dest)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
scil, err := encodeSingleConnIDLen(src)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return scil | dcil<<4, nil
|
|
}
|
|
|
|
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
|
|
len := id.Len()
|
|
if len == 0 {
|
|
return 0, nil
|
|
}
|
|
if len < 4 || len > 18 {
|
|
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
|
|
}
|
|
return byte(len - 3), nil
|
|
}
|
|
|
|
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
|
|
return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
|
|
}
|
|
|
|
// decodeSingleConnIDLen converts an encoded connection ID length back to a
// byte count: 0 stays 0, every other value decodes to enc+3.
func decodeSingleConnIDLen(enc uint8) int {
	if enc != 0 {
		return int(enc) + 3
	}
	return 0
}
|