• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <vector>
20 
21 #include "llvm/ADT/EquivalenceClasses.h"
22 #include "llvm/Support/Debug.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/utils/cycle_detector.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
27 #include "mlir/IR/MLIRContext.h"              // TF:llvm-project
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/Pass/Pass.h"               // TF:local_config_mlir
30 #include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
31 
32 // This pass has similar functionality of the fusion pass in XLA stack.
33 // However, unlike XLA, it targets the fully dynamic shape scenario.
34 // Currently, it implements the kLoop and kInput fusion templates.
35 // During conversion, it tries to greedily find kLoop/kInput fusion
36 // patterns.
37 //
// Similar to XLA, this pass supports fusion patterns having multiple outputs
// if the shapes of all outputs are consistent. Following are some examples.
40 //
41 //        kLoop                          kInput
42 // +----+  +----+  +----+    +----+    +----+    +----+
43 // |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
44 // +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
45 //   |       |       |         |                   |       |
46 //   |               |         |                   |       |
47 // +-v--+    |     +-v--+   +--v---+            +--v---+   |
48 // |elem+<---+----<+elem|   |reduce|            |reduce|   |
49 // +-+--+          +-+--+   +--+---+            +--+---+   |
50 //   |               |         |                   |       |
51 //   |               |         |                   |       |
52 //   v               v         v                   v       v
53 //
// To this end, we also add a simple shape constraint analysis phase.
55 // For kLoop fusion template, it requires all the outputs of the fused
56 // pattern have the same shape. However, we don't know the actual value
57 // of the shape at the compile time in the dynamic shape world.
58 // Fortunately, we could still infer the relationship among different ops
// according to their shape constraint traits. Currently, we only consider
60 // shape equality propagation for elementwise ops (assuming that implicit
61 // shape broadcast is forbidden). The above process could be built on the
62 // shape dialect once it is ready.
63 
64 namespace mlir {
65 namespace mhlo {
66 namespace {
67 
68 using llvm::EquivalenceClasses;
69 using FusionPattern = std::vector<Operation*>;
70 using FusionPlan = std::vector<FusionPattern>;
71 
72 // To support using EquivalenceClasses for Value
73 class ValueWrapper {
74  public:
ValueWrapper(Value value)75   explicit ValueWrapper(Value value) : value_(std::move(value)) {}
76 
getValue() const77   Value getValue() const { return value_; }
78 
operator ==(const ValueWrapper & rhs) const79   bool operator==(const ValueWrapper& rhs) const {
80     return getValue() == rhs.getValue();
81   }
82 
83  private:
84   Value value_;
85 };
86 
operator <(const ValueWrapper & lhs,const ValueWrapper & rhs)87 bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
88   auto lhs_value = lhs.getValue().getAsOpaquePointer();
89   auto rhs_value = rhs.getValue().getAsOpaquePointer();
90   return lhs_value < rhs_value;
91 }
92 
IsFusible(Operation * op)93 bool IsFusible(Operation* op) {
94   if (matchPattern(op, m_Constant())) {
95     return true;
96   }
97   auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
98   return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
99                            op_fusibility.isFusibleWithConsumer());
100 }
101 
GetInputsOfFusionPattern(const FusionPattern & pattern)102 SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
103   SmallVector<Value, 4> inputs;
104   DenseSet<Value> input_set;
105   DenseSet<Operation*> op_set;
106   for (Operation* op : pattern) {
107     bool inserted = op_set.insert(op).second;
108     (void)inserted;
109     assert(inserted && "FusionPattern contains duplicate operations");
110   }
111 
112   for (Operation* op : pattern) {
113     for (Value operand : op->getOperands()) {
114       Operation* operand_op = operand.getDefiningOp();
115       if (op_set.find(operand_op) != op_set.end()) {
116         // skip if defining op is in the pattern
117         continue;
118       }
119       if (input_set.insert(operand).second) {
120         inputs.push_back(operand);
121       }
122     }
123   }
124   return inputs;
125 }
126 
GetOutputsOfFusionPattern(const FusionPattern & pattern)127 SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
128   SmallVector<Value, 4> outputs;
129   DenseSet<Operation*> op_set;
130   for (Operation* op : pattern) {
131     bool inserted = op_set.insert(op).second;
132     (void)inserted;
133     assert(inserted && "FusionPattern contains duplicate operations");
134   }
135 
136   for (Operation* op : pattern) {
137     for (Value result : op->getResults()) {
138       bool has_external_user = llvm::any_of(
139           result.getUses(),
140           [&](OpOperand& use) { return !op_set.count(use.getOwner()); });
141       if (has_external_user) {
142         outputs.push_back(result);
143       }
144     }
145   }
146   return outputs;
147 }
148 
// Returns a new pattern containing the ops of `lhs` followed by those of
// `rhs`. The inputs are not modified and duplicates are not removed.
FusionPattern MergeFusionPattern(const FusionPattern& lhs,
                                 const FusionPattern& rhs) {
  FusionPattern merged;
  merged.reserve(lhs.size() + rhs.size());
  merged.insert(merged.end(), lhs.begin(), lhs.end());
  merged.insert(merged.end(), rhs.begin(), rhs.end());
  return merged;
}
155 
EffectiveSize(const FusionPattern & pattern)156 inline int EffectiveSize(const FusionPattern& pattern) {
157   return llvm::count_if(
158       pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
159 }
160 
161 // This is an simple shape constraint analysis, which is used to
162 // guide fusion decision (e.g. we only fuse shape-compatible ops).
163 //
164 // Currently, We only consider shape equality propagation based
165 // on the shape constrain traits of elementwise ops (assuming that
166 // implicit shape broadcast is forbidden).
167 class ShapeConstraintAnalysis {
168  public:
ShapeConstraintAnalysis(const SmallVectorImpl<Operation * > & op_list)169   explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
170     PropagateEquality(op_list);
171   }
172 
173   // Returns true is `lhs` and `rhs` are supposed to have same shape.
HasSameShape(Value lhs,Value rhs)174   bool HasSameShape(Value lhs, Value rhs) {
175     return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
176   }
177 
178  private:
179   // shape equality propagation based on the shape constrains of
180   // elementwise ops.
PropagateEquality(const SmallVectorImpl<Operation * > & op_list)181   void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
182     bool converged = true;
183     do {
184       converged = true;
185       auto update = [&](Value lhs, Value rhs) {
186         if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
187           converged = false;
188           impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
189         }
190       };
191       for (Operation* op : op_list) {
192         auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
193         if (!op_fusibility) continue;
194         int numInput = op->getNumOperands();
195         int numOutput = op->getNumResults();
196         // shape equality propagation between inputs.
197         for (int input1 = 0; input1 < numInput; ++input1)
198           for (int input2 = input1 + 1; input2 < numInput; ++input2)
199             if (op_fusibility.inferInputsShapeEquality(input1, input2))
200               update(op->getOperand(input1), op->getOperand(input2));
201 
202         // shape equality propagation between outputs.
203         for (int output1 = 0; output1 < numOutput; ++output1)
204           for (int output2 = output1 + 1; output2 < numOutput; ++output2)
205             if (op_fusibility.inferOutputsShapeEquality(output1, output2))
206               update(op->getResult(output1), op->getResult(output2));
207 
208         // shape equality propagation between input and output.
209         for (int input = 0; input < numInput; ++input)
210           for (int output = 0; output < numOutput; ++output)
211             if (op_fusibility.inferInputOutputShapeEquality(input, output))
212               update(op->getOperand(input), op->getResult(output));
213       }
214     } while (!converged);
215   }
216 
217   // a UnionFind set
218   EquivalenceClasses<ValueWrapper> impl_;
219 };
220 
221 // A fusion planner that can propose a fusion plan for a block of ops.
222 // The fusion plan is consisted of a group of fusion patterns.
223 //
224 // Currently all proposed patterns followed xla kLoop/kInput like fusion
225 // templates while are adapted to the fully dynamic shape world.
226 //
227 // kLoop fusion template satifies:
228 //   - all ops in the fusion pattern are element-wise.
229 //   - all the shapes of outputs of fusion pattern are same, and thus can
230 //     fit into a same parallel loop.
231 //
232 // kInput fusion template satifies:
233 //   - any op in the fusion pattern is either element-wise or a reduction.
234 //   - if a op is a reduction, its output cannot be consumered by other
235 //     ops in the same fusion pattern.
236 //   - all the effective shapes of outputs of fusion pattern are same.
237 //     - For element-wise op, its effective shape is its output shape.
238 //     - For reduction op, its effective shape is its operand shape.
239 class FusionPlanner {
240  public:
FusionPlanner(const SmallVectorImpl<Operation * > & op_list)241   explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
242       : op_list_(op_list),
243         shape_analysis_(op_list),
244         cycle_detector_(op_list.size()) {
245     BuildNodeMap();
246   }
247 
248   // Returns a fusion plan if success, otherwise none.
Run()249   llvm::Optional<FusionPlan> Run() {
250     // Greedily search connected fusible pattern, and ops belonging to
251     // a same fusion pattern are grouped into a cluster.
252     RunEdgeContractionLoop();
253 
254     // After doing edge contraction, each unique cluster having size
255     // more than one represents a potential fusion pattern.
256     // We collect all these clusters and construct a fusion plan.
257     //
258     // Note that the ops in a fusion pattern are in topological ordering.
259     FusionPlan plan;
260     DenseMap<int, int> pattern_ids;
261     for (Operation* op : op_list_) {
262       Cluster* cluster = GetClusterForNode(op);
263       int node_id = cluster->cycles_graph_node_id();
264       if (!IsFusible(op_list_[node_id]) ||
265           EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
266         continue;
267       }
268       if (!pattern_ids.count(node_id)) {
269         int pattern_id = pattern_ids.size();
270         pattern_ids[node_id] = pattern_id;
271         plan.emplace_back();
272       }
273       plan[pattern_ids[node_id]].push_back(op);
274     }
275     return plan;
276   }
277 
278   // Returns the op_list this planner operates on.
op_list() const279   const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
280 
281  private:
282   // Represent a (partial) fused pattern
283   class Cluster {
284    public:
Cluster(int node_id,FusionPlanner * planner)285     Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
286       const SmallVectorImpl<Operation*>& op_list = planner->op_list();
287       pattern_.push_back(op_list[node_id]);
288     }
289 
290     // Merges `other` into this cluster, and clears `other`.
Merge(Cluster * other)291     void Merge(Cluster* other) {
292       pattern_.insert(pattern_.end(), other->pattern_.begin(),
293                       other->pattern_.end());
294       other->pattern_.clear();
295     }
296 
297     // The number of nodes in this cluster.
cluster_size() const298     int cluster_size() const { return pattern_.size(); }
299 
300     // The ID of the cluster as represented in `cycle_detector_`.
cycles_graph_node_id() const301     int cycles_graph_node_id() const { return node_id_; }
302 
303     // Sets the ID of the cluster as represented in `cycle_detector_`.
set_cycles_graph_node_id(int cycles_graph_node_id)304     void set_cycles_graph_node_id(int cycles_graph_node_id) {
305       node_id_ = cycles_graph_node_id;
306     }
307 
308     // Currently the fused pattern this cluster holds.
fused_pattern()309     const FusionPattern& fused_pattern() { return pattern_; }
310 
311    private:
312     // ID of the representative node of this cluster.
313     int node_id_;
314 
315     // the fused pattern this cluster holds.
316     FusionPattern pattern_;
317   };
318 
319  private:
MakeCluster(int cycles_graph_node_id)320   Cluster* MakeCluster(int cycles_graph_node_id) {
321     cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
322     return cluster_storage_.back().get();
323   }
324 
BuildNodeMap()325   void BuildNodeMap() {
326     int num_nodes = op_list_.size();
327     for (int node_id = 0; node_id < num_nodes; ++node_id) {
328       Operation* op = op_list_[node_id];
329       MakeCluster(node_id);
330       op_to_node_id_[op] = node_id;
331       leader_for_node_.insert(node_id);
332       for (Value operand : op->getOperands()) {
333         Operation* operand_op = operand.getDefiningOp();
334         if (operand_op == nullptr) {
335           // skip block argument
336           continue;
337         }
338         auto iter = op_to_node_id_.find(operand_op);
339         assert(iter != op_to_node_id_.end());
340         cycle_detector_.InsertEdge(iter->second, node_id);
341       }
342     }
343   }
344 
345   // Returns the cluster contains this op.
GetClusterForNode(Operation * n)346   Cluster* GetClusterForNode(Operation* n) {
347     int id = op_to_node_id_.at(n);
348     id = leader_for_node_.getLeaderValue(id);
349     return cluster_storage_[id].get();
350   }
351 
352   // Returns the cluster contains the op having `node_id`.
GetClusterForCyclesGraphNode(int node_id)353   Cluster* GetClusterForCyclesGraphNode(int node_id) {
354     return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
355   }
356 
357   // Merges the clusters `cluster_from` and `cluster_to`.
MergeClusters(Cluster * cluster_from,Cluster * cluster_to)358   bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
359     int from = cluster_from->cycles_graph_node_id();
360     int to = cluster_to->cycles_graph_node_id();
361 
362     auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
363     if (!optional_merged_node.hasValue()) {
364       llvm::dbgs() << "Could not contract " << from << " -> " << to
365                    << " because contracting the edge would create a cycle.";
366       return false;
367     }
368 
369     // Merge the clusters.
370     cluster_from->Merge(cluster_to);
371     cluster_from->set_cycles_graph_node_id(*optional_merged_node);
372 
373     // Merge the UnionFind Set.
374     leader_for_node_.unionSets(from, to);
375     return true;
376   }
377 
378   template <typename FnTy>
ForEachEdgeInPostOrder(FnTy fn)379   bool ForEachEdgeInPostOrder(FnTy fn) {
380     bool changed = false;
381     for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
382       Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
383       // Make a copy of the set of successors because we may modify the graph in
384       // TryToContractEdge.
385       std::vector<int32_t> successors_copy =
386           cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
387 
388       for (int to : successors_copy) {
389         Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
390         bool contracted_edge = fn(cluster_from, cluster_to);
391         changed |= contracted_edge;
392       }
393     }
394 
395     return changed;
396   }
397 
398   // returns the outputs if two cluster were merged
GetResultsOfFusedPattern(Cluster * from,Cluster * to)399   SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
400     FusionPattern fused_pattern =
401         MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
402     return GetOutputsOfFusionPattern(fused_pattern);
403   }
404 
405   // This function check if fusing `from` with `to` is valid and if so perform
406   // the merge. The validity is based on the operations in the clusters and
407   // the compatibility of the shapes of the outputs of the would-be fused
408   // clusters.
409   // Returns true is the merge was performed.
TryToContractEdge(Cluster * from,Cluster * to)410   bool TryToContractEdge(Cluster* from, Cluster* to) {
411     int node_to = to->cycles_graph_node_id();
412     int node_from = from->cycles_graph_node_id();
413 
414     // Both node_to and node_from should be fusible
415     if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
416       return false;
417     }
418 
419     auto op_from_fusibility =
420         dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
421     if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
422       // This op cannot be fused with its consumers.
423       return false;
424     }
425 
426     auto op_to_fusibility =
427         dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
428     if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
429       // This op cannot be fused with its operands.
430       return false;
431     }
432 
433     // Output shapes of a fusion pattern should be compatible as described in
434     // the document of this class.
435     SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
436     auto get_workload_shape = [](Value v) {
437       Operation* op = v.getDefiningOp();
438       // Block argument
439       if (!op) return v;
440       auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
441       // Const value
442       if (!op_fusibility) return v;
443       llvm::Optional<Value> workload =
444           op_fusibility.inferEffectiveWorkloadShape();
445       return workload.hasValue() ? *workload : v;
446     };
447 
448     Value ref = get_workload_shape(results[0]);
449     if (!llvm::all_of(results, [&](Value result) {
450           Value val = get_workload_shape(result);
451           return shape_analysis_.HasSameShape(ref, val);
452         })) {
453       return false;
454     }
455 
456     return MergeClusters(from, to);
457   }
458 
459   // Greedily fuse connected node.
RunEdgeContractionLoop()460   bool RunEdgeContractionLoop() {
461     using std::placeholders::_1;
462     using std::placeholders::_2;
463     return ForEachEdgeInPostOrder(
464         std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
465   }
466 
467   const SmallVectorImpl<Operation*>& op_list_;
468 
469   // Shape equality checker
470   ShapeConstraintAnalysis shape_analysis_;
471 
472   // op -> node_id
473   std::unordered_map<Operation*, int> op_to_node_id_;
474 
475   // make sure not introduce cycle after fusion
476   GraphCycles cycle_detector_;
477   std::vector<std::unique_ptr<Cluster>> cluster_storage_;
478 
479   // a UnionFind set. Each set represents a (partial) fused pattern
480   // and has a leader as representation.
481   EquivalenceClasses<int32_t> leader_for_node_;
482 };
483 
484 struct MhloFusionPass : public MhloFusionPassBase<MhloFusionPass> {
runOnFunctionmlir::mhlo::__anonf3dc631f0111::MhloFusionPass485   void runOnFunction() override {
486     FuncOp func = getFunction();
487     if (!IsTargetFunc(func)) {
488       return;
489     }
490 
491     // process each block and do fusion within a block.
492     for (Block& block : func) {
493       SmallVector<Operation*, 4> op_list;
494       for (Operation& op : block) {
495         op_list.push_back(&op);
496       }
497 
498       FusionPlanner planner(op_list);
499       llvm::Optional<FusionPlan> plan = planner.Run();
500       if (!plan) {
501         emitError(func.getLoc(), "can't find a fusion plan");
502         signalPassFailure();
503         return;
504       }
505       if (!ApplyFusionPlan(*plan)) {
506         emitError(func.getLoc(), "apply fusion plan failed");
507         signalPassFailure();
508         return;
509       }
510     }
511   }
512 
IsTargetFuncmlir::mhlo::__anonf3dc631f0111::MhloFusionPass513   bool IsTargetFunc(FuncOp func) {
514     int num_fusible_ops = 0;
515     bool is_target_func = false;
516     // We only process the function having enough candidates
517     func.walk([&](Operation* op) {
518       num_fusible_ops +=
519           static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
520       is_target_func = (num_fusible_ops > 1);
521       // early stop
522       if (is_target_func) return WalkResult::interrupt();
523       return WalkResult::advance();
524     });
525     return is_target_func;
526   }
527 
ApplyFusionPlanmlir::mhlo::__anonf3dc631f0111::MhloFusionPass528   bool ApplyFusionPlan(const FusionPlan& plan) {
529     for (const FusionPattern& pattern : plan) {
530       OpBuilder b(pattern.back());
531 
532       SmallVector<Location, 4> locations;
533       locations.reserve(pattern.size());
534       for (Operation* op : pattern) {
535         locations.push_back(op->getLoc());
536       }
537       Location fused_loc =
538           FusedLoc::get(pattern.back()->getContext(), locations);
539 
540       SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
541       SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
542       SmallVector<Type, 4> output_types;
543       output_types.reserve(outputs.size());
544       for (Value v : outputs) {
545         output_types.push_back(v.getType());
546       }
547 
548       FusionOp fusion =
549           b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
550       Region& region = fusion.fused_computation();
551       region.push_back(new Block);
552       Block& block = region.front();
553       for (Operation* op : pattern) {
554         op->moveBefore(&block, block.end());
555       }
556       b.setInsertionPoint(&block, block.end());
557       b.create<mhlo::ReturnOp>(fused_loc, outputs);
558 
559       for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
560         Value output = std::get<0>(output_and_result);
561         Value fusion_result = std::get<1>(output_and_result);
562         for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
563           if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
564         }
565       }
566     }
567     return true;
568   }
569 };
570 
571 }  // namespace
572 
createMhloFusionPass()573 std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass() {
574   return std::make_unique<MhloFusionPass>();
575 }
576 
577 }  // namespace mhlo
578 }  // namespace mlir
579