• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This header file declares various operation folding utilities. These
// utilities are intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TRANSFORMS_FOLDUTILS_H
15 #define MLIR_TRANSFORMS_FOLDUTILS_H
16 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/DialectInterface.h"
20 #include "mlir/Interfaces/FoldInterfaces.h"
21 
22 namespace mlir {
23 class Operation;
24 class Value;
25 
26 
27 //===--------------------------------------------------------------------===//
28 // OperationFolder
29 //===--------------------------------------------------------------------===//
30 
31 /// A utility class for folding operations, and unifying duplicated constants
32 /// generated along the way.
33 class OperationFolder {
34 public:
OperationFolder(MLIRContext * ctx)35   OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
36 
37   /// Tries to perform folding on the given `op`, including unifying
38   /// deduplicated constants. If successful, replaces `op`'s uses with
39   /// folded results, and returns success. `preReplaceAction` is invoked on `op`
40   /// before it is replaced. 'processGeneratedConstants' is invoked for any new
41   /// operations generated when folding. If the op was completely folded it is
42   /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
43   LogicalResult
44   tryToFold(Operation *op,
45             function_ref<void(Operation *)> processGeneratedConstants = nullptr,
46             function_ref<void(Operation *)> preReplaceAction = nullptr,
47             bool *inPlaceUpdate = nullptr);
48 
49   /// Notifies that the given constant `op` should be remove from this
50   /// OperationFolder's internal bookkeeping.
51   ///
52   /// Note: this method must be called if a constant op is to be deleted
53   /// externally to this OperationFolder. `op` must be a constant op.
54   void notifyRemoval(Operation *op);
55 
56   /// Create an operation of specific op type with the given builder,
57   /// and immediately try to fold it. This function populates 'results' with
58   /// the results after folding the operation.
59   template <typename OpTy, typename... Args>
create(OpBuilder & builder,SmallVectorImpl<Value> & results,Location location,Args &&...args)60   void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
61               Location location, Args &&... args) {
62     // The op needs to be inserted only if the fold (below) fails, or the number
63     // of results produced by the successful folding is zero (which is treated
64     // as an in-place fold). Using create methods of the builder will insert the
65     // op, so not using it here.
66     OperationState state(location, OpTy::getOperationName());
67     OpTy::build(builder, state, std::forward<Args>(args)...);
68     Operation *op = Operation::create(state);
69 
70     if (failed(tryToFold(builder, op, results)) || results.empty()) {
71       builder.insert(op);
72       results.assign(op->result_begin(), op->result_end());
73       return;
74     }
75     op->destroy();
76   }
77 
78   /// Overload to create or fold a single result operation.
79   template <typename OpTy, typename... Args>
80   typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
81                           Value>::type
create(OpBuilder & builder,Location location,Args &&...args)82   create(OpBuilder &builder, Location location, Args &&... args) {
83     SmallVector<Value, 1> results;
84     create<OpTy>(builder, results, location, std::forward<Args>(args)...);
85     return results.front();
86   }
87 
88   /// Overload to create or fold a zero result operation.
89   template <typename OpTy, typename... Args>
90   typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
91                           OpTy>::type
create(OpBuilder & builder,Location location,Args &&...args)92   create(OpBuilder &builder, Location location, Args &&... args) {
93     auto op = builder.create<OpTy>(location, std::forward<Args>(args)...);
94     SmallVector<Value, 0> unused;
95     (void)tryToFold(op.getOperation(), unused);
96 
97     // Folding cannot remove a zero-result operation, so for convenience we
98     // continue to return it.
99     return op;
100   }
101 
102   /// Clear out any constants cached inside of the folder.
103   void clear();
104 
105   /// Get or create a constant using the given builder. On success this returns
106   /// the constant operation, nullptr otherwise.
107   Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
108                             Attribute value, Type type, Location loc);
109 
110 private:
111   /// This map keeps track of uniqued constants by dialect, attribute, and type.
112   /// A constant operation materializes an attribute with a type. Dialects may
113   /// generate different constants with the same input attribute and type, so we
114   /// also need to track per-dialect.
115   using ConstantMap =
116       DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
117 
118   /// Tries to perform folding on the given `op`. If successful, populates
119   /// `results` with the results of the folding.
120   LogicalResult tryToFold(
121       OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
122       function_ref<void(Operation *)> processGeneratedConstants = nullptr);
123 
124   /// Try to get or create a new constant entry. On success this returns the
125   /// constant operation, nullptr otherwise.
126   Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
127                                     Dialect *dialect, OpBuilder &builder,
128                                     Attribute value, Type type, Location loc);
129 
130   /// A mapping between an insertion region and the constants that have been
131   /// created within it.
132   DenseMap<Region *, ConstantMap> foldScopes;
133 
134   /// This map tracks all of the dialects that an operation is referenced by;
135   /// given that many dialects may generate the same constant.
136   DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;
137 
138   /// A collection of dialect folder interfaces.
139   DialectInterfaceCollection<DialectFoldInterface> interfaces;
140 };
141 
142 } // end namespace mlir
143 
144 #endif // MLIR_TRANSFORMS_FOLDUTILS_H
145