use the connection ID provided by the server in first Handshake packet

This commit is contained in:
Marten Seemann 2018-05-16 21:33:17 +09:00
parent d1316f2566
commit ec6118f7a0
4 changed files with 58 additions and 0 deletions

View file

@ -564,6 +564,10 @@ func (p *packetPacker) SetOmitConnectionID() {
p.omitConnectionID = true
}
func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
p.destConnID = connID
}
func (p *packetPacker) SetMaxPacketSize(size protocol.ByteCount) {
p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, size)
}

View file

@ -230,6 +230,21 @@ var _ = Describe("Packet packer", func() {
Expect(h.DestConnectionID).To(Equal(destConnID))
})
It("changes the destination connection ID", func() {
srcConnID := protocol.ConnectionID{1, 1, 1, 1, 1, 1, 1, 1}
packer.srcConnID = srcConnID
dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
dest2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
packer.ChangeDestConnectionID(dest1)
h := packer.getHeader(protocol.EncryptionUnencrypted)
Expect(h.SrcConnectionID).To(Equal(srcConnID))
Expect(h.DestConnectionID).To(Equal(dest1))
packer.ChangeDestConnectionID(dest2)
h = packer.getHeader(protocol.EncryptionUnencrypted)
Expect(h.SrcConnectionID).To(Equal(srcConnID))
Expect(h.DestConnectionID).To(Equal(dest2))
})
It("uses the Short Header format for forward-secure packets", func() {
h := packer.getHeader(protocol.EncryptionForwardSecure)
Expect(h.IsLongHeader).To(BeFalse())

View file

@ -637,6 +637,12 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
return err
}
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && !hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID)
s.destConnID = hdr.SrcConnectionID
s.packer.ChangeDestConnectionID(s.destConnID)
}
s.receivedFirstPacket = true
s.lastNetworkActivityTime = p.rcvTime
s.keepAlivePingSent = false

View file

@ -1782,6 +1782,39 @@ var _ = Describe("Client Session", func() {
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("changes the connection ID when receiving the first packet from the server", func() {
sess.version = protocol.VersionTLS
sess.packer.version = protocol.VersionTLS
unpacker := NewMockUnpacker(mockCtrl)
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)
sess.unpacker = unpacker
go func() {
defer GinkgoRecover()
sess.run()
}()
err := sess.handlePacketImpl(&receivedPacket{
header: &wire.Header{
Type: protocol.PacketTypeHandshake,
SrcConnectionID: protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7},
DestConnectionID: sess.srcConnID,
},
data: []byte{0},
})
Expect(err).ToNot(HaveOccurred())
// the session should have changed the dest connection ID now
sess.packer.hasSentPacket = true
sess.queueControlFrame(&wire.PingFrame{})
var packet []byte
Eventually(mconn.written).Should(Receive(&packet))
hdr, err := wire.ParseHeaderSentByClient(bytes.NewReader(packet))
Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}))
// make sure the go routine returns
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
Expect(sess.Close(nil)).To(Succeed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
Context("receiving packets", func() {
var hdr *wire.Header