move parsing of extended headers to the same file as the struct

This commit is contained in:
Marten Seemann 2018-11-25 15:36:59 +07:00
parent 9bcedd988d
commit af157408de
6 changed files with 96 additions and 91 deletions

View file

@ -4,8 +4,10 @@ import (
"bytes"
"crypto/rand"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
)
@ -22,6 +24,7 @@ type ExtendedHeader struct {
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
typeByte byte
Type protocol.PacketType
IsLongHeader bool
KeyPhase int
@ -29,6 +32,79 @@ type ExtendedHeader struct {
Token []byte
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
if h.IsLongHeader {
return h.parseLongHeader(b, v)
}
return h.parseShortHeader(b, v)
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
h.Type = protocol.PacketType(h.typeByte & 0x7f)
if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
}
if h.Type == protocol.PacketTypeRetry {
odcilByte, err := b.ReadByte()
if err != nil {
return nil, err
}
odcil := decodeSingleConnIDLen(odcilByte & 0xf)
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
if err != nil {
return nil, err
}
h.Token = make([]byte, b.Len())
if _, err := io.ReadFull(b, h.Token); err != nil {
return nil, err
}
return h, nil
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
if tokenLen > uint64(b.Len()) {
return nil, io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return nil, err
}
}
pl, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
h.Length = protocol.ByteCount(pl)
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = pn
h.PacketNumberLen = pnLen
return h, nil
}
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
h.KeyPhase = int(h.typeByte&0x40) >> 6
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = pn
h.PacketNumberLen = pnLen
return h, nil
}
// Write writes the Header.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
if h.IsLongHeader {

View file

@ -2,7 +2,6 @@ package wire
import (
"bytes"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -95,91 +94,21 @@ func (h *Header) IsVersionNegotiation() bool {
return h.IsLongHeader() && h.Version == 0
}
// Parse parses the version dependent part of the header.
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) Parse(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
if _, err := b.Seek(int64(h.len), io.SeekCurrent); err != nil {
return nil, err
}
if h.IsLongHeader() {
return h.parseLongHeader(b, ver)
}
return h.parseShortHeader(b, ver)
return h.toExtendedHeader().parse(b, ver)
}
func (h *Header) toExtendedHeader() *ExtendedHeader {
return &ExtendedHeader{
IsLongHeader: h.IsLongHeader(),
typeByte: h.typeByte,
DestConnectionID: h.DestConnectionID,
SrcConnectionID: h.SrcConnectionID,
Version: h.Version,
}
}
func (h *Header) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
eh := h.toExtendedHeader()
eh.Type = protocol.PacketType(h.typeByte & 0x7f)
if eh.Type != protocol.PacketTypeInitial && eh.Type != protocol.PacketTypeRetry && eh.Type != protocol.PacketType0RTT && eh.Type != protocol.PacketTypeHandshake {
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", eh.Type))
}
if eh.Type == protocol.PacketTypeRetry {
odcilByte, err := b.ReadByte()
if err != nil {
return nil, err
}
odcil := decodeSingleConnIDLen(odcilByte & 0xf)
eh.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
if err != nil {
return nil, err
}
eh.Token = make([]byte, b.Len())
if _, err := io.ReadFull(b, eh.Token); err != nil {
return nil, err
}
return eh, nil
}
if eh.Type == protocol.PacketTypeInitial {
tokenLen, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
if tokenLen > uint64(b.Len()) {
return nil, io.EOF
}
eh.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, eh.Token); err != nil {
return nil, err
}
}
pl, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
eh.Length = protocol.ByteCount(pl)
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
eh.PacketNumber = pn
eh.PacketNumberLen = pnLen
return eh, nil
}
func (h *Header) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
eh := h.toExtendedHeader()
eh.KeyPhase = int(h.typeByte&0x40) >> 6
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
eh.PacketNumber = pn
eh.PacketNumberLen = pnLen
return eh, nil
}

View file

