diff --git a/transport.go b/transport.go index 7590be32..34e0b21c 100644 --- a/transport.go +++ b/transport.go @@ -177,8 +177,6 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C func (t *Transport) init(isServer bool) error { t.initOnce.Do(func() { - getMultiplexer().AddConn(t.Conn) - var conn rawConn if c, ok := t.Conn.(rawConn); ok { conn = c @@ -212,6 +210,7 @@ func (t *Transport) init(isServer bool) error { t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} } + getMultiplexer().AddConn(t.Conn) go t.listen(conn) go t.runSendQueue() }) diff --git a/transport_test.go b/transport_test.go index c78742f8..f46affb3 100644 --- a/transport_test.go +++ b/transport_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "errors" "net" + "syscall" "time" mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" @@ -307,4 +308,27 @@ var _ = Describe("Transport", func() { pconn.EXPECT().Close() Expect(tr.Close()).To(Succeed()) }) + + It("doesn't add the PacketConn to the multiplexer if (*Transport).init fails", func() { + packetChan := make(chan packetToRead) + pconn := newMockPacketConn(packetChan) + syscallconn := &mockSyscallConn{pconn} + + tr := &Transport{ + Conn: syscallconn, + } + + err := tr.init(false) + Expect(err).To(HaveOccurred()) + conns := getMultiplexer().(*connMultiplexer).conns + Expect(len(conns)).To(BeZero()) + }) }) + +type mockSyscallConn struct { + net.PacketConn +} + +func (c *mockSyscallConn) SyscallConn() (syscall.RawConn, error) { + return nil, errors.New("mocked") +}