• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_MLIRCONTEXT_H
10 #define MLIR_IR_MLIRCONTEXT_H
11 
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Support/TypeID.h"
14 #include <functional>
15 #include <memory>
16 #include <vector>
17 
18 namespace mlir {
19 class AbstractOperation;
20 class DiagnosticEngine;
21 class Dialect;
22 class DialectRegistry;
23 class InFlightDiagnostic;
24 class Location;
25 class MLIRContextImpl;
26 class StorageUniquer;
27 
28 /// MLIRContext is the top-level object for a collection of MLIR operations. It
29 /// holds immortal uniqued objects like types, and the tables used to unique
30 /// them.
31 ///
32 /// MLIRContext gets a redundant "MLIR" prefix because otherwise it ends up with
33 /// a very generic name ("Context") and because it is uncommon for clients to
34 /// interact with it.
35 ///
36 class MLIRContext {
37 public:
38   /// Create a new Context.
39   /// The loadAllDialects parameters allows to load all dialects from the global
40   /// registry on Context construction. It is deprecated and will be removed
41   /// soon.
42   explicit MLIRContext();
43   ~MLIRContext();
44 
45   /// Return information about all IR dialects loaded in the context.
46   std::vector<Dialect *> getLoadedDialects();
47 
48   /// Return the dialect registry associated with this context.
49   DialectRegistry &getDialectRegistry();
50 
51   /// Return information about all available dialects in the registry in this
52   /// context.
53   std::vector<StringRef> getAvailableDialects();
54 
55   /// Get a registered IR dialect with the given namespace. If an exact match is
56   /// not found, then return nullptr.
57   Dialect *getLoadedDialect(StringRef name);
58 
59   /// Get a registered IR dialect for the given derived dialect type. The
60   /// derived type must provide a static 'getDialectNamespace' method.
61   template <typename T>
getLoadedDialect()62   T *getLoadedDialect() {
63     return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
64   }
65 
66   /// Get (or create) a dialect for the given derived dialect type. The derived
67   /// type must provide a static 'getDialectNamespace' method.
68   template <typename T>
getOrLoadDialect()69   T *getOrLoadDialect() {
70     return static_cast<T *>(
71         getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
72           std::unique_ptr<T> dialect(new T(this));
73           return dialect;
74         }));
75   }
76 
77   /// Load a dialect in the context.
78   template <typename Dialect>
loadDialect()79   void loadDialect() {
80     getOrLoadDialect<Dialect>();
81   }
82 
83   /// Load a list dialects in the context.
84   template <typename Dialect, typename OtherDialect, typename... MoreDialects>
loadDialect()85   void loadDialect() {
86     getOrLoadDialect<Dialect>();
87     loadDialect<OtherDialect, MoreDialects...>();
88   }
89 
90   /// Get (or create) a dialect for the given derived dialect name.
91   /// The dialect will be loaded from the registry if no dialect is found.
92   /// If no dialect is loaded for this name and none is available in the
93   /// registry, returns nullptr.
94   Dialect *getOrLoadDialect(StringRef name);
95 
96   /// Return true if we allow to create operation for unregistered dialects.
97   bool allowsUnregisteredDialects();
98 
99   /// Enables creating operations in unregistered dialects.
100   void allowUnregisteredDialects(bool allow = true);
101 
102   /// Return true if multi-threading is enabled by the context.
103   bool isMultithreadingEnabled();
104 
105   /// Set the flag specifying if multi-threading is disabled by the context.
106   void disableMultithreading(bool disable = true);
107   void enableMultithreading(bool enable = true) {
108     disableMultithreading(!enable);
109   }
110 
111   /// Return true if we should attach the operation to diagnostics emitted via
112   /// Operation::emit.
113   bool shouldPrintOpOnDiagnostic();
114 
115   /// Set the flag specifying if we should attach the operation to diagnostics
116   /// emitted via Operation::emit.
117   void printOpOnDiagnostic(bool enable);
118 
119   /// Return true if we should attach the current stacktrace to diagnostics when
120   /// emitted.
121   bool shouldPrintStackTraceOnDiagnostic();
122 
123   /// Set the flag specifying if we should attach the current stacktrace when
124   /// emitting diagnostics.
125   void printStackTraceOnDiagnostic(bool enable);
126 
127   /// Return information about all registered operations.  This isn't very
128   /// efficient: typically you should ask the operations about their properties
129   /// directly.
130   std::vector<AbstractOperation *> getRegisteredOperations();
131 
132   /// Return true if this operation name is registered in this context.
133   bool isOperationRegistered(StringRef name);
134 
135   // This is effectively private given that only MLIRContext.cpp can see the
136   // MLIRContextImpl type.
getImpl()137   MLIRContextImpl &getImpl() { return *impl; }
138 
139   /// Returns the diagnostic engine for this context.
140   DiagnosticEngine &getDiagEngine();
141 
142   /// Returns the storage uniquer used for creating affine constructs.
143   StorageUniquer &getAffineUniquer();
144 
145   /// Returns the storage uniquer used for constructing type storage instances.
146   /// This should not be used directly.
147   StorageUniquer &getTypeUniquer();
148 
149   /// Returns the storage uniquer used for constructing attribute storage
150   /// instances. This should not be used directly.
151   StorageUniquer &getAttributeUniquer();
152 
153   /// These APIs are tracking whether the context will be used in a
154   /// multithreading environment: this has no effect other than enabling
155   /// assertions on misuses of some APIs.
156   void enterMultiThreadedExecution();
157   void exitMultiThreadedExecution();
158 
159 private:
160   const std::unique_ptr<MLIRContextImpl> impl;
161 
162   /// Get a dialect for the provided namespace and TypeID: abort the program if
163   /// a dialect exist for this namespace with different TypeID. If a dialect has
164   /// not been loaded for this namespace/TypeID yet, use the provided ctor to
165   /// create one on the fly and load it. Returns a pointer to the dialect owned
166   /// by the context.
167   Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
168                             function_ref<std::unique_ptr<Dialect>()> ctor);
169 
170   MLIRContext(const MLIRContext &) = delete;
171   void operator=(const MLIRContext &) = delete;
172 };
173 
174 //===----------------------------------------------------------------------===//
175 // MLIRContext CommandLine Options
176 //===----------------------------------------------------------------------===//
177 
178 /// Register a set of useful command-line options that can be used to configure
179 /// various flags within the MLIRContext. These flags are used when constructing
180 /// an MLIR context for initialization.
181 void registerMLIRContextCLOptions();
182 
183 } // end namespace mlir
184 
185 #endif // MLIR_IR_MLIRCONTEXT_H
186