@ -85,7 +85,7 @@ var _ = Describe("Header Parsing", func() {
Expect(hdr.DestConnectionID).To(Equal(destConnID))
Expect(hdr.SrcConnectionID).To(Equal(srcConnID))
b := bytes.NewReader(data)
extHdr, err := hdr.Parse(b, versionIETFFrames)
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(extHdr.IsLongHeader).To(BeTrue())
@ -142,7 +142,7 @@ var _ = Describe("Header Parsing", func() {
hdr, err := ParseHeader(bytes.NewReader(data), 0)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
extHdr, err := hdr.Parse(b, versionIETFFrames)
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123)))
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2))
@ -161,7 +161,7 @@ var _ = Describe("Header Parsing", func() {
hdr, err := ParseHeader(bytes.NewReader(data), 0)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
extHdr, err := hdr.Parse(b, versionIETFFrames)
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(extHdr.OrigDestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
@ -183,7 +183,7 @@ var _ = Describe("Header Parsing", func() {
b := bytes.NewReader(buf.Bytes())
hdr, err := ParseHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = hdr.Parse(b, versionIETFFrames)
_, err = hdr.ParseExtended(b, versionIETFFrames)
Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42"))
})
@ -200,7 +200,7 @@ var _ = Describe("Header Parsing", func() {
b := bytes.NewReader(data)
hdr, err := ParseHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = hdr.Parse(b, versionIETFFrames)
_, err = hdr.ParseExtended(b, versionIETFFrames)
Expect(err).To(MatchError(io.EOF))
})
@ -231,7 +231,7 @@ var _ = Describe("Header Parsing", func() {
b := bytes.NewReader(data[:i])
hdr, err := ParseHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = hdr.Parse(b, versionIETFFrames)
_, err = hdr.ParseExtended(b, versionIETFFrames)
Expect(err).To(Equal(io.EOF))
}
})
@ -251,7 +251,7 @@ var _ = Describe("Header Parsing", func() {
b := bytes.NewReader(data[:i])
hdr, err := ParseHeader(b, 0)
Expect(err).ToNot(HaveOccurred())
_, err = hdr.Parse(b, versionIETFFrames)
_, err = hdr.ParseExtended(b, versionIETFFrames)
Expect(err).To(Equal(io.EOF))
}
})
@ -268,7 +268,7 @@ var _ = Describe("Header Parsing", func() {
Expect(hdr.IsVersionNegotiation()).To(BeFalse())
Expect(hdr.DestConnectionID).To(Equal(connID))
b := bytes.NewReader(data)
extHdr, err := hdr.Parse(b, versionIETFFrames)
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.KeyPhase).To(Equal(0))
Expect(extHdr.DestConnectionID).To(Equal(connID))
@ -286,7 +286,7 @@ var _ = Describe("Header Parsing", func() {
Expect(hdr.IsLongHeader()).To(BeFalse())
Expect(hdr.DestConnectionID).To(Equal(connID))
b := bytes.NewReader(data)
extHdr, err := hdr.Parse(b, versionIETFFrames)
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.KeyPhase).To(Equal(0))
Expect(extHdr.DestConnectionID).To(Equal(connID))
@ -304,7 +304,7 @@ var _ = Describe("Header Parsing", func() {
Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsLongHeader()).To(BeFalse())
b := bytes.NewReader(data)
extHdr, err := hdr.Parse(b, versionIETFFrames)
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.KeyPhase).To(Equal(1))
Expect(b.Len()).To(BeZero())
@ -319,7 +319,7 @@ var _ = Describe("Header Parsing", func() {
hdr, err := ParseHeader(bytes.NewReader(data), 4)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
extHdr, err := hdr.Parse(b, versionIETFFrames)
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.IsLongHeader).To(BeFalse())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
@ -336,7 +336,7 @@ var _ = Describe("Header Parsing", func() {
hdr, err := ParseHeader(bytes.NewReader(data), 10)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
extHdr, err := hdr.Parse(b, versionIETFFrames)
extHdr, err := hdr.ParseExtended(b, versionIETFFrames)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.IsLongHeader).To(BeFalse())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef)))
@ -366,7 +366,7 @@ var _ = Describe("Header Parsing", func() {
b := bytes.NewReader(data[:i])
hdr, err := ParseHeader(b, 6)
Expect(err).ToNot(HaveOccurred())
_, err = hdr.Parse(b, versionIETFFrames)
_, err = hdr.ParseExtended(b, versionIETFFrames)
Expect(err).To(Equal(io.EOF))
}
})

View file

@ -210,7 +210,7 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
if !hdr.IsVersionNegotiation() {
r = bytes.NewReader(data)
var err error
extHdr, err = hdr.Parse(r, version)
extHdr, err = hdr.ParseExtended(r, version)
if err != nil {
return fmt.Errorf("error parsing extended header: %s", err)
}

View file

@ -34,7 +34,7 @@ var _ = Describe("Packet packer", func() {
hdr, err := wire.ParseHeader(bytes.NewReader(data), 0)
Expect(err).ToNot(HaveOccurred())
r := bytes.NewReader(data)
extHdr, err := hdr.Parse(r, protocol.VersionWhatever)
extHdr, err := hdr.ParseExtended(r, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
ExpectWithOffset(0, extHdr.Length).To(BeEquivalentTo(r.Len() + int(extHdr.PacketNumberLen)))
}

View file

@ -101,7 +101,7 @@ var _ = Describe("Server", func() {
parseHeader := func(data []byte) *wire.ExtendedHeader {
hdr, err := wire.ParseHeader(bytes.NewReader(data), 0)
Expect(err).ToNot(HaveOccurred())
extHdr, err := hdr.Parse(bytes.NewReader(data), protocol.VersionTLS)
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
return extHdr
}