• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Analysis/userange_analysis.h"
17 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
18 #include "mlir-hlo/Transforms/PassDetail.h"
19 #include "mlir-hlo/Transforms/passes.h"
20 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
21 #include "mlir/Interfaces/CopyOpInterface.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 #include "mlir/Pass/Pass.h"
24 
25 namespace mlir {
26 
27 namespace {
28 
29 class CopyRemoval : bufferization::BufferPlacementTransformationBase {
30  public:
CopyRemoval(Operation * op)31   explicit CopyRemoval(Operation *op)
32       : BufferPlacementTransformationBase(op),
33         userange(op, allocs, aliases),
34         dominators(op) {}
35 
removeCopy()36   void removeCopy() {
37     // A vector with the allocation value / copy operation pairs s to process.
38     llvm::SmallVector<CopyOpInterface> toProcess;
39     fillProcessSetAndResolveAliases(toProcess);
40 
41     DenseMap<Value, UseInterval::Vector> updatedUserange;
42     DenseMap<Value, UserangeAnalysis::UsePositionList> updatedUsepositions;
43 
44     // Lambda expression to update the userange interval.
45     auto lambdaUserangeUpdate = [&](Value v,
46                                     DenseMap<Value, UseInterval::Vector> &map)
47         -> UseInterval::Vector & { return insertUserangeInterval(v, map); };
48 
49     // A set containing copy operations that can be erased.
50     SmallPtrSet<Operation *, 16> toErase;
51     while (!toProcess.empty()) {
52       CopyOpInterface copyOp = toProcess.pop_back_val();
53 
54       // Cast the Operation and get the Source and Target.
55       Value copySource = copyOp.getSource();
56       Value copyTarget = copyOp.getTarget();
57 
58       // Only remove copies if they do not affect maps.
59       if (copySource.getType().cast<MemRefType>().getLayout() !=
60           copyTarget.getType().cast<MemRefType>().getLayout())
61         continue;
62 
63       // Get the UserangeIntervals.
64       auto sourceAlloc = aliasToAllocMap[copySource];
65       UseInterval::Vector sourceInterval =
66           getOrInsert(sourceAlloc, updatedUserange, lambdaUserangeUpdate);
67       auto targetAlloc = aliasToAllocMap[copyTarget];
68       UseInterval::Vector targetInterval =
69           getOrInsert(targetAlloc, updatedUserange, lambdaUserangeUpdate);
70 
71       UseInterval::Vector intersect = sourceInterval;
72 
73       // Compute the intersection.
74       UseInterval::intervalIntersect(intersect, targetInterval);
75 
76       // If the sourceInterval contains more than one UseInterval, there are
77       // multiple operations that intersect. The sourceInterval must have at
78       // least one UseInterval that contains the copyOp.
79       if (intersect.size() != 1) continue;
80 
81       // Check if all operations inside the intersection are benign, part of the
82       // copyOp or a dealloc.
83       if (!usesInIntervalAreSafe(copyOp, copySource, *intersect.begin()))
84         continue;
85 
86       // Check if the currentOp dominates all uses of the copyTarget.
87       if (!checkDominance(copyOp, copyTarget.getUsers(), toErase)) continue;
88 
89       // The last op in the intersection of the use ranges needs to be a
90       // dealloc, as it ended the original source range. If we do the reuse,
91       // we have to remove that dealloc to extend the liferange of the original
92       // value.
93       auto *lastOp = userange.getOperation(intersect.back().end);
94       if (!isDeallocOperationFor(lastOp, copySource)) continue;
95       toErase.insert(lastOp);
96 
97       // Merge the Useranges.
98       UseInterval::intervalMerge(sourceInterval, targetInterval);
99 
100       // Replace all uses of the target with the source.
101       copyTarget.replaceAllUsesWith(copySource);
102       toErase.insert(copyOp);
103     }
104     // Erase the copy operations.
105     for (auto *eraseOp : toErase) eraseOp->erase();
106 
107     // Erase all allocs without uses.
108     for (const bufferization::BufferPlacementAllocs::AllocEntry &entry :
109          allocs) {
110       Value alloc = std::get<0>(entry);
111       if (alloc.use_empty()) alloc.getDefiningOp()->erase();
112     }
113   }
114 
115  private:
116   /// Iterate over all allocs and their aliases and add their uses to the
117   /// process set that implement a CopyOpInterface, where the alloc or alias is
118   /// the source of the CopyOpInterface.
fillProcessSetAndResolveAliases(llvm::SmallVectorImpl<CopyOpInterface> & toProcess)119   void fillProcessSetAndResolveAliases(
120       llvm::SmallVectorImpl<CopyOpInterface> &toProcess) {
121     // A Set that contains the already processed aliases.
122     SmallPtrSet<Value, 16U> processedAliases;
123 
124     // Iterate over the allocs.
125     for (const bufferization::BufferPlacementAllocs::AllocEntry &entry :
126          allocs) {
127       Value allocValue = std::get<0>(entry);
128 
129       // Resolve the aliases of the current alloc and iterate over them.
130       // At the same time, merge the use ranges of aliases into the use range
131       // of the corresponding allocation.
132       const ValueSetT &aliasSet = aliases.resolve(allocValue);
133       for (Value alias : aliasSet) {
134         // If the alias is already processed, continue.
135         if (!processedAliases.insert(alias).second) continue;
136         // Union the use ranges.
137         userange.unionRanges(allocValue, alias);
138         // Remember the alias.
139         aliasToAllocMap.insert({alias, allocValue});
140         // If any of the uses are a copy, we have a canidate.
141         for (auto *user : alias.getUsers()) {
142           auto copyOp = dyn_cast<CopyOpInterface>(user);
143           if (!copyOp) continue;
144           if (copyOp.getSource() != alias) continue;
145           toProcess.push_back(copyOp);
146         }
147       }
148     }
149   }
150 
151   /// Find the given Value in the DenseMap and return the pointer. If the given
152   /// Value is not in the Map, insert a copy of the given original to the
153   /// DenseMap using the pased update function and return a pointer to that
154   /// element.
155   template <typename T, typename TFunc>
getOrInsert(Value v,DenseMap<Value,T> & updateMap,const TFunc & updateFunc)156   T &getOrInsert(Value v, DenseMap<Value, T> &updateMap,
157                  const TFunc &updateFunc) {
158     auto iter = updateMap.find(v);
159     if (iter != updateMap.end()) return iter->second;
160     return updateFunc(v, updateMap);
161   }
162 
163   /// Insert the original userange intervals of the operation in the map.
insertUserangeInterval(Value v,DenseMap<Value,UseInterval::Vector> & updateMap)164   UseInterval::Vector &insertUserangeInterval(
165       Value v, DenseMap<Value, UseInterval::Vector> &updateMap) {
166     const auto *original = userange.getUserangeInterval(v).value();
167     auto &entry = updateMap[v];
168     entry = *original;
169     return entry;
170   }
171 
172   /// Check if all users in the given range are dominated by given operation.
173   /// Note: The target has always at least one use which is the copy operation.
checkDominance(Operation * operation,const Value::user_range & userRange,SmallPtrSet<Operation *,16> & ignoreSet)174   bool checkDominance(Operation *operation, const Value::user_range &userRange,
175                       SmallPtrSet<Operation *, 16> &ignoreSet) {
176     // Check if any use of the target is not dominated by the useOp. Erased
177     // operations are ignored as uses.
178     return llvm::all_of(userRange, [=](Operation *user) {
179       return ignoreSet.count(user) || dominators.dominates(operation, user);
180     });
181   }
182 
183   /// Checks whether op is a dealloction operation for value.
184   /// This helper is aware of aliasing via the alias_to_alloc_map_.
isDeallocOperationFor(Operation * op,Value value)185   bool isDeallocOperationFor(Operation *op, Value value) {
186     auto effect = dyn_cast<MemoryEffectOpInterface>(op);
187     Value originalAlloc = aliasToAllocMap[value];
188     return effect && effect.hasEffect<MemoryEffects::Free>() &&
189            llvm::any_of(op->getOperands(), [&](Value operand) {
190              Value operandAlloc = aliasToAllocMap[operand];
191              return operandAlloc == originalAlloc;
192            });
193   }
194 
195   /// Checks whether all uses within the given interval are safe, i.e., there
196   /// are no conflicts.
197   /// This currently means that the interval may only contain non-sideeffecting
198   /// operations or a dealloc of the given source value.
usesInIntervalAreSafe(Operation * op,Value source,UseInterval & interval)199   bool usesInIntervalAreSafe(Operation *op, Value source,
200                              UseInterval &interval) {
201     // Divide the start and end by two to remove read/write properties.
202     for (int id = interval.start / 2, e = interval.end / 2; id <= e; ++id) {
203       // Get the operation from the id. Multiply the id by 2, because the
204       // userange operates on doubled ids. Return false if the operation is not
205       // an ancestor.
206       // TODO(herhut): This is a bit of a big hammer. Ideally this should only
207       //               look at use positions. Refactor to use those here.
208       Operation *opInInterval = userange.getOperation(id * 2);
209       if (op->isAncestor(opInInterval)) continue;
210       auto effect = dyn_cast<MemoryEffectOpInterface>(opInInterval);
211       // If we do not know about effects, fail.
212       if (!effect) return false;
213       // If it has no effect we are safe. It is OK if it gets the operand as
214       // it does not use it.
215       if (effect.hasNoEffect()) continue;
216       if (isDeallocOperationFor(opInInterval, source)) continue;
217       return false;
218     }
219     return true;
220   }
221 
222   /// The current userange info.
223   UserangeAnalysis userange;
224 
225   /// A map from aliases to their allocation value.
226   DenseMap<Value, Value> aliasToAllocMap;
227 
228   /// The current dominance info.
229   DominanceInfo dominators;
230 };
231 
232 struct CopyRemovalPass : public CopyRemovalBase<CopyRemovalPass> {
runOnOperationmlir::__anon7409b2050111::CopyRemovalPass233   void runOnOperation() override {
234     Operation *funcOp = getOperation();
235     CopyRemoval removal(funcOp);
236     removal.removeCopy();
237   }
238 };
239 
240 }  // namespace
241 
createCopyRemovalPass()242 std::unique_ptr<OperationPass<func::FuncOp>> createCopyRemovalPass() {
243   return std::make_unique<CopyRemovalPass>();
244 }
245 
246 }  // namespace mlir
247