close connections created by DialAddr when the session is closed

This commit is contained in:
Marten Seemann 2018-08-06 16:41:53 +07:00
parent 61fb67096f
commit cfa55f91bc
3 changed files with 82 additions and 21 deletions

View file

@ -19,8 +19,12 @@ import (
type client struct {
mutex sync.Mutex
pconn net.PacketConn
conn connection
pconn net.PacketConn
conn connection
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
hostname string
receivedRetry bool
@ -73,7 +77,6 @@ func DialAddrContext(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
config = populateClientConfig(config, false)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -82,7 +85,7 @@ func DialAddrContext(
if err != nil {
return nil, err
}
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config)
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
@ -107,8 +110,7 @@ func DialContext(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
config = populateClientConfig(config, true)
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config)
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
}
func dialContext(
@ -118,13 +120,15 @@ func dialContext(
host string,
tlsConf *tls.Config,
config *Config,
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
multiplexer := getClientMultiplexer()
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove, createdPacketConn)
if err != nil {
return nil, err
}
@ -149,6 +153,7 @@ func newClient(
tlsConf *tls.Config,
host string,
closeCallback func(protocol.ConnectionID),
createdPacketConn bool,
) (*client, error) {
var hostname string
if tlsConf != nil {
@ -175,22 +180,23 @@ func newClient(
onClose = closeCallback
}
c := &client{
pconn: pconn,
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
hostname: hostname,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
pconn: pconn,
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
hostname: hostname,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, c.generateConnectionIDs()
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, onPacketConn bool) *Config {
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
if config == nil {
config = &Config{}
}
@ -229,7 +235,7 @@ func populateClientConfig(config *Config, onPacketConn bool) *Config {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 && onPacketConn {
if connIDLen == 0 && !createdPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
@ -342,6 +348,9 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != handshake.ErrCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
c.conn.Close()
}
errorChan <- err
}()

View file

@ -290,6 +290,58 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
})
It("closes the connection when it was created by DialAddr", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
var conn connection
run := make(chan struct{})
sessionCreated := make(chan struct{})
sess := NewMockQuicSession(mockCtrl)
newClientSession = func(
connP connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
connID protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (quicSession, error) {
conn = connP
close(sessionCreated)
return sess, nil
}
sess.EXPECT().run().Do(func() {
<-run
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := DialAddr("quic.clemente.io:1337", nil, nil)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(sessionCreated).Should(BeClosed())
// check that the connection is not closed
Expect(conn.Write([]byte("foobar"))).To(Succeed())
close(run)
time.Sleep(50 * time.Millisecond)
// check that the connection is closed
err := conn.Write([]byte("foobar"))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
Eventually(done).Should(BeClosed())
})
Context("quic.Config", func() {
It("setups with the right values", func() {
config := &Config{
@ -340,13 +392,13 @@ var _ = Describe("Client", func() {
It("uses 0-byte connection IDs when dialing an address", func() {
config := &Config{}
c := populateClientConfig(config, false)
c := populateClientConfig(config, true)
Expect(c.ConnectionIDLength).To(BeZero())
})
It("doesn't use 0-byte connection IDs when dialing an address", func() {
config := &Config{}
c := populateClientConfig(config, true)
c := populateClientConfig(config, false)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})

View file

@ -81,7 +81,7 @@ var _ = Describe("Server", func() {
It("doesn't use 0-byte connection IDs", func() {
config := &Config{}
c := populateClientConfig(config, true)
c := populateServerConfig(config)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})
})