From 5a834851a8206824cb839217b6fd37e9c9e5d2cf Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 25 Nov 2019 17:23:50 +0700 Subject: [PATCH] route retransmissions of the client's Initial to the right session --- conn_id_generator.go | 18 +++++++++++++++++- conn_id_generator_test.go | 14 ++++++++++++-- session.go | 15 +++++++++++---- session_test.go | 29 +++++++++++++++++++++-------- 4 files changed, 61 insertions(+), 15 deletions(-) diff --git a/conn_id_generator.go b/conn_id_generator.go index a95110af..61b8ed8a 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -13,7 +13,8 @@ type connIDGenerator struct { connIDLen int highestSeq uint64 - activeSrcConnIDs map[uint64]protocol.ConnectionID + activeSrcConnIDs map[uint64]protocol.ConnectionID + initialClientDestConnID protocol.ConnectionID addConnectionID func(protocol.ConnectionID) [16]byte removeConnectionID func(protocol.ConnectionID) @@ -24,6 +25,7 @@ type connIDGenerator struct { func newConnIDGenerator( initialConnectionID protocol.ConnectionID, + initialClientDestConnID protocol.ConnectionID, // nil for the client addConnectionID func(protocol.ConnectionID) [16]byte, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), @@ -40,6 +42,7 @@ func newConnIDGenerator( queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID + m.initialClientDestConnID = initialClientDestConnID return m } @@ -91,13 +94,26 @@ func (m *connIDGenerator) issueNewConnID() error { return nil } +func (m *connIDGenerator) SetHandshakeComplete() { + if m.initialClientDestConnID != nil { + m.retireConnectionID(m.initialClientDestConnID) + m.initialClientDestConnID = nil + } +} + func (m *connIDGenerator) RemoveAll() { + if m.initialClientDestConnID != nil { + m.removeConnectionID(m.initialClientDestConnID) + } for _, connID := range m.activeSrcConnIDs { m.removeConnectionID(connID) } } func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { + if m.initialClientDestConnID != nil { + m.replaceWithClosed(m.initialClientDestConnID, handler) + } for _, connID := range m.activeSrcConnIDs { m.replaceWithClosed(connID, handler) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 81b36306..1532139e 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -18,6 +18,7 @@ var _ = Describe("Connection ID Generator", func() { g *connIDGenerator ) initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} + initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe} BeforeEach(func() { addedConnIDs = nil @@ -27,6 +28,7 @@ var _ = Describe("Connection ID Generator", func() { replacedWithClosed = make(map[string]packetHandler) g = newConnIDGenerator( initialConnID, + initialClientDestConnID, func(c protocol.ConnectionID) [16]byte { addedConnIDs = append(addedConnIDs, c) l := uint8(len(addedConnIDs)) @@ -101,12 +103,19 @@ var _ = Describe("Connection ID Generator", func() { Expect(queuedFrames).To(HaveLen(1)) }) + It("retires the client's initial destination connection ID when the handshake completes", func() { + g.SetHandshakeComplete() + Expect(retiredConnIDs).To(HaveLen(1)) + Expect(retiredConnIDs[0]).To(Equal(initialClientDestConnID)) + }) + It("removes all connection IDs", func() { Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) Expect(queuedFrames).To(HaveLen(5)) g.RemoveAll() - Expect(removedConnIDs).To(HaveLen(6)) // initial connection ID and newly issued ones + Expect(removedConnIDs).To(HaveLen(7)) // initial conn ID, initial client dest conn id, and newly issued ones Expect(removedConnIDs).To(ContainElement(initialConnID)) + Expect(removedConnIDs).To(ContainElement(initialClientDestConnID)) for _, f := range queuedFrames { nf := f.(*wire.NewConnectionIDFrame) Expect(removedConnIDs).To(ContainElement(nf.ConnectionID)) @@ -118,7 +127,8 @@ var _ = Describe("Connection ID Generator", func() { Expect(queuedFrames).To(HaveLen(5)) sess := NewMockPacketHandler(mockCtrl) g.ReplaceWithClosed(sess) - Expect(replacedWithClosed).To(HaveLen(6)) // initial connection ID and newly issued ones + Expect(replacedWithClosed).To(HaveLen(7)) // initial conn ID, initial client dest conn id, and newly issued ones + Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess)) Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess)) for _, f := range queuedFrames { nf := f.(*wire.NewConnectionIDFrame) diff --git a/session.go b/session.go index 4ad24df7..79c84a7d 100644 --- a/session.go +++ b/session.go @@ -227,6 +227,7 @@ var newSession = func( ) s.connIDGenerator = newConnIDGenerator( srcConnID, + clientDestConnID, func(connID protocol.ConnectionID) [16]byte { return runner.Add(connID, s) }, runner.Remove, runner.Retire, @@ -238,6 +239,7 @@ var newSession = func( initialStream := newCryptoStream() handshakeStream := newCryptoStream() oneRTTStream := newPostHandshakeCryptoStream(s.framer) + runner.Add(clientDestConnID, s) token := runner.Add(srcConnID, s) params := &handshake.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, @@ -262,10 +264,13 @@ var newSession = func( conn.RemoteAddr(), params, &handshakeRunner{ - onReceivedParams: s.processTransportParameters, - onError: s.closeLocal, - dropKeys: s.dropEncryptionLevel, - onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, + onReceivedParams: s.processTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + onHandshakeComplete: func() { + runner.Retire(clientDestConnID) + close(s.handshakeCompleteChan) + }, }, tlsConf, s.rttStats, @@ -325,6 +330,7 @@ var newClientSession = func( ) s.connIDGenerator = newConnIDGenerator( srcConnID, + nil, func(connID protocol.ConnectionID) [16]byte { return runner.Add(connID, s) }, runner.Remove, runner.Retire, @@ -596,6 +602,7 @@ func (s *session) handleHandshakeComplete() { s.handshakeCompleteChan = nil // prevent this case from ever being selected again s.handshakeCtxCancel() + s.connIDGenerator.SetHandshakeComplete() s.sentPacketHandler.SetHandshakeComplete() // The client completes the handshake first (after sending the CFIN). // We need to make sure it learns about the server completing the handshake, diff --git a/session_test.go b/session_test.go index a12703d3..e942c078 100644 --- a/session_test.go +++ b/session_test.go @@ -82,6 +82,7 @@ var _ = Describe("Session", func() { ) srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} getPacket := func(pn protocol.PacketNumber) *packedPacket { buffer := getPacketBuffer() @@ -95,6 +96,7 @@ var _ = Describe("Session", func() { } expectReplaceWithClosed := func() { + sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1) sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) Expect(s.Close()).To(Succeed()) @@ -106,7 +108,7 @@ var _ = Describe("Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) sessionRunner = NewMockSessionRunner(mockCtrl) - sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any()) + sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any()).Times(2) mconn = newMockConnection() tokenGenerator, err := handshake.NewTokenGenerator() Expect(err).ToNot(HaveOccurred()) @@ -114,7 +116,7 @@ var _ = Describe("Session", func() { mconn, sessionRunner, nil, - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + clientDestConnID, destConnID, srcConnID, populateServerConfig(&Config{}), @@ -360,6 +362,9 @@ var _ = Describe("Session", func() { sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) }) + sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) + }) cryptoSetup.EXPECT().Close() go func() { @@ -381,6 +386,9 @@ var _ = Describe("Session", func() { sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) }) + sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) + }) cryptoSetup.EXPECT().Close() go func() { @@ -508,7 +516,7 @@ var _ = Describe("Session", func() { It("closes the session in order to recreate it", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Remove(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() sess.closeForRecreating() Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent @@ -519,7 +527,7 @@ var _ = Describe("Session", func() { It("destroys the session", func() { testErr := errors.New("close") streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Remove(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() sess.destroy(testErr) Eventually(areSessionsRunning).Should(BeFalse()) @@ -1191,6 +1199,7 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().AnyTimes() + sessionRunner.EXPECT().Retire(clientDestConnID) go func() { defer GinkgoRecover() <-finishHandshake @@ -1232,6 +1241,7 @@ var _ = Describe("Session", func() { It("sends a 1-RTT packet when the handshake completes", func() { done := make(chan struct{}) + sessionRunner.EXPECT().Retire(clientDestConnID) packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { defer close(done) return &packedPacket{ @@ -1335,7 +1345,7 @@ var _ = Describe("Session", func() { sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) Expect(s.Close()).To(Succeed()) - }).Times(4) + }).Times(5) // initial connection ID + initial client dest conn ID + 3 newly issued conn IDs packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() sess.Close() @@ -1422,7 +1432,7 @@ var _ = Describe("Session", func() { }) It("times out due to no network activity", func() { - sessionRunner.EXPECT().Remove(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()).Times(2) sess.handshakeComplete = true sess.lastPacketReceivedTime = time.Now().Add(-time.Hour) done := make(chan struct{}) @@ -1442,7 +1452,7 @@ var _ = Describe("Session", func() { It("times out due to non-completed handshake", func() { sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) - sessionRunner.EXPECT().Remove(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()).Times(2) cryptoSetup.EXPECT().Close() done := make(chan struct{}) go func() { @@ -1483,7 +1493,10 @@ var _ = Describe("Session", func() { It("closes the session due to the idle timeout after handshake", func() { packer.EXPECT().PackPacket().AnyTimes() - sessionRunner.EXPECT().Remove(gomock.Any()) + gomock.InOrder( + sessionRunner.EXPECT().Retire(clientDestConnID), + sessionRunner.EXPECT().Remove(gomock.Any()), + ) cryptoSetup.EXPECT().Close() sess.config.IdleTimeout = 0 done := make(chan struct{})