/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
17 #include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
18 #include "mlir-hlo/utils/cycle_detector.h"
19 #include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project
20 #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
21 #include "mlir/IR/MLIRContext.h" // TF:llvm-project
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/Pass/Pass.h" // TF:local_config_mlir
24 #include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
25
// This pass has similar functionality to the fusion pass in the XLA stack.
// However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it tries to greedily find kLoop/kInput fusion
// patterns.
//
// Similar to XLA, this pass supports fusion patterns with multiple outputs,
// as long as the shapes of all outputs are consistent. Following are some
// examples.
//
//       kLoop                       kInput
// +----+  +----+  +----+    +----+    +----+    +----+
// |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
// +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
//   |       |       |         |                   |       |
//   |               |         |                   |       |
// +-v--+    |     +-v--+   +--v---+            +--v---+   |
// |elem+<---+----<+elem|   |reduce|            |reduce|   |
// +-+--+          +-+--+   +--+---+            +--+---+   |
//   |               |         |                   |       |
//   |               |         |                   |       |
//   v               v         v                   v       v
//
// To this end, we also add a simple shape constraint analysis phase.
// The kLoop fusion template requires that all outputs of the fused
// pattern have the same shape. However, in the dynamic shape world we don't
// know the actual value of the shape at compile time.
// Fortunately, we can still infer the relationship among different ops
// according to their shape constraint traits. Currently, we only consider
// shape equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
//
// TODO(disc): This file implements fusion on buffer level, re-visit this after
// more shape inference/constraint infras are ready in mhlo level.
// TODO(disc): Not using fusibility interface a.t.m, re-visit this if necessary.
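//
// As an illustrative sketch only (op syntax is abbreviated and may differ
// across lmhlo versions), a kLoop fusion of two element-wise ops could
// produce IR like:
//
//   "lmhlo.fusion"() ({
//     "lmhlo.add"(%a, %b, %tmp)
//         : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
//     "lmhlo.abs"(%tmp, %out) : (memref<?xf32>, memref<?xf32>) -> ()
//     "lmhlo.terminator"() : () -> ()
//   }) : () -> ()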

namespace mlir {
namespace lmhlo {
namespace {

struct FusionOptions {
  // Maximum allowed number of arguments per fused kernel. Here arguments
  // include both read-only buffers and writable buffers.
  int max_num_arguments_per_kernel;
};

// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan consists of a group of fusion patterns.
//
// Currently all proposed patterns follow XLA kLoop/kInput-like fusion
// templates, adapted to the fully dynamic shape world.
//
// kLoop fusion template satisfies:
//   - all ops in the fusion pattern are element-wise.
//   - all the shapes of the outputs of the fusion pattern are the same, or
//     have the same number of elements, and thus can fit into the same
//     parallel loop.
//
// kInput fusion template satisfies:
//   - any op in the fusion pattern is either element-wise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all the effective shapes of the outputs of the fusion pattern are the
//     same.
//     - For an element-wise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
//   - currently our downstream codegen engine only supports 2d -> 1d tensor
//     reduction. TODO(disc): lift this limitation.
//     - 2D row reduction: out[i] = sum({in[i][j] for all j})
//     - 2D column reduction: out[j] = sum({in[i][j] for all i})
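//
// For example (illustrative): fusing `t = add(a, b)` with `r = reduce(t)`,
// where `t` has shape [m, n] and `r` has shape [m], forms a valid kInput
// pattern: the effective shape of the reduce output is its operand's shape
// (that of `t`), which matches the effective shape [m, n] of the
// element-wise output.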
class FusionPlanner {
 public:
  explicit FusionPlanner(const FusionOptions& options, Block* block)
      : options_(options), block_(block) {
    // Move up metadata-only ops (e.g. dim, shape_of) as far as possible.
    MoveUpMetadataOnlyOpsForFusion();

    for (Operation& op : *block) {
      op_list_.push_back(&op);
    }
    shape_analysis_.reset(new ShapeConstraintAnalysis(op_list_));
    cycle_detector_.reset(new GraphCycles(op_list_.size()));
    BuildNodeMap();
  }

  // Returns a fusion plan on success, otherwise none.
  llvm::Optional<FusionPlan> Run() {
    // Greedily search for connected fusible patterns; ops belonging to the
    // same fusion pattern are grouped into one cluster.
    RunEdgeContractionLoop();

    // After edge contraction, each unique cluster with a size greater than
    // one represents a potential fusion pattern.
    // We collect all these clusters and construct a fusion plan.
    FusionPlan plan;
    DenseSet<Cluster*> seen_clusters;
    for (Operation* op : op_list_) {
      Cluster* cluster = GetClusterForNode(op);
      if (!seen_clusters.insert(cluster).second) continue;
      FusionPattern& fusion_pattern = cluster->fused_pattern();
      // Make sure the ops in a fusion pattern are in topological order.
      fusion_pattern.sortFusionOpListBy(op_to_node_id_);
      if (!fusion_pattern.isFusible() || fusion_pattern.effectiveSize() <= 1) {
        continue;
      }
      plan.emplace_back(fusion_pattern);
    }

    // Re-order ops inside the block to make sure all producers are placed
    // before their consumers after fusion.
    ReorderOperationsInsideBlock();
    return plan;
  }

  // Returns the op_list this planner operates on.
  const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }

 private:
  // Represents a (partial) fused pattern.
  class Cluster {
   public:
    Cluster(int node_id, FusionPlanner* planner)
        : node_id_(node_id), pattern_(planner->op_list()[node_id]) {}

    // Merges `other` into this cluster, and clears `other`.
    void Merge(Cluster* other) {
      pattern_.mergeInplace(other->fused_pattern());
    }

    // The number of nodes in this cluster.
    int cluster_size() { return pattern_.size(); }

    // The ID of the cluster as represented in `cycle_detector_`.
    int cycles_graph_node_id() const { return node_id_; }

    // Sets the ID of the cluster as represented in `cycle_detector_`.
    void set_cycles_graph_node_id(int cycles_graph_node_id) {
      node_id_ = cycles_graph_node_id;
    }

    // The fused pattern this cluster currently holds.
    FusionPattern& fused_pattern() { return pattern_; }

   private:
    // ID of the representative node of this cluster.
    int node_id_;

    // The fused pattern this cluster holds.
    FusionPattern pattern_;
  };

 private:
  // Returns a new cluster with the specified `cycles_graph_node_id`.
  Cluster* MakeCluster(int cycles_graph_node_id) {
    cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
    return cluster_storage_.back().get();
  }

  // Metadata-only ops (e.g. shape_of, dim) don't change the data, thus we
  // move them up as far as possible inside the same block to enable more
  // fusion opportunities.
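  //
  // For example (an illustrative sketch; op syntax abbreviated):
  //
  //   lmhlo.add(%a, %b, %t0)     // writes %t0
  //   %d = memref.dim %a, %c0    // metadata-only: reads only the shape of %a
  //   lmhlo.mul(%t0, %c, %t1)    // writes %t1
  //
  // Since `memref.dim` only inspects metadata, it can be hoisted above
  // `lmhlo.add`, leaving the two lmhlo ops adjacent and thus fusible.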
  void MoveUpMetadataOnlyOpsForFusion() {
    SmallVector<Operation*, 4> ops;
    for (Operation& op : *block_) {
      ops.push_back(&op);
    }

    auto inBlock = [&](Operation* op, Block* block) {
      return op && op->getBlock() == block;
    };

    for (Operation* op : ops) {
      Block* block = op->getBlock();
      if (isa<shape::ShapeOfOp>(op)) {
        Operation* definingOp = op->getOperand(0).getDefiningOp();
        if (!inBlock(definingOp, block)) {
          op->moveBefore(block, block->begin());
        } else {
          op->moveAfter(definingOp);
        }
      } else if (isa<memref::DimOp>(op)) {
        Operation* firstOperandOp = op->getOperand(0).getDefiningOp();
        Operation* secondOperandOp = op->getOperand(1).getDefiningOp();
        if (!inBlock(firstOperandOp, block) &&
            !inBlock(secondOperandOp, block)) {
          op->moveBefore(block, block->begin());
        } else if (!inBlock(firstOperandOp, block)) {
          op->moveAfter(secondOperandOp);
        } else if (!inBlock(secondOperandOp, block)) {
          op->moveAfter(firstOperandOp);
        } else if (firstOperandOp->isBeforeInBlock(secondOperandOp)) {
          op->moveAfter(secondOperandOp);
        } else {
          op->moveAfter(firstOperandOp);
        }
      }
    }
  }

  // Returns all the values touched by this op or its nested ops.
  SmallVector<Value, 4> GetAllPossibleUsedValues(Operation* op) {
    SmallVector<Value, 4> values;
    op->walk([&](Operation* nest_op) {
      for (Value v : nest_op->getOperands()) {
        values.push_back(v);
      }
    });
    return values;
  }

  // Builds the initial dependency graph.
  void BuildNodeMap() {
    int num_nodes = op_list_.size();
    for (int node_id = 0; node_id < num_nodes; ++node_id) {
      Operation* op = op_list_[node_id];
      MakeCluster(node_id);
      op_to_node_id_[op] = node_id;
      leader_for_node_.insert(node_id);
      for (Value operand : GetAllPossibleUsedValues(op)) {
        Operation* operand_op = FindLastWriter(operand);
        // Only consider the operand_op inside the target block.
        auto iter = op_to_node_id_.find(operand_op);
        if (iter == op_to_node_id_.end()) {
          continue;
        }
        // Add an edge to connect the last writer and the current consumer.
        cycle_detector_->InsertEdge(iter->second, node_id);
      }

      // For some ops (e.g. lmhlo ops), some operands are the output memrefs,
      // and these operands are updated in place.
      // We assume that an op (or its nested ops) can only write the buffers
      // explicitly passed in as operands of this op.
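      // For example (illustrative): in `lmhlo.add(%lhs, %rhs, %out)`, the
      // two leading operands are inputs and the trailing operand `%out` is
      // the output buffer written by the op.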
      int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
      for (Value v : op->getOperands().drop_front(num_input_operand)) {
        auto it = last_writer_.try_emplace(v, op);
        (void)it;
        // Currently, a buffer is only supposed to be written once (as the
        // output operand of one lmhlo op).
        assert(it.second);
      }
    }
  }

  // Returns the cluster that contains this op.
  Cluster* GetClusterForNode(Operation* n) {
    int id = op_to_node_id_[n];
    id = leader_for_node_.getLeaderValue(id);
    return cluster_storage_[id].get();
  }

  // Returns the cluster that contains the op with the given `node_id`.
  Cluster* GetClusterForCyclesGraphNode(int node_id) {
    return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
  }

  // Merges the clusters `cluster_from` and `cluster_to`.
  bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
    int from = cluster_from->cycles_graph_node_id();
    int to = cluster_to->cycles_graph_node_id();

    auto optional_merged_node = cycle_detector_->ContractEdge(from, to);
    if (!optional_merged_node.hasValue()) {
      llvm::dbgs() << "Could not contract " << from << " -> " << to
                   << " because contracting the edge would create a cycle.";
      return false;
    }

    // Merge the clusters.
    cluster_from->Merge(cluster_to);
    cluster_from->set_cycles_graph_node_id(*optional_merged_node);

    // Merge the UnionFind set.
    leader_for_node_.unionSets(from, to);
    return true;
  }

  using FnTy = llvm::function_ref<bool(Cluster*, Cluster*)>;
  bool ForEachEdgeInPostOrder(FnTy fn, bool enable_cross_fusion = false) {
    bool changed = false;
    for (int32_t node : cycle_detector_->AllNodesInPostOrder()) {
      Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
      // Make a copy of the set of successors because we may modify the graph
      // in TryToContractEdge.
      std::vector<int32_t> successors_copy =
          cycle_detector_->SuccessorsCopy(cluster_from->cycles_graph_node_id());

      for (int to : successors_copy) {
        Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
        bool contracted_edge = fn(cluster_from, cluster_to);
        changed |= contracted_edge;
      }
    }

    if (!enable_cross_fusion) return changed;

    // To enable even more fusion opportunities (e.g. horizontal fusion), also
    // try to contract pairs of clusters that are not directly connected.
    for (int32_t lhs : cycle_detector_->AllNodesInPostOrder()) {
      Cluster* cluster_lhs = GetClusterForCyclesGraphNode(lhs);
      if (!cluster_lhs) {
        continue;
      }

      for (int32_t rhs : cycle_detector_->AllNodesInPostOrder()) {
        Cluster* cluster_rhs = GetClusterForCyclesGraphNode(rhs);
        if (!cluster_rhs || cluster_lhs == cluster_rhs) {
          continue;
        }

        bool contracted_edge = fn(cluster_lhs, cluster_rhs);
        changed |= contracted_edge;
      }
    }

    return changed;
  }

  // Checks whether fusing `from` with `to` is valid, and if so, performs the
  // merge. The validity is based on the operations in the clusters and the
  // compatibility of the shapes of the outputs of the would-be fused clusters.
  // Returns true if the merge was performed.
  bool TryToContractEdge(Cluster* from, Cluster* to) {
    // Try the merge and check whether it is valid.
    if (!from->fused_pattern().isMergeable(to->fused_pattern())) return false;
    FusionPattern fused_pattern =
        from->fused_pattern().merge(to->fused_pattern());
    auto& op_list = fused_pattern.getOpList();
    auto& operands = fused_pattern.getOperands();
    auto& results = fused_pattern.getResults();

    if (results.size() + operands.size() >
        options_.max_num_arguments_per_kernel) {
      // Some backend devices (e.g. GPU) do not support kernels with too
      // many arguments.
      return false;
    }

    // We currently do not support a constant op as the final output of a
    // fusion pattern.
    // TODO(disc): copy small consts in case necessary.
    for (Value result : results) {
      Operation* result_op = FindLastWriter(result);
      assert(result_op);
      if (isa<lmhlo::ConstOp>(result_op)) {
        return false;
      }
    }

    // A ReduceOp cannot have consumers within the fusion pattern.
    for (Operation* op : op_list) {
      if (!isa<lmhlo::ReduceOp>(op)) continue;
      int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
      for (Value v : op->getOperands().drop_front(num_input_operand)) {
        for (Operation* user : getValueUsers(v)) {
          if (user == op) continue;
          if (std::find(op_list.begin(), op_list.end(), user) !=
              op_list.end()) {
            return false;
          }
        }
      }
    }

    // All outputs of a fusion pattern should have compatible shapes.
    // Here `compatible` means:
    //   - if `to` and `from` are both kInput fusions, all outputs should
    //     have the same shape.
    //   - otherwise, all outputs should have the same number of elements.
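    //
    // For example (illustrative): outputs of shape [4, 5] and [20] may live
    // in the same kLoop fusion (both have 20 elements), while a kInput
    // fusion would require identical effective shapes.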

    // If there are no outside users, these ops may be eliminated entirely.
    // We fuse them here and let a later pass do the DCE.
    if (results.empty()) return true;

    bool check_same_shape = (to->fused_pattern().isKInputFusion() &&
                             from->fused_pattern().isKInputFusion());
    auto get_effective_shape = [&](Value v) {
      auto result_op = FindLastWriter(v);
      assert(result_op);
      // The effective shape of a reduce op is its operand's shape.
      return isa<lmhlo::ReduceOp>(result_op) ? result_op->getOperand(0) : v;
    };

    Value ref_shape = get_effective_shape(results[0]);
    if (!llvm::all_of(results, [&](Value result) {
          Value shape = get_effective_shape(result);
          return check_same_shape
                     ? shape_analysis_->HasSameShape(ref_shape, shape)
                     : shape_analysis_->HasSameNumElements(ref_shape, shape);
        })) {
      return false;
    }

    return MergeClusters(from, to);
  }

  // Greedily fuses connected nodes.
  bool RunEdgeContractionLoop() {
    using std::placeholders::_1;
    using std::placeholders::_2;
    bool changed = false;

    // Run the contraction loop repeatedly until there is nothing left to
    // fuse, recording whether any contraction happened.
    while (ForEachEdgeInPostOrder(
        std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2), false)) {
      changed = true;
    }
    return changed;
  }

  // Here `value` is supposed to be a pointer to a buffer.
  // Returns the defining op of `value` if no known op updates the buffer,
  // otherwise returns the last op that updates the buffer pointed to by
  // `value`.
  Operation* FindLastWriter(Value value) {
    auto it = last_writer_.find(value);
    if (it != last_writer_.end()) {
      return it->second;
    }
    return value.getDefiningOp();
  }

  // Re-order ops inside the block to make sure that producers are before
  // consumers after fusion.
  void ReorderOperationsInsideBlock() {
    auto reorder_func = [&](Cluster* from, Cluster* to) {
      FusionPattern& from_pattern = from->fused_pattern();
      FusionPattern& to_pattern = to->fused_pattern();

      Operation* last_op_in_from = from_pattern.getOpList().back();
      for (Operation* op : llvm::reverse(to_pattern.getOpList())) {
        if (!last_op_in_from->isBeforeInBlock(op))
          op->moveAfter(last_op_in_from);
      }
      return false;
    };

    ForEachEdgeInPostOrder(reorder_func);
  }

  // Hyper-parameters that control the behaviour of the fusion planner.
  FusionOptions options_;

  // The block that the fusion planner works on.
  Block* block_;

  // Ops inside the block.
  SmallVector<Operation*, 4> op_list_;

  // Shape equality checker.
  std::unique_ptr<ShapeConstraintAnalysis> shape_analysis_;

  // op -> node_id
  DenseMap<Operation*, int> op_to_node_id_;

  // Used to make sure that fusion does not introduce cycles.
  std::unique_ptr<GraphCycles> cycle_detector_;
  std::vector<std::unique_ptr<Cluster>> cluster_storage_;

  // A UnionFind set. Each set represents a (partial) fused pattern
  // and has a leader as its representative.
  EquivalenceClasses<int32_t> leader_for_node_;

  // Maps a buffer (each key `value` is supposed to be a pointer to a buffer)
  // to the last op that updates the underlying buffer.
  DenseMap<Value, Operation*> last_writer_;
};

struct LhloFusionPass : public LhloFusionPassBase<LhloFusionPass> {
  using LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase;
  explicit LhloFusionPass(int max_num_arguments_per_kernel)
      : LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase() {
    this->max_num_arguments_per_kernel_ = max_num_arguments_per_kernel;
  }

  void runOnFunction() override {
    FuncOp func = getFunction();

    // Collect all blocks inside the function.
    SmallVector<Block*, 4> blocks;
    CollectBlocksInsideFunction(func, blocks);

    // Process each block and do fusion within a block.
    FusionOptions options;
    options.max_num_arguments_per_kernel = max_num_arguments_per_kernel_;
    for (Block* block : blocks) {
      FusionPlanner planner(options, block);
      llvm::Optional<FusionPlan> plan = planner.Run();
      if (!plan) {
        emitError(func.getLoc(),
                  "an error occurred while trying to find fusion candidates");
        signalPassFailure();
        return;
      }
      if (!ApplyFusionPlan(*plan)) {
        emitError(func.getLoc(), "failed to apply the fusion plan");
        signalPassFailure();
        return;
      }
    }
  }

  bool ApplyFusionPlan(FusionPlan& plan) {
    for (FusionPattern& pattern : plan) {
      auto& op_list = pattern.getOpList();
      OpBuilder b(op_list.back());

      // Get the fused locations.
      SmallVector<Location, 4> locations;
      locations.reserve(op_list.size());
      for (Operation* op : op_list) {
        locations.push_back(op->getLoc());
      }
      Location fused_loc =
          FusedLoc::get(op_list.back()->getContext(), locations);

      // Move ops inside the fusion pattern to the region attached to the
      // fusion op.
      FusionOp fusion = b.create<lmhlo::FusionOp>(fused_loc);
      Region& region = fusion.region();
      Block& block = region.front();
      for (Operation* op : llvm::reverse(op_list)) {
        op->moveBefore(&block, block.begin());
      }
    }
    return true;
  }

  void CollectBlocksInsideFunction(FuncOp op, SmallVectorImpl<Block*>& blocks) {
    op.walk([&](Block* block) {
      // It does not make sense to fuse the regions attached to these ops.
      if (!isa<lmhlo::ReduceOp, lmhlo::FusionOp>(block->getParentOp()))
        blocks.push_back(block);
    });
  }
};

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass(
    int max_num_arguments_per_kernel) {
  return std::make_unique<LhloFusionPass>(max_num_arguments_per_kernel);
}
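
// A minimal usage sketch (assuming a standard MLIR pass pipeline; the pass
// manager `pm` and the argument limit 64 are hypothetical):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<FuncOp>(createLhloFusionPass(
//       /*max_num_arguments_per_kernel=*/64));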

}  // namespace lmhlo
}  // namespace mlir