mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
close connections created by DialAddr when the session is closed
This commit is contained in:
parent
61fb67096f
commit
cfa55f91bc
3 changed files with 82 additions and 21 deletions
45
client.go
45
client.go
|
@ -19,8 +19,12 @@ import (
|
|||
type client struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
pconn net.PacketConn
|
||||
conn connection
|
||||
pconn net.PacketConn
|
||||
conn connection
|
||||
// If the client is created with DialAddr, we create a packet conn.
|
||||
// If it is started with Dial, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
|
||||
hostname string
|
||||
|
||||
receivedRetry bool
|
||||
|
@ -73,7 +77,6 @@ func DialAddrContext(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, false)
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -82,7 +85,7 @@ func DialAddrContext(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config)
|
||||
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
|
@ -107,8 +110,7 @@ func DialContext(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, true)
|
||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config)
|
||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
|
||||
}
|
||||
|
||||
func dialContext(
|
||||
|
@ -118,13 +120,15 @@ func dialContext(
|
|||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
createdPacketConn bool,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
multiplexer := getClientMultiplexer()
|
||||
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove, createdPacketConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -149,6 +153,7 @@ func newClient(
|
|||
tlsConf *tls.Config,
|
||||
host string,
|
||||
closeCallback func(protocol.ConnectionID),
|
||||
createdPacketConn bool,
|
||||
) (*client, error) {
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
|
@ -175,22 +180,23 @@ func newClient(
|
|||
onClose = closeCallback
|
||||
}
|
||||
c := &client{
|
||||
pconn: pconn,
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
pconn: pconn,
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
createdPacketConn: createdPacketConn,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, c.generateConnectionIDs()
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, onPacketConn bool) *Config {
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
|
@ -229,7 +235,7 @@ func populateClientConfig(config *Config, onPacketConn bool) *Config {
|
|||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDLen := config.ConnectionIDLength
|
||||
if connIDLen == 0 && onPacketConn {
|
||||
if connIDLen == 0 && !createdPacketConn {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
|
||||
|
@ -342,6 +348,9 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
|
|||
|
||||
go func() {
|
||||
err := c.session.run() // returns as soon as the session is closed
|
||||
if err != handshake.ErrCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
|
||||
c.conn.Close()
|
||||
}
|
||||
errorChan <- err
|
||||
}()
|
||||
|
||||
|
|
|
@ -290,6 +290,58 @@ var _ = Describe("Client", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("closes the connection when it was created by DialAddr", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
|
||||
var conn connection
|
||||
run := make(chan struct{})
|
||||
sessionCreated := make(chan struct{})
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
newClientSession = func(
|
||||
connP connection,
|
||||
_ sessionRunner,
|
||||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
connID protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ utils.Logger,
|
||||
) (quicSession, error) {
|
||||
conn = connP
|
||||
close(sessionCreated)
|
||||
return sess, nil
|
||||
}
|
||||
sess.EXPECT().run().Do(func() {
|
||||
<-run
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := DialAddr("quic.clemente.io:1337", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(sessionCreated).Should(BeClosed())
|
||||
|
||||
// check that the connection is not closed
|
||||
Expect(conn.Write([]byte("foobar"))).To(Succeed())
|
||||
|
||||
close(run)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// check that the connection is closed
|
||||
err := conn.Write([]byte("foobar"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
|
||||
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("quic.Config", func() {
|
||||
It("setups with the right values", func() {
|
||||
config := &Config{
|
||||
|
@ -340,13 +392,13 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("uses 0-byte connection IDs when dialing an address", func() {
|
||||
config := &Config{}
|
||||
c := populateClientConfig(config, false)
|
||||
c := populateClientConfig(config, true)
|
||||
Expect(c.ConnectionIDLength).To(BeZero())
|
||||
})
|
||||
|
||||
It("doesn't use 0-byte connection IDs when dialing an address", func() {
|
||||
config := &Config{}
|
||||
c := populateClientConfig(config, true)
|
||||
c := populateClientConfig(config, false)
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
})
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("doesn't use 0-byte connection IDs", func() {
|
||||
config := &Config{}
|
||||
c := populateClientConfig(config, true)
|
||||
c := populateServerConfig(config)
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue