From eb70424fbaddd2577b787aa51e3f961451a0d49e Mon Sep 17 00:00:00 2001
From: Marten Seemann
Date: Tue, 21 Jan 2025 01:52:16 -0800
Subject: [PATCH] fix race condition on concurrent use of Transport.Dial and Close (#4904)

---
 transport.go      | 21 +++++++++++++--------
 transport_test.go | 46 ++++++++++++++++++++++++++++++++--------------
 2 files changed, 45 insertions(+), 22 deletions(-)

diff --git a/transport.go b/transport.go
index 5d1b072e..41dbc7ab 100644
--- a/transport.go
+++ b/transport.go
@@ -236,19 +236,13 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
 }
 
 func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
-	t.mutex.Lock()
-	if t.closeErr != nil {
-		t.mutex.Unlock()
-		return nil, t.closeErr
+	if err := t.init(t.isSingleUse); err != nil {
+		return nil, err
 	}
-	t.mutex.Unlock()
 	if err := validateConfig(conf); err != nil {
 		return nil, err
 	}
 	conf = populateConfig(conf)
-	if err := t.init(t.isSingleUse); err != nil {
-		return nil, err
-	}
 	tlsConf = tlsConf.Clone()
 	setTLSConfigServerName(tlsConf, addr, host)
 	return t.doDial(ctx,
@@ -283,6 +277,13 @@ func (t *Transport) doDial(
 
 	tracingID := nextConnTracingID()
 	ctx = context.WithValue(ctx, ConnectionTracingKey, tracingID)
+
+	t.mutex.Lock()
+	if t.closeErr != nil {
+		t.mutex.Unlock()
+		return nil, t.closeErr
+	}
+
 	var tracer *logging.ConnectionTracer
 	if config.Tracer != nil {
 		tracer = config.Tracer(ctx, protocol.PerspectiveClient, destConnID)
@@ -312,6 +313,7 @@ func (t *Transport) doDial(
 		version,
 	)
 	t.handlerMap.Add(srcConnID, conn)
+	t.mutex.Unlock()
 
 	// The error channel needs to be buffered, as the run loop will continue running
 	// after doDial returns (if the handshake is successful).
@@ -452,6 +454,9 @@ func (t *Transport) runSendQueue() {
 // If any listener was started, it will be closed as well.
 // It is invalid to start new listeners or connections after that.
 func (t *Transport) Close() error {
+	// avoid race condition if the transport is currently being initialized
+	t.init(false)
+
 	t.close(nil)
 	if t.createdConn {
 		if err := t.Conn.Close(); err != nil {
diff --git a/transport_test.go b/transport_test.go
index 203d92fd..b3af560b 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -120,21 +120,39 @@ func TestTransportPacketHandling(t *testing.T) {
 }
 
 func TestTransportAndListenerConcurrentClose(t *testing.T) {
-	// try 10 times to trigger race conditions
-	for i := 0; i < 10; i++ {
-		tr := &Transport{Conn: newUPDConnLocalhost(t)}
-		ln, err := tr.Listen(&tls.Config{}, nil)
+	tr := &Transport{Conn: newUPDConnLocalhost(t)}
+	ln, err := tr.Listen(&tls.Config{}, nil)
+	require.NoError(t, err)
+	// close transport and listener concurrently
+	lnErrChan := make(chan error, 1)
+	go func() { lnErrChan <- ln.Close() }()
+	require.NoError(t, tr.Close())
+	select {
+	case err := <-lnErrChan:
 		require.NoError(t, err)
-		// close transport and listener concurrently
-		lnErrChan := make(chan error, 1)
-		go func() { lnErrChan <- ln.Close() }()
-		require.NoError(t, tr.Close())
-		select {
-		case err := <-lnErrChan:
-			require.NoError(t, err)
-		case <-time.After(time.Second):
-			t.Fatal("timeout")
-		}
+	case <-time.After(time.Second):
+		t.Fatal("timeout")
+	}
+}
+
+func TestTransportAndDialConcurrentClose(t *testing.T) {
+	server := newUPDConnLocalhost(t)
+
+	tr := &Transport{Conn: newUPDConnLocalhost(t)}
+	// close transport and dial concurrently
+	errChan := make(chan error, 1)
+	go func() { errChan <- tr.Close() }()
+	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	defer cancel()
+	_, err := tr.Dial(ctx, server.LocalAddr(), &tls.Config{}, nil)
+	require.Error(t, err)
+	require.ErrorIs(t, err, ErrTransportClosed)
+	require.NotErrorIs(t, err, context.DeadlineExceeded)
+
+	select {
+	case <-errChan:
+	case <-time.After(time.Second):
+		t.Fatal("timeout")
 	}
 }
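
For readers skimming the patch, here is a minimal, illustrative sketch (not part of the change) of the caller-visible behavior the new TestTransportAndDialConcurrentClose pins down: when Close races with Dial, Dial should fail promptly with ErrTransportClosed instead of only failing once its context expires. The module path, addresses, and TLS settings below are assumptions made for the sake of a self-contained example.

// Illustrative sketch only: demonstrates the Dial/Close race this patch addresses.
// Addresses and TLS settings are placeholder assumptions.
package main

import (
	"context"
	"crypto/tls"
	"errors"
	"log"
	"net"
	"time"

	"github.com/quic-go/quic-go"
)

func main() {
	// A remote UDP socket to dial; it never has to answer for this sketch.
	remote, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		log.Fatal(err)
	}
	defer remote.Close()

	local, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		log.Fatal(err)
	}
	tr := &quic.Transport{Conn: local}

	// Close the transport concurrently with the Dial call below.
	go tr.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	_, err = tr.Dial(ctx, remote.LocalAddr(), &tls.Config{NextProtos: []string{"example"}}, nil)
	switch {
	case errors.Is(err, quic.ErrTransportClosed):
		// With the fix, a concurrent Close surfaces as ErrTransportClosed.
		log.Println("dial aborted: transport closed")
	case err != nil:
		// If Dial won the race, the handshake simply fails against the silent peer.
		log.Println("dial failed:", err)
	}
}

The test's require.NotErrorIs(err, context.DeadlineExceeded) assertion is what guards against the pre-fix behavior, where a Dial racing with Close could slip past the closeErr check and only fail once its context deadline expired.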