• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "LocationDetail.h"
15 #include "TypeDetail.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/Identifier.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/Location.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/Types.h"
27 #include "mlir/Support/ThreadLocalCache.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/DenseSet.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/ADT/Twine.h"
33 #include "llvm/Support/Allocator.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/RWMutex.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <memory>
39 
40 #define DEBUG_TYPE "mlircontext"
41 
42 using namespace mlir;
43 using namespace mlir::detail;
44 
45 using llvm::hash_combine;
46 using llvm::hash_combine_range;
47 
48 //===----------------------------------------------------------------------===//
49 // MLIRContext CommandLine Options
50 //===----------------------------------------------------------------------===//
51 
52 namespace {
53 /// This struct contains command line options that can be used to initialize
54 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
55 /// for global command line options.
56 struct MLIRContextOptions {
57   llvm::cl::opt<bool> disableThreading{
58       "mlir-disable-threading",
59       llvm::cl::desc("Disabling multi-threading within MLIR")};
60 
61   llvm::cl::opt<bool> printOpOnDiagnostic{
62       "mlir-print-op-on-diagnostic",
63       llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
64                      "the operation as an attached note"),
65       llvm::cl::init(true)};
66 
67   llvm::cl::opt<bool> printStackTraceOnDiagnostic{
68       "mlir-print-stacktrace-on-diagnostic",
69       llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
70                      "as an attached note")};
71 };
72 } // end anonymous namespace
73 
74 static llvm::ManagedStatic<MLIRContextOptions> clOptions;
75 
76 /// Register a set of useful command-line options that can be used to configure
77 /// various flags within the MLIRContext. These flags are used when constructing
78 /// an MLIR context for initialization.
registerMLIRContextCLOptions()79 void mlir::registerMLIRContextCLOptions() {
80   // Make sure that the options struct has been initialized.
81   *clOptions;
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // Locking Utilities
86 //===----------------------------------------------------------------------===//
87 
88 namespace {
89 /// Utility reader lock that takes a runtime flag that specifies if we really
90 /// need to lock.
91 struct ScopedReaderLock {
ScopedReaderLock__anon5fe09d030211::ScopedReaderLock92   ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
93       : mutex(shouldLock ? &mutexParam : nullptr) {
94     if (mutex)
95       mutex->lock_shared();
96   }
~ScopedReaderLock__anon5fe09d030211::ScopedReaderLock97   ~ScopedReaderLock() {
98     if (mutex)
99       mutex->unlock_shared();
100   }
101   llvm::sys::SmartRWMutex<true> *mutex;
102 };
103 /// Utility writer lock that takes a runtime flag that specifies if we really
104 /// need to lock.
105 struct ScopedWriterLock {
ScopedWriterLock__anon5fe09d030211::ScopedWriterLock106   ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
107       : mutex(shouldLock ? &mutexParam : nullptr) {
108     if (mutex)
109       mutex->lock();
110   }
~ScopedWriterLock__anon5fe09d030211::ScopedWriterLock111   ~ScopedWriterLock() {
112     if (mutex)
113       mutex->unlock();
114   }
115   llvm::sys::SmartRWMutex<true> *mutex;
116 };
117 } // end anonymous namespace.
118 
119 //===----------------------------------------------------------------------===//
120 // AffineMap and IntegerSet hashing
121 //===----------------------------------------------------------------------===//
122 
123 /// A utility function to safely get or create a uniqued instance within the
124 /// given set container.
125 template <typename ValueT, typename DenseInfoT, typename KeyT,
126           typename ConstructorFn>
safeGetOrCreate(DenseSet<ValueT,DenseInfoT> & container,KeyT && key,llvm::sys::SmartRWMutex<true> & mutex,bool threadingIsEnabled,ConstructorFn && constructorFn)127 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
128                               KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex,
129                               bool threadingIsEnabled,
130                               ConstructorFn &&constructorFn) {
131   // Check for an existing instance in read-only mode.
132   if (threadingIsEnabled) {
133     llvm::sys::SmartScopedReader<true> instanceLock(mutex);
134     auto it = container.find_as(key);
135     if (it != container.end())
136       return *it;
137   }
138 
139   // Acquire a writer-lock so that we can safely create the new instance.
140   ScopedWriterLock instanceLock(mutex, threadingIsEnabled);
141 
142   // Check for an existing instance again here, because another writer thread
143   // may have already created one. Otherwise, construct a new instance.
144   auto existing = container.insert_as(ValueT(), key);
145   if (existing.second)
146     return *existing.first = constructorFn();
147   return *existing.first;
148 }
149 
150 namespace {
151 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
152   // Affine maps are uniqued based on their dim/symbol counts and affine
153   // expressions.
154   using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>;
155   using DenseMapInfo<AffineMap>::isEqual;
156 
getHashValue__anon5fe09d030311::AffineMapKeyInfo157   static unsigned getHashValue(const AffineMap &key) {
158     return getHashValue(
159         KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults()));
160   }
161 
getHashValue__anon5fe09d030311::AffineMapKeyInfo162   static unsigned getHashValue(KeyTy key) {
163     return hash_combine(
164         std::get<0>(key), std::get<1>(key),
165         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
166   }
167 
isEqual__anon5fe09d030311::AffineMapKeyInfo168   static bool isEqual(const KeyTy &lhs, AffineMap rhs) {
169     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
170       return false;
171     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
172                                   rhs.getResults());
173   }
174 };
175 
176 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
177   // Integer sets are uniqued based on their dim/symbol counts, affine
178   // expressions appearing in the LHS of constraints, and eqFlags.
179   using KeyTy =
180       std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>;
181   using DenseMapInfo<IntegerSet>::isEqual;
182 
getHashValue__anon5fe09d030311::IntegerSetKeyInfo183   static unsigned getHashValue(const IntegerSet &key) {
184     return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(),
185                               key.getConstraints(), key.getEqFlags()));
186   }
187 
getHashValue__anon5fe09d030311::IntegerSetKeyInfo188   static unsigned getHashValue(KeyTy key) {
189     return hash_combine(
190         std::get<0>(key), std::get<1>(key),
191         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
192         hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
193   }
194 
isEqual__anon5fe09d030311::IntegerSetKeyInfo195   static bool isEqual(const KeyTy &lhs, IntegerSet rhs) {
196     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
197       return false;
198     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
199                                   rhs.getConstraints(), rhs.getEqFlags());
200   }
201 };
202 } // end anonymous namespace.
203 
204 //===----------------------------------------------------------------------===//
205 // MLIRContextImpl
206 //===----------------------------------------------------------------------===//
207 
208 namespace mlir {
209 /// This is the implementation of the MLIRContext class, using the pImpl idiom.
210 /// This class is completely private to this file, so everything is public.
211 class MLIRContextImpl {
212 public:
213   //===--------------------------------------------------------------------===//
214   // Identifier uniquing
215   //===--------------------------------------------------------------------===//
216 
217   // Identifier allocator and mutex for thread safety.
218   llvm::BumpPtrAllocator identifierAllocator;
219   llvm::sys::SmartRWMutex<true> identifierMutex;
220 
221   //===--------------------------------------------------------------------===//
222   // Diagnostics
223   //===--------------------------------------------------------------------===//
224   DiagnosticEngine diagEngine;
225 
226   //===--------------------------------------------------------------------===//
227   // Options
228   //===--------------------------------------------------------------------===//
229 
230   /// In most cases, creating operation in unregistered dialect is not desired
231   /// and indicate a misconfiguration of the compiler. This option enables to
232   /// detect such use cases
233   bool allowUnregisteredDialects = false;
234 
235   /// Enable support for multi-threading within MLIR.
236   bool threadingIsEnabled = true;
237 
238   /// Track if we are currently executing in a threaded execution environment
239   /// (like the pass-manager): this is only a debugging feature to help reducing
240   /// the chances of data races one some context APIs.
241 #ifndef NDEBUG
242   std::atomic<int> multiThreadedExecutionContext{0};
243 #endif
244 
245   /// If the operation should be attached to diagnostics printed via the
246   /// Operation::emit methods.
247   bool printOpOnDiagnostic = true;
248 
249   /// If the current stack trace should be attached when emitting diagnostics.
250   bool printStackTraceOnDiagnostic = false;
251 
252   //===--------------------------------------------------------------------===//
253   // Other
254   //===--------------------------------------------------------------------===//
255 
256   /// This is a list of dialects that are created referring to this context.
257   /// The MLIRContext owns the objects.
258   DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
259   DialectRegistry dialectsRegistry;
260 
261   /// This is a mapping from operation name to AbstractOperation for registered
262   /// operations.
263   llvm::StringMap<AbstractOperation> registeredOperations;
264 
265   /// Identifiers are uniqued by string value and use the internal string set
266   /// for storage.
267   llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;
268   /// A thread local cache of identifiers to reduce lock contention.
269   ThreadLocalCache<llvm::StringMap<llvm::StringMapEntry<llvm::NoneType> *>>
270       localIdentifierCache;
271 
272   /// An allocator used for AbstractAttribute and AbstractType objects.
273   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
274 
275   //===--------------------------------------------------------------------===//
276   // Affine uniquing
277   //===--------------------------------------------------------------------===//
278 
279   // Affine allocator and mutex for thread safety.
280   llvm::BumpPtrAllocator affineAllocator;
281   llvm::sys::SmartRWMutex<true> affineMutex;
282 
283   // Affine map uniquing.
284   using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
285   AffineMapSet affineMaps;
286 
287   // Integer set uniquing.
288   using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
289   IntegerSets integerSets;
290 
291   // Affine expression uniquing.
292   StorageUniquer affineUniquer;
293 
294   //===--------------------------------------------------------------------===//
295   // Type uniquing
296   //===--------------------------------------------------------------------===//
297 
298   DenseMap<TypeID, const AbstractType *> registeredTypes;
299   StorageUniquer typeUniquer;
300 
301   /// Cached Type Instances.
302   BFloat16Type bf16Ty;
303   Float16Type f16Ty;
304   Float32Type f32Ty;
305   Float64Type f64Ty;
306   IndexType indexTy;
307   IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
308   NoneType noneType;
309 
310   //===--------------------------------------------------------------------===//
311   // Attribute uniquing
312   //===--------------------------------------------------------------------===//
313 
314   DenseMap<TypeID, const AbstractAttribute *> registeredAttributes;
315   StorageUniquer attributeUniquer;
316 
317   /// Cached Attribute Instances.
318   BoolAttr falseAttr, trueAttr;
319   UnitAttr unitAttr;
320   UnknownLoc unknownLocAttr;
321   DictionaryAttr emptyDictionaryAttr;
322 
323 public:
MLIRContextImpl()324   MLIRContextImpl() : identifiers(identifierAllocator) {}
~MLIRContextImpl()325   ~MLIRContextImpl() {
326     for (auto typeMapping : registeredTypes)
327       typeMapping.second->~AbstractType();
328     for (auto attrMapping : registeredAttributes)
329       attrMapping.second->~AbstractAttribute();
330   }
331 };
332 } // end namespace mlir
333 
MLIRContext()334 MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
335   // Initialize values based on the command line flags if they were provided.
336   if (clOptions.isConstructed()) {
337     disableMultithreading(clOptions->disableThreading);
338     printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
339     printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
340   }
341 
342   // Ensure the builtin dialect is always pre-loaded.
343   getOrLoadDialect<BuiltinDialect>();
344 
345   // Initialize several common attributes and types to avoid the need to lock
346   // the context when accessing them.
347 
348   //// Types.
349   /// Floating-point Types.
350   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
351   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
352   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
353   impl->f64Ty = TypeUniquer::get<Float64Type>(this);
354   /// Index Type.
355   impl->indexTy = TypeUniquer::get<IndexType>(this);
356   /// Integer Types.
357   impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
358   impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
359   impl->int16Ty =
360       TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
361   impl->int32Ty =
362       TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
363   impl->int64Ty =
364       TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
365   impl->int128Ty =
366       TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
367   /// None Type.
368   impl->noneType = TypeUniquer::get<NoneType>(this);
369 
370   //// Attributes.
371   //// Note: These must be registered after the types as they may generate one
372   //// of the above types internally.
373   /// Bool Attributes.
374   impl->falseAttr = AttributeUniquer::get<IntegerAttr>(
375                         this, impl->int1Ty, APInt(/*numBits=*/1, false))
376                         .cast<BoolAttr>();
377   impl->trueAttr = AttributeUniquer::get<IntegerAttr>(
378                        this, impl->int1Ty, APInt(/*numBits=*/1, true))
379                        .cast<BoolAttr>();
380   /// Unit Attribute.
381   impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
382   /// Unknown Location Attribute.
383   impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
384   /// The empty dictionary attribute.
385   impl->emptyDictionaryAttr =
386       AttributeUniquer::get<DictionaryAttr>(this, ArrayRef<NamedAttribute>());
387 
388   // Register the affine storage objects with the uniquer.
389   impl->affineUniquer
390       .registerParametricStorageType<AffineBinaryOpExprStorage>();
391   impl->affineUniquer
392       .registerParametricStorageType<AffineConstantExprStorage>();
393   impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
394 }
395 
~MLIRContext()396 MLIRContext::~MLIRContext() {}
397 
398 /// Copy the specified array of elements into memory managed by the provided
399 /// bump pointer allocator.  This assumes the elements are all PODs.
400 template <typename T>
copyArrayRefInto(llvm::BumpPtrAllocator & allocator,ArrayRef<T> elements)401 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
402                                     ArrayRef<T> elements) {
403   auto result = allocator.Allocate<T>(elements.size());
404   std::uninitialized_copy(elements.begin(), elements.end(), result);
405   return ArrayRef<T>(result, elements.size());
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // Diagnostic Handlers
410 //===----------------------------------------------------------------------===//
411 
412 /// Returns the diagnostic engine for this context.
getDiagEngine()413 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
414 
415 //===----------------------------------------------------------------------===//
416 // Dialect and Operation Registration
417 //===----------------------------------------------------------------------===//
418 
getDialectRegistry()419 DialectRegistry &MLIRContext::getDialectRegistry() {
420   return impl->dialectsRegistry;
421 }
422 
423 /// Return information about all registered IR dialects.
getLoadedDialects()424 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
425   std::vector<Dialect *> result;
426   result.reserve(impl->loadedDialects.size());
427   for (auto &dialect : impl->loadedDialects)
428     result.push_back(dialect.second.get());
429   llvm::array_pod_sort(result.begin(), result.end(),
430                        [](Dialect *const *lhs, Dialect *const *rhs) -> int {
431                          return (*lhs)->getNamespace() < (*rhs)->getNamespace();
432                        });
433   return result;
434 }
getAvailableDialects()435 std::vector<StringRef> MLIRContext::getAvailableDialects() {
436   std::vector<StringRef> result;
437   for (auto &dialect : impl->dialectsRegistry)
438     result.push_back(dialect.first);
439   return result;
440 }
441 
442 /// Get a registered IR dialect with the given namespace. If none is found,
443 /// then return nullptr.
getLoadedDialect(StringRef name)444 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
445   // Dialects are sorted by name, so we can use binary search for lookup.
446   auto it = impl->loadedDialects.find(name);
447   return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
448 }
449 
getOrLoadDialect(StringRef name)450 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
451   Dialect *dialect = getLoadedDialect(name);
452   if (dialect)
453     return dialect;
454   return impl->dialectsRegistry.loadByName(name, this);
455 }
456 
457 /// Get a dialect for the provided namespace and TypeID: abort the program if a
458 /// dialect exist for this namespace with different TypeID. Returns a pointer to
459 /// the dialect owned by the context.
460 Dialect *
getOrLoadDialect(StringRef dialectNamespace,TypeID dialectID,function_ref<std::unique_ptr<Dialect> ()> ctor)461 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
462                               function_ref<std::unique_ptr<Dialect>()> ctor) {
463   auto &impl = getImpl();
464   // Get the correct insertion position sorted by namespace.
465   std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
466 
467   if (!dialect) {
468     LLVM_DEBUG(llvm::dbgs()
469                << "Load new dialect in Context " << dialectNamespace << "\n");
470 #ifndef NDEBUG
471     if (impl.multiThreadedExecutionContext != 0)
472       llvm::report_fatal_error(
473           "Loading a dialect (" + dialectNamespace +
474           ") while in a multi-threaded execution context (maybe "
475           "the PassManager): this can indicate a "
476           "missing `dependentDialects` in a pass for example.");
477 #endif
478     dialect = ctor();
479     assert(dialect && "dialect ctor failed");
480     return dialect.get();
481   }
482 
483   // Abort if dialect with namespace has already been registered.
484   if (dialect->getTypeID() != dialectID)
485     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
486                              "' has already been registered");
487 
488   return dialect.get();
489 }
490 
allowsUnregisteredDialects()491 bool MLIRContext::allowsUnregisteredDialects() {
492   return impl->allowUnregisteredDialects;
493 }
494 
allowUnregisteredDialects(bool allowing)495 void MLIRContext::allowUnregisteredDialects(bool allowing) {
496   impl->allowUnregisteredDialects = allowing;
497 }
498 
499 /// Return true if multi-threading is disabled by the context.
isMultithreadingEnabled()500 bool MLIRContext::isMultithreadingEnabled() {
501   return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
502 }
503 
504 /// Set the flag specifying if multi-threading is disabled by the context.
disableMultithreading(bool disable)505 void MLIRContext::disableMultithreading(bool disable) {
506   impl->threadingIsEnabled = !disable;
507 
508   // Update the threading mode for each of the uniquers.
509   impl->affineUniquer.disableMultithreading(disable);
510   impl->attributeUniquer.disableMultithreading(disable);
511   impl->typeUniquer.disableMultithreading(disable);
512 }
513 
enterMultiThreadedExecution()514 void MLIRContext::enterMultiThreadedExecution() {
515 #ifndef NDEBUG
516   ++impl->multiThreadedExecutionContext;
517 #endif
518 }
exitMultiThreadedExecution()519 void MLIRContext::exitMultiThreadedExecution() {
520 #ifndef NDEBUG
521   --impl->multiThreadedExecutionContext;
522 #endif
523 }
524 
525 /// Return true if we should attach the operation to diagnostics emitted via
526 /// Operation::emit.
shouldPrintOpOnDiagnostic()527 bool MLIRContext::shouldPrintOpOnDiagnostic() {
528   return impl->printOpOnDiagnostic;
529 }
530 
531 /// Set the flag specifying if we should attach the operation to diagnostics
532 /// emitted via Operation::emit.
printOpOnDiagnostic(bool enable)533 void MLIRContext::printOpOnDiagnostic(bool enable) {
534   impl->printOpOnDiagnostic = enable;
535 }
536 
537 /// Return true if we should attach the current stacktrace to diagnostics when
538 /// emitted.
shouldPrintStackTraceOnDiagnostic()539 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
540   return impl->printStackTraceOnDiagnostic;
541 }
542 
543 /// Set the flag specifying if we should attach the current stacktrace when
544 /// emitting diagnostics.
printStackTraceOnDiagnostic(bool enable)545 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
546   impl->printStackTraceOnDiagnostic = enable;
547 }
548 
549 /// Return information about all registered operations.  This isn't very
550 /// efficient, typically you should ask the operations about their properties
551 /// directly.
getRegisteredOperations()552 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
553   // We just have the operations in a non-deterministic hash table order. Dump
554   // into a temporary array, then sort it by operation name to get a stable
555   // ordering.
556   llvm::StringMap<AbstractOperation> &registeredOps =
557       impl->registeredOperations;
558 
559   std::vector<AbstractOperation *> result;
560   result.reserve(registeredOps.size());
561   for (auto &elt : registeredOps)
562     result.push_back(&elt.second);
563   llvm::array_pod_sort(
564       result.begin(), result.end(),
565       [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
566         return (*lhs)->name.compare((*rhs)->name);
567       });
568 
569   return result;
570 }
571 
isOperationRegistered(StringRef name)572 bool MLIRContext::isOperationRegistered(StringRef name) {
573   return impl->registeredOperations.count(name);
574 }
575 
addType(TypeID typeID,AbstractType && typeInfo)576 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
577   auto &impl = context->getImpl();
578   assert(impl.multiThreadedExecutionContext == 0 &&
579          "Registering a new type kind while in a multi-threaded execution "
580          "context");
581   auto *newInfo =
582       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
583           AbstractType(std::move(typeInfo));
584   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
585     llvm::report_fatal_error("Dialect Type already registered.");
586 }
587 
addAttribute(TypeID typeID,AbstractAttribute && attrInfo)588 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
589   auto &impl = context->getImpl();
590   assert(impl.multiThreadedExecutionContext == 0 &&
591          "Registering a new attribute kind while in a multi-threaded execution "
592          "context");
593   auto *newInfo =
594       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
595           AbstractAttribute(std::move(attrInfo));
596   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
597     llvm::report_fatal_error("Dialect Attribute already registered.");
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // AbstractAttribute
602 //===----------------------------------------------------------------------===//
603 
604 /// Get the dialect that registered the attribute with the provided typeid.
lookup(TypeID typeID,MLIRContext * context)605 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
606                                                    MLIRContext *context) {
607   auto &impl = context->getImpl();
608   auto it = impl.registeredAttributes.find(typeID);
609   if (it == impl.registeredAttributes.end())
610     llvm::report_fatal_error("Trying to create an Attribute that was not "
611                              "registered in this MLIRContext.");
612   return *it->second;
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // AbstractOperation
617 //===----------------------------------------------------------------------===//
618 
parseAssembly(OpAsmParser & parser,OperationState & result) const619 ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser,
620                                              OperationState &result) const {
621   return parseAssemblyFn(parser, result);
622 }
623 
624 /// Look up the specified operation in the operation set and return a pointer
625 /// to it if present. Otherwise, return a null pointer.
lookup(StringRef opName,MLIRContext * context)626 const AbstractOperation *AbstractOperation::lookup(StringRef opName,
627                                                    MLIRContext *context) {
628   auto &impl = context->getImpl();
629   auto it = impl.registeredOperations.find(opName);
630   if (it != impl.registeredOperations.end())
631     return &it->second;
632   return nullptr;
633 }
634 
insert(StringRef name,Dialect & dialect,OperationProperties opProperties,TypeID typeID,ParseAssemblyFn parseAssembly,PrintAssemblyFn printAssembly,VerifyInvariantsFn verifyInvariants,FoldHookFn foldHook,GetCanonicalizationPatternsFn getCanonicalizationPatterns,detail::InterfaceMap && interfaceMap,HasTraitFn hasTrait)635 void AbstractOperation::insert(
636     StringRef name, Dialect &dialect, OperationProperties opProperties,
637     TypeID typeID, ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
638     VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
639     GetCanonicalizationPatternsFn getCanonicalizationPatterns,
640     detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) {
641   AbstractOperation opInfo(name, dialect, opProperties, typeID, parseAssembly,
642                            printAssembly, verifyInvariants, foldHook,
643                            getCanonicalizationPatterns, std::move(interfaceMap),
644                            hasTrait);
645 
646   auto &impl = dialect.getContext()->getImpl();
647   assert(impl.multiThreadedExecutionContext == 0 &&
648          "Registering a new operation kind while in a multi-threaded execution "
649          "context");
650   if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
651     llvm::errs() << "error: operation named '" << name
652                  << "' is already registered.\n";
653     abort();
654   }
655 }
656 
AbstractOperation(StringRef name,Dialect & dialect,OperationProperties opProperties,TypeID typeID,ParseAssemblyFn parseAssembly,PrintAssemblyFn printAssembly,VerifyInvariantsFn verifyInvariants,FoldHookFn foldHook,GetCanonicalizationPatternsFn getCanonicalizationPatterns,detail::InterfaceMap && interfaceMap,HasTraitFn hasTrait)657 AbstractOperation::AbstractOperation(
658     StringRef name, Dialect &dialect, OperationProperties opProperties,
659     TypeID typeID, ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly,
660     VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
661     GetCanonicalizationPatternsFn getCanonicalizationPatterns,
662     detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait)
663     : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
664       typeID(typeID), opProperties(opProperties),
665       interfaceMap(std::move(interfaceMap)), foldHookFn(foldHook),
666       getCanonicalizationPatternsFn(getCanonicalizationPatterns),
667       hasTraitFn(hasTrait), parseAssemblyFn(parseAssembly),
668       printAssemblyFn(printAssembly), verifyInvariantsFn(verifyInvariants) {}
669 
670 //===----------------------------------------------------------------------===//
671 // AbstractType
672 //===----------------------------------------------------------------------===//
673 
lookup(TypeID typeID,MLIRContext * context)674 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
675   auto &impl = context->getImpl();
676   auto it = impl.registeredTypes.find(typeID);
677   if (it == impl.registeredTypes.end())
678     llvm::report_fatal_error(
679         "Trying to create a Type that was not registered in this MLIRContext.");
680   return *it->second;
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // Identifier uniquing
685 //===----------------------------------------------------------------------===//
686 
687 /// Return an identifier for the specified string.
get(StringRef str,MLIRContext * context)688 Identifier Identifier::get(StringRef str, MLIRContext *context) {
689   // Check invariants after seeing if we already have something in the
690   // identifier table - if we already had it in the table, then it already
691   // passed invariant checks.
692   assert(!str.empty() && "Cannot create an empty identifier");
693   assert(str.find('\0') == StringRef::npos &&
694          "Cannot create an identifier with a nul character");
695 
696   auto &impl = context->getImpl();
697   if (!context->isMultithreadingEnabled())
698     return Identifier(&*impl.identifiers.insert(str).first);
699 
700   // Check for an existing instance in the local cache.
701   auto *&localEntry = (*impl.localIdentifierCache)[str];
702   if (localEntry)
703     return Identifier(localEntry);
704 
705   // Check for an existing identifier in read-only mode.
706   {
707     llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
708     auto it = impl.identifiers.find(str);
709     if (it != impl.identifiers.end()) {
710       localEntry = &*it;
711       return Identifier(localEntry);
712     }
713   }
714 
715   // Acquire a writer-lock so that we can safely create the new instance.
716   llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
717   auto it = impl.identifiers.insert(str).first;
718   localEntry = &*it;
719   return Identifier(localEntry);
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // Type uniquing
724 //===----------------------------------------------------------------------===//
725 
726 /// Returns the storage uniquer used for constructing type storage instances.
727 /// This should not be used directly.
getTypeUniquer()728 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
729 
get(MLIRContext * context)730 BFloat16Type BFloat16Type::get(MLIRContext *context) {
731   return context->getImpl().bf16Ty;
732 }
get(MLIRContext * context)733 Float16Type Float16Type::get(MLIRContext *context) {
734   return context->getImpl().f16Ty;
735 }
get(MLIRContext * context)736 Float32Type Float32Type::get(MLIRContext *context) {
737   return context->getImpl().f32Ty;
738 }
get(MLIRContext * context)739 Float64Type Float64Type::get(MLIRContext *context) {
740   return context->getImpl().f64Ty;
741 }
742 
743 /// Get an instance of the IndexType.
get(MLIRContext * context)744 IndexType IndexType::get(MLIRContext *context) {
745   return context->getImpl().indexTy;
746 }
747 
748 /// Return an existing integer type instance if one is cached within the
749 /// context.
750 static IntegerType
getCachedIntegerType(unsigned width,IntegerType::SignednessSemantics signedness,MLIRContext * context)751 getCachedIntegerType(unsigned width,
752                      IntegerType::SignednessSemantics signedness,
753                      MLIRContext *context) {
754   if (signedness != IntegerType::Signless)
755     return IntegerType();
756 
757   switch (width) {
758   case 1:
759     return context->getImpl().int1Ty;
760   case 8:
761     return context->getImpl().int8Ty;
762   case 16:
763     return context->getImpl().int16Ty;
764   case 32:
765     return context->getImpl().int32Ty;
766   case 64:
767     return context->getImpl().int64Ty;
768   case 128:
769     return context->getImpl().int128Ty;
770   default:
771     return IntegerType();
772   }
773 }
774 
get(unsigned width,MLIRContext * context)775 IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
776   return get(width, IntegerType::Signless, context);
777 }
778 
get(unsigned width,IntegerType::SignednessSemantics signedness,MLIRContext * context)779 IntegerType IntegerType::get(unsigned width,
780                              IntegerType::SignednessSemantics signedness,
781                              MLIRContext *context) {
782   if (auto cached = getCachedIntegerType(width, signedness, context))
783     return cached;
784   return Base::get(context, width, signedness);
785 }
786 
getChecked(unsigned width,Location location)787 IntegerType IntegerType::getChecked(unsigned width, Location location) {
788   return getChecked(width, IntegerType::Signless, location);
789 }
790 
getChecked(unsigned width,SignednessSemantics signedness,Location location)791 IntegerType IntegerType::getChecked(unsigned width,
792                                     SignednessSemantics signedness,
793                                     Location location) {
794   if (auto cached =
795           getCachedIntegerType(width, signedness, location->getContext()))
796     return cached;
797   return Base::getChecked(location, width, signedness);
798 }
799 
800 /// Get an instance of the NoneType.
get(MLIRContext * context)801 NoneType NoneType::get(MLIRContext *context) {
802   if (NoneType cachedInst = context->getImpl().noneType)
803     return cachedInst;
804   // Note: May happen when initializing the singleton attributes of the builtin
805   // dialect.
806   return Base::get(context);
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // Attribute uniquing
811 //===----------------------------------------------------------------------===//
812 
813 /// Returns the storage uniquer used for constructing attribute storage
814 /// instances. This should not be used directly.
getAttributeUniquer()815 StorageUniquer &MLIRContext::getAttributeUniquer() {
816   return getImpl().attributeUniquer;
817 }
818 
819 /// Initialize the given attribute storage instance.
initializeAttributeStorage(AttributeStorage * storage,MLIRContext * ctx,TypeID attrID)820 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
821                                                   MLIRContext *ctx,
822                                                   TypeID attrID) {
823   storage->initialize(AbstractAttribute::lookup(attrID, ctx));
824 
825   // If the attribute did not provide a type, then default to NoneType.
826   if (!storage->getType())
827     storage->setType(NoneType::get(ctx));
828 }
829 
get(bool value,MLIRContext * context)830 BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
831   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
832 }
833 
get(MLIRContext * context)834 UnitAttr UnitAttr::get(MLIRContext *context) {
835   return context->getImpl().unitAttr;
836 }
837 
get(MLIRContext * context)838 Location UnknownLoc::get(MLIRContext *context) {
839   return context->getImpl().unknownLocAttr;
840 }
841 
842 /// Return empty dictionary.
getEmpty(MLIRContext * context)843 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
844   return context->getImpl().emptyDictionaryAttr;
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // AffineMap uniquing
849 //===----------------------------------------------------------------------===//
850 
getAffineUniquer()851 StorageUniquer &MLIRContext::getAffineUniquer() {
852   return getImpl().affineUniquer;
853 }
854 
getImpl(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> results,MLIRContext * context)855 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
856                              ArrayRef<AffineExpr> results,
857                              MLIRContext *context) {
858   auto &impl = context->getImpl();
859   auto key = std::make_tuple(dimCount, symbolCount, results);
860 
861   // Safely get or create an AffineMap instance.
862   return safeGetOrCreate(
863       impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] {
864         auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();
865 
866         // Copy the results into the bump pointer.
867         results = copyArrayRefInto(impl.affineAllocator, results);
868 
869         // Initialize the memory using placement new.
870         new (res)
871             detail::AffineMapStorage{dimCount, symbolCount, results, context};
872         return AffineMap(res);
873       });
874 }
875 
get(MLIRContext * context)876 AffineMap AffineMap::get(MLIRContext *context) {
877   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
878 }
879 
get(unsigned dimCount,unsigned symbolCount,MLIRContext * context)880 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
881                          MLIRContext *context) {
882   return getImpl(dimCount, symbolCount, /*results=*/{}, context);
883 }
884 
get(unsigned dimCount,unsigned symbolCount,AffineExpr result)885 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
886                          AffineExpr result) {
887   return getImpl(dimCount, symbolCount, {result}, result.getContext());
888 }
889 
get(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> results,MLIRContext * context)890 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
891                          ArrayRef<AffineExpr> results, MLIRContext *context) {
892   return getImpl(dimCount, symbolCount, results, context);
893 }
894 
895 //===----------------------------------------------------------------------===//
// Integer Sets: these are allocated with the context's bump-pointer allocator
// and are immutable. Unlike AffineMaps, they are uniqued only when small
// (fewer than IntegerSet::kUniquingThreshold constraints).
//===----------------------------------------------------------------------===//
899 
get(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> constraints,ArrayRef<bool> eqFlags)900 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
901                            ArrayRef<AffineExpr> constraints,
902                            ArrayRef<bool> eqFlags) {
903   // The number of constraints can't be zero.
904   assert(!constraints.empty());
905   assert(constraints.size() == eqFlags.size());
906 
907   auto &impl = constraints[0].getContext()->getImpl();
908 
909   // A utility function to construct a new IntegerSetStorage instance.
910   auto constructorFn = [&] {
911     auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>();
912 
913     // Copy the results and equality flags into the bump pointer.
914     constraints = copyArrayRefInto(impl.affineAllocator, constraints);
915     eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags);
916 
917     // Initialize the memory using placement new.
918     new (res)
919         detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags};
920     return IntegerSet(res);
921   };
922 
923   // If this instance is uniqued, then we handle it separately so that multiple
924   // threads may simultaneously access existing instances.
925   if (constraints.size() < IntegerSet::kUniquingThreshold) {
926     auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
927     return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
928                            impl.threadingIsEnabled, constructorFn);
929   }
930 
931   // Otherwise, acquire a writer-lock so that we can safely create the new
932   // instance.
933   ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled);
934   return constructorFn();
935 }
936 
937 //===----------------------------------------------------------------------===//
938 // StorageUniquerSupport
939 //===----------------------------------------------------------------------===//
940 
941 /// Utility method to generate a default location for use when checking the
942 /// construction invariants of a storage object. This is defined out-of-line to
943 /// avoid the need to include Location.h.
944 const AttributeStorage *
generateUnknownStorageLocation(MLIRContext * ctx)945 mlir::detail::generateUnknownStorageLocation(MLIRContext *ctx) {
946   return reinterpret_cast<const AttributeStorage *>(
947       ctx->getImpl().unknownLocAttr.getAsOpaquePointer());
948 }
949