1 //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the 'dialect' abstraction. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_DIALECT_H 14 #define MLIR_IR_DIALECT_H 15 16 #include "mlir/IR/OperationSupport.h" 17 #include "mlir/Support/TypeID.h" 18 19 #include <map> 20 21 namespace mlir { 22 class DialectAsmParser; 23 class DialectAsmPrinter; 24 class DialectInterface; 25 class OpBuilder; 26 class Type; 27 28 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>; 29 30 /// Dialects are groups of MLIR operations, types and attributes, as well as 31 /// behavior associated with the entire group. For example, hooks into other 32 /// systems for constant folding, interfaces, default named types for asm 33 /// printing, etc. 34 /// 35 /// Instances of the dialect object are loaded in a specific MLIRContext. 36 /// 37 class Dialect { 38 public: 39 virtual ~Dialect(); 40 41 /// Utility function that returns if the given string is a valid dialect 42 /// namespace. 43 static bool isValidNamespace(StringRef str); 44 getContext()45 MLIRContext *getContext() const { return context; } 46 getNamespace()47 StringRef getNamespace() const { return name; } 48 49 /// Returns the unique identifier that corresponds to this dialect. getTypeID()50 TypeID getTypeID() const { return dialectID; } 51 52 /// Returns true if this dialect allows for unregistered operations, i.e. 53 /// operations prefixed with the dialect namespace but not registered with 54 /// addOperation. allowsUnknownOperations()55 bool allowsUnknownOperations() const { return unknownOpsAllowed; } 56 57 /// Return true if this dialect allows for unregistered types, i.e., types 58 /// prefixed with the dialect namespace but not registered with addType. 59 /// These are represented with OpaqueType. allowsUnknownTypes()60 bool allowsUnknownTypes() const { return unknownTypesAllowed; } 61 62 /// Registered hook to materialize a single constant operation from a given 63 /// attribute value with the desired resultant type. This method should use 64 /// the provided builder to create the operation without changing the 65 /// insertion position. The generated operation is expected to be constant 66 /// like, i.e. single result, zero operands, non side-effecting, etc. On 67 /// success, this hook should return the value generated to represent the 68 /// constant value. Otherwise, it should return null on failure. materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)69 virtual Operation *materializeConstant(OpBuilder &builder, Attribute value, 70 Type type, Location loc) { 71 return nullptr; 72 } 73 74 //===--------------------------------------------------------------------===// 75 // Parsing Hooks 76 //===--------------------------------------------------------------------===// 77 78 /// Parse an attribute registered to this dialect. If 'type' is nonnull, it 79 /// refers to the expected type of the attribute. 80 virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const; 81 82 /// Print an attribute registered to this dialect. Note: The type of the 83 /// attribute need not be printed by this method as it is always printed by 84 /// the caller. printAttribute(Attribute,DialectAsmPrinter &)85 virtual void printAttribute(Attribute, DialectAsmPrinter &) const { 86 llvm_unreachable("dialect has no registered attribute printing hook"); 87 } 88 89 /// Parse a type registered to this dialect. 90 virtual Type parseType(DialectAsmParser &parser) const; 91 92 /// Print a type registered to this dialect. printType(Type,DialectAsmPrinter &)93 virtual void printType(Type, DialectAsmPrinter &) const { 94 llvm_unreachable("dialect has no registered type printing hook"); 95 } 96 97 //===--------------------------------------------------------------------===// 98 // Verification Hooks 99 //===--------------------------------------------------------------------===// 100 101 /// Verify an attribute from this dialect on the argument at 'argIndex' for 102 /// the region at 'regionIndex' on the given operation. Returns failure if 103 /// the verification failed, success otherwise. This hook may optionally be 104 /// invoked from any operation containing a region. 105 virtual LogicalResult verifyRegionArgAttribute(Operation *, 106 unsigned regionIndex, 107 unsigned argIndex, 108 NamedAttribute); 109 110 /// Verify an attribute from this dialect on the result at 'resultIndex' for 111 /// the region at 'regionIndex' on the given operation. Returns failure if 112 /// the verification failed, success otherwise. This hook may optionally be 113 /// invoked from any operation containing a region. 114 virtual LogicalResult verifyRegionResultAttribute(Operation *, 115 unsigned regionIndex, 116 unsigned resultIndex, 117 NamedAttribute); 118 119 /// Verify an attribute from this dialect on the given operation. Returns 120 /// failure if the verification failed, success otherwise. verifyOperationAttribute(Operation *,NamedAttribute)121 virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) { 122 return success(); 123 } 124 125 //===--------------------------------------------------------------------===// 126 // Interfaces 127 //===--------------------------------------------------------------------===// 128 129 /// Lookup an interface for the given ID if one is registered, otherwise 130 /// nullptr. getRegisteredInterface(TypeID interfaceID)131 const DialectInterface *getRegisteredInterface(TypeID interfaceID) { 132 auto it = registeredInterfaces.find(interfaceID); 133 return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; 134 } getRegisteredInterface()135 template <typename InterfaceT> const InterfaceT *getRegisteredInterface() { 136 return static_cast<const InterfaceT *>( 137 getRegisteredInterface(InterfaceT::getInterfaceID())); 138 } 139 140 protected: 141 /// The constructor takes a unique namespace for this dialect as well as the 142 /// context to bind to. 143 /// Note: The namespace must not contain '.' characters. 144 /// Note: All operations belonging to this dialect must have names starting 145 /// with the namespace followed by '.'. 146 /// Example: 147 /// - "tf" for the TensorFlow ops like "tf.add". 148 Dialect(StringRef name, MLIRContext *context, TypeID id); 149 150 /// This method is used by derived classes to add their operations to the set. 151 /// addOperations()152 template <typename... Args> void addOperations() { 153 (void)std::initializer_list<int>{ 154 0, (AbstractOperation::insert<Args>(*this), 0)...}; 155 } 156 157 /// Register a set of type classes with this dialect. addTypes()158 template <typename... Args> void addTypes() { 159 (void)std::initializer_list<int>{0, (addType<Args>(), 0)...}; 160 } 161 162 /// Register a set of attribute classes with this dialect. addAttributes()163 template <typename... Args> void addAttributes() { 164 (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...}; 165 } 166 167 /// Enable support for unregistered operations. 168 void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } 169 170 /// Enable support for unregistered types. 171 void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } 172 173 /// Register a dialect interface with this dialect instance. 174 void addInterface(std::unique_ptr<DialectInterface> interface); 175 176 /// Register a set of dialect interfaces with this dialect instance. addInterfaces()177 template <typename... Args> void addInterfaces() { 178 (void)std::initializer_list<int>{ 179 0, (addInterface(std::make_unique<Args>(this)), 0)...}; 180 } 181 182 private: 183 Dialect(const Dialect &) = delete; 184 void operator=(Dialect &) = delete; 185 186 /// Register an attribute instance with this dialect. addAttribute()187 template <typename T> void addAttribute() { 188 // Add this attribute to the dialect and register it with the uniquer. 189 addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this)); 190 detail::AttributeUniquer::registerAttribute<T>(context); 191 } 192 void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); 193 194 /// Register a type instance with this dialect. addType()195 template <typename T> void addType() { 196 // Add this type to the dialect and register it with the uniquer. 197 addType(T::getTypeID(), AbstractType::get<T>(*this)); 198 detail::TypeUniquer::registerType<T>(context); 199 } 200 void addType(TypeID typeID, AbstractType &&typeInfo); 201 202 /// The namespace of this dialect. 203 StringRef name; 204 205 /// The unique identifier of the derived Op class, this is used in the context 206 /// to allow registering multiple times the same dialect. 207 TypeID dialectID; 208 209 /// This is the context that owns this Dialect object. 210 MLIRContext *context; 211 212 /// Flag that specifies whether this dialect supports unregistered operations, 213 /// i.e. operations prefixed with the dialect namespace but not registered 214 /// with addOperation. 215 bool unknownOpsAllowed = false; 216 217 /// Flag that specifies whether this dialect allows unregistered types, i.e. 218 /// types prefixed with the dialect namespace but not registered with addType. 219 /// These types are represented with OpaqueType. 220 bool unknownTypesAllowed = false; 221 222 /// A collection of registered dialect interfaces. 223 DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces; 224 225 friend void registerDialect(); 226 friend class MLIRContext; 227 }; 228 229 /// The DialectRegistry maps a dialect namespace to a constructor for the 230 /// matching dialect. 231 /// This allows for decoupling the list of dialects "available" from the 232 /// dialects loaded in the Context. The parser in particular will lazily load 233 /// dialects in the Context as operations are encountered. 234 class DialectRegistry { 235 using MapTy = 236 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>; 237 238 public: 239 template <typename ConcreteDialect> insert()240 void insert() { 241 insert(TypeID::get<ConcreteDialect>(), 242 ConcreteDialect::getDialectNamespace(), 243 static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) { 244 // Just allocate the dialect, the context 245 // takes ownership of it. 246 return ctx->getOrLoadDialect<ConcreteDialect>(); 247 }))); 248 } 249 250 template <typename ConcreteDialect, typename OtherDialect, 251 typename... MoreDialects> insert()252 void insert() { 253 insert<ConcreteDialect>(); 254 insert<OtherDialect, MoreDialects...>(); 255 } 256 257 /// Add a new dialect constructor to the registry. 258 void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor); 259 260 /// Load a dialect for this namespace in the provided context. 261 Dialect *loadByName(StringRef name, MLIRContext *context); 262 263 // Register all dialects available in the current registry with the registry 264 // in the provided context. appendTo(DialectRegistry & destination)265 void appendTo(DialectRegistry &destination) { 266 for (const auto &nameAndRegistrationIt : registry) 267 destination.insert(nameAndRegistrationIt.second.first, 268 nameAndRegistrationIt.first, 269 nameAndRegistrationIt.second.second); 270 } 271 // Load all dialects available in the registry in the provided context. loadAll(MLIRContext * context)272 void loadAll(MLIRContext *context) { 273 for (const auto &nameAndRegistrationIt : registry) 274 nameAndRegistrationIt.second.second(context); 275 } 276 begin()277 MapTy::const_iterator begin() const { return registry.begin(); } end()278 MapTy::const_iterator end() const { return registry.end(); } 279 280 private: 281 MapTy registry; 282 }; 283 284 } // namespace mlir 285 286 namespace llvm { 287 /// Provide isa functionality for Dialects. 288 template <typename T> 289 struct isa_impl<T, ::mlir::Dialect> { 290 static inline bool doit(const ::mlir::Dialect &dialect) { 291 return mlir::TypeID::get<T>() == dialect.getTypeID(); 292 } 293 }; 294 } // namespace llvm 295 296 #endif 297