diff --git a/internal/protocol/version.go b/internal/protocol/version.go index e6c12dc3..25f851db 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -80,11 +80,27 @@ func (vn VersionNumber) UsesIETFFrameFormat() bool { return !vn.isGQUIC() } +// UsesIETFHeaderFormat tells if this version uses the IETF header format +func (vn VersionNumber) UsesIETFHeaderFormat() bool { + return !vn.isGQUIC() || vn >= Version44 +} + +// UsesLengthInHeader tells if this version uses the Length field in the IETF header +func (vn VersionNumber) UsesLengthInHeader() bool { + return !vn.isGQUIC() +} + +// UsesTokenInHeader tells if this version uses the Token field in the IETF header +func (vn VersionNumber) UsesTokenInHeader() bool { + return !vn.isGQUIC() +} + // UsesStopWaitingFrames tells if this version uses STOP_WAITING frames func (vn VersionNumber) UsesStopWaitingFrames() bool { return vn.isGQUIC() && vn <= Version43 } +// UsesVarintPacketNumbers tells if this version uses 7/14/30 bit packet numbers func (vn VersionNumber) UsesVarintPacketNumbers() bool { return !vn.isGQUIC() } diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index c67eecbe..cdfaa26c 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -75,20 +75,32 @@ var _ = Describe("Version", func() { It("tells if a version uses the IETF frame types", func() { Expect(Version39.UsesIETFFrameFormat()).To(BeFalse()) Expect(Version43.UsesIETFFrameFormat()).To(BeFalse()) + Expect(Version44.UsesIETFFrameFormat()).To(BeFalse()) Expect(VersionTLS.UsesIETFFrameFormat()).To(BeTrue()) }) + It("tells if a version uses the IETF header format", func() { + Expect(Version39.UsesIETFHeaderFormat()).To(BeFalse()) + Expect(Version43.UsesIETFHeaderFormat()).To(BeFalse()) + Expect(Version44.UsesIETFHeaderFormat()).To(BeTrue()) + Expect(VersionTLS.UsesIETFHeaderFormat()).To(BeTrue()) + }) + It("tells if a version uses varint packet numbers", func() { Expect(Version39.UsesVarintPacketNumbers()).To(BeFalse()) Expect(Version43.UsesVarintPacketNumbers()).To(BeFalse()) + Expect(Version44.UsesVarintPacketNumbers()).To(BeFalse()) Expect(VersionTLS.UsesVarintPacketNumbers()).To(BeTrue()) }) - It("tells if a version uses the IETF frame types", func() { - Expect(Version39.UsesIETFFrameFormat()).To(BeFalse()) - Expect(Version43.UsesIETFFrameFormat()).To(BeFalse()) - Expect(Version44.UsesIETFFrameFormat()).To(BeFalse()) - Expect(VersionTLS.UsesIETFFrameFormat()).To(BeTrue()) + It("tells if a version uses the Length field in the IETF header", func() { + Expect(Version44.UsesLengthInHeader()).To(BeFalse()) + Expect(VersionTLS.UsesLengthInHeader()).To(BeTrue()) + }) + + It("tells if a version uses the Token field in the IETF header", func() { + Expect(Version44.UsesTokenInHeader()).To(BeFalse()) + Expect(VersionTLS.UsesTokenInHeader()).To(BeTrue()) }) It("tells if a version uses STOP_WAITING frames", func() { diff --git a/internal/wire/header.go b/internal/wire/header.go index 466ab289..99e5b930 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -42,7 +42,7 @@ type Header struct { Token []byte } -var errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes") +var errInvalidPacketNumberLen = errors.New("invalid packet number length") // Write writes the Header. func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { @@ -155,7 +155,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ case protocol.PacketNumberLen4: utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) case protocol.PacketNumberLen6: - return errInvalidPacketNumberLen6 + return errInvalidPacketNumberLen default: return errors.New("PublicHeader: PacketNumberLen not set") } @@ -193,7 +193,7 @@ func (h *Header) getHeaderLength() (protocol.ByteCount, error) { func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) { length := protocol.ByteCount(1) // 1 byte for public flags if h.PacketNumberLen == protocol.PacketNumberLen6 { - return 0, errInvalidPacketNumberLen6 + return 0, errInvalidPacketNumberLen } if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 { return 0, errPacketNumberLenNotSet diff --git a/internal/wire/header_parser.go b/internal/wire/header_parser.go index fdf93a38..08f5b406 100644 --- a/internal/wire/header_parser.go +++ b/internal/wire/header_parser.go @@ -79,7 +79,7 @@ func (iv *InvariantHeader) Parse(b *bytes.Reader, sentBy protocol.Perspective, v if iv.Version == 0 { // Version Negotiation Packet return iv.parseVersionNegotiationPacket(b) } - return iv.parseLongHeader(b) + return iv.parseLongHeader(b, sentBy, ver) } // The Public Header never uses 6 byte packet numbers. // Therefore, the third and fourth bit will never be 11. @@ -90,8 +90,7 @@ func (iv *InvariantHeader) Parse(b *bytes.Reader, sentBy protocol.Perspective, v } return iv.parsePublicHeader(b, sentBy, ver) } - return iv.parseShortHeader(b) - + return iv.parseShortHeader(b, ver) } func (iv *InvariantHeader) toHeader() *Header { @@ -121,7 +120,7 @@ func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Head return h, nil } -func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader) (*Header, error) { +func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, v protocol.VersionNumber) (*Header, error) { h := iv.toHeader() h.Type = protocol.PacketType(iv.typeByte & 0x7f) @@ -146,7 +145,7 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader) (*Header, error) { return h, nil } - if h.Type == protocol.PacketTypeInitial { + if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() { tokenLen, err := utils.ReadVarInt(b) if err != nil { return nil, err @@ -160,30 +159,69 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader) (*Header, error) { } } - pl, err := utils.ReadVarInt(b) - if err != nil { - return nil, err + if v.UsesLengthInHeader() { + pl, err := utils.ReadVarInt(b) + if err != nil { + return nil, err + } + h.PayloadLen = protocol.ByteCount(pl) } - h.PayloadLen = protocol.ByteCount(pl) - pn, pnLen, err := utils.ReadVarIntPacketNumber(b) - if err != nil { - return nil, err + if v.UsesVarintPacketNumbers() { + pn, pnLen, err := utils.ReadVarIntPacketNumber(b) + if err != nil { + return nil, err + } + h.PacketNumber = pn + h.PacketNumberLen = pnLen + } else { + pn, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, err + } + h.PacketNumber = protocol.PacketNumber(pn) + h.PacketNumberLen = protocol.PacketNumberLen4 + } + if h.Type == protocol.PacketType0RTT && v == protocol.Version44 && sentBy == protocol.PerspectiveServer { + h.DiversificationNonce = make([]byte, 32) + if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } } - h.PacketNumber = pn - h.PacketNumberLen = pnLen return h, nil } -func (iv *InvariantHeader) parseShortHeader(b *bytes.Reader) (*Header, error) { +func (iv *InvariantHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*Header, error) { h := iv.toHeader() h.KeyPhase = int(iv.typeByte&0x40) >> 6 - pn, pnLen, err := utils.ReadVarIntPacketNumber(b) - if err != nil { - return nil, err + + if v.UsesVarintPacketNumbers() { + pn, pnLen, err := utils.ReadVarIntPacketNumber(b) + if err != nil { + return nil, err + } + h.PacketNumber = pn + h.PacketNumberLen = pnLen + } else { + switch iv.typeByte & 0x3 { + case 0x0: + h.PacketNumberLen = protocol.PacketNumberLen1 + case 0x1: + h.PacketNumberLen = protocol.PacketNumberLen2 + case 0x2: + h.PacketNumberLen = protocol.PacketNumberLen4 + default: + return nil, errInvalidPacketNumberLen + } + p, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen)) + if err != nil { + return nil, err + } + h.PacketNumber = protocol.PacketNumber(p) } - h.PacketNumber = pn - h.PacketNumberLen = pnLen return h, nil } diff --git a/internal/wire/header_parser_test.go b/internal/wire/header_parser_test.go index 94d7986a..cf126674 100644 --- a/internal/wire/header_parser_test.go +++ b/internal/wire/header_parser_test.go @@ -383,6 +383,153 @@ var _ = Describe("Header Parsing", func() { }) }) + Context("gQUIC 44", func() { + Context("Long Headers", func() { + It("parses a Long Header", func() { + destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + srcConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + data := []byte{ + 0x80 ^ uint8(protocol.PacketTypeInitial), + 0x1, 0x2, 0x3, 0x4, // version + 0x55, // connection ID lengths + } + data = append(data, destConnID...) + data = append(data, srcConnID...) + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeTrue()) + Expect(iHdr.Version).To(Equal(protocol.VersionNumber(0x1020304))) + Expect(iHdr.DestConnectionID).To(Equal(destConnID)) + Expect(iHdr.SrcConnectionID).To(Equal(srcConnID)) + hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, protocol.Version44) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsPublicHeader).To(BeFalse()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdeadbeef))) + }) + + It("parses a Long Header containing a Diversification Nonce", func() { + srcConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + divNonce := bytes.Repeat([]byte{'f'}, 32) + data := []byte{ + 0x80 ^ uint8(protocol.PacketType0RTT), + 0x1, 0x2, 0x3, 0x4, // version + 0x5, // connection ID lengths + } + data = append(data, srcConnID...) + data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) + data = append(data, divNonce...) + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeTrue()) + Expect(iHdr.Version).To(Equal(protocol.VersionNumber(0x1020304))) + Expect(iHdr.SrcConnectionID).To(Equal(srcConnID)) + hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, protocol.Version44) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsPublicHeader).To(BeFalse()) + Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdeadbeef))) + Expect(hdr.DiversificationNonce).To(Equal(divNonce)) + }) + + It("errors on EOF, for Long Headers containing a Diversification Nonce", func() { + data := []byte{ + 0x80 ^ uint8(protocol.PacketType0RTT), + 0x1, 0x2, 0x3, 0x4, // version + 0x5, + 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID + } + iHdrLen := len(data) + data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) // packet number + data = append(data, bytes.Repeat([]byte{'d'}, 32)...) + for i := iHdrLen; i < len(data); i++ { + b := bytes.NewReader(data[:i]) + iHdr, err := ParseInvariantHeader(b, 8) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeTrue()) + _, err = iHdr.Parse(b, protocol.PerspectiveServer, protocol.Version44) + Expect(err).To(Equal(io.EOF)) + } + }) + }) + + Context("Short Headers", func() { + It("parses a Short Header with a 1 byte packet number", func() { + destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + data := []byte{0x30} + data = append(data, destConnID...) + data = append(data, 0x42) // packet number + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 8) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeFalse()) + Expect(iHdr.DestConnectionID).To(Equal(destConnID)) + hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, protocol.Version44) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsPublicHeader).To(BeFalse()) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + }) + + It("parses a Short Header with a 2 byte packet number", func() { + data := []byte{0x30 ^ 0x1, 0xca, 0xfe} + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeFalse()) + Expect(iHdr.DestConnectionID.Len()).To(BeZero()) + hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, protocol.Version44) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsPublicHeader).To(BeFalse()) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xcafe))) + }) + + It("parses a Short Header with a 4 byte packet number", func() { + data := []byte{0x30 ^ 0x2, 0xde, 0xad, 0xbe, 0xef} + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeFalse()) + Expect(iHdr.DestConnectionID.Len()).To(BeZero()) + hdr, err := iHdr.Parse(b, protocol.PerspectiveServer, protocol.Version44) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.IsPublicHeader).To(BeFalse()) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdeadbeef))) + }) + + It("errors on an invalid packet number length flag", func() { + data := []byte{0x30 ^ 0x3, 0xde, 0xad, 0xbe, 0xef} + b := bytes.NewReader(data) + iHdr, err := ParseInvariantHeader(b, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeFalse()) + Expect(iHdr.DestConnectionID.Len()).To(BeZero()) + _, err = iHdr.Parse(b, protocol.PerspectiveServer, protocol.Version44) + Expect(err).To(MatchError(errInvalidPacketNumberLen)) + }) + + It("errors on EOF", func() { + data := []byte{0x30 ^ 0x2, 0xde, 0xad, 0xbe, 0xef} + iHdrLen := 1 + for i := iHdrLen; i < len(data); i++ { + b := bytes.NewReader(data[:i]) + iHdr, err := ParseInvariantHeader(b, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(iHdr.IsLongHeader).To(BeFalse()) + _, err = iHdr.Parse(b, protocol.PerspectiveServer, protocol.Version44) + Expect(err).To(Equal(io.EOF)) + } + }) + }) + }) + Context("Public Header", func() { It("accepts a sample client header", func() { data := []byte{ diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index a9a4fcf1..2f53e189 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -390,7 +390,7 @@ var _ = Describe("Header", func() { PacketNumberLen: protocol.PacketNumberLen6, } err := hdr.writePublicHeader(buf, protocol.PerspectiveServer, versionPublicHeader) - Expect(err).To(MatchError(errInvalidPacketNumberLen6)) + Expect(err).To(MatchError(errInvalidPacketNumberLen)) }) }) })