1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Analysis/userange_analysis.h"
17 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
18 #include "mlir-hlo/Transforms/PassDetail.h"
19 #include "mlir-hlo/Transforms/passes.h"
20 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
21 #include "mlir/Interfaces/CopyOpInterface.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 #include "mlir/Pass/Pass.h"
24
25 namespace mlir {
26
27 namespace {
28
29 class CopyRemoval : bufferization::BufferPlacementTransformationBase {
30 public:
CopyRemoval(Operation * op)31 explicit CopyRemoval(Operation *op)
32 : BufferPlacementTransformationBase(op),
33 userange(op, allocs, aliases),
34 dominators(op) {}
35
removeCopy()36 void removeCopy() {
37 // A vector with the allocation value / copy operation pairs s to process.
38 llvm::SmallVector<CopyOpInterface> toProcess;
39 fillProcessSetAndResolveAliases(toProcess);
40
41 DenseMap<Value, UseInterval::Vector> updatedUserange;
42 DenseMap<Value, UserangeAnalysis::UsePositionList> updatedUsepositions;
43
44 // Lambda expression to update the userange interval.
45 auto lambdaUserangeUpdate = [&](Value v,
46 DenseMap<Value, UseInterval::Vector> &map)
47 -> UseInterval::Vector & { return insertUserangeInterval(v, map); };
48
49 // A set containing copy operations that can be erased.
50 SmallPtrSet<Operation *, 16> toErase;
51 while (!toProcess.empty()) {
52 CopyOpInterface copyOp = toProcess.pop_back_val();
53
54 // Cast the Operation and get the Source and Target.
55 Value copySource = copyOp.getSource();
56 Value copyTarget = copyOp.getTarget();
57
58 // Only remove copies if they do not affect maps.
59 if (copySource.getType().cast<MemRefType>().getLayout() !=
60 copyTarget.getType().cast<MemRefType>().getLayout())
61 continue;
62
63 // Get the UserangeIntervals.
64 auto sourceAlloc = aliasToAllocMap[copySource];
65 UseInterval::Vector sourceInterval =
66 getOrInsert(sourceAlloc, updatedUserange, lambdaUserangeUpdate);
67 auto targetAlloc = aliasToAllocMap[copyTarget];
68 UseInterval::Vector targetInterval =
69 getOrInsert(targetAlloc, updatedUserange, lambdaUserangeUpdate);
70
71 UseInterval::Vector intersect = sourceInterval;
72
73 // Compute the intersection.
74 UseInterval::intervalIntersect(intersect, targetInterval);
75
76 // If the sourceInterval contains more than one UseInterval, there are
77 // multiple operations that intersect. The sourceInterval must have at
78 // least one UseInterval that contains the copyOp.
79 if (intersect.size() != 1) continue;
80
81 // Check if all operations inside the intersection are benign, part of the
82 // copyOp or a dealloc.
83 if (!usesInIntervalAreSafe(copyOp, copySource, *intersect.begin()))
84 continue;
85
86 // Check if the currentOp dominates all uses of the copyTarget.
87 if (!checkDominance(copyOp, copyTarget.getUsers(), toErase)) continue;
88
89 // The last op in the intersection of the use ranges needs to be a
90 // dealloc, as it ended the original source range. If we do the reuse,
91 // we have to remove that dealloc to extend the liferange of the original
92 // value.
93 auto *lastOp = userange.getOperation(intersect.back().end);
94 if (!isDeallocOperationFor(lastOp, copySource)) continue;
95 toErase.insert(lastOp);
96
97 // Merge the Useranges.
98 UseInterval::intervalMerge(sourceInterval, targetInterval);
99
100 // Replace all uses of the target with the source.
101 copyTarget.replaceAllUsesWith(copySource);
102 toErase.insert(copyOp);
103 }
104 // Erase the copy operations.
105 for (auto *eraseOp : toErase) eraseOp->erase();
106
107 // Erase all allocs without uses.
108 for (const bufferization::BufferPlacementAllocs::AllocEntry &entry :
109 allocs) {
110 Value alloc = std::get<0>(entry);
111 if (alloc.use_empty()) alloc.getDefiningOp()->erase();
112 }
113 }
114
115 private:
116 /// Iterate over all allocs and their aliases and add their uses to the
117 /// process set that implement a CopyOpInterface, where the alloc or alias is
118 /// the source of the CopyOpInterface.
fillProcessSetAndResolveAliases(llvm::SmallVectorImpl<CopyOpInterface> & toProcess)119 void fillProcessSetAndResolveAliases(
120 llvm::SmallVectorImpl<CopyOpInterface> &toProcess) {
121 // A Set that contains the already processed aliases.
122 SmallPtrSet<Value, 16U> processedAliases;
123
124 // Iterate over the allocs.
125 for (const bufferization::BufferPlacementAllocs::AllocEntry &entry :
126 allocs) {
127 Value allocValue = std::get<0>(entry);
128
129 // Resolve the aliases of the current alloc and iterate over them.
130 // At the same time, merge the use ranges of aliases into the use range
131 // of the corresponding allocation.
132 const ValueSetT &aliasSet = aliases.resolve(allocValue);
133 for (Value alias : aliasSet) {
134 // If the alias is already processed, continue.
135 if (!processedAliases.insert(alias).second) continue;
136 // Union the use ranges.
137 userange.unionRanges(allocValue, alias);
138 // Remember the alias.
139 aliasToAllocMap.insert({alias, allocValue});
140 // If any of the uses are a copy, we have a canidate.
141 for (auto *user : alias.getUsers()) {
142 auto copyOp = dyn_cast<CopyOpInterface>(user);
143 if (!copyOp) continue;
144 if (copyOp.getSource() != alias) continue;
145 toProcess.push_back(copyOp);
146 }
147 }
148 }
149 }
150
151 /// Find the given Value in the DenseMap and return the pointer. If the given
152 /// Value is not in the Map, insert a copy of the given original to the
153 /// DenseMap using the pased update function and return a pointer to that
154 /// element.
155 template <typename T, typename TFunc>
getOrInsert(Value v,DenseMap<Value,T> & updateMap,const TFunc & updateFunc)156 T &getOrInsert(Value v, DenseMap<Value, T> &updateMap,
157 const TFunc &updateFunc) {
158 auto iter = updateMap.find(v);
159 if (iter != updateMap.end()) return iter->second;
160 return updateFunc(v, updateMap);
161 }
162
163 /// Insert the original userange intervals of the operation in the map.
insertUserangeInterval(Value v,DenseMap<Value,UseInterval::Vector> & updateMap)164 UseInterval::Vector &insertUserangeInterval(
165 Value v, DenseMap<Value, UseInterval::Vector> &updateMap) {
166 const auto *original = userange.getUserangeInterval(v).value();
167 auto &entry = updateMap[v];
168 entry = *original;
169 return entry;
170 }
171
172 /// Check if all users in the given range are dominated by given operation.
173 /// Note: The target has always at least one use which is the copy operation.
checkDominance(Operation * operation,const Value::user_range & userRange,SmallPtrSet<Operation *,16> & ignoreSet)174 bool checkDominance(Operation *operation, const Value::user_range &userRange,
175 SmallPtrSet<Operation *, 16> &ignoreSet) {
176 // Check if any use of the target is not dominated by the useOp. Erased
177 // operations are ignored as uses.
178 return llvm::all_of(userRange, [=](Operation *user) {
179 return ignoreSet.count(user) || dominators.dominates(operation, user);
180 });
181 }
182
183 /// Checks whether op is a dealloction operation for value.
184 /// This helper is aware of aliasing via the alias_to_alloc_map_.
isDeallocOperationFor(Operation * op,Value value)185 bool isDeallocOperationFor(Operation *op, Value value) {
186 auto effect = dyn_cast<MemoryEffectOpInterface>(op);
187 Value originalAlloc = aliasToAllocMap[value];
188 return effect && effect.hasEffect<MemoryEffects::Free>() &&
189 llvm::any_of(op->getOperands(), [&](Value operand) {
190 Value operandAlloc = aliasToAllocMap[operand];
191 return operandAlloc == originalAlloc;
192 });
193 }
194
195 /// Checks whether all uses within the given interval are safe, i.e., there
196 /// are no conflicts.
197 /// This currently means that the interval may only contain non-sideeffecting
198 /// operations or a dealloc of the given source value.
usesInIntervalAreSafe(Operation * op,Value source,UseInterval & interval)199 bool usesInIntervalAreSafe(Operation *op, Value source,
200 UseInterval &interval) {
201 // Divide the start and end by two to remove read/write properties.
202 for (int id = interval.start / 2, e = interval.end / 2; id <= e; ++id) {
203 // Get the operation from the id. Multiply the id by 2, because the
204 // userange operates on doubled ids. Return false if the operation is not
205 // an ancestor.
206 // TODO(herhut): This is a bit of a big hammer. Ideally this should only
207 // look at use positions. Refactor to use those here.
208 Operation *opInInterval = userange.getOperation(id * 2);
209 if (op->isAncestor(opInInterval)) continue;
210 auto effect = dyn_cast<MemoryEffectOpInterface>(opInInterval);
211 // If we do not know about effects, fail.
212 if (!effect) return false;
213 // If it has no effect we are safe. It is OK if it gets the operand as
214 // it does not use it.
215 if (effect.hasNoEffect()) continue;
216 if (isDeallocOperationFor(opInInterval, source)) continue;
217 return false;
218 }
219 return true;
220 }
221
222 /// The current userange info.
223 UserangeAnalysis userange;
224
225 /// A map from aliases to their allocation value.
226 DenseMap<Value, Value> aliasToAllocMap;
227
228 /// The current dominance info.
229 DominanceInfo dominators;
230 };
231
232 struct CopyRemovalPass : public CopyRemovalBase<CopyRemovalPass> {
runOnOperationmlir::__anon7409b2050111::CopyRemovalPass233 void runOnOperation() override {
234 Operation *funcOp = getOperation();
235 CopyRemoval removal(funcOp);
236 removal.removeCopy();
237 }
238 };
239
240 } // namespace
241
createCopyRemovalPass()242 std::unique_ptr<OperationPass<func::FuncOp>> createCopyRemovalPass() {
243 return std::make_unique<CopyRemovalPass>();
244 }
245
246 } // namespace mlir
247