close connections created by DialAddr when the session is closed

This commit is contained in:
Marten Seemann 2018-08-06 16:41:53 +07:00
parent 61fb67096f
commit cfa55f91bc
3 changed files with 82 additions and 21 deletions

View file

@ -19,8 +19,12 @@ import (
type client struct {
mutex sync.Mutex
pconn net.PacketConn
conn connection
pconn net.PacketConn
conn connection
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
hostname string
receivedRetry bool
@ -73,7 +77,6 @@ func DialAddrContext(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
config = populateClientConfig(config, false)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -82,7 +85,7 @@ func DialAddrContext(
if err != nil {
return nil, err
}
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config)
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
@ -107,8 +110,7 @@ func DialContext(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
config = populateClientConfig(config, true)
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config)
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
}
func dialContext(
@ -118,13 +120,15 @@ func dialContext(
host string,
tlsConf *tls.Config,
config *Config,
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
multiplexer := getClientMultiplexer()
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove, createdPacketConn)
if err != nil {
return nil, err
}
@ -149,6 +153,7 @@ func newClient(
tlsConf *tls.Config,
host string,
closeCallback func(protocol.ConnectionID),
createdPacketConn bool,
) (*client, error) {
var hostname string
if tlsConf != nil {
@ -175,22 +180,23 @@ func newClient(
onClose = closeCallback
}
c := &client{
pconn: pconn,
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
hostname: hostname,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
pconn: pconn,
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
hostname: hostname,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, c.generateConnectionIDs()
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, onPacketConn bool) *Config {
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
if config == nil {
config = &Config{}
}
@ -229,7 +235,7 @@ func populateClientConfig(config *Config, onPacketConn bool) *Config {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 && onPacketConn {
if connIDLen == 0 && !createdPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
@ -342,6 +348,9 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != handshake.ErrCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
c.conn.Close()
}
errorChan <- err
}()