From 03489f56a7a2639cd2a733e6b6035f71530ed960 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 12 Dec 2018 13:06:16 +0630 Subject: [PATCH] handle the packet length before parsing the extended header --- internal/wire/extended_header.go | 2 +- internal/wire/header.go | 11 ++++++++--- internal/wire/header_test.go | 2 ++ packet_unpacker.go | 20 +++++++++----------- packet_unpacker_test.go | 16 ---------------- 5 files changed, 20 insertions(+), 31 deletions(-) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index a08dd487..b95c9cc6 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -30,7 +30,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*Exte if err != nil { return nil, err } - if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil { + if _, err := b.Seek(int64(h.ParsedLen())-1, io.SeekCurrent); err != nil { return nil, err } if h.IsLongHeader { diff --git a/internal/wire/header.go b/internal/wire/header.go index c40d40b2..5b0d6eff 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -24,8 +24,8 @@ type Header struct { SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet - typeByte byte - len int // how many bytes were read while parsing this header + typeByte byte + parsedLen protocol.ByteCount // how many bytes were read while parsing this header } // ParseHeader parses the header. @@ -39,7 +39,7 @@ func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { if err != nil { return nil, err } - h.len = startLen - b.Len() + h.parsedLen = protocol.ByteCount(startLen - b.Len()) return h, nil } @@ -171,6 +171,11 @@ func (h *Header) IsVersionNegotiation() bool { return h.IsLongHeader && h.Version == 0 } +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *Header) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + // ParseExtended parses the version dependent part of the header. // The Reader has to be set such that it points to the first byte of the header. func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index ff56b74a..6d138322 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -74,6 +74,7 @@ var _ = Describe("Header Parsing", func() { data = append(data, encodeVarInt(6)...) // token length data = append(data, []byte("foobar")...) // token data = append(data, encodeVarInt(0x1337)...) // length + hdrLen := len(data) data = append(data, []byte{0, 0, 0xbe, 0xef}...) hdr, err := ParseHeader(bytes.NewReader(data), 0) @@ -92,6 +93,7 @@ var _ = Describe("Header Parsing", func() { Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef))) Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) Expect(b.Len()).To(BeZero()) + Expect(hdr.ParsedLen()).To(BeEquivalentTo(hdrLen)) }) It("errors if 0x40 is not set", func() { diff --git a/packet_unpacker.go b/packet_unpacker.go index c8887cdd..cb7f2c0f 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -43,6 +43,15 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker { func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { r := bytes.NewReader(data) + + if hdr.IsLongHeader { + if protocol.ByteCount(r.Len()) < hdr.Length { + return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + data = data[:int(hdr.ParsedLen()+hdr.Length)] + // TODO(#1312): implement parsing of compound packets + } + extHdr, err := hdr.ParseExtended(r, u.version) if err != nil { return nil, fmt.Errorf("error parsing extended header: %s", err) @@ -50,17 +59,6 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, extHdr.Raw = data[:len(data)-r.Len()] data = data[len(data)-r.Len():] - if hdr.IsLongHeader { - if hdr.Length < protocol.ByteCount(extHdr.PacketNumberLen) { - return nil, fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", extHdr.Length, extHdr.PacketNumberLen) - } - if protocol.ByteCount(len(data))+protocol.ByteCount(extHdr.PacketNumberLen) < extHdr.Length { - return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(extHdr.PacketNumberLen), extHdr.Length) - } - data = data[:int(extHdr.Length)-int(extHdr.PacketNumberLen)] - // TODO(#1312): implement parsing of compound packets - } - pn := protocol.DecodePacketNumber( extHdr.PacketNumberLen, u.largestRcvdPacketNumber, diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index ec67a5ce..331289ba 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -103,22 +103,6 @@ var _ = Describe("Packet Unpacker", func() { Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) }) - It("errors when receiving a packet that has a length smaller than the packet number length", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: connID, - Type: protocol.PacketTypeHandshake, - Length: 3, - Version: protocol.VersionTLS, - }, - PacketNumberLen: protocol.PacketNumberLen4, - } - hdr, hdrRaw := getHeader(extHdr) - _, err := unpacker.Unpack(hdr, hdrRaw) - Expect(err).To(MatchError("packet length (3 bytes) shorter than packet number (4 bytes)")) - }) - It("cuts packets to the right length", func() { pnLen := protocol.PacketNumberLen2 extHdr := &wire.ExtendedHeader{