mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-07 06:07:36 +03:00
close connections created by DialAddr when the session is closed
This commit is contained in:
parent
61fb67096f
commit
cfa55f91bc
3 changed files with 82 additions and 21 deletions
23
client.go
23
client.go
|
@ -21,6 +21,10 @@ type client struct {
|
||||||
|
|
||||||
pconn net.PacketConn
|
pconn net.PacketConn
|
||||||
conn connection
|
conn connection
|
||||||
|
// If the client is created with DialAddr, we create a packet conn.
|
||||||
|
// If it is started with Dial, we take a packet conn as a parameter.
|
||||||
|
createdPacketConn bool
|
||||||
|
|
||||||
hostname string
|
hostname string
|
||||||
|
|
||||||
receivedRetry bool
|
receivedRetry bool
|
||||||
|
@ -73,7 +77,6 @@ func DialAddrContext(
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
config = populateClientConfig(config, false)
|
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -82,7 +85,7 @@ func DialAddrContext(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config)
|
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||||
|
@ -107,8 +110,7 @@ func DialContext(
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
config = populateClientConfig(config, true)
|
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
|
||||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func dialContext(
|
func dialContext(
|
||||||
|
@ -118,13 +120,15 @@ func dialContext(
|
||||||
host string,
|
host string,
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
|
createdPacketConn bool,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
|
config = populateClientConfig(config, createdPacketConn)
|
||||||
multiplexer := getClientMultiplexer()
|
multiplexer := getClientMultiplexer()
|
||||||
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
|
manager, err := multiplexer.AddConn(pconn, config.ConnectionIDLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
|
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove, createdPacketConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -149,6 +153,7 @@ func newClient(
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
host string,
|
host string,
|
||||||
closeCallback func(protocol.ConnectionID),
|
closeCallback func(protocol.ConnectionID),
|
||||||
|
createdPacketConn bool,
|
||||||
) (*client, error) {
|
) (*client, error) {
|
||||||
var hostname string
|
var hostname string
|
||||||
if tlsConf != nil {
|
if tlsConf != nil {
|
||||||
|
@ -177,6 +182,7 @@ func newClient(
|
||||||
c := &client{
|
c := &client{
|
||||||
pconn: pconn,
|
pconn: pconn,
|
||||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||||
|
createdPacketConn: createdPacketConn,
|
||||||
hostname: hostname,
|
hostname: hostname,
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
config: config,
|
config: config,
|
||||||
|
@ -190,7 +196,7 @@ func newClient(
|
||||||
|
|
||||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||||
// it may be called with nil
|
// it may be called with nil
|
||||||
func populateClientConfig(config *Config, onPacketConn bool) *Config {
|
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
config = &Config{}
|
config = &Config{}
|
||||||
}
|
}
|
||||||
|
@ -229,7 +235,7 @@ func populateClientConfig(config *Config, onPacketConn bool) *Config {
|
||||||
maxIncomingUniStreams = 0
|
maxIncomingUniStreams = 0
|
||||||
}
|
}
|
||||||
connIDLen := config.ConnectionIDLength
|
connIDLen := config.ConnectionIDLength
|
||||||
if connIDLen == 0 && onPacketConn {
|
if connIDLen == 0 && !createdPacketConn {
|
||||||
connIDLen = protocol.DefaultConnectionIDLength
|
connIDLen = protocol.DefaultConnectionIDLength
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,6 +348,9 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := c.session.run() // returns as soon as the session is closed
|
err := c.session.run() // returns as soon as the session is closed
|
||||||
|
if err != handshake.ErrCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
errorChan <- err
|
errorChan <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -290,6 +290,58 @@ var _ = Describe("Client", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("closes the connection when it was created by DialAddr", func() {
|
||||||
|
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||||
|
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
|
|
||||||
|
var conn connection
|
||||||
|
run := make(chan struct{})
|
||||||
|
sessionCreated := make(chan struct{})
|
||||||
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
|
newClientSession = func(
|
||||||
|
connP connection,
|
||||||
|
_ sessionRunner,
|
||||||
|
_ string,
|
||||||
|
_ protocol.VersionNumber,
|
||||||
|
connID protocol.ConnectionID,
|
||||||
|
_ *tls.Config,
|
||||||
|
_ *Config,
|
||||||
|
_ protocol.VersionNumber,
|
||||||
|
_ []protocol.VersionNumber,
|
||||||
|
_ utils.Logger,
|
||||||
|
) (quicSession, error) {
|
||||||
|
conn = connP
|
||||||
|
close(sessionCreated)
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
sess.EXPECT().run().Do(func() {
|
||||||
|
<-run
|
||||||
|
})
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := DialAddr("quic.clemente.io:1337", nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
Eventually(sessionCreated).Should(BeClosed())
|
||||||
|
|
||||||
|
// check that the connection is not closed
|
||||||
|
Expect(conn.Write([]byte("foobar"))).To(Succeed())
|
||||||
|
|
||||||
|
close(run)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
// check that the connection is closed
|
||||||
|
err := conn.Write([]byte("foobar"))
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
|
||||||
|
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
Context("quic.Config", func() {
|
Context("quic.Config", func() {
|
||||||
It("setups with the right values", func() {
|
It("setups with the right values", func() {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
|
@ -340,13 +392,13 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("uses 0-byte connection IDs when dialing an address", func() {
|
It("uses 0-byte connection IDs when dialing an address", func() {
|
||||||
config := &Config{}
|
config := &Config{}
|
||||||
c := populateClientConfig(config, false)
|
c := populateClientConfig(config, true)
|
||||||
Expect(c.ConnectionIDLength).To(BeZero())
|
Expect(c.ConnectionIDLength).To(BeZero())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't use 0-byte connection IDs when dialing an address", func() {
|
It("doesn't use 0-byte connection IDs when dialing an address", func() {
|
||||||
config := &Config{}
|
config := &Config{}
|
||||||
c := populateClientConfig(config, true)
|
c := populateClientConfig(config, false)
|
||||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -81,7 +81,7 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
It("doesn't use 0-byte connection IDs", func() {
|
It("doesn't use 0-byte connection IDs", func() {
|
||||||
config := &Config{}
|
config := &Config{}
|
||||||
c := populateClientConfig(config, true)
|
c := populateServerConfig(config)
|
||||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue