From e89fc1152bc2f704445f24ec7fc7fda0ce75f2f0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 27 Dec 2022 10:02:45 +1300 Subject: [PATCH] stop using the ExtendedHeader for writing short header packets in tests --- connection_test.go | 86 +++++++------------- fuzzing/header/cmd/corpus.go | 14 +++- fuzzing/header/fuzz.go | 3 + integrationtests/self/mitm_test.go | 73 +++++++++++------ internal/wire/extended_header.go | 15 +--- internal/wire/extended_header_test.go | 113 -------------------------- internal/wire/header_test.go | 33 -------- internal/wire/short_header.go | 3 + packet_unpacker_test.go | 70 ++++++---------- 9 files changed, 121 insertions(+), 289 deletions(-) diff --git a/connection_test.go b/connection_test.go index ef433641..c73b73cf 100644 --- a/connection_test.go +++ b/connection_test.go @@ -572,11 +572,7 @@ var _ = Describe("Connection", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() buf := &bytes.Buffer{} - hdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen2, - } - Expect(hdr.Write(buf, conn.version)).To(Succeed()) + Expect(wire.WriteShortHeader(buf, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version) @@ -652,7 +648,18 @@ var _ = Describe("Connection", func() { conn.unpacker = unpacker }) - getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { + getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect(wire.WriteShortHeader(buf, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + rcvTime: time.Now(), + } + } + + getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { + ExpectWithOffset(1, extHdr.IsLongHeader).To(BeTrue()) buf := &bytes.Buffer{} Expect(extHdr.Write(buf, conn.version)).To(Succeed()) return &receivedPacket{ @@ -663,7 +670,7 @@ var _ = Describe("Connection", func() { } It("drops Retry packets", func() { - p := getPacket(&wire.ExtendedHeader{Header: wire.Header{ + p := getLongHeaderPacket(&wire.ExtendedHeader{Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, DestConnectionID: destConnID, @@ -689,7 +696,7 @@ var _ = Describe("Connection", func() { }) It("drops packets for which header decryption fails", func() { - p := getPacket(&wire.ExtendedHeader{ + p := getLongHeaderPacket(&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -703,7 +710,7 @@ var _ = Describe("Connection", func() { }) It("drops packets for which the version is unsupported", func() { - p := getPacket(&wire.ExtendedHeader{ + p := getLongHeaderPacket(&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -723,7 +730,7 @@ var _ = Describe("Connection", func() { }() protocol.SupportedVersions = append(protocol.SupportedVersions, conn.version+1) - p := getPacket(&wire.ExtendedHeader{ + p := getLongHeaderPacket(&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -751,7 +758,7 @@ var _ = Describe("Connection", func() { } unpackedHdr := *hdr unpackedHdr.PacketNumber = 0x1337 - packet := getPacket(hdr, nil) + packet := getLongHeaderPacket(hdr, nil) packet.ecn = protocol.ECNCE rcvTime := time.Now().Add(-10 * time.Second) unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ @@ -772,15 +779,10 @@ var _ = Describe("Connection", func() { }) It("informs the ReceivedPacketHandler about ack-eliciting packets", func() { - hdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x37, - PacketNumberLen: protocol.PacketNumberLen1, - } rcvTime := time.Now().Add(-10 * time.Second) b, err := (&wire.PingFrame{}).Append(nil, conn.version) Expect(err).ToNot(HaveOccurred()) - packet := getPacket(hdr, nil) + packet := getShortHeaderPacket(srcConnID, 0x37, nil) packet.ecn = protocol.ECT1 unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, b, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) @@ -795,12 +797,7 @@ var _ = Describe("Connection", func() { }) It("drops duplicate packets", func() { - hdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x37, - PacketNumberLen: protocol.PacketNumberLen1, - } - packet := getPacket(hdr, nil) + packet := getShortHeaderPacket(srcConnID, 0x37, nil) unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseOne, []byte("foobar"), nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) @@ -820,7 +817,7 @@ var _ = Describe("Connection", func() { conn.run() }() expectReplaceWithClosed() - p := getPacket(&wire.ExtendedHeader{ + p := getLongHeaderPacket(&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -854,11 +851,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackCoalescedPacket(false) // only expect a single call for i := 0; i < 3; i++ { - conn.handlePacket(getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen2, - }, []byte("foobar"))) + conn.handlePacket(getShortHeaderPacket(srcConnID, 0x1337+protocol.PacketNumber(i), []byte("foobar"))) } go func() { @@ -893,11 +886,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackCoalescedPacket(false).Times(3) // only expect a single call for i := 0; i < 3; i++ { - conn.handlePacket(getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen2, - }, []byte("foobar"))) + conn.handlePacket(getShortHeaderPacket(srcConnID, 0x1337+protocol.PacketNumber(i), []byte("foobar"))) } go func() { @@ -936,10 +925,7 @@ var _ = Describe("Connection", func() { }() expectReplaceWithClosed() mconn.EXPECT().Write(gomock.Any()) - packet := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil) + packet := getShortHeaderPacket(srcConnID, 0x42, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() conn.handlePacket(packet) @@ -959,10 +945,7 @@ var _ = Describe("Connection", func() { }() expectReplaceWithClosed() tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropHeaderParseError) - conn.handlePacket(getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil)) + conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) Consistently(runErr).ShouldNot(Receive()) // make the go routine return packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -990,13 +973,9 @@ var _ = Describe("Connection", func() { }() expectReplaceWithClosed() mconn.EXPECT().Write(gomock.Any()) - packet := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.handlePacket(packet) + conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1033,12 +1012,12 @@ var _ = Describe("Connection", func() { hdr: hdr1, data: []byte{0}, // one PADDING frame }, nil) - p1 := getPacket(hdr1, nil) + p1 := getLongHeaderPacket(hdr1, nil) tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any()) Expect(conn.handlePacketImpl(p1)).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. - p2 := getPacket(hdr2, nil) + p2 := getLongHeaderPacket(hdr2, nil) tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID) Expect(conn.handlePacketImpl(p2)).To(BeFalse()) }) @@ -1058,7 +1037,7 @@ var _ = Describe("Connection", func() { PacketNumber: 1, } unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) - packet := getPacket(hdr, nil) + packet := getLongHeaderPacket(hdr, nil) tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size()) Expect(conn.handlePacketImpl(packet)).To(BeFalse()) Expect(conn.undecryptablePackets).To(Equal([]*receivedPacket{packet})) @@ -1067,10 +1046,7 @@ var _ = Describe("Connection", func() { Context("updating the remote address", func() { It("doesn't support connection migration", func() { unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil) - packet := getPacket(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: srcConnID}, - PacketNumberLen: protocol.PacketNumberLen1, - }, nil) + packet := getShortHeaderPacket(srcConnID, 0x42, nil) packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) @@ -1096,7 +1072,7 @@ var _ = Describe("Connection", func() { hdrLen := hdr.GetLength(conn.version) b := make([]byte, 1) rand.Read(b) - packet := getPacket(hdr, bytes.Repeat(b, int(length)-3)) + packet := getLongHeaderPacket(hdr, bytes.Repeat(b, int(length)-3)) return int(hdrLen), packet } diff --git a/fuzzing/header/cmd/corpus.go b/fuzzing/header/cmd/corpus.go index 0c02b699..e5622213 100644 --- a/fuzzing/header/cmd/corpus.go +++ b/fuzzing/header/cmd/corpus.go @@ -11,7 +11,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -const version = protocol.VersionTLS +const version = protocol.Version1 func getRandomData(l int) []byte { b := make([]byte, l) @@ -84,9 +84,6 @@ func main() { Token: getRandomData(1000), Version: version, }, - { // Short-Header - DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), - }, } for _, h := range headers { @@ -111,6 +108,15 @@ func main() { } } + // short header + b := &bytes.Buffer{} + if err := wire.WriteShortHeader(b, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne); err != nil { + log.Fatal(err) + } + if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil { + log.Fatal(err) + } + vnps := [][]byte{ getVNP( protocol.ArbitraryLenConnectionID(getRandomData(8)), diff --git a/fuzzing/header/fuzz.go b/fuzzing/header/fuzz.go index ea0054d8..7c561ea1 100644 --- a/fuzzing/header/fuzz.go +++ b/fuzzing/header/fuzz.go @@ -58,6 +58,9 @@ func Fuzz(data []byte) int { if hdr.IsLongHeader && hdr.Length > 16383 { return 1 } + if !hdr.IsLongHeader { + return 1 + } b := &bytes.Buffer{} if err := extHdr.Write(b, version); err != nil { // We are able to parse packets with connection IDs longer than 20 bytes, diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index f965acb5..790efdd8 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -95,34 +95,57 @@ var _ = Describe("MITM test", func() { sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) { defer GinkgoRecover() - hdr, _, _, err := wire.ParsePacket(raw, connIDLen) - Expect(err).ToNot(HaveOccurred()) - replyHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: hdr.IsLongHeader, - DestConnectionID: hdr.DestConnectionID, - SrcConnectionID: hdr.SrcConnectionID, - Type: hdr.Type, - Version: hdr.Version, - }, - PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)), - PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1), - } - const numPackets = 10 ticker := time.NewTicker(rtt / numPackets) - for i := 0; i < numPackets; i++ { - payloadLen := mrand.Int31n(100) - replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1)) - buf := &bytes.Buffer{} - Expect(replyHdr.Write(buf, version)).To(Succeed()) - b := make([]byte, payloadLen) - mrand.Read(b) - buf.Write(b) - if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { - return + defer ticker.Stop() + + if wire.IsLongHeaderPacket(raw[0]) { + hdr, _, _, err := wire.ParsePacket(raw, connIDLen) + Expect(err).ToNot(HaveOccurred()) + replyHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: hdr.IsLongHeader, + DestConnectionID: hdr.DestConnectionID, + SrcConnectionID: hdr.SrcConnectionID, + Type: hdr.Type, + Version: hdr.Version, + }, + PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)), + PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1), + } + + for i := 0; i < numPackets; i++ { + payloadLen := mrand.Int31n(100) + replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1)) + buf := &bytes.Buffer{} + Expect(replyHdr.Write(buf, version)).To(Succeed()) + b := make([]byte, payloadLen) + mrand.Read(b) + buf.Write(b) + if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { + return + } + <-ticker.C + } + } else { + connID, err := wire.ParseConnectionID(raw, connIDLen) + Expect(err).ToNot(HaveOccurred()) + _, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen) + if err != nil { // normally, ParseShortHeader is called after decrypting the header + Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) + } + for i := 0; i < numPackets; i++ { + buf := &bytes.Buffer{} + Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))).To(Succeed()) + payloadLen := mrand.Int31n(100) + b := make([]byte, payloadLen) + mrand.Read(b) + buf.Write(b) + if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { + return + } + <-ticker.C } - <-ticker.C } } diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 74f2fc5f..310546f2 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -124,7 +124,7 @@ func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) erro if h.IsLongHeader { return h.writeLongHeader(b, ver) } - return h.writeShortHeader(b, ver) + panic("tried to write short extended header") } func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error { @@ -180,17 +180,6 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.Versi return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen) } -func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { - typeByte := 0x40 | uint8(h.PacketNumberLen-1) - if h.KeyPhase == protocol.KeyPhaseOne { - typeByte |= byte(1 << 2) - } - - b.WriteByte(typeByte) - b.Write(h.DestConnectionID.Bytes()) - return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen) -} - // ParsedLen returns the number of bytes that were consumed when parsing the header func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { return h.parsedLen @@ -228,7 +217,7 @@ func (h *ExtendedHeader) Log(logger utils.Logger) { } logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) } else { - logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + panic("logged short ExtendedHeader") } } diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go index 0b1e5824..0c5a2404 100644 --- a/internal/wire/extended_header_test.go +++ b/internal/wire/extended_header_test.go @@ -174,74 +174,6 @@ var _ = Describe("Header", func() { Expect(buf.Bytes()[0]>>4&0b11 == 0b10) }) }) - - Context("short header", func() { - It("writes a header with connection ID", func() { - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40, - 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID - 0x42, // packet number - })) - }) - - It("writes a header without connection ID", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40, - 0x42, // packet number - })) - }) - - It("writes a header with a 2 byte packet number", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen2, - PacketNumber: 0x765, - }).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0x40 | 0x1} - expected = append(expected, []byte{0x7, 0x65}...) // packet number - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("writes a header with a 4 byte packet number", func() { - Expect((&ExtendedHeader{ - PacketNumberLen: protocol.PacketNumberLen4, - PacketNumber: 0x12345678, - }).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{0x40 | 0x3} - expected = append(expected, []byte{0x12, 0x34, 0x56, 0x78}...) - Expect(buf.Bytes()).To(Equal(expected)) - }) - - It("errors when given an invalid packet number length", func() { - err := (&ExtendedHeader{ - PacketNumberLen: 5, - PacketNumber: 0xdecafbad, - }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid packet number length: 5")) - }) - - It("writes the Key Phase Bit", func() { - Expect((&ExtendedHeader{ - KeyPhase: protocol.KeyPhaseOne, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, - }).Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Bytes()).To(Equal([]byte{ - 0x40 | 0x4, - 0x42, // packet number - })) - }) - }) }) Context("getting the length", func() { @@ -336,39 +268,6 @@ var _ = Describe("Header", func() { Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) Expect(buf.Len()).To(Equal(expectedLen)) }) - - It("has the right length for a Short Header containing a connection ID", func() { - h := &ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), - }, - PacketNumberLen: protocol.PacketNumberLen1, - } - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 8 + 1))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(10)) - }) - - It("has the right length for a short header without a connection ID", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 1))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(2)) - }) - - It("has the right length for a short header with a 2 byte packet number", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen2} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 2))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(3)) - }) - - It("has the right length for a short header with a 5 byte packet number", func() { - h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen4} - Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 4))) - Expect(h.Write(buf, versionIETFHeader)).To(Succeed()) - Expect(buf.Len()).To(Equal(5)) - }) }) Context("Logging", func() { @@ -450,17 +349,5 @@ var _ = Describe("Header", func() { }).Log(logger) Expect(buf.String()).To(ContainSubstring("Long Header{Type: Retry, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0x123456, Version: 0xfeed}")) }) - - It("logs Short Headers containing a connection ID", func() { - (&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), - }, - KeyPhase: protocol.KeyPhaseOne, - PacketNumber: 1337, - PacketNumberLen: 4, - }).Log(logger) - Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}")) - }) }) }) diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 6f96d7c8..87d854b7 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -31,39 +31,6 @@ var _ = Describe("Header Parsing", func() { Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) }) - It("parses the connection ID of a short header packet", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - buf.Write([]byte("foobar")) - connID, err := ParseConnectionID(buf.Bytes(), 4) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) - }) - - It("errors on EOF, for short header packets", func() { - buf := &bytes.Buffer{} - Expect((&ExtendedHeader{ - Header: Header{ - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), - }, - PacketNumberLen: 2, - }).Write(buf, protocol.Version1)).To(Succeed()) - data := buf.Bytes()[:buf.Len()-2] // cut the packet number - _, err := ParseConnectionID(data, 8) - Expect(err).ToNot(HaveOccurred()) - for i := 0; i < len(data); i++ { - b := make([]byte, i) - copy(b, data[:i]) - _, err := ParseConnectionID(b, 8) - Expect(err).To(MatchError(io.EOF)) - } - }) - It("errors on EOF, for long header packets", func() { buf := &bytes.Buffer{} Expect((&ExtendedHeader{ diff --git a/internal/wire/short_header.go b/internal/wire/short_header.go index 57913aaf..77308ad2 100644 --- a/internal/wire/short_header.go +++ b/internal/wire/short_header.go @@ -10,6 +10,9 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) +// ParseShortHeader parses a short header packet. +// It must be called after header protection was removed. +// Otherwise, the check for the reserved bits will (most likely) fail. func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) { if len(data) == 0 { return 0, 0, 0, 0, io.EOF diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 4d43999b..b4c157e3 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -27,7 +27,8 @@ var _ = Describe("Packet Unpacker", func() { payload = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ) - getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { + getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { + ExpectWithOffset(1, extHdr.IsLongHeader).To(BeTrue()) buf := &bytes.Buffer{} ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed()) hdrLen := buf.Len() @@ -39,6 +40,12 @@ var _ = Describe("Packet Unpacker", func() { return hdr, buf.Bytes()[:hdrLen] } + getShortHeader := func(connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) []byte { + buf := &bytes.Buffer{} + Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, kp)).To(Succeed()) + return buf.Bytes() + } + BeforeEach(func() { cs = mocks.NewMockCryptoSetup(mockCtrl) unpacker = newPacketUnpacker(cs, 4, version).(*packetUnpacker) @@ -55,7 +62,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 1337, PacketNumberLen: protocol.PacketNumberLen2, } - hdr, hdrRaw := getHeader(extHdr) + hdr, hdrRaw := getLongHeader(extHdr) data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) @@ -67,12 +74,9 @@ var _ = Describe("Packet Unpacker", func() { }) It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() { - _, hdrRaw := getHeader(&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 1337, - PacketNumberLen: protocol.PacketNumberLen2, - }) - data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) + buf := &bytes.Buffer{} + Expect(wire.WriteShortHeader(buf, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed()) + data := append(buf.Bytes(), make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil) _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data) @@ -92,7 +96,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 2, PacketNumberLen: 3, } - hdr, hdrRaw := getHeader(extHdr) + hdr, hdrRaw := getLongHeader(extHdr) opener := mocks.NewMockLongHeaderOpener(mockCtrl) gomock.InOrder( cs.EXPECT().GetInitialOpener().Return(opener, nil), @@ -118,7 +122,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 20, PacketNumberLen: 2, } - hdr, hdrRaw := getHeader(extHdr) + hdr, hdrRaw := getLongHeader(extHdr) opener := mocks.NewMockLongHeaderOpener(mockCtrl) gomock.InOrder( cs.EXPECT().Get0RTTOpener().Return(opener, nil), @@ -133,13 +137,7 @@ var _ = Describe("Packet Unpacker", func() { }) It("opens short header packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - KeyPhase: protocol.KeyPhaseOne, - PacketNumber: 99, - PacketNumberLen: protocol.PacketNumberLen4, - } - _, hdrRaw := getHeader(extHdr) + hdrRaw := getShortHeader(connID, 99, protocol.PacketNumberLen4, protocol.KeyPhaseOne) opener := mocks.NewMockShortHeaderOpener(mockCtrl) now := time.Now() gomock.InOrder( @@ -157,12 +155,7 @@ var _ = Describe("Packet Unpacker", func() { }) It("returns the error when getting the opener fails", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 0x1337, - PacketNumberLen: 2, - } - _, hdrRaw := getHeader(extHdr) + hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne) cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable) _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable)) @@ -179,7 +172,7 @@ var _ = Describe("Packet Unpacker", func() { KeyPhase: protocol.KeyPhaseOne, PacketNumberLen: protocol.PacketNumberLen4, } - hdr, hdrRaw := getHeader(extHdr) + hdr, hdrRaw := getLongHeader(extHdr) opener := mocks.NewMockLongHeaderOpener(mockCtrl) gomock.InOrder( cs.EXPECT().GetHandshakeOpener().Return(opener, nil), @@ -195,12 +188,7 @@ var _ = Describe("Packet Unpacker", func() { }) It("errors on empty packets, for short header packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - KeyPhase: protocol.KeyPhaseOne, - PacketNumberLen: protocol.PacketNumberLen4, - } - _, hdrRaw := getHeader(extHdr) + hdrRaw := getShortHeader(connID, 0x42, protocol.PacketNumberLen4, protocol.KeyPhaseOne) opener := mocks.NewMockShortHeaderOpener(mockCtrl) now := time.Now() gomock.InOrder( @@ -228,7 +216,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 2, PacketNumberLen: 3, } - hdr, hdrRaw := getHeader(extHdr) + hdr, hdrRaw := getLongHeader(extHdr) opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) @@ -250,7 +238,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 0x1337, PacketNumberLen: 2, } - hdr, hdrRaw := getHeader(extHdr) + hdr, hdrRaw := getLongHeader(extHdr) hdrRaw[0] |= 0xc opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) @@ -262,12 +250,7 @@ var _ = Describe("Packet Unpacker", func() { }) It("defends against the timing side-channel when the reserved bits are wrong, for short header packets", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 0x1337, - PacketNumberLen: 2, - } - _, hdrRaw := getHeader(extHdr) + hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseZero) hdrRaw[0] |= 0x18 opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) @@ -289,7 +272,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 0x1337, PacketNumberLen: 2, } - hdr, hdrRaw := getHeader(extHdr) + hdr, hdrRaw := getLongHeader(extHdr) hdrRaw[0] |= 0x18 opener := mocks.NewMockLongHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) @@ -301,12 +284,7 @@ var _ = Describe("Packet Unpacker", func() { }) It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for short headers", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 0x1337, - PacketNumberLen: 2, - } - _, hdrRaw := getHeader(extHdr) + hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseZero) hdrRaw[0] |= 0x18 opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) @@ -329,7 +307,7 @@ var _ = Describe("Packet Unpacker", func() { PacketNumber: 0x1337, PacketNumberLen: 2, } - hdr, hdrRaw := getHeader(extHdr) + hdr, hdrRaw := getLongHeader(extHdr) origHdrRaw := append([]byte{}, hdrRaw...) // save a copy of the header firstHdrByte := hdrRaw[0] hdrRaw[0] ^= 0xff // invert the first byte