mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
move parsing of extended headers to the same file as the struct
This commit is contained in:
parent
9bcedd988d
commit
af157408de
6 changed files with 96 additions and 91 deletions
|
@ -4,8 +4,10 @@ import (
|
|||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -22,6 +24,7 @@ type ExtendedHeader struct {
|
|||
PacketNumberLen protocol.PacketNumberLen
|
||||
PacketNumber protocol.PacketNumber
|
||||
|
||||
typeByte byte
|
||||
Type protocol.PacketType
|
||||
IsLongHeader bool
|
||||
KeyPhase int
|
||||
|
@ -29,6 +32,79 @@ type ExtendedHeader struct {
|
|||
Token []byte
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
if h.IsLongHeader {
|
||||
return h.parseLongHeader(b, v)
|
||||
}
|
||||
return h.parseShortHeader(b, v)
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
h.Type = protocol.PacketType(h.typeByte & 0x7f)
|
||||
|
||||
if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
|
||||
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
odcilByte, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
odcil := decodeSingleConnIDLen(odcilByte & 0xf)
|
||||
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.Token = make([]byte, b.Len())
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
tokenLen, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenLen > uint64(b.Len()) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
h.Token = make([]byte, tokenLen)
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pl, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.Length = protocol.ByteCount(pl)
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = pn
|
||||
h.PacketNumberLen = pnLen
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
h.KeyPhase = int(h.typeByte&0x40) >> 6
|
||||
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = pn
|
||||
h.PacketNumberLen = pnLen
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Write writes the Header.
|
||||
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
|
||||
if h.IsLongHeader {
|
||||
|
|
|
@ -2,7 +2,6 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -95,91 +94,21 @@ func (h *Header) IsVersionNegotiation() bool {
|
|||
return h.IsLongHeader() && h.Version == 0
|
||||
}
|
||||
|
||||
// Parse parses the version dependent part of the header.
|
||||
// ParseExtended parses the version dependent part of the header.
|
||||
// The Reader has to be set such that it points to the first byte of the header.
|
||||
func (h *Header) Parse(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
if _, err := b.Seek(int64(h.len), io.SeekCurrent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if h.IsLongHeader() {
|
||||
return h.parseLongHeader(b, ver)
|
||||
}
|
||||
return h.parseShortHeader(b, ver)
|
||||
return h.toExtendedHeader().parse(b, ver)
|
||||
}
|
||||
|
||||
func (h *Header) toExtendedHeader() *ExtendedHeader {
|
||||
return &ExtendedHeader{
|
||||
IsLongHeader: h.IsLongHeader(),
|
||||
typeByte: h.typeByte,
|
||||
DestConnectionID: h.DestConnectionID,
|
||||
SrcConnectionID: h.SrcConnectionID,
|
||||
Version: h.Version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Header) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
eh := h.toExtendedHeader()
|
||||
eh.Type = protocol.PacketType(h.typeByte & 0x7f)
|
||||
|
||||
if eh.Type != protocol.PacketTypeInitial && eh.Type != protocol.PacketTypeRetry && eh.Type != protocol.PacketType0RTT && eh.Type != protocol.PacketTypeHandshake {
|
||||
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", eh.Type))
|
||||
}
|
||||
|
||||
if eh.Type == protocol.PacketTypeRetry {
|
||||
odcilByte, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
odcil := decodeSingleConnIDLen(odcilByte & 0xf)
|
||||
eh.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eh.Token = make([]byte, b.Len())
|
||||
if _, err := io.ReadFull(b, eh.Token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return eh, nil
|
||||
}
|
||||
|
||||
if eh.Type == protocol.PacketTypeInitial {
|
||||
tokenLen, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenLen > uint64(b.Len()) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
eh.Token = make([]byte, tokenLen)
|
||||
if _, err := io.ReadFull(b, eh.Token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pl, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eh.Length = protocol.ByteCount(pl)
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eh.PacketNumber = pn
|
||||
eh.PacketNumberLen = pnLen
|
||||
|
||||
return eh, nil
|
||||
}
|
||||
|
||||
func (h *Header) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
eh := h.toExtendedHeader()
|
||||
eh.KeyPhase = int(h.typeByte&0x40) >> 6
|
||||
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eh.PacketNumber = pn
|
||||
eh.PacketNumberLen = pnLen
|
||||
|
||||
return eh, nil
|
||||
}
|
|
@ -85,7 +85,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
Expect(hdr.DestConnectionID).To(Equal(destConnID))
|
||||
Expect(hdr.SrcConnectionID).To(Equal(srcConnID))
|
||||
b := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(b, versionIETFFrames)
|
||||
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(extHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(extHdr.IsLongHeader).To(BeTrue())
|
||||
|
@ -142,7 +142,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
hdr, err := ParseHeader(bytes.NewReader(data), 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(b, versionIETFFrames)
|
||||
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123)))
|
||||
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2))
|
||||
|
@ -161,7 +161,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
hdr, err := ParseHeader(bytes.NewReader(data), 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(b, versionIETFFrames)
|
||||
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(extHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
Expect(extHdr.OrigDestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
|
||||
|
@ -183,7 +183,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
b := bytes.NewReader(buf.Bytes())
|
||||
hdr, err := ParseHeader(b, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = hdr.Parse(b, versionIETFFrames)
|
||||
_, err = hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42"))
|
||||
})
|
||||
|
||||
|
@ -200,7 +200,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
b := bytes.NewReader(data)
|
||||
hdr, err := ParseHeader(b, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = hdr.Parse(b, versionIETFFrames)
|
||||
_, err = hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
|
||||
|
@ -231,7 +231,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
b := bytes.NewReader(data[:i])
|
||||
hdr, err := ParseHeader(b, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = hdr.Parse(b, versionIETFFrames)
|
||||
_, err = hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).To(Equal(io.EOF))
|
||||
}
|
||||
})
|
||||
|
@ -251,7 +251,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
b := bytes.NewReader(data[:i])
|
||||
hdr, err := ParseHeader(b, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = hdr.Parse(b, versionIETFFrames)
|
||||
_, err = hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).To(Equal(io.EOF))
|
||||
}
|
||||
})
|
||||
|
@ -268,7 +268,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
Expect(hdr.IsVersionNegotiation()).To(BeFalse())
|
||||
Expect(hdr.DestConnectionID).To(Equal(connID))
|
||||
b := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(b, versionIETFFrames)
|
||||
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(extHdr.KeyPhase).To(Equal(0))
|
||||
Expect(extHdr.DestConnectionID).To(Equal(connID))
|
||||
|
@ -286,7 +286,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
Expect(hdr.IsLongHeader()).To(BeFalse())
|
||||
Expect(hdr.DestConnectionID).To(Equal(connID))
|
||||
b := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(b, versionIETFFrames)
|
||||
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(extHdr.KeyPhase).To(Equal(0))
|
||||
Expect(extHdr.DestConnectionID).To(Equal(connID))
|
||||
|
@ -304,7 +304,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hdr.IsLongHeader()).To(BeFalse())
|
||||
b := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(b, versionIETFFrames)
|
||||
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(extHdr.KeyPhase).To(Equal(1))
|
||||
Expect(b.Len()).To(BeZero())
|
||||
|
@ -319,7 +319,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
hdr, err := ParseHeader(bytes.NewReader(data), 4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(b, versionIETFFrames)
|
||||
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(extHdr.IsLongHeader).To(BeFalse())
|
||||
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
|
||||
|
@ -336,7 +336,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
hdr, err := ParseHeader(bytes.NewReader(data), 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(b, versionIETFFrames)
|
||||
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(extHdr.IsLongHeader).To(BeFalse())
|
||||
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef)))
|
||||
|
@ -366,7 +366,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
b := bytes.NewReader(data[:i])
|
||||
hdr, err := ParseHeader(b, 6)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = hdr.Parse(b, versionIETFFrames)
|
||||
_, err = hdr.ParseExtended(b, versionIETFFrames)
|
||||
Expect(err).To(Equal(io.EOF))
|
||||
}
|
||||
})
|
|
@ -210,7 +210,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
|||
if !hdr.IsVersionNegotiation() {
|
||||
r = bytes.NewReader(data)
|
||||
var err error
|
||||
extHdr, err = hdr.Parse(r, version)
|
||||
extHdr, err = hdr.ParseExtended(r, version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing extended header: %s", err)
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ var _ = Describe("Packet packer", func() {
|
|||
hdr, err := wire.ParseHeader(bytes.NewReader(data), 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
r := bytes.NewReader(data)
|
||||
extHdr, err := hdr.Parse(r, protocol.VersionWhatever)
|
||||
extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(0, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen)))
|
||||
}
|
||||
|
|
|
@ -101,7 +101,7 @@ var _ = Describe("Server", func() {
|
|||
parseHeader := func(data []byte) *wire.ExtendedHeader {
|
||||
hdr, err := wire.ParseHeader(bytes.NewReader(data), 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
extHdr, err := hdr.Parse(bytes.NewReader(data), protocol.VersionTLS)
|
||||
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return extHdr
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue