Mirror of https://github.com/refraction-networking/uquic.git (synced 2025-04-04 12:47:36 +03:00)
fix race condition on concurrent use of Transport.Dial and Close (#4904)
parent 5d4835e422
commit eb70424fba
2 changed files with 45 additions and 22 deletions
transport.go (21 changed lines)
diff --git a/transport.go b/transport.go
@@ -236,19 +236,13 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
 }
 
 func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
-    t.mutex.Lock()
-    if t.closeErr != nil {
-        t.mutex.Unlock()
-        return nil, t.closeErr
+    if err := t.init(t.isSingleUse); err != nil {
+        return nil, err
     }
-    t.mutex.Unlock()
     if err := validateConfig(conf); err != nil {
         return nil, err
     }
     conf = populateConfig(conf)
-    if err := t.init(t.isSingleUse); err != nil {
-        return nil, err
-    }
     tlsConf = tlsConf.Clone()
     setTLSConfigServerName(tlsConf, addr, host)
     return t.doDial(ctx,
@@ -283,6 +277,13 @@ func (t *Transport) doDial(
 
     tracingID := nextConnTracingID()
     ctx = context.WithValue(ctx, ConnectionTracingKey, tracingID)
+
+    t.mutex.Lock()
+    if t.closeErr != nil {
+        t.mutex.Unlock()
+        return nil, t.closeErr
+    }
+
     var tracer *logging.ConnectionTracer
     if config.Tracer != nil {
         tracer = config.Tracer(ctx, protocol.PerspectiveClient, destConnID)
@@ -312,6 +313,7 @@ func (t *Transport) doDial(
         version,
     )
     t.handlerMap.Add(srcConnID, conn)
+    t.mutex.Unlock()
 
     // The error channel needs to be buffered, as the run loop will continue running
     // after doDial returns (if the handshake is successful).
@@ -452,6 +454,9 @@ func (t *Transport) runSendQueue() {
 // If any listener was started, it will be closed as well.
 // It is invalid to start new listeners or connections after that.
 func (t *Transport) Close() error {
+    // avoid race condition if the transport is currently being initialized
+    t.init(false)
+
     t.close(nil)
     if t.createdConn {
         if err := t.Conn.Close(); err != nil {
diff --git a/transport_test.go b/transport_test.go
@@ -120,21 +120,39 @@ func TestTransportPacketHandling(t *testing.T) {
 }
 
 func TestTransportAndListenerConcurrentClose(t *testing.T) {
-    // try 10 times to trigger race conditions
-    for i := 0; i < 10; i++ {
-        tr := &Transport{Conn: newUPDConnLocalhost(t)}
-        ln, err := tr.Listen(&tls.Config{}, nil)
-        require.NoError(t, err)
-        // close transport and listener concurrently
-        lnErrChan := make(chan error, 1)
-        go func() { lnErrChan <- ln.Close() }()
-        require.NoError(t, tr.Close())
-        select {
-        case err := <-lnErrChan:
-            require.NoError(t, err)
-        case <-time.After(time.Second):
-            t.Fatal("timeout")
-        }
+    tr := &Transport{Conn: newUPDConnLocalhost(t)}
+    ln, err := tr.Listen(&tls.Config{}, nil)
+    require.NoError(t, err)
+    // close transport and listener concurrently
+    lnErrChan := make(chan error, 1)
+    go func() { lnErrChan <- ln.Close() }()
+    require.NoError(t, tr.Close())
+    select {
+    case err := <-lnErrChan:
+        require.NoError(t, err)
+    case <-time.After(time.Second):
+        t.Fatal("timeout")
+    }
+}
+
+func TestTransportAndDialConcurrentClose(t *testing.T) {
+    server := newUPDConnLocalhost(t)
+
+    tr := &Transport{Conn: newUPDConnLocalhost(t)}
+    // close transport and dial concurrently
+    errChan := make(chan error, 1)
+    go func() { errChan <- tr.Close() }()
+    ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
+    defer cancel()
+    _, err := tr.Dial(ctx, server.LocalAddr(), &tls.Config{}, nil)
+    require.Error(t, err)
+    require.ErrorIs(t, err, ErrTransportClosed)
+    require.NotErrorIs(t, err, context.DeadlineExceeded)
+
+    select {
+    case <-errChan:
+    case <-time.After(time.Second):
+        t.Fatal("timeout")
+    }
     }
 }
 
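The change is easiest to read as a check-then-act race: the old Transport.dial checked t.closeErr under the mutex, released it, and the connection was only registered in the handler map later, so a concurrent Close could fall between the check and the registration. The new code performs the closeErr check inside doDial, under the same lock acquisition that calls t.handlerMap.Add, and Close runs t.init(false) first so it cannot race with a transport that is still being initialized. Below is a minimal, self-contained sketch of that pattern; the registry type, its fields, and errClosed are hypothetical stand-ins for illustration, not quic-go's actual internals.

```go
// Illustrative sketch only: the check-then-act race that the commit removes.
// The registry type, its fields, and errClosed are hypothetical stand-ins,
// not quic-go's actual internals.
package main

import (
	"errors"
	"fmt"
	"sync"
)

var errClosed = errors.New("transport closed")

type registry struct {
	mu     sync.Mutex
	closed bool
	conns  map[int]struct{}
}

// addRacy mirrors the old Transport.dial: it checks the closed flag under the
// lock, releases the lock to do setup work, and only then registers the
// connection. A concurrent close can run in the gap, so a connection may end
// up registered on an already-closed registry.
func (r *registry) addRacy(id int) error {
	r.mu.Lock()
	if r.closed {
		r.mu.Unlock()
		return errClosed
	}
	r.mu.Unlock()

	// ... connection setup happens here, without the lock held ...

	r.mu.Lock()
	r.conns[id] = struct{}{} // may execute after close() set r.closed
	r.mu.Unlock()
	return nil
}

// addSafe mirrors the new doDial: the closed check and the registration happen
// under a single lock acquisition, so either the connection is registered
// before the close, or the caller gets errClosed.
func (r *registry) addSafe(id int) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.closed {
		return errClosed
	}
	r.conns[id] = struct{}{}
	return nil
}

func (r *registry) close() {
	r.mu.Lock()
	r.closed = true
	r.mu.Unlock()
}

func main() {
	r := &registry{conns: make(map[int]struct{})}
	go r.close()
	if err := r.addSafe(1); err != nil {
		fmt.Println("dial refused:", err) // analogous to ErrTransportClosed
	} else {
		fmt.Println("connection registered before close")
	}
}
```

The racy variant can register a connection after the close has already run; the safe variant either registers before the close or returns errClosed, which mirrors what the new TestTransportAndDialConcurrentClose asserts through ErrTransportClosed.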