route retransmissions of the client's Initial to the right session

This commit is contained in:
Marten Seemann 2019-11-25 17:23:50 +07:00
parent 7445bde357
commit 5a834851a8
4 changed files with 61 additions and 15 deletions

View file

@ -13,7 +13,8 @@ type connIDGenerator struct {
connIDLen int
highestSeq uint64
activeSrcConnIDs map[uint64]protocol.ConnectionID
activeSrcConnIDs map[uint64]protocol.ConnectionID
initialClientDestConnID protocol.ConnectionID
addConnectionID func(protocol.ConnectionID) [16]byte
removeConnectionID func(protocol.ConnectionID)
@ -24,6 +25,7 @@ type connIDGenerator struct {
func newConnIDGenerator(
initialConnectionID protocol.ConnectionID,
initialClientDestConnID protocol.ConnectionID, // nil for the client
addConnectionID func(protocol.ConnectionID) [16]byte,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
@ -40,6 +42,7 @@ func newConnIDGenerator(
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
m.initialClientDestConnID = initialClientDestConnID
return m
}
@ -91,13 +94,26 @@ func (m *connIDGenerator) issueNewConnID() error {
return nil
}
func (m *connIDGenerator) SetHandshakeComplete() {
if m.initialClientDestConnID != nil {
m.retireConnectionID(m.initialClientDestConnID)
m.initialClientDestConnID = nil
}
}
func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.removeConnectionID(m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.removeConnectionID(connID)
}
}
func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
if m.initialClientDestConnID != nil {
m.replaceWithClosed(m.initialClientDestConnID, handler)
}
for _, connID := range m.activeSrcConnIDs {
m.replaceWithClosed(connID, handler)
}

View file

@ -18,6 +18,7 @@ var _ = Describe("Connection ID Generator", func() {
g *connIDGenerator
)
initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}
initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe}
BeforeEach(func() {
addedConnIDs = nil
@ -27,6 +28,7 @@ var _ = Describe("Connection ID Generator", func() {
replacedWithClosed = make(map[string]packetHandler)
g = newConnIDGenerator(
initialConnID,
initialClientDestConnID,
func(c protocol.ConnectionID) [16]byte {
addedConnIDs = append(addedConnIDs, c)
l := uint8(len(addedConnIDs))
@ -101,12 +103,19 @@ var _ = Describe("Connection ID Generator", func() {
Expect(queuedFrames).To(HaveLen(1))
})
It("retires the client's initial destination connection ID when the handshake completes", func() {
g.SetHandshakeComplete()
Expect(retiredConnIDs).To(HaveLen(1))
Expect(retiredConnIDs[0]).To(Equal(initialClientDestConnID))
})
It("removes all connection IDs", func() {
Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
Expect(queuedFrames).To(HaveLen(5))
g.RemoveAll()
Expect(removedConnIDs).To(HaveLen(6)) // initial connection ID and newly issued ones
Expect(removedConnIDs).To(HaveLen(7)) // initial conn ID, initial client dest conn id, and newly issued ones
Expect(removedConnIDs).To(ContainElement(initialConnID))
Expect(removedConnIDs).To(ContainElement(initialClientDestConnID))
for _, f := range queuedFrames {
nf := f.(*wire.NewConnectionIDFrame)
Expect(removedConnIDs).To(ContainElement(nf.ConnectionID))
@ -118,7 +127,8 @@ var _ = Describe("Connection ID Generator", func() {
Expect(queuedFrames).To(HaveLen(5))
sess := NewMockPacketHandler(mockCtrl)
g.ReplaceWithClosed(sess)
Expect(replacedWithClosed).To(HaveLen(6)) // initial connection ID and newly issued ones
Expect(replacedWithClosed).To(HaveLen(7)) // initial conn ID, initial client dest conn id, and newly issued ones
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess))
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess))
for _, f := range queuedFrames {
nf := f.(*wire.NewConnectionIDFrame)

View file

@ -227,6 +227,7 @@ var newSession = func(
)
s.connIDGenerator = newConnIDGenerator(
srcConnID,
clientDestConnID,
func(connID protocol.ConnectionID) [16]byte { return runner.Add(connID, s) },
runner.Remove,
runner.Retire,
@ -238,6 +239,7 @@ var newSession = func(
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
runner.Add(clientDestConnID, s)
token := runner.Add(srcConnID, s)
params := &handshake.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
@ -262,10 +264,13 @@ var newSession = func(
conn.RemoteAddr(),
params,
&handshakeRunner{
onReceivedParams: s.processTransportParameters,
onError: s.closeLocal,
dropKeys: s.dropEncryptionLevel,
onHandshakeComplete: func() { close(s.handshakeCompleteChan) },
onReceivedParams: s.processTransportParameters,
onError: s.closeLocal,
dropKeys: s.dropEncryptionLevel,
onHandshakeComplete: func() {
runner.Retire(clientDestConnID)
close(s.handshakeCompleteChan)
},
},
tlsConf,
s.rttStats,
@ -325,6 +330,7 @@ var newClientSession = func(
)
s.connIDGenerator = newConnIDGenerator(
srcConnID,
nil,
func(connID protocol.ConnectionID) [16]byte { return runner.Add(connID, s) },
runner.Remove,
runner.Retire,
@ -596,6 +602,7 @@ func (s *session) handleHandshakeComplete() {
s.handshakeCompleteChan = nil // prevent this case from ever being selected again
s.handshakeCtxCancel()
s.connIDGenerator.SetHandshakeComplete()
s.sentPacketHandler.SetHandshakeComplete()
// The client completes the handshake first (after sending the CFIN).
// We need to make sure it learns about the server completing the handshake,

View file

@ -82,6 +82,7 @@ var _ = Describe("Session", func() {
)
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
getPacket := func(pn protocol.PacketNumber) *packedPacket {
buffer := getPacketBuffer()
@ -95,6 +96,7 @@ var _ = Describe("Session", func() {
}
expectReplaceWithClosed := func() {
sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1)
sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{}))
Expect(s.Close()).To(Succeed())
@ -106,7 +108,7 @@ var _ = Describe("Session", func() {
Eventually(areSessionsRunning).Should(BeFalse())
sessionRunner = NewMockSessionRunner(mockCtrl)
sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any())
sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any()).Times(2)
mconn = newMockConnection()
tokenGenerator, err := handshake.NewTokenGenerator()
Expect(err).ToNot(HaveOccurred())
@ -114,7 +116,7 @@ var _ = Describe("Session", func() {
mconn,
sessionRunner,
nil,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
clientDestConnID,
destConnID,
srcConnID,
populateServerConfig(&Config{}),
@ -360,6 +362,9 @@ var _ = Describe("Session", func() {
sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{}))
})
sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{}))
})
cryptoSetup.EXPECT().Close()
go func() {
@ -381,6 +386,9 @@ var _ = Describe("Session", func() {
sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{}))
})
sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{}))
})
cryptoSetup.EXPECT().Close()
go func() {
@ -508,7 +516,7 @@ var _ = Describe("Session", func() {
It("closes the session in order to recreate it", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
cryptoSetup.EXPECT().Close()
sess.closeForRecreating()
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
@ -519,7 +527,7 @@ var _ = Describe("Session", func() {
It("destroys the session", func() {
testErr := errors.New("close")
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
cryptoSetup.EXPECT().Close()
sess.destroy(testErr)
Eventually(areSessionsRunning).Should(BeFalse())
@ -1191,6 +1199,7 @@ var _ = Describe("Session", func() {
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().SendMode().AnyTimes()
sessionRunner.EXPECT().Retire(clientDestConnID)
go func() {
defer GinkgoRecover()
<-finishHandshake
@ -1232,6 +1241,7 @@ var _ = Describe("Session", func() {
It("sends a 1-RTT packet when the handshake completes", func() {
done := make(chan struct{})
sessionRunner.EXPECT().Retire(clientDestConnID)
packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) {
defer close(done)
return &packedPacket{
@ -1335,7 +1345,7 @@ var _ = Describe("Session", func() {
sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{}))
Expect(s.Close()).To(Succeed())
}).Times(4)
}).Times(5) // initial connection ID + initial client dest conn ID + 3 newly issued conn IDs
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
cryptoSetup.EXPECT().Close()
sess.Close()
@ -1422,7 +1432,7 @@ var _ = Describe("Session", func() {
})
It("times out due to no network activity", func() {
sessionRunner.EXPECT().Remove(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any()).Times(2)
sess.handshakeComplete = true
sess.lastPacketReceivedTime = time.Now().Add(-time.Hour)
done := make(chan struct{})
@ -1442,7 +1452,7 @@ var _ = Describe("Session", func() {
It("times out due to non-completed handshake", func() {
sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second)
sessionRunner.EXPECT().Remove(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any()).Times(2)
cryptoSetup.EXPECT().Close()
done := make(chan struct{})
go func() {
@ -1483,7 +1493,10 @@ var _ = Describe("Session", func() {
It("closes the session due to the idle timeout after handshake", func() {
packer.EXPECT().PackPacket().AnyTimes()
sessionRunner.EXPECT().Remove(gomock.Any())
gomock.InOrder(
sessionRunner.EXPECT().Retire(clientDestConnID),
sessionRunner.EXPECT().Remove(gomock.Any()),
)
cryptoSetup.EXPECT().Close()
sess.config.IdleTimeout = 0
done := make(chan struct{})