From 0c54d416eeb10c489deae30a853dde4553e94e96 Mon Sep 17 00:00:00 2001 From: kelmenhorst <45046038+kelmenhorst@users.noreply.github.com> Date: Thu, 29 Jun 2023 19:35:16 +0200 Subject: [PATCH] transport: don't add connection to multiplexer if init fails (#3931) * Remove conn from multiplexer when (*Transport).init fails * Transport: AddConn to multiplexer directly before start listening * Update transport_test.go --------- Co-authored-by: Marten Seemann --- transport.go | 3 +-- transport_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/transport.go b/transport.go index 7590be32..34e0b21c 100644 --- a/transport.go +++ b/transport.go @@ -177,8 +177,6 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C func (t *Transport) init(isServer bool) error { t.initOnce.Do(func() { - getMultiplexer().AddConn(t.Conn) - var conn rawConn if c, ok := t.Conn.(rawConn); ok { conn = c @@ -212,6 +210,7 @@ func (t *Transport) init(isServer bool) error { t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} } + getMultiplexer().AddConn(t.Conn) go t.listen(conn) go t.runSendQueue() }) diff --git a/transport_test.go b/transport_test.go index c78742f8..f46affb3 100644 --- a/transport_test.go +++ b/transport_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "errors" "net" + "syscall" "time" mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" @@ -307,4 +308,27 @@ var _ = Describe("Transport", func() { pconn.EXPECT().Close() Expect(tr.Close()).To(Succeed()) }) + + It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() { + packetChan := make(chan packetToRead) + pconn := newMockPacketConn(packetChan) + syscallconn := &mockSyscallConn{pconn} + + tr := &Transport{ + Conn: syscallconn, + } + + err := tr.init(false) + Expect(err).To(HaveOccurred()) + conns := getMultiplexer().(*connMultiplexer).conns + Expect(len(conns)).To(BeZero()) + }) }) + +type mockSyscallConn struct { + net.PacketConn +} + +func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) { + return nil, errors.New("mocked") +}