mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 13:17:36 +03:00
close connections created by DialAddr when the session is closed
This commit is contained in:
parent
61fb67096f
commit
cfa55f91bc
3 changed files with 82 additions and 21 deletions
45
client.go
45
client.go
|
@ -19,8 +19,12 @@ import (
|
|||
type client struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
pconn net.PacketConn
|
||||
conn connection
|
||||
pconn net.PacketConn
|
||||
conn connection
|
||||
// If the client is created with DialAddr, we create a packet conn.
|
||||
// If it is started with Dial, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
|
||||
hostname string
|
||||
|
||||
receivedRetry bool
|
||||
|
@ -73,7 +77,6 @@ func DialAddrContext(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, false)
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -82,7 +85,7 @@ func DialAddrContext(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config)
|
||||
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
|
@ -107,8 +110,7 @@ func DialContext(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, true)
|
||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config)
|
||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
|
||||
}
|
||||
|
||||
func dialContext(
|
||||
|
@ -118,13 +120,15 @@ func dialContext(
|
|||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
createdPacketConn bool,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
multiplexer := getClientMultiplexer()
|
||||
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove, createdPacketConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -149,6 +153,7 @@ func newClient(
|
|||
tlsConf *tls.Config,
|
||||
host string,
|
||||
closeCallback func(protocol.ConnectionID),
|
||||
createdPacketConn bool,
|
||||
) (*client, error) {
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
|
@ -175,22 +180,23 @@ func newClient(
|
|||
onClose = closeCallback
|
||||
}
|
||||
c := &client{
|
||||
pconn: pconn,
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
pconn: pconn,
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
createdPacketConn: createdPacketConn,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, c.generateConnectionIDs()
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, onPacketConn bool) *Config {
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
|
@ -229,7 +235,7 @@ func populateClientConfig(config *Config, onPacketConn bool) *Config {
|
|||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDLen := config.ConnectionIDLength
|
||||
if connIDLen == 0 && onPacketConn {
|
||||
if connIDLen == 0 && !createdPacketConn {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
|
||||
|
@ -342,6 +348,9 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
|
|||
|
||||
go func() {
|
||||
err := c.session.run() // returns as soon as the session is closed
|
||||
if err != handshake.ErrCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
|
||||
c.conn.Close()
|
||||
}
|
||||
errorChan <- err
|
||||
}()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue