From 07bbe8c383e1dfa3e634d3c2ab1f0b664ecec3eb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 20 Jan 2025 21:54:40 -0800 Subject: [PATCH] add ErrTransportClosed and use it for Listen/Dial after transport close (#4883) * add ErrTransportClosed and use it for Listen/Dial after transport close * include the original error in the ErrTransportClosed error string --- connection.go | 2 + integrationtests/self/close_test.go | 105 ++++++++++++++++++++++++++++ transport.go | 47 +++++++++++-- transport_test.go | 4 +- 4 files changed, 150 insertions(+), 8 deletions(-) diff --git a/connection.go b/connection.go index f6ae8759..0325a32b 100644 --- a/connection.go +++ b/connection.go @@ -1657,6 +1657,8 @@ func (s *connection) handleCloseError(closeErr *closeError) { errors.As(e, &recreateErr), errors.As(e, &applicationErr), errors.As(e, &transportErr): + case closeErr.immediate: + e = closeErr.err default: e = &qerr.TransportError{ ErrorCode: qerr.InternalError, diff --git a/integrationtests/self/close_test.go b/integrationtests/self/close_test.go index 583a804d..fd6619b8 100644 --- a/integrationtests/self/close_test.go +++ b/integrationtests/self/close_test.go @@ -2,6 +2,8 @@ package self_test import ( "context" + "crypto/tls" + "errors" "fmt" "net" "sync/atomic" @@ -113,3 +115,106 @@ func TestDrainServerAcceptQueue(t *testing.T) { _, err = server.Accept(ctx) require.ErrorIs(t, err, quic.ErrServerClosed) } + +type brokenConn struct { + net.PacketConn + + broken chan struct{} + breakErr atomic.Pointer[error] +} + +func newBrokenConn(conn net.PacketConn) *brokenConn { + c := &brokenConn{ + PacketConn: conn, + broken: make(chan struct{}), + } + go func() { + <-c.broken + // make calls to ReadFrom return + c.PacketConn.SetDeadline(time.Now()) + }() + return c +} + +func (c *brokenConn) ReadFrom(b []byte) (int, net.Addr, error) { + if err := c.breakErr.Load(); err != nil { + return 0, nil, *err + } + n, addr, err := c.PacketConn.ReadFrom(b) + if err != nil { + select { + case <-c.broken: + err = *c.breakErr.Load() + default: + } + } + return n, addr, err +} + +func (c *brokenConn) Break(e error) { + c.breakErr.Store(&e) + close(c.broken) +} + +func TestTransportClose(t *testing.T) { + t.Run("Close", func(t *testing.T) { + conn := newUPDConnLocalhost(t) + testTransportClose(t, conn, func() { conn.Close() }, nil) + }) + + t.Run("connection error", func(t *testing.T) { + t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true") + + bc := newBrokenConn(newUPDConnLocalhost(t)) + testErr := errors.New("test error") + testTransportClose(t, bc, func() { bc.Break(testErr) }, testErr) + }) +} + +func testTransportClose(t *testing.T, conn net.PacketConn, closeFn func(), expectedErr error) { + server := newUPDConnLocalhost(t) + tr := &quic.Transport{Conn: conn} + + errChan := make(chan error, 1) + go func() { + _, err := tr.Dial(context.Background(), server.LocalAddr(), &tls.Config{}, getQuicConfig(nil)) + errChan <- err + }() + + select { + case <-errChan: + t.Fatal("didn't expect Dial to return yet") + case <-time.After(scaleDuration(10 * time.Millisecond)): + } + + closeFn() + + select { + case err := <-errChan: + require.Error(t, err) + require.ErrorIs(t, err, quic.ErrTransportClosed) + if expectedErr != nil { + require.ErrorIs(t, err, expectedErr) + } + case <-time.After(time.Second): + t.Fatal("timeout") + } + + // it's not possible to dial new connections + ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) + defer cancel() + _, err := tr.Dial(ctx, server.LocalAddr(), &tls.Config{}, getQuicConfig(nil)) + require.Error(t, err) + require.ErrorIs(t, err, quic.ErrTransportClosed) + if expectedErr != nil { + require.ErrorIs(t, err, expectedErr) + } + + // it's not possible to create new listeners + _, err = tr.Listen(&tls.Config{}, nil) + require.Error(t, err) + require.ErrorIs(t, err, quic.ErrTransportClosed) + if expectedErr != nil { + require.ErrorIs(t, err, expectedErr) + } +} diff --git a/transport.go b/transport.go index ab7a77fd..5d1b072e 100644 --- a/transport.go +++ b/transport.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/tls" "errors" + "fmt" "net" "sync" "sync/atomic" @@ -16,6 +17,27 @@ import ( "github.com/quic-go/quic-go/logging" ) +// ErrTransportClosed is returned by the Transport's Listen or Dial method after it was closed. +var ErrTransportClosed = &errTransportClosed{} + +type errTransportClosed struct { + err error +} + +func (e *errTransportClosed) Unwrap() []error { return []error{net.ErrClosed, e.err} } + +func (e *errTransportClosed) Error() string { + if e.err == nil { + return "quic: transport closed" + } + return fmt.Sprintf("quic: transport closed: %s", e.err) +} + +func (e *errTransportClosed) Is(target error) bool { + _, ok := target.(*errTransportClosed) + return ok +} + var errListenerAlreadySet = errors.New("listener already set") // The Transport is the central point to manage incoming and outgoing QUIC connections. @@ -126,7 +148,7 @@ type Transport struct { statelessResetQueue chan receivedPacket listening chan struct{} // is closed when listen returns - closed bool + closeErr error createdConn bool isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial @@ -169,6 +191,9 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo t.mutex.Lock() defer t.mutex.Unlock() + if t.closeErr != nil { + return nil, t.closeErr + } if t.server != nil { return nil, errListenerAlreadySet } @@ -211,6 +236,12 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C } func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) { + t.mutex.Lock() + if t.closeErr != nil { + t.mutex.Unlock() + return nil, t.closeErr + } + t.mutex.Unlock() if err := validateConfig(conf); err != nil { return nil, err } @@ -417,11 +448,11 @@ func (t *Transport) runSendQueue() { } } -// Close closes the underlying connection. +// Close stops listening for UDP datagrams on the Transport.Conn. // If any listener was started, it will be closed as well. // It is invalid to start new listeners or connections after that. func (t *Transport) Close() error { - t.close(errors.New("closing")) + t.close(nil) if t.createdConn { if err := t.Conn.Close(); err != nil { return err @@ -440,7 +471,7 @@ func (t *Transport) closeServer() { t.mutex.Lock() t.server = nil if t.isSingleUse { - t.closed = true + t.closeErr = ErrServerClosed } t.mutex.Unlock() if t.createdConn { @@ -456,10 +487,12 @@ func (t *Transport) closeServer() { func (t *Transport) close(e error) { t.mutex.Lock() defer t.mutex.Unlock() - if t.closed { + + if t.closeErr != nil { return } + e = &errTransportClosed{err: e} if t.handlerMap != nil { t.handlerMap.Close(e) } @@ -469,7 +502,7 @@ func (t *Transport) close(e error) { if t.Tracer != nil && t.Tracer.Close != nil { t.Tracer.Close() } - t.closed = true + t.closeErr = e } // only print warnings about the UDP receive buffer size once @@ -486,7 +519,7 @@ func (t *Transport) listen(conn rawConn) { // See https://github.com/quic-go/quic-go/issues/1737 for details. if nerr, ok := err.(net.Error); ok && nerr.Temporary() { t.mutex.Lock() - closed := t.closed + closed := t.closeErr != nil t.mutex.Unlock() if closed { return diff --git a/transport_test.go b/transport_test.go index 37ee580f..203d92fd 100644 --- a/transport_test.go +++ b/transport_test.go @@ -166,7 +166,9 @@ func TestTransportErrFromConn(t *testing.T) { t.Fatal("timeout") } - // TODO(#4778): test that it's not possible to listen after the transport is closed + _, err := tr.Listen(&tls.Config{}, nil) + require.Error(t, err) + require.ErrorIs(t, err, ErrTransportClosed) } func TestTransportStatelessResetReceiving(t *testing.T) {