diff --git a/client.go b/client.go index 84dad27d..4e157885 100644 --- a/client.go +++ b/client.go @@ -19,8 +19,12 @@ import ( type client struct { mutex sync.Mutex - pconn net.PacketConn - conn connection + pconn net.PacketConn + conn connection + // If the client is created with DialAddr, we create a packet conn. + // If it is started with Dial, we take a packet conn as a parameter. + createdPacketConn bool + hostname string receivedRetry bool @@ -73,7 +77,6 @@ func DialAddrContext( tlsConf *tls.Config, config *Config, ) (Session, error) { - config = populateClientConfig(config, false) udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -82,7 +85,7 @@ func DialAddrContext( if err != nil { return nil, err } - return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config) + return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true) } // Dial establishes a new QUIC connection to a server using a net.PacketConn. @@ -107,8 +110,7 @@ func DialContext( tlsConf *tls.Config, config *Config, ) (Session, error) { - config = populateClientConfig(config, true) - return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config) + return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false) } func dialContext( @@ -118,13 +120,15 @@ func dialContext( host string, tlsConf *tls.Config, config *Config, + createdPacketConn bool, ) (Session, error) { + config = populateClientConfig(config, createdPacketConn) multiplexer := getClientMultiplexer() manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength) if err != nil { return nil, err } - c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove) + c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove, createdPacketConn) if err != nil { return nil, err } @@ -149,6 +153,7 @@ func newClient( tlsConf *tls.Config, host string, closeCallback func(protocol.ConnectionID), + createdPacketConn bool, ) (*client, error) { var hostname string if tlsConf != nil { @@ -175,22 +180,23 @@ func newClient( onClose = closeCallback } c := &client{ - pconn: pconn, - conn: &conn{pconn: pconn, currentAddr: remoteAddr}, - hostname: hostname, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - closeCallback: onClose, - logger: utils.DefaultLogger.WithPrefix("client"), + pconn: pconn, + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + createdPacketConn: createdPacketConn, + hostname: hostname, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + closeCallback: onClose, + logger: utils.DefaultLogger.WithPrefix("client"), } return c, c.generateConnectionIDs() } // populateClientConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil -func populateClientConfig(config *Config, onPacketConn bool) *Config { +func populateClientConfig(config *Config, createdPacketConn bool) *Config { if config == nil { config = &Config{} } @@ -229,7 +235,7 @@ func populateClientConfig(config *Config, onPacketConn bool) *Config { maxIncomingUniStreams = 0 } connIDLen := config.ConnectionIDLength - if connIDLen == 0 && onPacketConn { + if connIDLen == 0 && !createdPacketConn { connIDLen = protocol.DefaultConnectionIDLength } @@ -342,6 +348,9 @@ func (c *client) establishSecureConnection(ctx context.Context) error { go func() { err := c.session.run() // returns as soon as the session is closed + if err != handshake.ErrCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn { + c.conn.Close() + } errorChan <- err }() diff --git a/client_test.go b/client_test.go index e79bfc06..9af37f73 100644 --- a/client_test.go +++ b/client_test.go @@ -290,6 +290,58 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) }) + It("closes the connection when it was created by DialAddr", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any()) + + var conn connection + run := make(chan struct{}) + sessionCreated := make(chan struct{}) + sess := NewMockQuicSession(mockCtrl) + newClientSession = func( + connP connection, + _ sessionRunner, + _ string, + _ protocol.VersionNumber, + connID protocol.ConnectionID, + _ *tls.Config, + _ *Config, + _ protocol.VersionNumber, + _ []protocol.VersionNumber, + _ utils.Logger, + ) (quicSession, error) { + conn = connP + close(sessionCreated) + return sess, nil + } + sess.EXPECT().run().Do(func() { + <-run + }) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := DialAddr("quic.clemente.io:1337", nil, nil) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + Eventually(sessionCreated).Should(BeClosed()) + + // check that the connection is not closed + Expect(conn.Write([]byte("foobar"))).To(Succeed()) + + close(run) + time.Sleep(50 * time.Millisecond) + // check that the connection is closed + err := conn.Write([]byte("foobar")) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("use of closed network connection")) + + Eventually(done).Should(BeClosed()) + }) + Context("quic.Config", func() { It("setups with the right values", func() { config := &Config{ @@ -340,13 +392,13 @@ var _ = Describe("Client", func() { It("uses 0-byte connection IDs when dialing an address", func() { config := &Config{} - c := populateClientConfig(config, false) + c := populateClientConfig(config, true) Expect(c.ConnectionIDLength).To(BeZero()) }) It("doesn't use 0-byte connection IDs when dialing an address", func() { config := &Config{} - c := populateClientConfig(config, true) + c := populateClientConfig(config, false) Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) }) diff --git a/server_test.go b/server_test.go index 7c86014d..248039df 100644 --- a/server_test.go +++ b/server_test.go @@ -81,7 +81,7 @@ var _ = Describe("Server", func() { It("doesn't use 0-byte connection IDs", func() { config := &Config{} - c := populateClientConfig(config, true) + c := populateServerConfig(config) Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) }) })