• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the 'dialect' abstraction.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECT_H
14 #define MLIR_IR_DIALECT_H
15 
16 #include "mlir/IR/OperationSupport.h"
17 #include "mlir/Support/TypeID.h"
18 
19 #include <map>
20 
21 namespace mlir {
22 class DialectAsmParser;
23 class DialectAsmPrinter;
24 class DialectInterface;
25 class OpBuilder;
26 class Type;
27 
28 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
29 
30 /// Dialects are groups of MLIR operations, types and attributes, as well as
31 /// behavior associated with the entire group.  For example, hooks into other
32 /// systems for constant folding, interfaces, default named types for asm
33 /// printing, etc.
34 ///
35 /// Instances of the dialect object are loaded in a specific MLIRContext.
36 ///
37 class Dialect {
38 public:
39   virtual ~Dialect();
40 
41   /// Utility function that returns if the given string is a valid dialect
42   /// namespace.
43   static bool isValidNamespace(StringRef str);
44 
getContext()45   MLIRContext *getContext() const { return context; }
46 
getNamespace()47   StringRef getNamespace() const { return name; }
48 
49   /// Returns the unique identifier that corresponds to this dialect.
getTypeID()50   TypeID getTypeID() const { return dialectID; }
51 
52   /// Returns true if this dialect allows for unregistered operations, i.e.
53   /// operations prefixed with the dialect namespace but not registered with
54   /// addOperation.
allowsUnknownOperations()55   bool allowsUnknownOperations() const { return unknownOpsAllowed; }
56 
57   /// Return true if this dialect allows for unregistered types, i.e., types
58   /// prefixed with the dialect namespace but not registered with addType.
59   /// These are represented with OpaqueType.
allowsUnknownTypes()60   bool allowsUnknownTypes() const { return unknownTypesAllowed; }
61 
62   /// Registered hook to materialize a single constant operation from a given
63   /// attribute value with the desired resultant type. This method should use
64   /// the provided builder to create the operation without changing the
65   /// insertion position. The generated operation is expected to be constant
66   /// like, i.e. single result, zero operands, non side-effecting, etc. On
67   /// success, this hook should return the value generated to represent the
68   /// constant value. Otherwise, it should return null on failure.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)69   virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
70                                          Type type, Location loc) {
71     return nullptr;
72   }
73 
74   //===--------------------------------------------------------------------===//
75   // Parsing Hooks
76   //===--------------------------------------------------------------------===//
77 
78   /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
79   /// refers to the expected type of the attribute.
80   virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
81 
82   /// Print an attribute registered to this dialect. Note: The type of the
83   /// attribute need not be printed by this method as it is always printed by
84   /// the caller.
printAttribute(Attribute,DialectAsmPrinter &)85   virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
86     llvm_unreachable("dialect has no registered attribute printing hook");
87   }
88 
89   /// Parse a type registered to this dialect.
90   virtual Type parseType(DialectAsmParser &parser) const;
91 
92   /// Print a type registered to this dialect.
printType(Type,DialectAsmPrinter &)93   virtual void printType(Type, DialectAsmPrinter &) const {
94     llvm_unreachable("dialect has no registered type printing hook");
95   }
96 
97   //===--------------------------------------------------------------------===//
98   // Verification Hooks
99   //===--------------------------------------------------------------------===//
100 
101   /// Verify an attribute from this dialect on the argument at 'argIndex' for
102   /// the region at 'regionIndex' on the given operation. Returns failure if
103   /// the verification failed, success otherwise. This hook may optionally be
104   /// invoked from any operation containing a region.
105   virtual LogicalResult verifyRegionArgAttribute(Operation *,
106                                                  unsigned regionIndex,
107                                                  unsigned argIndex,
108                                                  NamedAttribute);
109 
110   /// Verify an attribute from this dialect on the result at 'resultIndex' for
111   /// the region at 'regionIndex' on the given operation. Returns failure if
112   /// the verification failed, success otherwise. This hook may optionally be
113   /// invoked from any operation containing a region.
114   virtual LogicalResult verifyRegionResultAttribute(Operation *,
115                                                     unsigned regionIndex,
116                                                     unsigned resultIndex,
117                                                     NamedAttribute);
118 
119   /// Verify an attribute from this dialect on the given operation. Returns
120   /// failure if the verification failed, success otherwise.
verifyOperationAttribute(Operation *,NamedAttribute)121   virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
122     return success();
123   }
124 
125   //===--------------------------------------------------------------------===//
126   // Interfaces
127   //===--------------------------------------------------------------------===//
128 
129   /// Lookup an interface for the given ID if one is registered, otherwise
130   /// nullptr.
getRegisteredInterface(TypeID interfaceID)131   const DialectInterface *getRegisteredInterface(TypeID interfaceID) {
132     auto it = registeredInterfaces.find(interfaceID);
133     return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
134   }
getRegisteredInterface()135   template <typename InterfaceT> const InterfaceT *getRegisteredInterface() {
136     return static_cast<const InterfaceT *>(
137         getRegisteredInterface(InterfaceT::getInterfaceID()));
138   }
139 
140 protected:
141   /// The constructor takes a unique namespace for this dialect as well as the
142   /// context to bind to.
143   /// Note: The namespace must not contain '.' characters.
144   /// Note: All operations belonging to this dialect must have names starting
145   ///       with the namespace followed by '.'.
146   /// Example:
147   ///       - "tf" for the TensorFlow ops like "tf.add".
148   Dialect(StringRef name, MLIRContext *context, TypeID id);
149 
150   /// This method is used by derived classes to add their operations to the set.
151   ///
addOperations()152   template <typename... Args> void addOperations() {
153     (void)std::initializer_list<int>{
154         0, (AbstractOperation::insert<Args>(*this), 0)...};
155   }
156 
157   /// Register a set of type classes with this dialect.
addTypes()158   template <typename... Args> void addTypes() {
159     (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
160   }
161 
162   /// Register a set of attribute classes with this dialect.
addAttributes()163   template <typename... Args> void addAttributes() {
164     (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
165   }
166 
167   /// Enable support for unregistered operations.
168   void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
169 
170   /// Enable support for unregistered types.
171   void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
172 
173   /// Register a dialect interface with this dialect instance.
174   void addInterface(std::unique_ptr<DialectInterface> interface);
175 
176   /// Register a set of dialect interfaces with this dialect instance.
addInterfaces()177   template <typename... Args> void addInterfaces() {
178     (void)std::initializer_list<int>{
179         0, (addInterface(std::make_unique<Args>(this)), 0)...};
180   }
181 
182 private:
183   Dialect(const Dialect &) = delete;
184   void operator=(Dialect &) = delete;
185 
186   /// Register an attribute instance with this dialect.
addAttribute()187   template <typename T> void addAttribute() {
188     // Add this attribute to the dialect and register it with the uniquer.
189     addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
190     detail::AttributeUniquer::registerAttribute<T>(context);
191   }
192   void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
193 
194   /// Register a type instance with this dialect.
addType()195   template <typename T> void addType() {
196     // Add this type to the dialect and register it with the uniquer.
197     addType(T::getTypeID(), AbstractType::get<T>(*this));
198     detail::TypeUniquer::registerType<T>(context);
199   }
200   void addType(TypeID typeID, AbstractType &&typeInfo);
201 
202   /// The namespace of this dialect.
203   StringRef name;
204 
205   /// The unique identifier of the derived Op class, this is used in the context
206   /// to allow registering multiple times the same dialect.
207   TypeID dialectID;
208 
209   /// This is the context that owns this Dialect object.
210   MLIRContext *context;
211 
212   /// Flag that specifies whether this dialect supports unregistered operations,
213   /// i.e. operations prefixed with the dialect namespace but not registered
214   /// with addOperation.
215   bool unknownOpsAllowed = false;
216 
217   /// Flag that specifies whether this dialect allows unregistered types, i.e.
218   /// types prefixed with the dialect namespace but not registered with addType.
219   /// These types are represented with OpaqueType.
220   bool unknownTypesAllowed = false;
221 
222   /// A collection of registered dialect interfaces.
223   DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
224 
225   friend void registerDialect();
226   friend class MLIRContext;
227 };
228 
229 /// The DialectRegistry maps a dialect namespace to a constructor for the
230 /// matching dialect.
231 /// This allows for decoupling the list of dialects "available" from the
232 /// dialects loaded in the Context. The parser in particular will lazily load
233 /// dialects in the Context as operations are encountered.
234 class DialectRegistry {
235   using MapTy =
236       std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
237 
238 public:
239   template <typename ConcreteDialect>
insert()240   void insert() {
241     insert(TypeID::get<ConcreteDialect>(),
242            ConcreteDialect::getDialectNamespace(),
243            static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
244              // Just allocate the dialect, the context
245              // takes ownership of it.
246              return ctx->getOrLoadDialect<ConcreteDialect>();
247            })));
248   }
249 
250   template <typename ConcreteDialect, typename OtherDialect,
251             typename... MoreDialects>
insert()252   void insert() {
253     insert<ConcreteDialect>();
254     insert<OtherDialect, MoreDialects...>();
255   }
256 
257   /// Add a new dialect constructor to the registry.
258   void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
259 
260   /// Load a dialect for this namespace in the provided context.
261   Dialect *loadByName(StringRef name, MLIRContext *context);
262 
263   // Register all dialects available in the current registry with the registry
264   // in the provided context.
appendTo(DialectRegistry & destination)265   void appendTo(DialectRegistry &destination) {
266     for (const auto &nameAndRegistrationIt : registry)
267       destination.insert(nameAndRegistrationIt.second.first,
268                          nameAndRegistrationIt.first,
269                          nameAndRegistrationIt.second.second);
270   }
271   // Load all dialects available in the registry in the provided context.
loadAll(MLIRContext * context)272   void loadAll(MLIRContext *context) {
273     for (const auto &nameAndRegistrationIt : registry)
274       nameAndRegistrationIt.second.second(context);
275   }
276 
begin()277   MapTy::const_iterator begin() const { return registry.begin(); }
end()278   MapTy::const_iterator end() const { return registry.end(); }
279 
280 private:
281   MapTy registry;
282 };
283 
284 } // namespace mlir
285 
286 namespace llvm {
287 /// Provide isa functionality for Dialects.
288 template <typename T>
289 struct isa_impl<T, ::mlir::Dialect> {
290   static inline bool doit(const ::mlir::Dialect &dialect) {
291     return mlir::TypeID::get<T>() == dialect.getTypeID();
292   }
293 };
294 } // namespace llvm
295 
296 #endif
297