diff --git a/packet_packer.go b/packet_packer.go index f616db17..2f58c967 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -564,6 +564,10 @@ func (p *packetPacker) SetOmitConnectionID() { p.omitConnectionID = true } +func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) { + p.destConnID = connID +} + func (p *packetPacker) SetMaxPacketSize(size protocol.ByteCount) { p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, size) } diff --git a/packet_packer_test.go b/packet_packer_test.go index d682cc95..e7fdec1f 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -230,6 +230,21 @@ var _ = Describe("Packet packer", func() { Expect(h.DestConnectionID).To(Equal(destConnID)) }) + It("changes the destination connection ID", func() { + srcConnID := protocol.ConnectionID{1, 1, 1, 1, 1, 1, 1, 1} + packer.srcConnID = srcConnID + dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + dest2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + packer.ChangeDestConnectionID(dest1) + h := packer.getHeader(protocol.EncryptionUnencrypted) + Expect(h.SrcConnectionID).To(Equal(srcConnID)) + Expect(h.DestConnectionID).To(Equal(dest1)) + packer.ChangeDestConnectionID(dest2) + h = packer.getHeader(protocol.EncryptionUnencrypted) + Expect(h.SrcConnectionID).To(Equal(srcConnID)) + Expect(h.DestConnectionID).To(Equal(dest2)) + }) + It("uses the Short Header format for forward-secure packets", func() { h := packer.getHeader(protocol.EncryptionForwardSecure) Expect(h.IsLongHeader).To(BeFalse()) diff --git a/session.go b/session.go index d542672f..c209504a 100644 --- a/session.go +++ b/session.go @@ -637,6 +637,12 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { return err } + if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && !hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID) + s.destConnID = hdr.SrcConnectionID + s.packer.ChangeDestConnectionID(s.destConnID) + } + s.receivedFirstPacket = true s.lastNetworkActivityTime = p.rcvTime s.keepAlivePingSent = false diff --git a/session_test.go b/session_test.go index b4d9b5d8..38aec526 100644 --- a/session_test.go +++ b/session_test.go @@ -1782,6 +1782,39 @@ var _ = Describe("Client Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) + It("changes the connection ID when receiving the first packet from the server", func() { + sess.version = protocol.VersionTLS + sess.packer.version = protocol.VersionTLS + unpacker := NewMockUnpacker(mockCtrl) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) + sess.unpacker = unpacker + go func() { + defer GinkgoRecover() + sess.run() + }() + err := sess.handlePacketImpl(&receivedPacket{ + header: &wire.Header{ + Type: protocol.PacketTypeHandshake, + SrcConnectionID: protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}, + DestConnectionID: sess.srcConnID, + }, + data: []byte{0}, + }) + Expect(err).ToNot(HaveOccurred()) + // the session should have changed the dest connection ID now + sess.packer.hasSentPacket = true + sess.queueControlFrame(&wire.PingFrame{}) + var packet []byte + Eventually(mconn.written).Should(Receive(&packet)) + hdr, err := wire.ParseHeaderSentByClient(bytes.NewReader(packet)) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7})) + // make sure the go routine returns + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + Expect(sess.Close(nil)).To(Succeed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + Context("receiving packets", func() { var hdr *wire.Header