testutils: add a perspective function parameter to ComposeInitialPacket (#4276)

Currently not used, but this is useful when crafting Initial packets
sent from the client. No functional change expected.
This commit is contained in:
Marten Seemann 2024-01-29 12:30:23 +07:00 committed by GitHub
parent 940feef063
commit 03ba124241
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 15 additions and 10 deletions

View file

@ -3167,7 +3167,7 @@ var _ = Describe("Client Connection", func() {
// the connection to immediately break down // the connection to immediately break down
It("fails on Initial-level ACK for unsent packet", func() { It("fails on Initial-level ACK for unsent packet", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack}) initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, destConnID, []wire.Frame{ack}, protocol.PerspectiveServer, conn.version)
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse())
}) })
@ -3179,7 +3179,7 @@ var _ = Describe("Client Connection", func() {
IsApplicationError: true, IsApplicationError: true,
ReasonPhrase: "mitm attacker", ReasonPhrase: "mitm attacker",
} }
initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame}) initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, destConnID, []wire.Frame{connCloseFrame}, protocol.PerspectiveServer, conn.version)
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue())
}) })
@ -3197,7 +3197,7 @@ var _ = Describe("Client Connection", func() {
tracer.EXPECT().ReceivedRetry(gomock.Any()) tracer.EXPECT().ReceivedRetry(gomock.Any())
conn.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), conn.version))) conn.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), conn.version)))
initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.version, conn.connIDManager.Get(), nil) initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.connIDManager.Get(), nil, protocol.PerspectiveServer, conn.version)
tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), gomock.Any()) tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), gomock.Any())
Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse())
}) })

View file

@ -417,7 +417,7 @@ var _ = Describe("MITM test", func() {
} }
defer close(done) defer close(done)
injected = true injected = true
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil) initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, protocol.PerspectiveServer, hdr.Version)
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()) _, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
@ -449,7 +449,7 @@ var _ = Describe("MITM test", func() {
injected = true injected = true
// Fake Initial with ACK for packet 2 (unsent) // Fake Initial with ACK for packet 2 (unsent)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack}) initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, []wire.Frame{ack}, protocol.PerspectiveServer, hdr.Version)
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()) _, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }

View file

@ -33,10 +33,15 @@ func packRawPayload(version protocol.VersionNumber, frames []wire.Frame) []byte
return b return b
} }
// ComposeInitialPacket returns an Initial packet encrypted under key // ComposeInitialPacket returns an Initial packet encrypted under key (the original destination connection ID)
// (the original destination connection ID) containing specified frames // containing specified frames.
func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, version protocol.VersionNumber, key protocol.ConnectionID, frames []wire.Frame) []byte { func ComposeInitialPacket(
sealer, _ := handshake.NewInitialAEAD(key, protocol.PerspectiveServer, version) srcConnID, destConnID, key protocol.ConnectionID,
frames []wire.Frame,
sentBy protocol.Perspective,
version protocol.VersionNumber,
) []byte {
sealer, _ := handshake.NewInitialAEAD(key, sentBy, version)
// compose payload // compose payload
var payload []byte var payload []byte
@ -48,7 +53,7 @@ func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.C
// compose Initial header // compose Initial header
payloadSize := len(payload) payloadSize := len(payload)
pnLength := protocol.PacketNumberLen4 const pnLength = protocol.PacketNumberLen4
length := payloadSize + int(pnLength) + sealer.Overhead() length := payloadSize + int(pnLength) + sealer.Overhead()
hdr := &wire.ExtendedHeader{ hdr := &wire.ExtendedHeader{
Header: wire.Header{ Header: wire.Header{