/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
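// This file implements the buffer-reuse pass of the TF kernel generator: it
// analyzes allocations in TF-framework entry functions and annotates them with
// the input buffers they may reuse and the function result they are returned
// as.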

#include <cstddef>
#include <vector>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/BufferAliasAnalysis.h"  // from @llvm-project
#include "mlir/Analysis/Liveness.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"

// Needed to build `llvm::EquivalenceClasses` of `mlir::Value`s.
namespace mlir {
static bool operator<(const Value &lhs, const Value &rhs) {
  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}
}  // namespace mlir

constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFAllocOp::kReuseOutputAttrName;
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFAllocOp::kReuseInputCandidatesAttrName;
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFFrameworkDialect::kTFEntryAttrName;

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {

/// A temporary buffer size analysis that is correct but may be incomplete.
class BufferSizeAnalysis {
 public:
  BufferSizeAnalysis(FuncOp f, const BufferAliasAnalysis &aliases) {
    build(f, aliases);
  }

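  // Returns true if `a` and `b` are known to be buffers of the same size.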
  bool is_same_size(Value a, Value b) { return ecs_.isEquivalent(a, b); }

 private:
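  // Groups the buffer values of `f` into size-equivalence classes.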
  void build(FuncOp &f, const BufferAliasAnalysis &aliases) {
    auto buffers = find_buffer_values(f);

    // Memrefs with a statically known and equal number of elements, the same
    // element type, and equal symbol-free affine maps must be of the same
    // size.
    int n = buffers.size();
    for (int i = 0; i < n; ++i) {
      for (int j = i + 1; j < n; ++j) {
        Value a = buffers[i];
        Value b = buffers[j];
        auto a_ty = a.getType().dyn_cast<MemRefType>();
        auto b_ty = b.getType().dyn_cast<MemRefType>();
        if (a_ty && b_ty && a_ty.hasStaticShape() && b_ty.hasStaticShape() &&
            a_ty.getNumElements() == b_ty.getNumElements() &&
            a_ty.getElementType() == b_ty.getElementType() &&
            affine_maps_symbol_free_and_equal(a_ty.getAffineMaps(),
                                              b_ty.getAffineMaps())) {
          ecs_.unionSets(a, b);
        }
      }
    }

    // Operands to `linalg.generic` with equal affine maps must be of the same
    // size.
    f.walk([&](linalg::GenericOp genericOp) {
      auto operand_buffers = genericOp.getShapedOperands();
      int n = operand_buffers.size();
      for (int i = 0; i < n; ++i) {
        for (int j = i + 1; j < n; ++j) {
          Value a = operand_buffers[i];
          Value b = operand_buffers[j];
          auto a_ty = a.getType().dyn_cast<MemRefType>();
          auto b_ty = b.getType().dyn_cast<MemRefType>();
          if (a_ty && b_ty && a_ty.getElementType() == b_ty.getElementType() &&
              a_ty.getAffineMaps() == b_ty.getAffineMaps()) {
            AffineMap map_i = genericOp.getIndexingMap(i);
            AffineMap map_j = genericOp.getIndexingMap(j);
            if (map_i == map_j && map_i.isPermutation()) ecs_.unionSets(a, b);
          }
        }
      }
    });

    // All aliases of a memref must be of the same underlying buffer size.
    for (auto e : aliases) {
      Value value = e.getFirst();
      if (!value.getType().isa<BaseMemRefType>()) continue;
      for (Value alias : e.getSecond()) {
        assert(alias.getType().isa<BaseMemRefType>() &&
               "Expected aliases of memref to be memrefs.");
        ecs_.unionSets(value, alias);
      }
    }
  }

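  // Returns true if `as` and `bs` are equal and all of their affine maps are
  // free of symbols.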
  bool affine_maps_symbol_free_and_equal(ArrayRef<AffineMap> as,
                                         ArrayRef<AffineMap> bs) {
    auto is_symbol_free = [](AffineMap map) {
      return map.getNumSymbols() == 0;
    };
    return llvm::all_of(as, is_symbol_free) &&
           llvm::all_of(bs, is_symbol_free) && as == bs;
  }

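  // Collects all memref-typed values in `f`: op results and block arguments.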
  llvm::SmallVector<Value, 8> find_buffer_values(FuncOp f) {
    llvm::SmallVector<Value, 8> buffers;
    f.walk([&](Operation *op) {
      for (Value val : op->getResults())
        if (val.getType().isa<BaseMemRefType>()) buffers.push_back(val);
    });
    f.walk([&](Block *block) {
      for (Value val : block->getArguments()) {
        if (val.getType().isa<BaseMemRefType>()) buffers.push_back(val);
      }
    });
    return buffers;
  }

  llvm::EquivalenceClasses<Value> ecs_;
};

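/// Computes buffer-reuse information for a function: for every allocation,
/// which function arguments provide buffers that the allocation may reuse, and
/// at which result index the allocated buffer is returned.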
class BufferReuseAnalysis {
 public:
  explicit BufferReuseAnalysis(FuncOp f) { build(f); }

  static constexpr int32_t kIndexAmbiguous = -1;

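  // Returns the argument indices of the reuse candidates for `op`, or
  // `llvm::None` if no entry was recorded for `op`.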
  Optional<SmallVector<int32_t, 2>> get_reuse_candidates(AllocOp op) {
    auto it = reuse_candidates_.find(op);
    if (it == reuse_candidates_.end()) return llvm::None;
    return it->second;
  }

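  // Returns the result index at which the buffer allocated by `op` is returned
  // from the function, `kIndexAmbiguous` if that index is ambiguous or the
  // buffer is not returned at all, or `llvm::None` if no entry was recorded
  // for `op`.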
  Optional<int32_t> get_output_index(AllocOp op) {
    auto it = output_indices_.find(op);
    if (it == output_indices_.end()) return llvm::None;
    return it->second;
  }

 private:
  void build(FuncOp &f) {
    BufferAliasAnalysis aliases(f);
    find_output_indices(f, aliases);
    find_reuse_candidates(f, aliases);
  }

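  // For every allocation in `f`, determines the result index at which the
  // allocated buffer (through any of its aliases) is returned. The index is
  // `kIndexAmbiguous` if the buffer is returned at more than one distinct
  // index or not returned at all.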
  void find_output_indices(FuncOp &f, BufferAliasAnalysis &aliases) {
    f.walk([&](AllocOp alloc_op) {
      int32_t output_index = kIndexAmbiguous;
      int count_return_uses = 0;
      auto buffer_aliases = aliases.resolve(alloc_op.getResult());
      for (Value alias : buffer_aliases) {
        for (auto &use : alias.getUses()) {
          if (isa<ReturnOp>(use.getOwner())) {
            int32_t index = use.getOperandNumber();
            if (count_return_uses++ == 0)
              output_index = index;
            else if (output_index != index)
              output_index = kIndexAmbiguous;
          }
        }
      }
      output_indices_[alloc_op] = output_index;
    });
  }

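  // Finds reuse candidates for all allocations in `f`, one block at a time.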
  void find_reuse_candidates(FuncOp &f, BufferAliasAnalysis &aliases) {
    Liveness liveness(f);
    BufferSizeAnalysis size_equivalences(f, aliases);
    f.walk([&](Block *block) {
      find_reuse_candidates(block, aliases, liveness.getLiveness(block),
                            size_equivalences, f.getArguments());
    });
  }

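  // For every allocation in `block`, finds the function arguments whose
  // buffers may be reused: a candidate must be known to be of the same size
  // and its lifetime must be compatible, i.e. it must no longer be live at the
  // allocation's first use (or that use must allow local reuse).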
  void find_reuse_candidates(Block *block, BufferAliasAnalysis &aliases,
                             const LivenessBlockInfo *liveness,
                             BufferSizeAnalysis &size_equivalences,
                             ArrayRef<BlockArgument> arguments) {
    for (Operation &op : *block) {
      auto alloc_op = dyn_cast<AllocOp>(op);
      if (!alloc_op) continue;

      // Find the first use of the newly allocated buffer within this block.
      Value new_buffer = alloc_op.getResult();
      Operation *first_reuse = find_first_use_in_block(new_buffer, block);
      assert((first_reuse == nullptr || first_reuse->getBlock() == block) &&
             "Expected first use in same block if found.");

      // Find reuse candidates for the allocation under consideration.
      SmallVector<int32_t, 2> local_reuse_candidates;
      for (BlockArgument old_buffer : arguments) {
        if (!old_buffer.getType().isa<BaseMemRefType>()) continue;

        // Size criterion: Do not reuse buffers of different size as they may
        // be too small.
        if (!size_equivalences.is_same_size(new_buffer, old_buffer)) continue;

        // Lifetime criterion: Only reuse buffers that are no longer alive at
        // the point of their first reuse.
        bool lifetimes_compatible = true;
        for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
          if (first_reuse == nullptr) {
            // If the first use is beyond the end of this block we look at the
            // block end. An argument buffer that is already reusable there is
            // certainly reusable at any later actual use. Otherwise, lifetimes
            // are incompatible.
            if (liveness->isLiveOut(old_buffer_alias)) {
              lifetimes_compatible = false;
              break;
            }
          } else {
            // A buffer is reusable if
            //   i) its last use is before the point of reuse, or
            //   ii) its last use is also its first reuse and the operation
            //       allows for local reuse.
            // Otherwise, lifetimes are incompatible.
            Operation *last_use =
                liveness->getEndOperation(old_buffer_alias, &block->front());
            assert(last_use != nullptr && last_use->getBlock() == block &&
                   "Expected last use in same block.");
            if (first_reuse->isBeforeInBlock(last_use)) {
              lifetimes_compatible = false;
              break;
            }
            if (first_reuse == last_use &&
                !can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) {
              lifetimes_compatible = false;
              break;
            }
          }
        }

        if (lifetimes_compatible) {
          // All criteria are fulfilled.
          int32_t old_buffer_index = old_buffer.getArgNumber();
          local_reuse_candidates.push_back(old_buffer_index);
        }
      }

      reuse_candidates_[&op] = local_reuse_candidates;
    }
  }

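  // Returns the earliest operation in `block` that contains a use of `value`,
  // or nullptr if `value` has no use within `block`.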
  Operation *find_first_use_in_block(Value value, Block *block) {
    Operation *first_use = nullptr;
    for (Operation *op : value.getUsers()) {
      Operation *ancestor_op = block->findAncestorOpInBlock(*op);
      if (ancestor_op == nullptr) continue;
      if (first_use == nullptr || ancestor_op->isBeforeInBlock(first_use))
        first_use = ancestor_op;
    }
    return first_use;
  }

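  // Returns all memref-typed arguments of `f`.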
  std::vector<Value> get_buffer_arguments(FuncOp &f) {
    std::vector<Value> buffer_arguments;
    for (BlockArgument arg : f.getArguments()) {
      if (arg.getType().isa<BaseMemRefType>()) buffer_arguments.push_back(arg);
    }
    return buffer_arguments;
  }

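  // Returns true if `op` itself allows reusing `old_buffer` as `new_buffer`,
  // i.e. per memory location the last use of `old_buffer` happens before its
  // reuse through `new_buffer`.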
  bool can_reuse_locally(Operation *op, Value old_buffer, Value new_buffer) {
    // For now, we support only memrefs with the same memory layout.
    auto old_buffer_ty = old_buffer.getType().dyn_cast<MemRefType>();
    auto new_buffer_ty = new_buffer.getType().dyn_cast<MemRefType>();
    if (!old_buffer_ty || !new_buffer_ty ||
        old_buffer_ty.getAffineMaps() != new_buffer_ty.getAffineMaps())
      return false;

    if (auto generic_op = dyn_cast<linalg::GenericOp>(op)) {
      assert(llvm::find(op->getOperands(), old_buffer) !=
                 op->getOperands().end() &&
             llvm::find(op->getOperands(), new_buffer) !=
                 op->getOperands().end() &&
             "Expect `old/new_buffer` to be operands of `op`.");

      // If the `linalg.generic` indexing maps are the same for the input and
      // the output buffer, then the last use of the input buffer happens
      // before its first reuse (per memory location).
      auto operand_buffers = generic_op.getShapedOperands();
      int old_index =
          llvm::find(operand_buffers, old_buffer) - operand_buffers.begin();
      int new_index =
          llvm::find(operand_buffers, new_buffer) - operand_buffers.begin();
      AffineMap old_indexing_map = generic_op.getIndexingMap(old_index);
      AffineMap new_indexing_map = generic_op.getIndexingMap(new_index);
      return old_indexing_map == new_indexing_map &&
             old_indexing_map.isPermutation();
    }
    return false;
  }

  DenseMap<Operation *, SmallVector<int32_t, 2>> reuse_candidates_;
  DenseMap<Operation *, int32_t> output_indices_;
};

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

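// A function pass that annotates allocations in TF-framework entry functions
// (functions carrying `kTFEntryAttrName`) with the analysis results: the
// argument indices of reuse candidates (`kReuseInputCandidatesAttrName`) and
// the result index at which the buffer is returned (`kReuseOutputAttrName`).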
struct BufferReusePass : public BufferReusePassBase<BufferReusePass> {
  void runOnFunction() override {
    if (!getFunction()->getAttrOfType<UnitAttr>(
            tf_framework::TFFrameworkDialect::kTFEntryAttrName))
      return;

    BufferReuseAnalysis analysis(getFunction());

    // Annotate IR with reuse candidates and output indices per allocation.
    Builder builder(&getContext());
    getFunction().walk([&](AllocOp op) {
      if (auto output_index = analysis.get_output_index(op)) {
        auto attr = builder.getI32IntegerAttr(*output_index);
        op.getOperation()->setAttr(
            tf_framework::TFAllocOp::kReuseOutputAttrName, attr);
      }
      if (auto reuse_candidates = analysis.get_reuse_candidates(op)) {
        auto attr = builder.getI32ArrayAttr(*reuse_candidates);
        op.getOperation()->setAttr(
            tf_framework::TFAllocOp::kReuseInputCandidatesAttrName, attr);
      }
    });
  }
};

}  // namespace

std::unique_ptr<FunctionPass> CreateBufferReusePass() {
  return std::make_unique<BufferReusePass>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir