• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
17 #include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
18 #include "mlir-hlo/utils/cycle_detector.h"
19 #include "mlir/Dialect/Shape/IR/Shape.h"      // TF:llvm-project
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
21 #include "mlir/IR/MLIRContext.h"              // TF:llvm-project
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/Pass/Pass.h"               // TF:local_config_mlir
24 #include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
25 
26 // This pass has similar functionality of the fusion pass in XLA stack.
27 // However, unlike XLA, it targets the fully dynamic shape scenario.
28 // Currently, it implements the kLoop and kInput fusion templates.
29 // During conversion, it tries to greedily find kLoop/kInput fusion
30 // patterns.
31 //
32 // Similar to XLA, this pass supports fusion pattern having multiple outputs
33 // if all the shapes of the outputs are consistent. Following are some examples.
34 //
35 //        kLoop                          kInput
36 // +----+  +----+  +----+    +----+    +----+    +----+
37 // |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
38 // +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
39 //   |       |       |         |                   |       |
40 //   |               |         |                   |       |
41 // +-v--+    |     +-v--+   +--v---+            +--v---+   |
42 // |elem+<---+----<+elem|   |reduce|            |reduce|   |
43 // +-+--+          +-+--+   +--+---+            +--+---+   |
44 //   |               |         |                   |       |
45 //   |               |         |                   |       |
46 //   v               v         v                   v       v
47 //
48 // To this end, we also add a simple shape constraint analysis phase.
49 // For kLoop fusion template, it requires all the outputs of the fused
50 // pattern have the same shape. However, we don't know the actual value
51 // of the shape at the compile time in the dynamic shape world.
52 // Fortunately, we could still infer the relationship among different ops
53 // according to their shape constraint traits. Currently, we only consider
54 // shape equality propagation for elementwise ops (assuming that implicit
55 // shape broadcast is forbidden). The above process could be built on the
56 // shape dialect once it is ready.
57 //
58 // TODO(disc): This file implements fusion on buffer level, re-visit this after
59 // more shape inference/constraint infras are ready in mhlo level.
60 // TODO(disc): Not using fusibility interface a.t.m, re-visit this if necessary.
61 
62 namespace mlir {
63 namespace lmhlo {
64 namespace {
65 
// Options that control the behaviour of the fusion planner.
struct FusionOptions {
  // Maximum allowed number of arguments per fused kernel. Here arguments
  // include both read-only buffers and writable buffers.
  int max_num_arguments_per_kernel;
};
71 
72 // A fusion planner that can propose a fusion plan for a block of ops.
73 // The fusion plan consists of a group of fusion patterns.
74 //
75 // Currently all proposed patterns follow XLA kLoop/kInput-like fusion
76 // templates while being adapted to the fully dynamic shape world.
77 //
78 // kLoop fusion template satisfies:
79 //   - all ops in the fusion pattern are element-wise.
80 //   - all the shapes of outputs of fusion pattern are same or have same number
81 //   of elements, and thus can fit into a same parallel loop.
82 //
83 // kInput fusion template satisfies:
84 //   - any op in the fusion pattern is either element-wise or a reduction.
85 //   - if an op is a reduction, its output cannot be consumed by other
86 //     ops in the same fusion pattern.
87 //   - all the effective shapes of outputs of fusion pattern are same.
88 //     - For element-wise op, its effective shape is its output shape.
89 //     - For reduction op, its effective shape is its operand shape.
90 //   - currently our downstreaming codegen engine only support 2d -> 1d tensor
91 //   reduction. TODO(disc): lift this limitation.
92 //     - 2D row reduction: out[i] = sum({in[i][j] for all j})
93 //     - 2D column reduction: out[j] = sum({in[i][j] for all i}
94 class FusionPlanner {
95  public:
FusionPlanner(const FusionOptions & options,Block * block)96   explicit FusionPlanner(const FusionOptions& options, Block* block)
97       : options_(options), block_(block) {
98     // Move up metadata-only ops (e.g. dim, shape_of) as far as possible.
99     MoveUpMetadataOnlyOpsForFusion();
100 
101     for (Operation& op : *block) {
102       op_list_.push_back(&op);
103     }
104     shape_analysis_.reset(new ShapeConstraintAnalysis(op_list_));
105     cycle_detector_.reset(new GraphCycles(op_list_.size()));
106     BuildNodeMap();
107   }
108 
109   // Returns a fusion plan if success, otherwise none.
Run()110   llvm::Optional<FusionPlan> Run() {
111     // Greedily search connected fusible pattern, and ops belonging to
112     // a same fusion pattern are grouped into a cluster.
113     RunEdgeContractionLoop();
114 
115     // After doing edge contraction, each unique cluster having size
116     // more than one represents a potential fusion pattern.
117     // We collect all these clusters and construct a fusion plan.
118     FusionPlan plan;
119     DenseSet<Cluster*> seen_clusters;
120     for (Operation* op : op_list_) {
121       Cluster* cluster = GetClusterForNode(op);
122       if (!seen_clusters.insert(cluster).second) continue;
123       FusionPattern& fusion_pattern = cluster->fused_pattern();
124       // Make sure the ops in a fusion pattern are in topological ordering.
125       fusion_pattern.sortFusionOpListBy(op_to_node_id_);
126       if (!fusion_pattern.isFusible() || fusion_pattern.effectiveSize() <= 1) {
127         continue;
128       }
129       plan.emplace_back(fusion_pattern);
130     }
131 
132     // Re-order ops inside the blocks to make sure all producers are placed
133     // before its consumers after fusion.
134     ReorderOperationsInsideBlock();
135     return plan;
136   }
137 
138   // Returns the op_list this planner operates on.
op_list() const139   const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
140 
141  private:
142   // Represent a (partial) fused pattern
143   class Cluster {
144    public:
Cluster(int node_id,FusionPlanner * planner)145     Cluster(int node_id, FusionPlanner* planner)
146         : node_id_(node_id), pattern_(planner->op_list()[node_id]) {}
147 
148     // Merges `other` into this cluster, and clears `other`.
Merge(Cluster * other)149     void Merge(Cluster* other) {
150       pattern_.mergeInplace(other->fused_pattern());
151     }
152 
153     // The number of nodes in this cluster.
cluster_size()154     int cluster_size() { return pattern_.size(); }
155 
156     // The ID of the cluster as represented in `cycle_detector_`.
cycles_graph_node_id() const157     int cycles_graph_node_id() const { return node_id_; }
158 
159     // Sets the ID of the cluster as represented in `cycle_detector_`.
set_cycles_graph_node_id(int cycles_graph_node_id)160     void set_cycles_graph_node_id(int cycles_graph_node_id) {
161       node_id_ = cycles_graph_node_id;
162     }
163 
164     // Currently the fused pattern this cluster holds.
fused_pattern()165     FusionPattern& fused_pattern() { return pattern_; }
166 
167    private:
168     // ID of the representative node of this cluster.
169     int node_id_;
170 
171     // the fused pattern this cluster holds.
172     FusionPattern pattern_;
173   };
174 
175  private:
176   // Returns a new cluster with specified `cycles_graph_node_id`
MakeCluster(int cycles_graph_node_id)177   Cluster* MakeCluster(int cycles_graph_node_id) {
178     cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
179     return cluster_storage_.back().get();
180   }
181 
182   // Metadata ops (e.g. shapeOf, dimOp) don't change data thus we move forward
183   // them as far as possible inside the same block to enable more fusion
184   // opportunities.
MoveUpMetadataOnlyOpsForFusion()185   void MoveUpMetadataOnlyOpsForFusion() {
186     SmallVector<Operation*, 4> ops;
187     for (Operation& op : *block_) {
188       ops.push_back(&op);
189     }
190 
191     auto inBlock = [&](Operation* op, Block* block) {
192       return op && op->getBlock() == block;
193     };
194 
195     for (Operation* op : ops) {
196       Block* block = op->getBlock();
197       if (isa<shape::ShapeOfOp>(op)) {
198         Operation* definingOp = op->getOperand(0).getDefiningOp();
199         if (!inBlock(definingOp, block)) {
200           op->moveBefore(block, block->begin());
201         } else {
202           op->moveAfter(definingOp);
203         }
204       } else if (isa<memref::DimOp>(op)) {
205         Operation* firstOperandOp = op->getOperand(0).getDefiningOp();
206         Operation* secondOperandOp = op->getOperand(1).getDefiningOp();
207         if (!inBlock(firstOperandOp, block) &&
208             !inBlock(secondOperandOp, block)) {
209           op->moveBefore(block, block->begin());
210         } else if (!inBlock(firstOperandOp, block)) {
211           op->moveAfter(secondOperandOp);
212         } else if (!inBlock(secondOperandOp, block)) {
213           op->moveAfter(firstOperandOp);
214         } else if (firstOperandOp->isBeforeInBlock(secondOperandOp)) {
215           op->moveAfter(secondOperandOp);
216         } else {
217           op->moveAfter(firstOperandOp);
218         }
219       }
220     }
221   }
222 
223   // Returns all the values touched by this op or its nested ops.
GetAllPossibleUsedValues(Operation * op)224   SmallVector<Value, 4> GetAllPossibleUsedValues(Operation* op) {
225     SmallVector<Value, 4> values;
226     op->walk([&](Operation* nest_op) {
227       for (Value v : nest_op->getOperands()) {
228         values.push_back(v);
229       }
230     });
231     return values;
232   }
233 
234   // Builds the initial dependency graph.
BuildNodeMap()235   void BuildNodeMap() {
236     int num_nodes = op_list_.size();
237     for (int node_id = 0; node_id < num_nodes; ++node_id) {
238       Operation* op = op_list_[node_id];
239       MakeCluster(node_id);
240       op_to_node_id_[op] = node_id;
241       leader_for_node_.insert(node_id);
242       for (Value operand : GetAllPossibleUsedValues(op)) {
243         Operation* operand_op = FindLastWriter(operand);
244         // Only consider the operand_op inside the target block.
245         auto iter = op_to_node_id_.find(operand_op);
246         if (iter == op_to_node_id_.end()) {
247           continue;
248         }
249         // Add an edge to connect the last writer and the current consumer.
250         cycle_detector_->InsertEdge(iter->second, node_id);
251       }
252 
253       // For some ops (e.g. lmhlo ops), some operands are the output memrefs
254       // Thus these operands are supposed to be updated.
255       // Suppose that a op (or its nested ops) can only write the buffers
256       // explicit passed in as operands of this op.
257       int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
258       for (Value v : op->getOperands().drop_front(num_input_operand)) {
259         auto it = last_writer_.try_emplace(v, op);
260         (void)it;
261         // Currently, a buffer is only supposed to be written once (as the
262         // output operand of one lmhlo op).
263         assert(it.second);
264       }
265     }
266   }
267 
268   // Returns the cluster contains this op.
GetClusterForNode(Operation * n)269   Cluster* GetClusterForNode(Operation* n) {
270     int id = op_to_node_id_[n];
271     id = leader_for_node_.getLeaderValue(id);
272     return cluster_storage_[id].get();
273   }
274 
275   // Returns the cluster contains the op having `node_id`.
GetClusterForCyclesGraphNode(int node_id)276   Cluster* GetClusterForCyclesGraphNode(int node_id) {
277     return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
278   }
279 
280   // Merges the clusters `cluster_from` and `cluster_to`.
MergeClusters(Cluster * cluster_from,Cluster * cluster_to)281   bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
282     int from = cluster_from->cycles_graph_node_id();
283     int to = cluster_to->cycles_graph_node_id();
284 
285     auto optional_merged_node = cycle_detector_->ContractEdge(from, to);
286     if (!optional_merged_node.hasValue()) {
287       llvm::dbgs() << "Could not contract " << from << " -> " << to
288                    << " because contracting the edge would create a cycle.";
289       return false;
290     }
291 
292     // Merge the clusters.
293     cluster_from->Merge(cluster_to);
294     cluster_from->set_cycles_graph_node_id(*optional_merged_node);
295 
296     // Merge the UnionFind Set.
297     leader_for_node_.unionSets(from, to);
298     return true;
299   }
300 
301   using FnTy = llvm::function_ref<bool(Cluster*, Cluster*)>;
ForEachEdgeInPostOrder(FnTy fn,bool enable_cross_fusion=false)302   bool ForEachEdgeInPostOrder(FnTy fn, bool enable_cross_fusion = false) {
303     bool changed = false;
304     for (int32_t node : cycle_detector_->AllNodesInPostOrder()) {
305       Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
306       // Make a copy of the set of successors because we may modify the graph in
307       // TryToContractEdge.
308       std::vector<int32_t> successors_copy =
309           cycle_detector_->SuccessorsCopy(cluster_from->cycles_graph_node_id());
310 
311       for (int to : successors_copy) {
312         Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
313         bool contracted_edge = fn(cluster_from, cluster_to);
314         changed |= contracted_edge;
315       }
316     }
317 
318     if (!enable_cross_fusion) return changed;
319 
320     // To enable even more fusion opportunities (e.g. horizontal fusion)
321     for (int32_t lhs : cycle_detector_->AllNodesInPostOrder()) {
322       Cluster* cluster_lhs = GetClusterForCyclesGraphNode(lhs);
323       if (!cluster_lhs) {
324         continue;
325       }
326 
327       for (int32_t rhs : cycle_detector_->AllNodesInPostOrder()) {
328         Cluster* cluster_rhs = GetClusterForCyclesGraphNode(rhs);
329         if (!cluster_rhs || cluster_lhs == cluster_rhs) {
330           continue;
331         }
332 
333         bool contracted_edge = fn(cluster_lhs, cluster_rhs);
334         changed |= contracted_edge;
335       }
336     }
337 
338     return changed;
339   }
340 
341   // This function check if fusing `from` with `to` is valid and if so perform
342   // the merge. The validity is based on the operations in the clusters and
343   // the compatibility of the shapes of the outputs of the would-be fused
344   // clusters.
345   // Returns true is the merge was performed.
TryToContractEdge(Cluster * from,Cluster * to)346   bool TryToContractEdge(Cluster* from, Cluster* to) {
347     // Try merge and check if valid.
348     if (!from->fused_pattern().isMergeable(to->fused_pattern())) return false;
349     FusionPattern fused_pattern =
350         from->fused_pattern().merge(to->fused_pattern());
351     auto& op_list = fused_pattern.getOpList();
352     auto& operands = fused_pattern.getOperands();
353     auto& results = fused_pattern.getResults();
354 
355     if (results.size() + operands.size() >
356         options_.max_num_arguments_per_kernel) {
357       // some backend devices (e.g. GPU) do not support a kernel with
358       // too many arguments.
359       return false;
360     }
361 
362     // We currently do not support a constant op as final output of a fusion
363     // pattern.
364     // TODO(disc): copy small const in case necessary.
365     for (Value result : results) {
366       Operation* result_op = FindLastWriter(result);
367       assert(result_op);
368       if (isa<lmhlo::ConstOp>(result_op)) {
369         return false;
370       }
371     }
372 
373     // ReduceOp can not have consumer within the fusion pattern.
374     for (Operation* op : op_list) {
375       if (!isa<lmhlo::ReduceOp>(op)) continue;
376       int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
377       for (Value v : op->getOperands().drop_front(num_input_operand)) {
378         for (Operation* user : getValueUsers(v)) {
379           if (user == op) continue;
380           if (std::find(op_list.begin(), op_list.end(), user) !=
381               op_list.end()) {
382             return false;
383           }
384         }
385       }
386     }
387 
388     // All outputs of a fusion pattern should have compatible shape.
389     // Here `compatible` means:
390     // - if `to` and `from` are both kInput fusion, all output should have same
391     // shape.
392     // - otherwise, all output should have same number of elements.
393 
394     // No outside users, these ops may be eliminated. We fused it here and let
395     // latter pass to do such DCE.
396     if (results.empty()) return true;
397 
398     bool check_same_shape = (to->fused_pattern().isKInputFusion() &&
399                              from->fused_pattern().isKInputFusion());
400     auto get_effective_shape = [&](Value v) {
401       auto result_op = FindLastWriter(v);
402       assert(result_op);
403       // effective shape of reduce op is its operand's shape.
404       return isa<lmhlo::ReduceOp>(result_op) ? result_op->getOperand(0) : v;
405     };
406 
407     Value ref_shape = get_effective_shape(results[0]);
408     if (!llvm::all_of(results, [&](Value result) {
409           Value shape = get_effective_shape(result);
410           return check_same_shape
411                      ? shape_analysis_->HasSameShape(ref_shape, shape)
412                      : shape_analysis_->HasSameNumElements(ref_shape, shape);
413         })) {
414       return false;
415     }
416 
417     return MergeClusters(from, to);
418   }
419 
420   // Greedily fuse connected node.
RunEdgeContractionLoop()421   bool RunEdgeContractionLoop() {
422     using std::placeholders::_1;
423     using std::placeholders::_2;
424     bool changed = false;
425 
426     // Run fusion pass repeatedly until nothing to be fused
427     while (ForEachEdgeInPostOrder(
428         std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2), false)) {
429       // empty statement by design
430     }
431     return changed;
432   }
433 
434   // Here `value` is supported to be a pointer to buffer.
435   // Returns the defining op of `value `if no known op updates the buffer,
436   // otherwise returns the last op that updates the buffer pointed by the
437   // `value`.
FindLastWriter(Value value)438   Operation* FindLastWriter(Value value) {
439     auto it = last_writer_.find(value);
440     if (it != last_writer_.end()) {
441       return it->second;
442     }
443     return value.getDefiningOp();
444   }
445 
446   // Re-order ops inside the block to make sure that producers are before
447   // consumers after fusion.
ReorderOperationsInsideBlock()448   void ReorderOperationsInsideBlock() {
449     auto reorder_func = [&](Cluster* from, Cluster* to) {
450       FusionPattern& from_pattern = from->fused_pattern();
451       FusionPattern& to_pattern = to->fused_pattern();
452 
453       Operation* last_op_in_from = from_pattern.getOpList().back();
454       for (Operation* op : llvm::reverse(to_pattern.getOpList())) {
455         if (!last_op_in_from->isBeforeInBlock(op))
456           op->moveAfter(last_op_in_from);
457       }
458       return false;
459     };
460 
461     ForEachEdgeInPostOrder(reorder_func);
462   }
463 
464   // hyper-parameters that controls the behaviour of the fusion planner.
465   FusionOptions options_;
466 
467   // The block that fusion planner works on.
468   Block* block_;
469 
470   // Ops inside the block
471   SmallVector<Operation*, 4> op_list_;
472 
473   // Shape equality checker
474   std::unique_ptr<ShapeConstraintAnalysis> shape_analysis_;
475 
476   // op -> node_id
477   DenseMap<Operation*, int> op_to_node_id_;
478 
479   // make sure not introduce cycle after fusion
480   std::unique_ptr<GraphCycles> cycle_detector_;
481   std::vector<std::unique_ptr<Cluster>> cluster_storage_;
482 
483   // a UnionFind set. Each set represents a (partial) fused pattern
484   // and has a leader as representation.
485   EquivalenceClasses<int32_t> leader_for_node_;
486 
487   // Here `value` is supported to be a pointer to buffer.
488   // Returns the defining op of `value `if no known op updates the buffer,
489   // otherwise returns the last op that updates the buffer pointed by the
490   // `value`.
491   DenseMap<Value, Operation*> last_writer_;
492 };
493 
494 struct LhloFusionPass : public LhloFusionPassBase<LhloFusionPass> {
495   using LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase;
LhloFusionPassmlir::lmhlo::__anonb89c297e0111::LhloFusionPass496   explicit LhloFusionPass(int max_num_arguments_per_kernel)
497       : LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase() {
498     this->max_num_arguments_per_kernel_ = max_num_arguments_per_kernel;
499   }
500 
runOnFunctionmlir::lmhlo::__anonb89c297e0111::LhloFusionPass501   void runOnFunction() override {
502     FuncOp func = getFunction();
503 
504     // collect all blocks inside the function.
505     SmallVector<Block*, 4> blocks;
506     CollectBlocksInsideFunction(func, blocks);
507 
508     // process each block and do fusion within a block.
509     FusionOptions options;
510     options.max_num_arguments_per_kernel = max_num_arguments_per_kernel_;
511     for (Block* block : blocks) {
512       FusionPlanner planner(options, block);
513       llvm::Optional<FusionPlan> plan = planner.Run();
514       if (!plan) {
515         emitError(func.getLoc(),
516                   "an error occurs while trying to find fusion candidates");
517         signalPassFailure();
518         return;
519       }
520       if (!ApplyFusionPlan(*plan)) {
521         emitError(func.getLoc(), "apply fusion plan failed");
522         signalPassFailure();
523         return;
524       }
525     }
526   }
527 
ApplyFusionPlanmlir::lmhlo::__anonb89c297e0111::LhloFusionPass528   bool ApplyFusionPlan(FusionPlan& plan) {
529     for (FusionPattern& pattern : plan) {
530       auto& op_list = pattern.getOpList();
531       OpBuilder b(op_list.back());
532 
533       // Get the fused locations
534       SmallVector<Location, 4> locations;
535       locations.reserve(op_list.size());
536       for (Operation* op : op_list) {
537         locations.push_back(op->getLoc());
538       }
539       Location fused_loc =
540           FusedLoc::get(op_list.back()->getContext(), locations);
541 
542       // Move ops inside fusion pattern to the region attached to the fusion op.
543       FusionOp fusion = b.create<lmhlo::FusionOp>(fused_loc);
544       Region& region = fusion.region();
545       Block& block = region.front();
546       for (Operation* op : llvm::reverse(op_list)) {
547         op->moveBefore(&block, block.begin());
548       }
549     }
550     return true;
551   }
552 
CollectBlocksInsideFunctionmlir::lmhlo::__anonb89c297e0111::LhloFusionPass553   void CollectBlocksInsideFunction(FuncOp op, SmallVectorImpl<Block*>& blocks) {
554     op.walk([&](Block* block) {
555       // It does not make sense to fuse the region attached to these ops.
556       if (!isa<lmhlo::ReduceOp, lmhlo::FusionOp>(block->getParentOp()))
557         blocks.push_back(block);
558     });
559   }
560 };
561 
562 }  // namespace
563 
createLhloFusionPass(int max_num_arguments_per_kernel)564 std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass(
565     int max_num_arguments_per_kernel) {
566   return std::make_unique<LhloFusionPass>(max_num_arguments_per_kernel);
567 }
568 
569 }  // namespace lmhlo
570 }  // namespace mlir
571