mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
210 lines
6.3 KiB
Go
210 lines
6.3 KiB
Go
package wire
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/utils"
|
|
"github.com/quic-go/quic-go/quicvarint"
|
|
)
|
|
|
|
// ErrInvalidReservedBits is returned when the reserved bits are incorrect.
|
|
// When this error is returned, parsing continues, and an ExtendedHeader is returned.
|
|
// This is necessary because we need to decrypt the packet in that case,
|
|
// in order to avoid a timing side-channel.
|
|
var ErrInvalidReservedBits = errors.New("invalid reserved bits")
|
|
|
|
// ExtendedHeader is the header of a QUIC packet.
|
|
type ExtendedHeader struct {
|
|
Header
|
|
|
|
typeByte byte
|
|
|
|
KeyPhase protocol.KeyPhaseBit
|
|
|
|
PacketNumberLen protocol.PacketNumberLen
|
|
PacketNumber protocol.PacketNumber
|
|
|
|
parsedLen protocol.ByteCount
|
|
}
|
|
|
|
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
|
|
startLen := b.Len()
|
|
// read the (now unencrypted) first byte
|
|
var err error
|
|
h.typeByte, err = b.ReadByte()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
|
|
return false, err
|
|
}
|
|
reservedBitsValid, err := h.parseLongHeader(b, v)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
h.parsedLen = protocol.ByteCount(startLen - b.Len())
|
|
return reservedBitsValid, err
|
|
}
|
|
|
|
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
|
|
if err := h.readPacketNumber(b); err != nil {
|
|
return false, err
|
|
}
|
|
if h.typeByte&0xc != 0 {
|
|
return false, nil
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
|
|
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
|
|
switch h.PacketNumberLen {
|
|
case protocol.PacketNumberLen1:
|
|
n, err := b.ReadByte()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.PacketNumber = protocol.PacketNumber(n)
|
|
case protocol.PacketNumberLen2:
|
|
n, err := utils.BigEndian.ReadUint16(b)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.PacketNumber = protocol.PacketNumber(n)
|
|
case protocol.PacketNumberLen3:
|
|
n, err := utils.BigEndian.ReadUint24(b)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.PacketNumber = protocol.PacketNumber(n)
|
|
case protocol.PacketNumberLen4:
|
|
n, err := utils.BigEndian.ReadUint32(b)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.PacketNumber = protocol.PacketNumber(n)
|
|
default:
|
|
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Append appends the Header.
|
|
func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) {
|
|
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
|
|
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
|
|
}
|
|
if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
|
|
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
|
|
}
|
|
|
|
var packetType uint8
|
|
if v == protocol.Version2 {
|
|
//nolint:exhaustive
|
|
switch h.Type {
|
|
case protocol.PacketTypeInitial:
|
|
packetType = 0b01
|
|
case protocol.PacketType0RTT:
|
|
packetType = 0b10
|
|
case protocol.PacketTypeHandshake:
|
|
packetType = 0b11
|
|
case protocol.PacketTypeRetry:
|
|
packetType = 0b00
|
|
}
|
|
} else {
|
|
//nolint:exhaustive
|
|
switch h.Type {
|
|
case protocol.PacketTypeInitial:
|
|
packetType = 0b00
|
|
case protocol.PacketType0RTT:
|
|
packetType = 0b01
|
|
case protocol.PacketTypeHandshake:
|
|
packetType = 0b10
|
|
case protocol.PacketTypeRetry:
|
|
packetType = 0b11
|
|
}
|
|
}
|
|
firstByte := 0xc0 | packetType<<4
|
|
if h.Type != protocol.PacketTypeRetry {
|
|
// Retry packets don't have a packet number
|
|
firstByte |= uint8(h.PacketNumberLen - 1)
|
|
}
|
|
|
|
b = append(b, firstByte)
|
|
b = append(b, make([]byte, 4)...)
|
|
binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version))
|
|
b = append(b, uint8(h.DestConnectionID.Len()))
|
|
b = append(b, h.DestConnectionID.Bytes()...)
|
|
b = append(b, uint8(h.SrcConnectionID.Len()))
|
|
b = append(b, h.SrcConnectionID.Bytes()...)
|
|
|
|
//nolint:exhaustive
|
|
switch h.Type {
|
|
case protocol.PacketTypeRetry:
|
|
b = append(b, h.Token...)
|
|
return b, nil
|
|
case protocol.PacketTypeInitial:
|
|
b = quicvarint.Append(b, uint64(len(h.Token)))
|
|
b = append(b, h.Token...)
|
|
}
|
|
b = quicvarint.AppendWithLen(b, uint64(h.Length), 2)
|
|
return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
|
}
|
|
|
|
// ParsedLen returns the number of bytes that were consumed when parsing the header
|
|
func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
|
|
return h.parsedLen
|
|
}
|
|
|
|
// GetLength determines the length of the Header.
|
|
func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
|
|
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
|
|
if h.Type == protocol.PacketTypeInitial {
|
|
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
|
}
|
|
return length
|
|
}
|
|
|
|
// Log logs the Header
|
|
func (h *ExtendedHeader) Log(logger utils.Logger) {
|
|
var token string
|
|
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
|
|
if len(h.Token) == 0 {
|
|
token = "Token: (empty), "
|
|
} else {
|
|
token = fmt.Sprintf("Token: %#x, ", h.Token)
|
|
}
|
|
if h.Type == protocol.PacketTypeRetry {
|
|
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
|
|
return
|
|
}
|
|
}
|
|
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
|
|
}
|
|
|
|
func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) {
|
|
switch pnLen {
|
|
case protocol.PacketNumberLen1:
|
|
b = append(b, uint8(pn))
|
|
case protocol.PacketNumberLen2:
|
|
buf := make([]byte, 2)
|
|
binary.BigEndian.PutUint16(buf, uint16(pn))
|
|
b = append(b, buf...)
|
|
case protocol.PacketNumberLen3:
|
|
buf := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(buf, uint32(pn))
|
|
b = append(b, buf[1:]...)
|
|
case protocol.PacketNumberLen4:
|
|
buf := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(buf, uint32(pn))
|
|
b = append(b, buf...)
|
|
default:
|
|
return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
|
|
}
|
|
return b, nil
|
|
}
|