close connections created by DialAddr when the session is closed

This commit is contained in:
Marten Seemann 2018-08-06 16:41:53 +07:00
parent 61fb67096f
commit cfa55f91bc
3 changed files with 82 additions and 21 deletions

View file

@ -290,6 +290,58 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
})
It("closes the connection when it was created by DialAddr", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
var conn connection
run := make(chan struct{})
sessionCreated := make(chan struct{})
sess := NewMockQuicSession(mockCtrl)
newClientSession = func(
connP connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
connID protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (quicSession, error) {
conn = connP
close(sessionCreated)
return sess, nil
}
sess.EXPECT().run().Do(func() {
<-run
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := DialAddr("quic.clemente.io:1337", nil, nil)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(sessionCreated).Should(BeClosed())
// check that the connection is not closed
Expect(conn.Write([]byte("foobar"))).To(Succeed())
close(run)
time.Sleep(50 * time.Millisecond)
// check that the connection is closed
err := conn.Write([]byte("foobar"))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
Eventually(done).Should(BeClosed())
})
Context("quic.Config", func() {
It("setups with the right values", func() {
config := &Config{
@ -340,13 +392,13 @@ var _ = Describe("Client", func() {
It("uses 0-byte connection IDs when dialing an address", func() {
config := &Config{}
c := populateClientConfig(config, false)
c := populateClientConfig(config, true)
Expect(c.ConnectionIDLength).To(BeZero())
})
It("doesn't use 0-byte connection IDs when dialing an address", func() {
config := &Config{}
c := populateClientConfig(config, true)
c := populateClientConfig(config, false)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})