/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <vector>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/BufferAliasAnalysis.h"  // from @llvm-project
#include "mlir/Analysis/Liveness.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"

// Needed to build `llvm::EquivalenceClasses` of `mlir::Value`s.
namespace mlir {
static bool operator<(const Value &lhs, const Value &rhs) {
  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}
}  // namespace mlir

constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFAllocOp::kReuseOutputAttrName;
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFAllocOp::kReuseInputCandidatesAttrName;
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework::TFFrameworkDialect::kTFEntryAttrName;

namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {

/// A temporary buffer size analysis that is correct but may be incomplete.
class BufferSizeAnalysis {
 public:
  BufferSizeAnalysis(FuncOp f, const BufferAliasAnalysis &aliases) {
    build(f, aliases);
  }

  bool is_same_size(Value a, Value b) { return ecs_.isEquivalent(a, b); }

 private:
  void build(FuncOp &f, const BufferAliasAnalysis &aliases) {
    auto buffers = find_buffer_values(f);

    // Memrefs with a statically known, equal number of elements, the same
    // element type, and equal symbol-free affine maps must be of the same
    // size.
    int n = buffers.size();
    for (int i = 0; i < n; ++i) {
      for (int j = i + 1; j < n; ++j) {
        Value a = buffers[i];
        Value b = buffers[j];
        auto a_ty = a.getType().dyn_cast<MemRefType>();
        auto b_ty = b.getType().dyn_cast<MemRefType>();
        if (a_ty && b_ty && a_ty.hasStaticShape() && b_ty.hasStaticShape() &&
            a_ty.getNumElements() == b_ty.getNumElements() &&
            a_ty.getElementType() == b_ty.getElementType() &&
            affine_maps_symbol_free_and_equal(a_ty.getAffineMaps(),
                                              b_ty.getAffineMaps())) {
          ecs_.unionSets(a, b);
        }
      }
    }

    // Operands to `linalg.generic` whose types carry equal affine maps must be
    // of the same size if their indexing maps are equal permutations.
    f.walk([&](linalg::GenericOp genericOp) {
      auto operand_buffers = genericOp.getShapedOperands();
      int n = operand_buffers.size();
      for (int i = 0; i < n; ++i) {
        for (int j = i + 1; j < n; ++j) {
          Value a = operand_buffers[i];
          Value b = operand_buffers[j];
          auto a_ty = a.getType().dyn_cast<MemRefType>();
          auto b_ty = b.getType().dyn_cast<MemRefType>();
          if (a_ty && b_ty && a_ty.getElementType() == b_ty.getElementType() &&
              a_ty.getAffineMaps() == b_ty.getAffineMaps()) {
            AffineMap map_i = genericOp.getIndexingMap(i);
            AffineMap map_j = genericOp.getIndexingMap(j);
            if (map_i == map_j && map_i.isPermutation()) ecs_.unionSets(a, b);
          }
        }
      }
    });

    // All aliases of a memref must be of the same underlying buffer size.
    for (auto e : aliases) {
      Value value = e.getFirst();
      if (!value.getType().isa<BaseMemRefType>()) continue;
      for (Value alias : e.getSecond()) {
        assert(alias.getType().isa<BaseMemRefType>() &&
               "Expected aliases of memref to be memrefs.");
        ecs_.unionSets(value, alias);
      }
    }
  }

  bool affine_maps_symbol_free_and_equal(ArrayRef<AffineMap> as,
                                         ArrayRef<AffineMap> bs) {
    auto is_symbol_free = [](AffineMap map) {
      return map.getNumSymbols() == 0;
    };
    return llvm::all_of(as, is_symbol_free) &&
           llvm::all_of(bs, is_symbol_free) && as == bs;
  }

  llvm::SmallVector<Value, 8> find_buffer_values(FuncOp f) {
    llvm::SmallVector<Value, 8> buffers;
    f.walk([&](Operation *op) {
      for (Value val : op->getResults())
        if (val.getType().isa<BaseMemRefType>()) buffers.push_back(val);
    });
    f.walk([&](Block *block) {
      for (Value val : block->getArguments()) {
        if (val.getType().isa<BaseMemRefType>()) buffers.push_back(val);
      }
    });
    return buffers;
  }

  llvm::EquivalenceClasses<Value> ecs_;
};

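/// Analysis that, for every buffer allocation in a function, determines (i)
/// which of the function's argument buffers could be reused in its place and
/// (ii) at which position the allocated buffer is returned, if any. The
/// results are attached to the corresponding allocations as attributes by
/// BufferReusePass below.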
class BufferReuseAnalysis {
 public:
  explicit BufferReuseAnalysis(FuncOp f) { build(f); }

  static constexpr int32_t kIndexAmbiguous = -1;

  Optional<SmallVector<int32_t, 2>> get_reuse_candidates(AllocOp op) {
    auto it = reuse_candidates_.find(op);
    if (it == reuse_candidates_.end()) return llvm::None;
    return it->second;
  }

  Optional<int32_t> get_output_index(AllocOp op) {
    auto it = output_indices_.find(op);
    if (it == output_indices_.end()) return llvm::None;
    return it->second;
  }

 private:
  void build(FuncOp &f) {
    BufferAliasAnalysis aliases(f);
    find_output_indices(f, aliases);
    find_reuse_candidates(f, aliases);
  }

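  // For every allocation, determine the operand position at which the
  // allocated buffer (or one of its aliases) is passed to the function's
  // return. If it is returned at more than one position, the output index is
  // recorded as ambiguous.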
  void find_output_indices(FuncOp &f, BufferAliasAnalysis &aliases) {
    f.walk([&](AllocOp alloc_op) {
      int32_t output_index = kIndexAmbiguous;
      int count_return_uses = 0;
      auto buffer_aliases = aliases.resolve(alloc_op.getResult());
      for (Value alias : buffer_aliases) {
        for (auto &use : alias.getUses()) {
          if (isa<ReturnOp>(use.getOwner())) {
            int32_t index = use.getOperandNumber();
            if (count_return_uses++ == 0)
              output_index = index;
            else if (output_index != index)
              output_index = kIndexAmbiguous;
          }
        }
      }
      output_indices_[alloc_op] = output_index;
    });
  }

  void find_reuse_candidates(FuncOp &f, BufferAliasAnalysis &aliases) {
    Liveness liveness(f);
    BufferSizeAnalysis size_equivalences(f, aliases);
    f.walk([&](Block *block) {
      find_reuse_candidates(block, aliases, liveness.getLiveness(block),
                            size_equivalences, f.getArguments());
    });
  }

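  // For every allocation in `block`, collect the indices of all function
  // argument buffers that satisfy both the size criterion and the lifetime
  // criterion checked below.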
  void find_reuse_candidates(Block *block, BufferAliasAnalysis &aliases,
                             const LivenessBlockInfo *liveness,
                             BufferSizeAnalysis &size_equivalences,
                             ArrayRef<BlockArgument> arguments) {
    for (Operation &op : *block) {
      auto alloc_op = dyn_cast<AllocOp>(op);
      if (!alloc_op) continue;

      // Find the first use of the newly allocated buffer within this block.
      Value new_buffer = alloc_op.getResult();
      Operation *first_reuse = find_first_use_in_block(new_buffer, block);
      assert((first_reuse == nullptr || first_reuse->getBlock() == block) &&
             "Expected first use in same block if found.");

      // Find reuse candidates for this allocation.
      SmallVector<int32_t, 2> local_reuse_candidates;
      for (BlockArgument old_buffer : arguments) {
        if (!old_buffer.getType().isa<BaseMemRefType>()) continue;

        // Size criterion: Do not reuse buffers of a different size as they may
        // be too small.
        if (!size_equivalences.is_same_size(new_buffer, old_buffer)) continue;

        // Lifetime criterion: Only reuse buffers that are no longer in use at
        // the point of first reuse, i.e. that are no longer alive there.
        bool lifetimes_compatible = true;
        for (Value old_buffer_alias : aliases.resolve(old_buffer)) {
          if (first_reuse == nullptr) {
            // If the first use is beyond the end of this block we look at the
            // block end. An argument buffer that is already reusable there is
            // certainly reusable at any later actual use. Otherwise, lifetimes
            // are incompatible.
            if (liveness->isLiveOut(old_buffer_alias)) {
              lifetimes_compatible = false;
              break;
            }
          } else {
            // A buffer is reusable if
            //   i)  its last use is before the point of reuse, or
            //   ii) its last use is also its first reuse and the operation
            //       allows for local reuse.
            // Otherwise, lifetimes are incompatible.
            Operation *last_use =
                liveness->getEndOperation(old_buffer_alias, &block->front());
            assert(last_use != nullptr && last_use->getBlock() == block &&
                   "Expected last use in same block.");
            if (first_reuse->isBeforeInBlock(last_use)) {
              lifetimes_compatible = false;
              break;
            }
            if (first_reuse == last_use &&
                !can_reuse_locally(first_reuse, old_buffer_alias, new_buffer)) {
              lifetimes_compatible = false;
              break;
            }
          }
        }

        if (lifetimes_compatible) {
          // All criteria are fulfilled.
          int32_t old_buffer_index = old_buffer.getArgNumber();
          local_reuse_candidates.push_back(old_buffer_index);
        }
      }

      reuse_candidates_[&op] = local_reuse_candidates;
    }
  }

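  // Returns the operation in `block` that contains the first (possibly
  // nested) use of `value`, or nullptr if `value` is not used in this block.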
  Operation *find_first_use_in_block(Value value, Block *block) {
    Operation *first_use = nullptr;
    for (Operation *op : value.getUsers()) {
      Operation *ancestor_op = block->findAncestorOpInBlock(*op);
      if (ancestor_op == nullptr) continue;
      if (first_use == nullptr || ancestor_op->isBeforeInBlock(first_use))
        first_use = ancestor_op;
    }
    return first_use;
  }

  std::vector<Value> get_buffer_arguments(FuncOp &f) {
    std::vector<Value> buffer_arguments;
    for (BlockArgument arg : f.getArguments()) {
      if (arg.getType().isa<BaseMemRefType>()) buffer_arguments.push_back(arg);
    }
    return buffer_arguments;
  }

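  // Returns true if `old_buffer` may be reused for `new_buffer` even though
  // `op` is both the last use of the old buffer and the first use of the new
  // one, i.e. if, per memory location, the old value is read before the new
  // value is written.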
  bool can_reuse_locally(Operation *op, Value old_buffer, Value new_buffer) {
    // For now, we support only memrefs with the same memory layout.
    auto old_buffer_ty = old_buffer.getType().dyn_cast<MemRefType>();
    auto new_buffer_ty = new_buffer.getType().dyn_cast<MemRefType>();
    if (!old_buffer_ty || !new_buffer_ty ||
        old_buffer_ty.getAffineMaps() != new_buffer_ty.getAffineMaps())
      return false;

    if (auto generic_op = dyn_cast<linalg::GenericOp>(op)) {
      assert(llvm::find(op->getOperands(), old_buffer) !=
                 op->getOperands().end() &&
             llvm::find(op->getOperands(), new_buffer) !=
                 op->getOperands().end() &&
             "Expect `old_buffer` and `new_buffer` to be operands of `op`.");

      // If the `linalg.generic` indexing maps are the same for the input and
      // the output buffer then the last use of the input buffer happens before
      // its first reuse (per memory location).
      auto operand_buffers = generic_op.getShapedOperands();
      int old_index =
          llvm::find(operand_buffers, old_buffer) - operand_buffers.begin();
      int new_index =
          llvm::find(operand_buffers, new_buffer) - operand_buffers.begin();
      AffineMap old_indexing_map = generic_op.getIndexingMap(old_index);
      AffineMap new_indexing_map = generic_op.getIndexingMap(new_index);
      return old_indexing_map == new_indexing_map &&
             old_indexing_map.isPermutation();
    }
    return false;
  }

  DenseMap<Operation *, SmallVector<int32_t, 2>> reuse_candidates_;
  DenseMap<Operation *, int32_t> output_indices_;
};

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"

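// BufferReusePass runs only on functions marked with the TF framework entry
// attribute (kTFEntryAttrName). For every `AllocOp` in such a function it
// records the analysis results as attributes on the op:
//   - kReuseInputCandidatesAttrName: an I32ArrayAttr with the indices of the
//     function arguments whose buffers may be reused for this allocation, and
//   - kReuseOutputAttrName: an I32IntegerAttr with the position at which the
//     allocated buffer is returned, if that position is unambiguous.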
struct BufferReusePass : public BufferReusePassBase<BufferReusePass> {
  void runOnFunction() override {
    if (!getFunction()->getAttrOfType<UnitAttr>(
            tf_framework::TFFrameworkDialect::kTFEntryAttrName))
      return;

    BufferReuseAnalysis analysis(getFunction());

    // Annotate the IR with reuse candidates and output indices per allocation.
    Builder builder(&getContext());
    getFunction().walk([&](AllocOp op) {
      if (auto output_index = analysis.get_output_index(op)) {
        auto attr = builder.getI32IntegerAttr(*output_index);
        op.getOperation()->setAttr(
            tf_framework::TFAllocOp::kReuseOutputAttrName, attr);
      }
      if (auto reuse_candidates = analysis.get_reuse_candidates(op)) {
        auto attr = builder.getI32ArrayAttr(*reuse_candidates);
        op.getOperation()->setAttr(
            tf_framework::TFAllocOp::kReuseInputCandidatesAttrName, attr);
      }
    });
  }
};

}  // namespace

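// Sketch of how the factory below might be added to a pass pipeline
// (illustrative only; the actual pipeline assembly lives outside this file):
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(CreateBufferReusePass());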
std::unique_ptr<FunctionPass> CreateBufferReusePass() {
  return std::make_unique<BufferReusePass>();
}

}  // namespace transforms
}  // namespace kernel_gen
}  // namespace mlir