1 //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file declares various operation folding utilities. These 10 // utilities are intended to be used by passes to unify and simply their logic. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_TRANSFORMS_FOLDUTILS_H 15 #define MLIR_TRANSFORMS_FOLDUTILS_H 16 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/Dialect.h" 19 #include "mlir/IR/DialectInterface.h" 20 #include "mlir/Interfaces/FoldInterfaces.h" 21 22 namespace mlir { 23 class Operation; 24 class Value; 25 26 27 //===--------------------------------------------------------------------===// 28 // OperationFolder 29 //===--------------------------------------------------------------------===// 30 31 /// A utility class for folding operations, and unifying duplicated constants 32 /// generated along the way. 33 class OperationFolder { 34 public: OperationFolder(MLIRContext * ctx)35 OperationFolder(MLIRContext *ctx) : interfaces(ctx) {} 36 37 /// Tries to perform folding on the given `op`, including unifying 38 /// deduplicated constants. If successful, replaces `op`'s uses with 39 /// folded results, and returns success. `preReplaceAction` is invoked on `op` 40 /// before it is replaced. 'processGeneratedConstants' is invoked for any new 41 /// operations generated when folding. If the op was completely folded it is 42 /// erased. If it is just updated in place, `inPlaceUpdate` is set to true. 43 LogicalResult 44 tryToFold(Operation *op, 45 function_ref<void(Operation *)> processGeneratedConstants = nullptr, 46 function_ref<void(Operation *)> preReplaceAction = nullptr, 47 bool *inPlaceUpdate = nullptr); 48 49 /// Notifies that the given constant `op` should be remove from this 50 /// OperationFolder's internal bookkeeping. 51 /// 52 /// Note: this method must be called if a constant op is to be deleted 53 /// externally to this OperationFolder. `op` must be a constant op. 54 void notifyRemoval(Operation *op); 55 56 /// Create an operation of specific op type with the given builder, 57 /// and immediately try to fold it. This function populates 'results' with 58 /// the results after folding the operation. 59 template <typename OpTy, typename... Args> create(OpBuilder & builder,SmallVectorImpl<Value> & results,Location location,Args &&...args)60 void create(OpBuilder &builder, SmallVectorImpl<Value> &results, 61 Location location, Args &&... args) { 62 // The op needs to be inserted only if the fold (below) fails, or the number 63 // of results produced by the successful folding is zero (which is treated 64 // as an in-place fold). Using create methods of the builder will insert the 65 // op, so not using it here. 66 OperationState state(location, OpTy::getOperationName()); 67 OpTy::build(builder, state, std::forward<Args>(args)...); 68 Operation *op = Operation::create(state); 69 70 if (failed(tryToFold(builder, op, results)) || results.empty()) { 71 builder.insert(op); 72 results.assign(op->result_begin(), op->result_end()); 73 return; 74 } 75 op->destroy(); 76 } 77 78 /// Overload to create or fold a single result operation. 79 template <typename OpTy, typename... Args> 80 typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(), 81 Value>::type create(OpBuilder & builder,Location location,Args &&...args)82 create(OpBuilder &builder, Location location, Args &&... args) { 83 SmallVector<Value, 1> results; 84 create<OpTy>(builder, results, location, std::forward<Args>(args)...); 85 return results.front(); 86 } 87 88 /// Overload to create or fold a zero result operation. 89 template <typename OpTy, typename... Args> 90 typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(), 91 OpTy>::type create(OpBuilder & builder,Location location,Args &&...args)92 create(OpBuilder &builder, Location location, Args &&... args) { 93 auto op = builder.create<OpTy>(location, std::forward<Args>(args)...); 94 SmallVector<Value, 0> unused; 95 (void)tryToFold(op.getOperation(), unused); 96 97 // Folding cannot remove a zero-result operation, so for convenience we 98 // continue to return it. 99 return op; 100 } 101 102 /// Clear out any constants cached inside of the folder. 103 void clear(); 104 105 /// Get or create a constant using the given builder. On success this returns 106 /// the constant operation, nullptr otherwise. 107 Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect, 108 Attribute value, Type type, Location loc); 109 110 private: 111 /// This map keeps track of uniqued constants by dialect, attribute, and type. 112 /// A constant operation materializes an attribute with a type. Dialects may 113 /// generate different constants with the same input attribute and type, so we 114 /// also need to track per-dialect. 115 using ConstantMap = 116 DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>; 117 118 /// Tries to perform folding on the given `op`. If successful, populates 119 /// `results` with the results of the folding. 120 LogicalResult tryToFold( 121 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results, 122 function_ref<void(Operation *)> processGeneratedConstants = nullptr); 123 124 /// Try to get or create a new constant entry. On success this returns the 125 /// constant operation, nullptr otherwise. 126 Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants, 127 Dialect *dialect, OpBuilder &builder, 128 Attribute value, Type type, Location loc); 129 130 /// A mapping between an insertion region and the constants that have been 131 /// created within it. 132 DenseMap<Region *, ConstantMap> foldScopes; 133 134 /// This map tracks all of the dialects that an operation is referenced by; 135 /// given that many dialects may generate the same constant. 136 DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects; 137 138 /// A collection of dialect folder interfaces. 139 DialectInterfaceCollection<DialectFoldInterface> interfaces; 140 }; 141 142 } // end namespace mlir 143 144 #endif // MLIR_TRANSFORMS_FOLDUTILS_H 145