only create a single session for duplicate Initials

This commit is contained in:
Marten Seemann 2020-05-29 15:54:25 +07:00
parent 85c19fbb5a
commit dad30e7c33
6 changed files with 123 additions and 58 deletions

View file

@ -403,10 +403,7 @@ var _ = Describe("Server", func() {
var token [16]byte
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
@ -439,9 +436,13 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool {
Expect(c).To(Equal(newConnID))
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
fn()
return true
})
@ -502,14 +503,20 @@ var _ = Describe("Server", func() {
Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue())
// Then receive the Initial packet.
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
fn()
return true
})
Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue())
Expect(createdSession).To(BeTrue())
})
It("drops packets if the receive queue is full", func() {
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
phm.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).AnyTimes()
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
acceptSession := make(chan struct{})
@ -532,7 +539,12 @@ var _ = Describe("Server", func() {
) quicSession {
<-acceptSession
atomic.AddUint32(&counter, 1)
return nil
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().run()
sess.EXPECT().Context().Return(context.Background())
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess
}
serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}))
@ -577,10 +589,9 @@ var _ = Describe("Server", func() {
}
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdSession).To(BeTrue())
Expect(createdSession).To(BeFalse())
})
It("rejects new connection attempts if the accept queue is full", func() {
@ -612,8 +623,11 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).Times(protocol.MaxAcceptQueueSize)
var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize)
@ -673,8 +687,11 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
@ -772,8 +789,11 @@ var _ = Describe("Server", func() {
sess.EXPECT().Context().Return(context.Background())
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
@ -836,8 +856,11 @@ var _ = Describe("Server", func() {
sess.EXPECT().Context().Return(context.Background())
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
close(ready)
@ -874,8 +897,11 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
serv.handlePacket(getInitialWithRandomDestConnID())
}
@ -927,8 +953,11 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
Eventually(sessionCreated).Should(BeClosed())