1 //===- OperationSupport.h ---------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a number of support types that Operation and related
10 // classes build on top of.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef MLIR_IR_OPERATION_SUPPORT_H
15 #define MLIR_IR_OPERATION_SUPPORT_H
16
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BlockSupport.h"
19 #include "mlir/IR/Identifier.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/TypeRange.h"
22 #include "mlir/IR/Types.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/InterfaceSupport.h"
25 #include "llvm/ADT/BitmaskEnum.h"
26 #include "llvm/ADT/PointerUnion.h"
27 #include "llvm/Support/PointerLikeTypeTraits.h"
28 #include "llvm/Support/TrailingObjects.h"
29 #include <memory>
30
31 namespace mlir {
32 class Dialect;
33 class DictionaryAttr;
34 class ElementsAttr;
35 class MutableDictionaryAttr;
36 class Operation;
37 struct OperationState;
38 class OpAsmParser;
39 class OpAsmParserResult;
40 class OpAsmPrinter;
41 class OperandRange;
42 class OpFoldResult;
43 class ParseResult;
44 class Pattern;
45 class Region;
46 class ResultRange;
47 class RewritePattern;
48 class Type;
49 class Value;
50 class ValueRange;
51 template <typename ValueRangeT> class ValueTypeRange;
52
53 class OwningRewritePatternList;
54
55 //===----------------------------------------------------------------------===//
56 // AbstractOperation
57 //===----------------------------------------------------------------------===//
58
/// Bit flags describing static properties of an operation. Values are
/// distinct powers of two so they can be OR'ed together into a bitmask.
enum class OperationProperty {
  /// Set for a commutative operation, i.e. one whose result does not depend
  /// on the order of its operands: for a binary commutative operation,
  /// "a op b" and "b op a" produce the same results.
  Commutative = 1 << 0,

  /// Set for a terminator, i.e. an operation that ends a block.
  Terminator = 1 << 1,

  /// Set for operations whose regions are explicit-capture only: they may
  /// never implicitly reference values defined above the parent operation.
  IsolatedFromAbove = 1 << 2,
};
77
78 /// This is a "type erased" representation of a registered operation. This
79 /// should only be used by things like the AsmPrinter and other things that need
80 /// to be parameterized by generic operation hooks. Most user code should use
81 /// the concrete operation types.
82 class AbstractOperation {
83 public:
84 using OperationProperties = uint32_t;
85
86 using GetCanonicalizationPatternsFn = void (*)(OwningRewritePatternList &,
87 MLIRContext *);
88 using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
89 SmallVectorImpl<OpFoldResult> &);
90 using HasTraitFn = bool (*)(TypeID);
91 using ParseAssemblyFn = ParseResult (*)(OpAsmParser &, OperationState &);
92 using PrintAssemblyFn = void (*)(Operation *, OpAsmPrinter &);
93 using VerifyInvariantsFn = LogicalResult (*)(Operation *);
94
95 /// This is the name of the operation.
96 const Identifier name;
97
98 /// This is the dialect that this operation belongs to.
99 Dialect &dialect;
100
101 /// The unique identifier of the derived Op class.
102 TypeID typeID;
103
104 /// Use the specified object to parse this ops custom assembly format.
105 ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
106
107 /// This hook implements the AsmPrinter for this operation.
printAssembly(Operation * op,OpAsmPrinter & p)108 void printAssembly(Operation *op, OpAsmPrinter &p) const {
109 return printAssemblyFn(op, p);
110 }
111
112 /// This hook implements the verifier for this operation. It should emits an
113 /// error message and returns failure if a problem is detected, or returns
114 /// success if everything is ok.
verifyInvariants(Operation * op)115 LogicalResult verifyInvariants(Operation *op) const {
116 return verifyInvariantsFn(op);
117 }
118
119 /// This hook implements a generalized folder for this operation. Operations
120 /// can implement this to provide simplifications rules that are applied by
121 /// the Builder::createOrFold API and the canonicalization pass.
122 ///
123 /// This is an intentionally limited interface - implementations of this hook
124 /// can only perform the following changes to the operation:
125 ///
126 /// 1. They can leave the operation alone and without changing the IR, and
127 /// return failure.
128 /// 2. They can mutate the operation in place, without changing anything else
129 /// in the IR. In this case, return success.
130 /// 3. They can return a list of existing values that can be used instead of
131 /// the operation. In this case, fill in the results list and return
132 /// success. The caller will remove the operation and use those results
133 /// instead.
134 ///
135 /// This allows expression of some simple in-place canonicalizations (e.g.
136 /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
137 /// generalized constant folding.
foldHook(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)138 LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
139 SmallVectorImpl<OpFoldResult> &results) const {
140 return foldHookFn(op, operands, results);
141 }
142
143 /// This hook returns any canonicalization pattern rewrites that the operation
144 /// supports, for use by the canonicalization pass.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)145 void getCanonicalizationPatterns(OwningRewritePatternList &results,
146 MLIRContext *context) const {
147 return getCanonicalizationPatternsFn(results, context);
148 }
149
150 /// Returns whether the operation has a particular property.
hasProperty(OperationProperty property)151 bool hasProperty(OperationProperty property) const {
152 return opProperties & static_cast<OperationProperties>(property);
153 }
154
155 /// Returns an instance of the concept object for the given interface if it
156 /// was registered to this operation, null otherwise. This should not be used
157 /// directly.
getInterface()158 template <typename T> typename T::Concept *getInterface() const {
159 return interfaceMap.lookup<T>();
160 }
161
162 /// Returns true if the operation has a particular trait.
hasTrait()163 template <template <typename T> class Trait> bool hasTrait() const {
164 return hasTraitFn(TypeID::get<Trait>());
165 }
166
167 /// Look up the specified operation in the specified MLIRContext and return a
168 /// pointer to it if present. Otherwise, return a null pointer.
169 static const AbstractOperation *lookup(StringRef opName,
170 MLIRContext *context);
171
172 /// This constructor is used by Dialect objects when they register the list of
173 /// operations they contain.
insert(Dialect & dialect)174 template <typename T> static void insert(Dialect &dialect) {
175 insert(T::getOperationName(), dialect, T::getOperationProperties(),
176 TypeID::get<T>(), T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
177 T::getVerifyInvariantsFn(), T::getFoldHookFn(),
178 T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
179 T::getHasTraitFn());
180 }
181
182 private:
183 static void insert(StringRef name, Dialect &dialect,
184 OperationProperties opProperties, TypeID typeID,
185 ParseAssemblyFn parseAssembly,
186 PrintAssemblyFn printAssembly,
187 VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
188 GetCanonicalizationPatternsFn getCanonicalizationPatterns,
189 detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
190
191 AbstractOperation(StringRef name, Dialect &dialect,
192 OperationProperties opProperties, TypeID typeID,
193 ParseAssemblyFn parseAssembly,
194 PrintAssemblyFn printAssembly,
195 VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook,
196 GetCanonicalizationPatternsFn getCanonicalizationPatterns,
197 detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait);
198
199 /// The properties of the operation.
200 const OperationProperties opProperties;
201
202 /// A map of interfaces that were registered to this operation.
203 detail::InterfaceMap interfaceMap;
204
205 /// Internal callback hooks provided by the op implementation.
206 FoldHookFn foldHookFn;
207 GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
208 HasTraitFn hasTraitFn;
209 ParseAssemblyFn parseAssemblyFn;
210 PrintAssemblyFn printAssemblyFn;
211 VerifyInvariantsFn verifyInvariantsFn;
212 };
213
214 //===----------------------------------------------------------------------===//
215 // NamedAttrList
216 //===----------------------------------------------------------------------===//
217
218 /// NamedAttrList is array of NamedAttributes that tracks whether it is sorted
219 /// and does some basic work to remain sorted.
220 class NamedAttrList {
221 public:
222 using const_iterator = SmallVectorImpl<NamedAttribute>::const_iterator;
223 using const_reference = const NamedAttribute &;
224 using reference = NamedAttribute &;
225 using size_type = size_t;
226
NamedAttrList()227 NamedAttrList() : dictionarySorted({}, true) {}
228 NamedAttrList(ArrayRef<NamedAttribute> attributes);
229 NamedAttrList(const_iterator in_start, const_iterator in_end);
230
231 bool operator!=(const NamedAttrList &other) const {
232 return !(*this == other);
233 }
234 bool operator==(const NamedAttrList &other) const {
235 return attrs == other.attrs;
236 }
237
238 /// Add an attribute with the specified name.
239 void append(StringRef name, Attribute attr);
240
241 /// Add an attribute with the specified name.
242 void append(Identifier name, Attribute attr);
243
244 /// Add an array of named attributes.
245 void append(ArrayRef<NamedAttribute> newAttributes);
246
247 /// Add a range of named attributes.
248 void append(const_iterator in_start, const_iterator in_end);
249
250 /// Replaces the attributes with new list of attributes.
251 void assign(const_iterator in_start, const_iterator in_end);
252
253 /// Replaces the attributes with new list of attributes.
assign(ArrayRef<NamedAttribute> range)254 void assign(ArrayRef<NamedAttribute> range) {
255 append(range.begin(), range.end());
256 }
257
empty()258 bool empty() const { return attrs.empty(); }
259
reserve(size_type N)260 void reserve(size_type N) { attrs.reserve(N); }
261
262 /// Add an attribute with the specified name.
263 void push_back(NamedAttribute newAttribute);
264
265 /// Pop last element from list.
pop_back()266 void pop_back() { attrs.pop_back(); }
267
268 /// Returns an entry with a duplicate name the list, if it exists, else
269 /// returns llvm::None.
270 Optional<NamedAttribute> findDuplicate() const;
271
272 /// Return a dictionary attribute for the underlying dictionary. This will
273 /// return an empty dictionary attribute if empty rather than null.
274 DictionaryAttr getDictionary(MLIRContext *context) const;
275
276 /// Return all of the attributes on this operation.
277 ArrayRef<NamedAttribute> getAttrs() const;
278
279 /// Return the specified attribute if present, null otherwise.
280 Attribute get(Identifier name) const;
281 Attribute get(StringRef name) const;
282
283 /// Return the specified named attribute if present, None otherwise.
284 Optional<NamedAttribute> getNamed(StringRef name) const;
285 Optional<NamedAttribute> getNamed(Identifier name) const;
286
287 /// If the an attribute exists with the specified name, change it to the new
288 /// value. Otherwise, add a new attribute with the specified name/value.
289 void set(Identifier name, Attribute value);
290 void set(StringRef name, Attribute value);
291
292 /// Erase the attribute with the given name from the list. Return the
293 /// attribute that was erased, or nullptr if there was no attribute with such
294 /// name.
295 Attribute erase(Identifier name);
296 Attribute erase(StringRef name);
297
begin()298 const_iterator begin() const { return attrs.begin(); }
end()299 const_iterator end() const { return attrs.end(); }
300
301 NamedAttrList &operator=(const SmallVectorImpl<NamedAttribute> &rhs);
302 operator ArrayRef<NamedAttribute>() const;
303 operator MutableDictionaryAttr() const;
304
305 private:
306 /// Return whether the attributes are sorted.
isSorted()307 bool isSorted() const { return dictionarySorted.getInt(); }
308
309 /// Erase the attribute at the given iterator position.
310 Attribute eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it);
311
312 // These are marked mutable as they may be modified (e.g., sorted)
313 mutable SmallVector<NamedAttribute, 4> attrs;
314 // Pair with cached DictionaryAttr and status of whether attrs is sorted.
315 // Note: just because sorted does not mean a DictionaryAttr has been created
316 // but the case where there is a DictionaryAttr but attrs isn't sorted should
317 // not occur.
318 mutable llvm::PointerIntPair<Attribute, 1, bool> dictionarySorted;
319 };
320
321 //===----------------------------------------------------------------------===//
322 // OperationName
323 //===----------------------------------------------------------------------===//
324
325 class OperationName {
326 public:
327 using RepresentationUnion =
328 PointerUnion<Identifier, const AbstractOperation *>;
329
OperationName(AbstractOperation * op)330 OperationName(AbstractOperation *op) : representation(op) {}
331 OperationName(StringRef name, MLIRContext *context);
332
333 /// Return the name of the dialect this operation is registered to.
334 StringRef getDialect() const;
335
336 /// Return the operation name with dialect name stripped, if it has one.
337 StringRef stripDialect() const;
338
339 /// Return the name of this operation. This always succeeds.
340 StringRef getStringRef() const;
341
342 /// Return the name of this operation as an identifier. This always succeeds.
343 Identifier getIdentifier() const;
344
345 /// If this operation has a registered operation description, return it.
346 /// Otherwise return null.
347 const AbstractOperation *getAbstractOperation() const;
348
349 void print(raw_ostream &os) const;
350 void dump() const;
351
getAsOpaquePointer()352 void *getAsOpaquePointer() const {
353 return static_cast<void *>(representation.getOpaqueValue());
354 }
355 static OperationName getFromOpaquePointer(const void *pointer);
356
357 private:
358 RepresentationUnion representation;
OperationName(RepresentationUnion representation)359 OperationName(RepresentationUnion representation)
360 : representation(representation) {}
361 };
362
363 inline raw_ostream &operator<<(raw_ostream &os, OperationName identifier) {
364 identifier.print(os);
365 return os;
366 }
367
368 inline bool operator==(OperationName lhs, OperationName rhs) {
369 return lhs.getAsOpaquePointer() == rhs.getAsOpaquePointer();
370 }
371
372 inline bool operator!=(OperationName lhs, OperationName rhs) {
373 return lhs.getAsOpaquePointer() != rhs.getAsOpaquePointer();
374 }
375
376 // Make operation names hashable.
hash_value(OperationName arg)377 inline llvm::hash_code hash_value(OperationName arg) {
378 return llvm::hash_value(arg.getAsOpaquePointer());
379 }
380
381 //===----------------------------------------------------------------------===//
382 // OperationState
383 //===----------------------------------------------------------------------===//
384
385 /// This represents an operation in an abstracted form, suitable for use with
386 /// the builder APIs. This object is a large and heavy weight object meant to
387 /// be used as a temporary object on the stack. It is generally unwise to put
388 /// this in a collection.
389 struct OperationState {
390 Location location;
391 OperationName name;
392 SmallVector<Value, 4> operands;
393 /// Types of the results of this operation.
394 SmallVector<Type, 4> types;
395 NamedAttrList attributes;
396 /// Successors of this operation and their respective operands.
397 SmallVector<Block *, 1> successors;
398 /// Regions that the op will hold.
399 SmallVector<std::unique_ptr<Region>, 1> regions;
400
401 public:
402 OperationState(Location location, StringRef name);
403
404 OperationState(Location location, OperationName name);
405
406 OperationState(Location location, StringRef name, ValueRange operands,
407 TypeRange types, ArrayRef<NamedAttribute> attributes,
408 BlockRange successors = {},
409 MutableArrayRef<std::unique_ptr<Region>> regions = {});
410
411 void addOperands(ValueRange newOperands);
412
addTypesOperationState413 void addTypes(ArrayRef<Type> newTypes) {
414 types.append(newTypes.begin(), newTypes.end());
415 }
416 template <typename RangeT>
417 std::enable_if_t<!std::is_convertible<RangeT, ArrayRef<Type>>::value>
addTypesOperationState418 addTypes(RangeT &&newTypes) {
419 types.append(newTypes.begin(), newTypes.end());
420 }
421
422 /// Add an attribute with the specified name.
addAttributeOperationState423 void addAttribute(StringRef name, Attribute attr) {
424 addAttribute(Identifier::get(name, getContext()), attr);
425 }
426
427 /// Add an attribute with the specified name.
addAttributeOperationState428 void addAttribute(Identifier name, Attribute attr) {
429 attributes.append(name, attr);
430 }
431
432 /// Add an array of named attributes.
addAttributesOperationState433 void addAttributes(ArrayRef<NamedAttribute> newAttributes) {
434 attributes.append(newAttributes);
435 }
436
addSuccessorsOperationState437 void addSuccessors(Block *successor) { successors.push_back(successor); }
438 void addSuccessors(BlockRange newSuccessors);
439
440 /// Create a region that should be attached to the operation. These regions
441 /// can be filled in immediately without waiting for Operation to be
442 /// created. When it is, the region bodies will be transferred.
443 Region *addRegion();
444
445 /// Take a region that should be attached to the Operation. The body of the
446 /// region will be transferred when the Operation is constructed. If the
447 /// region is null, a new empty region will be attached to the Operation.
448 void addRegion(std::unique_ptr<Region> &®ion);
449
450 /// Take ownership of a set of regions that should be attached to the
451 /// Operation.
452 void addRegions(MutableArrayRef<std::unique_ptr<Region>> regions);
453
454 /// Get the context held by this operation state.
getContextOperationState455 MLIRContext *getContext() const { return location->getContext(); }
456 };
457
458 //===----------------------------------------------------------------------===//
459 // OperandStorage
460 //===----------------------------------------------------------------------===//
461
462 namespace detail {
463 /// This class contains the information for a trailing operand storage.
464 struct TrailingOperandStorage final
465 : public llvm::TrailingObjects<TrailingOperandStorage, OpOperand> {
~TrailingOperandStoragefinal466 ~TrailingOperandStorage() {
467 for (auto &operand : getOperands())
468 operand.~OpOperand();
469 }
470
471 /// Return the operands held by this storage.
getOperandsfinal472 MutableArrayRef<OpOperand> getOperands() {
473 return {getTrailingObjects<OpOperand>(), numOperands};
474 }
475
476 /// The number of operands within the storage.
477 unsigned numOperands;
478 /// The total capacity number of operands that the storage can hold.
479 unsigned capacity : 31;
480 /// We reserve a range of bits for use by the operand storage.
481 unsigned reserved : 1;
482 };
483
/// This class handles the management of operation operands. Operands are
/// stored either in a trailing array, or a dynamically resizable vector.
///
/// The state is a single 64-bit word (see the union at the bottom):
///  - inline: the word is a TrailingOperandStorage header (`inlineStorage`);
///  - dynamic: bit 63 (`DynamicStorageBit`) is set and the remaining bits are
///    a pointer to a separately allocated TrailingOperandStorage.
class OperandStorage final
    : private llvm::TrailingObjects<OperandStorage, OpOperand> {
public:
  OperandStorage(Operation *owner, ValueRange values);
  ~OperandStorage();

  /// Replace the operands contained in the storage with the ones provided in
  /// 'values'.
  void setOperands(Operation *owner, ValueRange values);

  /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
  /// with the ones provided in 'operands'. 'operands' may be smaller or larger
  /// than the range pointed to by 'start'+'length'.
  void setOperands(Operation *owner, unsigned start, unsigned length,
                   ValueRange operands);

  /// Erase the operands held by the storage within the given range.
  void eraseOperands(unsigned start, unsigned length);

  /// Get the operation operands held by the storage, from whichever
  /// representation (inline or dynamic) is currently active.
  MutableArrayRef<OpOperand> getOperands() {
    return getStorage().getOperands();
  }

  /// Return the number of operands held in the storage.
  unsigned size() { return getStorage().numOperands; }

  /// Returns the additional size necessary for allocating this object with
  /// room for `numOperands` trailing OpOperands.
  static size_t additionalAllocSize(unsigned numOperands) {
    return additionalSizeToAlloc<OpOperand>(numOperands);
  }

private:
  enum : uint64_t {
    /// The bit used to mark the storage as dynamic (bit 63 of
    /// `representation`).
    DynamicStorageBit = 1ull << 63ull
  };

  /// Resize the storage to the given size. Returns the array containing the new
  /// operands.
  MutableArrayRef<OpOperand> resize(Operation *owner, unsigned newSize);

  /// Returns the current internal storage instance. The dynamic case is
  /// hinted as unlikely.
  TrailingOperandStorage &getStorage() {
    return LLVM_UNLIKELY(isDynamicStorage()) ? getDynamicStorage()
                                             : getInlineStorage();
  }

  /// Returns the storage container if the storage is inline.
  TrailingOperandStorage &getInlineStorage() {
    assert(!isDynamicStorage() && "expected storage to be inline");
    // The header must occupy exactly the 64-bit word shared with
    // `representation` in the union below.
    static_assert(sizeof(TrailingOperandStorage) == sizeof(uint64_t),
                  "inline storage representation must match the opaque "
                  "representation");
    // NOTE(review): presumably the inline header's trailing OpOperands are the
    // ones allocated after this object via additionalAllocSize(); that relies
    // on both TrailingObjects views starting at the same address. Confirm.
    return inlineStorage;
  }

  /// Returns the storage container if this storage is dynamic.
  TrailingOperandStorage &getDynamicStorage() {
    assert(isDynamicStorage() && "expected dynamic storage");
    // Clear the tag bit to recover the heap pointer.
    uint64_t maskedRepresentation = representation & ~DynamicStorageBit;
    return *reinterpret_cast<TrailingOperandStorage *>(maskedRepresentation);
  }

  /// Returns true if the storage is currently dynamic.
  /// NOTE(review): in inline mode bit 63 presumably overlaps
  /// TrailingOperandStorage::reserved, which must then stay 0; this depends
  /// on bitfield layout/endianness. Confirm.
  bool isDynamicStorage() const { return representation & DynamicStorageBit; }

  /// The current representation of the storage. This is either a
  /// TrailingOperandStorage header stored inline, or a tagged pointer to a
  /// TrailingOperandStorage.
  union {
    TrailingOperandStorage inlineStorage;
    uint64_t representation;
  };

  /// This stuff is used by the TrailingObjects template.
  friend llvm::TrailingObjects<OperandStorage, OpOperand>;
};
563 } // end namespace detail
564
565 //===----------------------------------------------------------------------===//
566 // ResultStorage
567 //===----------------------------------------------------------------------===//
568
569 namespace detail {
570 /// This class provides the implementation for an in-line operation result. This
571 /// is an operation result whose number can be stored inline inside of the bits
572 /// of an Operation*.
573 struct alignas(8) InLineOpResult : public IRObjectWithUseList<OpOperand> {};
574 /// This class provides the implementation for an out-of-line operation result.
575 /// This is an operation result whose number cannot be stored inline inside of
576 /// the bits of an Operation*.
577 struct alignas(8) TrailingOpResult : public IRObjectWithUseList<OpOperand> {
TrailingOpResultTrailingOpResult578 TrailingOpResult(uint64_t trailingResultNumber)
579 : trailingResultNumber(trailingResultNumber) {}
580
581 /// Returns the parent operation of this trailing result.
582 Operation *getOwner();
583
584 /// Return the proper result number of this op result.
getResultNumberTrailingOpResult585 unsigned getResultNumber() {
586 return trailingResultNumber + OpResult::getMaxInlineResults();
587 }
588
589 /// The trailing result number, or the offset from the beginning of the
590 /// trailing array.
591 uint64_t trailingResultNumber;
592 };
593 } // end namespace detail
594
595 //===----------------------------------------------------------------------===//
596 // OpPrintingFlags
597 //===----------------------------------------------------------------------===//
598
599 /// Set of flags used to control the behavior of the various IR print methods
600 /// (e.g. Operation::Print).
601 class OpPrintingFlags {
602 public:
603 OpPrintingFlags();
OpPrintingFlags(llvm::NoneType)604 OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {}
605
606 /// Enables the elision of large elements attributes by printing a lexically
607 /// valid but otherwise meaningless form instead of the element data. The
608 /// `largeElementLimit` is used to configure what is considered to be a
609 /// "large" ElementsAttr by providing an upper limit to the number of
610 /// elements.
611 OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16);
612
613 /// Enable printing of debug information. If 'prettyForm' is set to true,
614 /// debug information is printed in a more readable 'pretty' form. Note: The
615 /// IR generated with 'prettyForm' is not parsable.
616 OpPrintingFlags &enableDebugInfo(bool prettyForm = false);
617
618 /// Always print operations in the generic form.
619 OpPrintingFlags &printGenericOpForm();
620
621 /// Use local scope when printing the operation. This allows for using the
622 /// printer in a more localized and thread-safe setting, but may not
623 /// necessarily be identical to what the IR will look like when dumping
624 /// the full module.
625 OpPrintingFlags &useLocalScope();
626
627 /// Return if the given ElementsAttr should be elided.
628 bool shouldElideElementsAttr(ElementsAttr attr) const;
629
630 /// Return the size limit for printing large ElementsAttr.
631 Optional<int64_t> getLargeElementsAttrLimit() const;
632
633 /// Return if debug information should be printed.
634 bool shouldPrintDebugInfo() const;
635
636 /// Return if debug information should be printed in the pretty form.
637 bool shouldPrintDebugInfoPrettyForm() const;
638
639 /// Return if operations should be printed in the generic form.
640 bool shouldPrintGenericOpForm() const;
641
642 /// Return if the printer should use local scope when dumping the IR.
643 bool shouldUseLocalScope() const;
644
645 private:
646 /// Elide large elements attributes if the number of elements is larger than
647 /// the upper limit.
648 Optional<int64_t> elementsAttrElementLimit;
649
650 /// Print debug information.
651 bool printDebugInfoFlag : 1;
652 bool printDebugInfoPrettyFormFlag : 1;
653
654 /// Print operations in the generic form.
655 bool printGenericOpFormFlag : 1;
656
657 /// Print operations with numberings local to the current operation.
658 bool printLocalScope : 1;
659 };
660
661 //===----------------------------------------------------------------------===//
662 // Operation Value-Iterators
663 //===----------------------------------------------------------------------===//
664
665 //===----------------------------------------------------------------------===//
666 // OperandRange
667
668 /// This class implements the operand iterators for the Operation class.
669 class OperandRange final : public llvm::detail::indexed_accessor_range_base<
670 OperandRange, OpOperand *, Value, Value, Value> {
671 public:
672 using RangeBaseT::RangeBaseT;
673 OperandRange(Operation *op);
674
675 /// Returns the types of the values within this range.
676 using type_iterator = ValueTypeIterator<iterator>;
677 using type_range = ValueTypeRange<OperandRange>;
getTypes()678 type_range getTypes() const { return {begin(), end()}; }
getType()679 auto getType() const { return getTypes(); }
680
681 /// Return the operand index of the first element of this range. The range
682 /// must not be empty.
683 unsigned getBeginOperandIndex() const;
684
685 private:
686 /// See `llvm::detail::indexed_accessor_range_base` for details.
offset_base(OpOperand * object,ptrdiff_t index)687 static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
688 return object + index;
689 }
690 /// See `llvm::detail::indexed_accessor_range_base` for details.
dereference_iterator(OpOperand * object,ptrdiff_t index)691 static Value dereference_iterator(OpOperand *object, ptrdiff_t index) {
692 return object[index].get();
693 }
694
695 /// Allow access to `offset_base` and `dereference_iterator`.
696 friend RangeBaseT;
697 };
698
699 //===----------------------------------------------------------------------===//
700 // MutableOperandRange
701
702 /// This class provides a mutable adaptor for a range of operands. It allows for
703 /// setting, inserting, and erasing operands from the given range.
704 class MutableOperandRange {
705 public:
706 /// A pair of a named attribute corresponding to an operand segment attribute,
707 /// and the index within that attribute. The attribute should correspond to an
708 /// i32 DenseElementsAttr.
709 using OperandSegment = std::pair<unsigned, NamedAttribute>;
710
711 /// Construct a new mutable range from the given operand, operand start index,
712 /// and range length. `operandSegments` is an optional set of operand segments
713 /// to be updated when mutating the operand list.
714 MutableOperandRange(Operation *owner, unsigned start, unsigned length,
715 ArrayRef<OperandSegment> operandSegments = llvm::None);
716 MutableOperandRange(Operation *owner);
717
718 /// Slice this range into a sub range, with the additional operand segment.
719 MutableOperandRange slice(unsigned subStart, unsigned subLen,
720 Optional<OperandSegment> segment = llvm::None);
721
722 /// Append the given values to the range.
723 void append(ValueRange values);
724
725 /// Assign this range to the given values.
726 void assign(ValueRange values);
727
728 /// Assign the range to the given value.
729 void assign(Value value);
730
731 /// Erase the operands within the given sub-range.
732 void erase(unsigned subStart, unsigned subLen = 1);
733
734 /// Clear this range and erase all of the operands.
735 void clear();
736
737 /// Returns the current size of the range.
size()738 unsigned size() const { return length; }
739
740 /// Allow implicit conversion to an OperandRange.
741 operator OperandRange() const;
742
743 /// Returns the owning operation.
getOwner()744 Operation *getOwner() const { return owner; }
745
746 private:
747 /// Update the length of this range to the one provided.
748 void updateLength(unsigned newLength);
749
750 /// The owning operation of this range.
751 Operation *owner;
752
753 /// The start index of the operand range within the owner operand list, and
754 /// the length starting from `start`.
755 unsigned start, length;
756
757 /// Optional set of operand segments that should be updated when mutating the
758 /// length of this range.
759 SmallVector<std::pair<unsigned, NamedAttribute>, 1> operandSegments;
760 };
761
762 //===----------------------------------------------------------------------===//
763 // ResultRange
764
765 /// This class implements the result iterators for the Operation class.
766 class ResultRange final
767 : public llvm::indexed_accessor_range<ResultRange, Operation *, OpResult,
768 OpResult, OpResult> {
769 public:
770 using indexed_accessor_range<ResultRange, Operation *, OpResult, OpResult,
771 OpResult>::indexed_accessor_range;
772 ResultRange(Operation *op);
773
774 /// Returns the types of the values within this range.
775 using type_iterator = ArrayRef<Type>::iterator;
776 using type_range = ArrayRef<Type>;
777 type_range getTypes() const;
getType()778 auto getType() const { return getTypes(); }
779
780 private:
781 /// See `llvm::indexed_accessor_range` for details.
782 static OpResult dereference(Operation *op, ptrdiff_t index);
783
784 /// Allow access to `dereference_iterator`.
785 friend llvm::indexed_accessor_range<ResultRange, Operation *, OpResult,
786 OpResult, OpResult>;
787 };
788
789 //===----------------------------------------------------------------------===//
790 // ValueRange
791
792 namespace detail {
793 /// The type representing the owner of a ValueRange. This is either a list of
794 /// values, operands, or an Operation+start index for results.
795 struct ValueRangeOwner {
ValueRangeOwnerValueRangeOwner796 ValueRangeOwner(const Value *owner) : ptr(owner), startIndex(0) {}
ValueRangeOwnerValueRangeOwner797 ValueRangeOwner(OpOperand *owner) : ptr(owner), startIndex(0) {}
ValueRangeOwnerValueRangeOwner798 ValueRangeOwner(Operation *owner, unsigned startIndex)
799 : ptr(owner), startIndex(startIndex) {}
800 bool operator==(const ValueRangeOwner &rhs) const { return ptr == rhs.ptr; }
801
802 /// The owner pointer of the range. The owner has represents three distinct
803 /// states:
804 /// const Value *: The owner is the base to a contiguous array of Value.
805 /// OpOperand * : The owner is the base to a contiguous array of operands.
806 /// void* : This owner is an Operation*. It is marked as void* here
807 /// because the definition of Operation is not visible here.
808 PointerUnion<const Value *, OpOperand *, void *> ptr;
809
810 /// Ths start index into the range. This is only used for Operation* owners.
811 unsigned startIndex;
812 };
813 } // end namespace detail
814
815 /// This class provides an abstraction over the different types of ranges over
816 /// Values. In many cases, this prevents the need to explicitly materialize a
817 /// SmallVector/std::vector. This class should be used in places that are not
818 /// suitable for a more derived type (e.g. ArrayRef) or a template range
819 /// parameter.
820 class ValueRange final
821 : public llvm::detail::indexed_accessor_range_base<
822 ValueRange, detail::ValueRangeOwner, Value, Value, Value> {
823 public:
824 using RangeBaseT::RangeBaseT;
825
826 template <typename Arg,
827 typename = typename std::enable_if_t<
828 std::is_constructible<ArrayRef<Value>, Arg>::value &&
829 !std::is_convertible<Arg, Value>::value>>
ValueRange(Arg && arg)830 ValueRange(Arg &&arg) : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {}
ValueRange(const Value & value)831 ValueRange(const Value &value) : ValueRange(&value, /*count=*/1) {}
ValueRange(const std::initializer_list<Value> & values)832 ValueRange(const std::initializer_list<Value> &values)
833 : ValueRange(ArrayRef<Value>(values)) {}
ValueRange(iterator_range<OperandRange::iterator> values)834 ValueRange(iterator_range<OperandRange::iterator> values)
835 : ValueRange(OperandRange(values)) {}
ValueRange(iterator_range<ResultRange::iterator> values)836 ValueRange(iterator_range<ResultRange::iterator> values)
837 : ValueRange(ResultRange(values)) {}
ValueRange(ArrayRef<BlockArgument> values)838 ValueRange(ArrayRef<BlockArgument> values)
839 : ValueRange(ArrayRef<Value>(values.data(), values.size())) {}
840 ValueRange(ArrayRef<Value> values = llvm::None);
841 ValueRange(OperandRange values);
842 ValueRange(ResultRange values);
843
844 /// Returns the types of the values within this range.
845 using type_iterator = ValueTypeIterator<iterator>;
846 using type_range = ValueTypeRange<ValueRange>;
getTypes()847 type_range getTypes() const { return {begin(), end()}; }
getType()848 auto getType() const { return getTypes(); }
849
850 private:
851 using OwnerT = detail::ValueRangeOwner;
852
853 /// See `llvm::detail::indexed_accessor_range_base` for details.
854 static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
855 /// See `llvm::detail::indexed_accessor_range_base` for details.
856 static Value dereference_iterator(const OwnerT &owner, ptrdiff_t index);
857
858 /// Allow access to `offset_base` and `dereference_iterator`.
859 friend RangeBaseT;
860 };
861
862 //===----------------------------------------------------------------------===//
863 // Operation Equivalency
864 //===----------------------------------------------------------------------===//
865
866 /// This class provides utilities for computing if two operations are
867 /// equivalent.
868 struct OperationEquivalence {
869 enum Flags {
870 None = 0,
871
872 /// This flag signals that operands should not be considered when checking
873 /// for equivalence. This allows for users to implement there own
874 /// equivalence schemes for operand values. The number of operands are still
875 /// checked, just not the operands themselves.
876 IgnoreOperands = 1,
877
878 LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreOperands)
879 };
880
881 /// Compute a hash for the given operation.
882 static llvm::hash_code computeHash(Operation *op, Flags flags = Flags::None);
883
884 /// Compare two operations and return if they are equivalent.
885 static bool isEquivalentTo(Operation *lhs, Operation *rhs,
886 Flags flags = Flags::None);
887 };
888
889 /// Enable Bitmask enums for OperationEquivalence::Flags.
890 LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
891
892 } // end namespace mlir
893
894 namespace llvm {
895 // Identifiers hash just like pointers, there is no need to hash the bytes.
896 template <> struct DenseMapInfo<mlir::OperationName> {
897 static mlir::OperationName getEmptyKey() {
898 auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
899 return mlir::OperationName::getFromOpaquePointer(pointer);
900 }
901 static mlir::OperationName getTombstoneKey() {
902 auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
903 return mlir::OperationName::getFromOpaquePointer(pointer);
904 }
905 static unsigned getHashValue(mlir::OperationName Val) {
906 return DenseMapInfo<void *>::getHashValue(Val.getAsOpaquePointer());
907 }
908 static bool isEqual(mlir::OperationName LHS, mlir::OperationName RHS) {
909 return LHS == RHS;
910 }
911 };
912
913 /// The pointer inside of an identifier comes from a StringMap, so its alignment
914 /// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to
915 /// steal the low bits.
916 template <> struct PointerLikeTypeTraits<mlir::OperationName> {
917 public:
918 static inline void *getAsVoidPointer(mlir::OperationName I) {
919 return const_cast<void *>(I.getAsOpaquePointer());
920 }
921 static inline mlir::OperationName getFromVoidPointer(void *P) {
922 return mlir::OperationName::getFromOpaquePointer(P);
923 }
924 static constexpr int NumLowBitsAvailable = PointerLikeTypeTraits<
925 mlir::OperationName::RepresentationUnion>::NumLowBitsAvailable;
926 };
927
928 } // end namespace llvm
929
930 #endif
931